diff --git a/research_seed/baselines/__init__.py b/research_seed/baselines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/research_seed/baselines/mnist_baseline/README.md b/research_seed/baselines/mnist_baseline/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f9202487238ece1e2e4ba36ededde4a9de6ce8fe --- /dev/null +++ b/research_seed/baselines/mnist_baseline/README.md @@ -0,0 +1,17 @@ +## MNIST Baseline +In this readme, give instructions on how to run your code. + +#### CPU +```python +python mnist_baseline_trainer.py +``` + +#### Multiple-GPUs +```python +python mnist_baseline_trainer.py --gpus '0,1,2,3' +``` + +#### On multiple nodes +```python +python mnist_baseline_trainer.py --gpus '0,1,2,3' --nodes 4 +``` diff --git a/research_seed/baselines/mnist_baseline/__init__.py b/research_seed/baselines/mnist_baseline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/research_seed/baselines/mnist_baseline/mnist_baseline.py b/research_seed/baselines/mnist_baseline/mnist_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..c48ea849cfefb6e143e9bba6a769ec5ad27a9db9 --- /dev/null +++ b/research_seed/baselines/mnist_baseline/mnist_baseline.py @@ -0,0 +1,77 @@ +""" +This file defines the core research contribution +""" +import os +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST +import torchvision.transforms as transforms +from argparse import ArgumentParser + +import pytorch_lightning as pl + + +class CoolSystem(pl.LightningModule): + + def __init__(self, hparams): + super(CoolSystem, self).__init__() + # not the best model... + self.hparams = hparams + self.l1 = torch.nn.Linear(28 * 28, 10) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1))) + + def training_step(self, batch, batch_nb): + # REQUIRED + x, y = batch + y_hat = self.forward(x) + return {'loss': F.cross_entropy(y_hat, y)} + + def validation_step(self, batch, batch_nb): + # OPTIONAL + x, y = batch + y_hat = self.forward(x) + return {'val_loss': F.cross_entropy(y_hat, y)} + + def validation_end(self, outputs): + # OPTIONAL + avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + return {'avg_val_loss': avg_loss} + + def configure_optimizers(self): + # REQUIRED + # can return multiple optimizers and learning_rate schedulers + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + @pl.data_loader + def tng_dataloader(self): + # REQUIRED + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) + + @pl.data_loader + def val_dataloader(self): + # OPTIONAL + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) + + @pl.data_loader + def test_dataloader(self): + # OPTIONAL + return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) + + @staticmethod + def add_model_specific_args(parent_parser): + """ + Specify the hyperparams for this LightningModule + """ + # MODEL specific + parser = ArgumentParser(parents=[parent_parser]) + parser.add_argument('--learning_rate', default=0.02, type=float) + parser.add_argument('--batch_size', default=32, type=int) + + # training specific (for this model) + parser.add_argument('--max_nb_epochs', default=2, type=int) + + return parser + diff --git a/research_seed/baselines/mnist_baseline/mnist_baseline_trainer.py b/research_seed/baselines/mnist_baseline/mnist_baseline_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..152ae4d1caaf6c7bee5fc6e496baa17917518290 --- /dev/null +++ b/research_seed/baselines/mnist_baseline/mnist_baseline_trainer.py @@ -0,0 +1,34 @@ +""" +This file runs the main training/val loop, etc... using Lightning Trainer +""" +from pytorch_lightning import Trainer +from argparse import ArgumentParser +from research_seed.mnist.mnist import CoolSystem + + +def main(hparams): + # init module + model = CoolSystem(hparams) + + # most basic trainer, uses good defaults + trainer = Trainer( + max_nb_epochs=hparams.max_nb_epochs, + gpus=hparams.gpus, + nb_gpu_nodes=hparams.nodes, + ) + trainer.fit(model) + + +if __name__ == '__main__': + parser = ArgumentParser(add_help=False) + parser.add_argument('--gpus', type=str, default=None) + parser.add_argument('--nodes', type=int, default=1) + + # give the module a chance to add own params + # good practice to define LightningModule speficic params in the module + parser = CoolSystem.add_model_specific_args(parser) + + # parse params + hparams = parser.parse_args() + + main(hparams)