diff --git a/.gitignore b/.gitignore
index 73c615f495e411a8b0f83d29460d4279dce974d0..1dc6a3997d333ca448b9226e8e5cf2d9ef63557a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -120,6 +120,7 @@ venv.bak/
 
 # IDEs
 .idea
+.vscode
 
 # seed project
 lightning_logs/
diff --git a/requirements.txt b/requirements.txt
index 299248718de256390ef419166ee3f51be197d3c0..260eeda4423b12fd41ed697dc29276f2eef18c50 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-pytorch-lightning >= 0.7.5
\ No newline at end of file
+pytorch-lightning >= 0.9.0
\ No newline at end of file
diff --git a/src/research_mnist/mnist.py b/src/research_mnist/mnist.py
index d441101d3af99b6910c5965ae2f8abd1faadba9c..dffe47f0dd3a702986b2a215e5fc0118cad95751 100644
--- a/src/research_mnist/mnist.py
+++ b/src/research_mnist/mnist.py
@@ -29,52 +29,54 @@ class CoolSystem(pl.LightningModule):
         y_hat = self.forward(x)
         loss = F.cross_entropy(y_hat, y)
-        tensorboard_logs = {'train_loss': loss}
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
 
-        return {'loss': loss, 'log': tensorboard_logs}
+        return result
 
     def validation_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.valid_batch_loss = loss
+        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+
+        return result
 
     def validation_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
 
-        tensorboard_logs = {'avg_val_loss': avg_loss}
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+        return result
 
     def test_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'test_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.test_batch_loss = loss
+        result.log('test_loss', loss, on_epoch=True)
+
+        return result
 
     def test_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+        avg_loss = outputs.test_batch_loss.mean()
 
-        tensorboard_logs = {'test_val_loss': avg_loss}
-        return {'test_loss': avg_loss, 'log': tensorboard_logs}
+        result = pl.EvalResult()
+        result.log('test_loss', avg_loss, on_epoch=True)
+        return result
 
     def configure_optimizers(self):
         # REQUIRED
         # can return multiple optimizers and learning_rate schedulers
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
 
-    def train_dataloader(self):
-        # REQUIRED
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def val_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def test_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
 
     @staticmethod
     def add_model_specific_args(parent_parser):
@@ -84,7 +86,6 @@ class CoolSystem(pl.LightningModule):
         # MODEL specific
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument('--learning_rate', default=0.02, type=float)
-        parser.add_argument('--batch_size', default=32, type=int)
 
         # training specific (for this model)
         parser.add_argument('--max_nb_epochs', default=2, type=int)
diff --git a/src/research_mnist/mnist_data_module.py b/src/research_mnist/mnist_data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f7c256383bd287bdd8732b8da548e8617b02af
--- /dev/null
+++ b/src/research_mnist/mnist_data_module.py
@@ -0,0 +1,92 @@
+from argparse import ArgumentParser
+
+import pytorch_lightning as pl
+from pytorch_lightning.metrics.functional import accuracy
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import random_split, DataLoader
+
+# Note - you must have torchvision installed for this example
+from torchvision.datasets import MNIST, CIFAR10
+from torchvision import transforms
+
+
+class MNISTDataModule(pl.LightningDataModule):
+    def __init__(self, hparams):
+
+        super().__init__()
+
+        self.hparams = hparams
+
+        self.data_dir = self.hparams.data_dir
+        self.batch_size = self.hparams.batch_size
+
+        # We hardcode dataset specific stuff here.
+        self.num_classes = 10
+        self.dims = (1, 28, 28)
+        self.transform = transforms.Compose([transforms.ToTensor(),])
+
+        # Basic test that parameters passed are sensible.
+        assert (
+            self.hparams.train_size + self.hparams.valid_size == 60_000
+        ), "Invalid Train and Valid Split, make sure they add up to 60,000"
+
+    def prepare_data(self):
+        # download the dataset
+        MNIST(self.data_dir, train=True, download=True)
+        MNIST(self.data_dir, train=False, download=True)
+
+    def setup(self, stage=None):
+
+        # Assign train/val datasets for use in dataloaders
+        if stage == "fit" or stage is None:
+            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+            self.mnist_train, self.mnist_val = random_split(
+                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
+            )
+
+        # Assign test dataset for use in dataloader(s)
+        if stage == "test" or stage is None:
+            self.mnist_test = MNIST(
+                self.data_dir, train=False, transform=self.transform
+            )
+
+    def train_dataloader(self):
+        # REQUIRED
+        return DataLoader(
+            self.mnist_train,
+            batch_size=self.batch_size,
+            num_workers=self.hparams.workers,
+        )
+
+    def val_dataloader(self):
+        # OPTIONAL
+        return DataLoader(
+            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
+        )
+
+    def test_dataloader(self):
+        # OPTIONAL
+        return DataLoader(
+            self.mnist_test,
+            batch_size=self.batch_size,
+            num_workers=self.hparams.workers,
+        )
+
+    @staticmethod
+    def add_data_specific_args(parent_parser):
+        """
+        Specify the hyperparams for this LightningModule
+        """
+        # Dataset specific
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--batch_size", default=32, type=int)
+        parser.add_argument("--data_dir", default="./", type=str)
+
+        # training specific
+        parser.add_argument("--train_size", default=55_000, type=int)
+        parser.add_argument("--valid_size", default=5_000, type=int)
+        parser.add_argument("--workers", default=8, type=int)
+
+        return parser
diff --git a/src/research_mnist/mnist_trainer.py b/src/research_mnist/mnist_trainer.py
index 74396b93af5ba4ca8f3d0c0b2ccc14c6420ddaa8..6e05f059cad4305de31998bfb39f5316379b9ad3 100644
--- a/src/research_mnist/mnist_trainer.py
+++ b/src/research_mnist/mnist_trainer.py
@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
 from pytorch_lightning import Trainer, seed_everything
 from argparse import ArgumentParser
 from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule
 
 # sets seeds for numpy, torch, etc...
 # must do for DDP to work well
 seed_everything(123)
 
 
 def main(args):
-    # init module
+    # init modules
+    dm = MNISTDataModule(hparams=args)
     model = CoolSystem(hparams=args)
 
     # most basic trainer, uses good defaults
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, dm)
     trainer.test()
 
@@ -29,6 +31,8 @@ if __name__ == '__main__':
     # give the module a chance to add own params
     # good practice to define LightningModule speficic params in the module
     parser = CoolSystem.add_model_specific_args(parser)
+    # same goes for data modules
+    parser = MNISTDataModule.add_data_specific_args(parser)
 
     # parse params
     args = parser.parse_args()
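# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the patch above): exercising
# the new MNISTDataModule on its own. Assumes the repo root is on PYTHONPATH
# and torchvision is installed; the Namespace fields simply mirror the
# argparse defaults registered in add_data_specific_args.
# ---------------------------------------------------------------------------
from argparse import Namespace

from src.research_mnist.mnist_data_module import MNISTDataModule

hparams = Namespace(
    data_dir="./", batch_size=32, train_size=55_000, valid_size=5_000, workers=8
)

dm = MNISTDataModule(hparams=hparams)
dm.prepare_data()   # downloads MNIST into data_dir
dm.setup("fit")     # builds the 55k/5k train/val split via random_split

# Pull one batch to sanity-check shapes before handing dm to Trainer.fit.
images, labels = next(iter(dm.train_dataloader()))
print(images.shape)  # expected: torch.Size([32, 1, 28, 28])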