## MNIST Baseline
This readme shows how to run the MNIST baseline on CPU, on multiple GPUs, and across multiple nodes.
#### CPU
```bash
python mnist_baseline_trainer.py
```
#### Multiple GPUs
```bash
python mnist_baseline_trainer.py --gpus '0,1,2,3'
```
#### Multiple nodes
```bash
python mnist_baseline_trainer.py --gpus '0,1,2,3' --nodes 4
```
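The module also registers its own hyperparameters via `add_model_specific_args` (see the source below), so they can be combined with the trainer flags. For example (the values here are illustrative, not recommended settings):
```bash
python mnist_baseline_trainer.py --gpus '0,1' --learning_rate 0.02 --batch_size 64 --max_nb_epochs 5
```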
"""
This file defines the core research contribution
"""
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from argparse import ArgumentParser
import pytorch_lightning as pl

class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super(CoolSystem, self).__init__()
        # not the best model...
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten the image and apply a single linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL: aggregate the per-batch validation losses
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
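        # Illustrative sketch, not in the original file: Lightning also accepts
        # several optimizers from this hook, optionally paired with
        # learning-rate schedulers, e.g.
        #   optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        #   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        #   return [optimizer], [scheduler]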

    @pl.data_loader
    def tng_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=self.hparams.batch_size)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL (note: this template reuses the training split for validation)
        return DataLoader(MNIST(os.getcwd(), train=True, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=self.hparams.batch_size)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL: evaluate on the held-out test split (train=False)
        return DataLoader(MNIST(os.getcwd(), train=False, download=True,
                                transform=transforms.ToTensor()),
                          batch_size=self.hparams.batch_size)
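    # Illustrative, not part of the original template: MNIST inputs are often
    # normalized with the dataset's mean/std, e.g.
    #   transform = transforms.Compose([
    #       transforms.ToTensor(),
    #       transforms.Normalize((0.1307,), (0.3081,)),
    #   ])
    # and that `transform` would replace transforms.ToTensor() above.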

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule.
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('--learning_rate', default=0.02, type=float)
        parser.add_argument('--batch_size', default=32, type=int)

        # training specific (for this model)
        parser.add_argument('--max_nb_epochs', default=2, type=int)

        return parser
```
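As a quick sanity check, the module can be exercised on its own. Here is a minimal sketch (not part of the repo; the `Namespace` fields simply mirror the flags registered in `add_model_specific_args`):

```python
from argparse import Namespace

import torch

from research_seed.mnist.mnist import CoolSystem

# hypothetical hyperparameter values, matching the parser defaults above
hparams = Namespace(learning_rate=0.02, batch_size=32, max_nb_epochs=2)
model = CoolSystem(hparams)

x = torch.rand(4, 1, 28, 28)   # a fake batch of four MNIST-sized images
logits = model(x)              # forward() flattens and applies the linear layer
assert logits.shape == (4, 10)
```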
"""
This file runs the main training/val loop, etc... using Lightning Trainer
"""
from pytorch_lightning import Trainer
from argparse import ArgumentParser
from research_seed.mnist.mnist import CoolSystem

def main(hparams):
    # init module
    model = CoolSystem(hparams)

    # most basic trainer, uses good defaults
    trainer = Trainer(
        max_nb_epochs=hparams.max_nb_epochs,
        gpus=hparams.gpus,
        nb_gpu_nodes=hparams.nodes,
    )
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser(add_help=False)
    parser.add_argument('--gpus', type=str, default=None)
    parser.add_argument('--nodes', type=int, default=1)

    # give the module a chance to add its own params;
    # good practice is to define LightningModule-specific params in the module
    parser = CoolSystem.add_model_specific_args(parser)

    # parse params
    hparams = parser.parse_args()

    main(hparams)
```
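For quick experiments in a notebook or REPL, the same loop can be driven without the CLI by passing a `Namespace` directly. A minimal sketch, assuming the package is importable and using field names that mirror the flags above:

```python
from argparse import Namespace

from pytorch_lightning import Trainer

from research_seed.mnist.mnist import CoolSystem

# hypothetical settings: CPU only, single node, parser defaults otherwise
hparams = Namespace(gpus=None, nodes=1,
                    learning_rate=0.02, batch_size=32, max_nb_epochs=2)

trainer = Trainer(max_nb_epochs=hparams.max_nb_epochs,
                  gpus=hparams.gpus,
                  nb_gpu_nodes=hparams.nodes)
trainer.fit(CoolSystem(hparams))
```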