diff --git a/project/lit_autoencoder.py b/project/lit_autoencoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9ff0c6c9994d0561b464c3303bd2faa614f445
--- /dev/null
+++ b/project/lit_autoencoder.py
@@ -0,0 +1,88 @@
+from argparse import ArgumentParser
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import pytorch_lightning as pl
+from torch.utils.data import random_split
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class LitAutoEncoder(pl.LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(28 * 28, 64),
+            nn.ReLU(),
+            nn.Linear(64, 3)
+        )
+        self.decoder = nn.Sequential(
+            nn.Linear(3, 64),
+            nn.ReLU(),
+            nn.Linear(64, 28 * 28)
+        )
+
+    def forward(self, x):
+        # in lightning, forward defines the prediction/inference actions
+        embedding = self.encoder(x)
+        return embedding
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+        loss = F.mse_loss(x_hat, x)
+        return loss
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        return optimizer
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--hidden_dim', type=int, default=128)
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitAutoEncoder()
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    result = trainer.test(test_dataloaders=test_loader)
+    print(result)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/project/lit_classifier_main.py b/project/lit_classifier_main.py
deleted file mode 100644
index 020b72c6c0e843b1166d4a95824b196a31eec45f..0000000000000000000000000000000000000000
--- a/project/lit_classifier_main.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from argparse import ArgumentParser
-
-import torch
-import pytorch_lightning as pl
-from pytorch_lightning.metrics.functional import accuracy
-from torch.nn import functional as F
-
-
-class LitClassifier(pl.LightningModule):
-    def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
-        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
-
-        self.mnist_train = None
-        self.mnist_val = None
-
-    def forward(self, x):
-        x = x.view(x.size(0), -1)
-        x = torch.relu(self.l1(x))
-        x = torch.relu(self.l2(x))
-        return x
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss)
-        return result
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('val_loss', loss)
-        result.log('val_acc', accuracy(y_hat, y))
-        return result
-
-    def test_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('test_loss', loss)
-        result.log('test_acc', accuracy(y_hat, y))
-        return result
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-
-
-def cli_main():
-    from project.datasets.mnist import mnist
-    pl.seed_everything(1234)
-
-    # args
-    parser = ArgumentParser()
-    parser.add_argument('--gpus', default=0, type=int)
-
-    # optional... automatically add all the params
-    # parser = pl.Trainer.add_argparse_args(parser)
-    # parser = MNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    # data
-    mnist_train, mnist_val, test_dataset = mnist()
-
-    # model
-    model = LitClassifier(**vars(args))
-
-    # training
-    trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
-    trainer.fit(model, mnist_train, mnist_val)
-
-    trainer.test(test_dataloaders=test_dataset)
-
-
-if __name__ == '__main__':  # pragma: no cover
-    cli_main()
diff --git a/project/lit_image_classifier.py b/project/lit_image_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..1296a3f126eaf69628fc20a9ec58bf95043d6668
--- /dev/null
+++ b/project/lit_image_classifier.py
@@ -0,0 +1,109 @@
+from argparse import ArgumentParser
+
+import torch
+import pytorch_lightning as pl
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class Backbone(torch.nn.Module):
+    def __init__(self, hidden_dim=128):
+        super().__init__()
+        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
+        self.l2 = torch.nn.Linear(hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = torch.relu(self.l2(x))
+        return x
+
+
+class LitClassifier(pl.LightningModule):
+    def __init__(self, backbone, learning_rate=1e-3):
+        super().__init__()
+        self.save_hyperparameters()
+        self.backbone = backbone
+
+    def forward(self, x):
+        # use forward for inference/predictions
+        embedding = self.backbone(x)
+        return embedding
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('train_loss', loss, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('valid_loss', loss, on_step=True)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        # self.hparams available because we called self.save_hyperparameters()
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--hidden_dim', type=int, default=128)
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    result = trainer.test(test_dataloaders=test_loader)
+    print(result)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/project/lit_mnist.py b/project/lit_mnist.py
new file mode 100644
index 0000000000000000000000000000000000000000..873337858fed5e3c816dc63984cd350ccf303c22
--- /dev/null
+++ b/project/lit_mnist.py
@@ -0,0 +1,96 @@
+from argparse import ArgumentParser
+
+import torch
+import pytorch_lightning as pl
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class LitClassifier(pl.LightningModule):
+    def __init__(self, hidden_dim=128, learning_rate=1e-3):
+        super().__init__()
+        self.save_hyperparameters()
+
+        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
+        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = torch.relu(self.l2(x))
+        return x
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('valid_loss', loss)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--hidden_dim', type=int, default=128)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitClassifier(args.hidden_dim, args.learning_rate)
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    trainer.test(test_dataloaders=test_loader)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/requirements.txt b/requirements.txt
index fb1b543e039eb6f5b4f299be59b010161258dc8a..30840f0ce8f203f5935035cf3c2d7701d13e9f03 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-pytorch-lightning >= 0.9.0
+pytorch-lightning >= 1.0.0rc2
 torch >= 1.3.0
 torchvision >= 0.6.0
diff --git a/tests/test_classifier.py b/tests/test_classifier.py
index 96be92bdd6c6ba8a7b00f7ab07b1a04772638b04..e173fd518d28d57f5c0d0ed46b3ab73a42cbeec4 100644
--- a/tests/test_classifier.py
+++ b/tests/test_classifier.py
@@ -1,6 +1,6 @@
 from pytorch_lightning import Trainer, seed_everything
-
-from project.lit_classifier_main import LitClassifier
+from project.lit_mnist import LitClassifier
+from project.datasets.mnist import mnist
 
 
 def test_lit_classifier():
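
The patch replaces the 0.9-era Result objects (pl.TrainResult / pl.EvalResult) used in the
deleted lit_classifier_main.py with the self.log API introduced in pytorch-lightning 1.0.
A minimal sketch of that migration for a training step that already computes loss, using
the names from the deleted file:

    # 0.9.x style, removed in 1.0:
    result = pl.TrainResult(minimize=loss)
    result.log('train_loss', loss)
    return result

    # 1.0 style, as in the new scripts:
    self.log('train_loss', loss)
    return loss

Because every cli_main wires its parser through pl.Trainer.add_argparse_args, the standard
Trainer flags are accepted on the command line alongside each script's own arguments,
for example:

    python project/lit_mnist.py --batch_size 32 --max_epochs 2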