Commit 1a3b828b authored by William Falcon

clean up sample project

parent 67a75aa1
@@ -52,10 +52,10 @@ pip install -r requirements.txt
Next, navigate to [Your Main Contribution (MNIST here)] and run it.
```bash
# module folder
-cd research_mnist/
+cd src
# run module (example: mnist as your main contribution)
-python simplest_mnist.py
+python mnist_classifier.py
```
## Main Contribution
......
## Research Seed Folder
Create a folder for each contribution (e.g. MNIST, BERT, etc.).
Each folder will have:
##### contribution_name_trainer.py
Runs your LightningModule. Abstracts training loop, distributed training, etc...
##### contribution_name.py
Holds your main contribution
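For the MNIST example in this repo, that layout looks roughly like this (a sketch reconstructed from the imports in `mnist_trainer.py` below):
```
research_mnist/
    mnist_trainer.py        # the contribution_name_trainer.py file: runs the LightningModule
    simplest_mnist.py       # the contribution_name.py file: the contribution itself (CoolSystem)
    mnist_data_module.py    # LightningDataModule wrapping the MNIST dataset
```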
## Example
The folder here gives an example for MNIST.
### MNIST
In this README, give instructions on how to run your code.
#### CPU
```bash
python mnist_trainer.py
```
#### Multiple GPUs
```bash
python mnist_trainer.py --gpus 4
```
or specific GPUs
```bash
python mnist_trainer.py --gpus '0,3'
```
#### On multiple nodes
```bash
python mnist_trainer.py --gpus 4 --nodes 4 --precision 16
```
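These flags are parsed by `Trainer.add_argparse_args` in `mnist_trainer.py` and forwarded through `Trainer.from_argparse_args`. As a rough sketch, the multi-node call above corresponds to constructing the trainer directly like this (the Trainer keyword behind `--nodes` is `num_nodes`, and the DDP backend is assumed):
```python
from pytorch_lightning import Trainer

# Sketch only: 4 GPUs per node, 4 nodes, 16-bit precision, DDP for multi-node training.
trainer = Trainer(gpus=4, num_nodes=4, precision=16, distributed_backend="ddp")
```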
from argparse import ArgumentParser
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from torchvision import transforms
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, hparams):
super().__init__()
self.hparams = hparams
self.data_dir = self.hparams.data_dir
self.batch_size = self.hparams.batch_size
# We hardcode dataset specific stuff here.
self.num_classes = 10
self.dims = (1, 28, 28)
self.transform = transforms.Compose([transforms.ToTensor(), ])
# Basic test that parameters passed are sensible.
assert (
self.hparams.train_size + self.hparams.valid_size == 60_000
), "Invalid Train and Valid Split, make sure they add up to 60,000"
def prepare_data(self):
# download the dataset
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == "fit" or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(
mnist_full, [self.hparams.train_size, self.hparams.valid_size]
)
# Assign test dataset for use in dataloader(s)
if stage == "test" or stage is None:
self.mnist_test = MNIST(
self.data_dir, train=False, transform=self.transform
)
def train_dataloader(self):
# REQUIRED
return DataLoader(
self.mnist_train,
batch_size=self.batch_size,
num_workers=self.hparams.workers,
)
def val_dataloader(self):
# OPTIONAL
return DataLoader(
self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
)
def test_dataloader(self):
# OPTIONAL
return DataLoader(
self.mnist_test,
batch_size=self.batch_size,
num_workers=self.hparams.workers,
)
@staticmethod
def add_data_specific_args(parent_parser):
"""
Specify the hyperparams for this LightningModule
"""
# Dataset specific
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--data_dir", default="./", type=str)
# training specific
parser.add_argument("--train_size", default=55_000, type=int)
parser.add_argument("--valid_size", default=5_000, type=int)
parser.add_argument("--workers", default=8, type=int)
return parser
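# Usage sketch (added for illustration, not part of the original template): build the
# DataModule from its own argparse defaults and pull one batch as a quick sanity check.
if __name__ == "__main__":
    parser = MNISTDataModule.add_data_specific_args(ArgumentParser(add_help=False))
    args = parser.parse_args([])      # fall back to the defaults declared above
    dm = MNISTDataModule(hparams=args)
    dm.prepare_data()                 # downloads MNIST into --data_dir if needed
    dm.setup("fit")
    x, y = next(iter(dm.train_dataloader()))
    print(x.shape, y.shape)           # expected: [32, 1, 28, 28] and [32]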
"""
This file runs the main training/val loop, etc. using Lightning Trainer.
"""
from argparse import ArgumentParser
from pytorch_lightning import Trainer, seed_everything
from research_mnist.simplest_mnist import CoolSystem
from research_mnist.mnist_data_module import MNISTDataModule
def main(args):
# init modules
dm = MNISTDataModule(hparams=args)
model = CoolSystem(hparams=args)
# most basic trainer, uses good defaults
trainer = Trainer.from_argparse_args(args)
trainer.fit(model, dm)
trainer.test()
def main_cli():
# sets seeds for numpy, torch, etc...
# must do for DDP to work well
seed_everything(123)
parser = ArgumentParser(add_help=False)
# add args from trainer
parser = Trainer.add_argparse_args(parser)
# give the module a chance to add own params
    # good practice to define LightningModule-specific params in the module
parser = CoolSystem.add_model_specific_args(parser)
# same goes for data modules
parser = MNISTDataModule.add_data_specific_args(parser)
# parse params
args = parser.parse_args()
main(args)
if __name__ == '__main__':
main_cli()
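# Example invocation (illustrative only; the flag names come from the parsers wired up
# above and from Trainer.add_argparse_args):
#   python mnist_trainer.py --batch_size 64 --learning_rate 0.02 --gpus 1 --max_epochs 5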
"""
This file defines the core research contribution.
"""
from argparse import ArgumentParser
import pytorch_lightning as pl
import torch
from torch.nn import functional as F
class CoolSystem(pl.LightningModule):
def __init__(self, hparams):
super(CoolSystem, self).__init__()
# not the best model...
self.hparams = hparams
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_idx):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss, prog_bar=True)
return result
def validation_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult()
result.valid_batch_loss = loss
result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
return result
def validation_epoch_end(self, outputs):
# OPTIONAL
avg_loss = outputs.valid_batch_loss.mean()
result = pl.EvalResult(checkpoint_on=avg_loss)
result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
return result
def test_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult()
result.test_batch_loss = loss
result.log('test_loss', loss, on_epoch=True)
return result
def test_epoch_end(self, outputs):
# OPTIONAL
avg_loss = outputs.test_batch_loss.mean()
result = pl.EvalResult()
result.log('test_loss', avg_loss, on_epoch=True)
return result
def configure_optimizers(self):
# REQUIRED
# can return multiple optimizers and learning_rate schedulers
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
"""
Specify the hyperparams for this LightningModule
"""
# MODEL specific
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--learning_rate', default=0.02, type=float)
# training specific (for this model)
parser.add_argument('--max_nb_epochs', default=2, type=int)
return parser
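# Quick sanity check (illustrative, not part of the original template): run CoolSystem
# together with the MNISTDataModule for a single fast_dev_run batch.
if __name__ == "__main__":
    from research_mnist.mnist_data_module import MNISTDataModule

    parser = ArgumentParser(add_help=False)
    parser = CoolSystem.add_model_specific_args(parser)
    parser = MNISTDataModule.add_data_specific_args(parser)
    args = parser.parse_args([])          # use the defaults defined in both parsers
    model = CoolSystem(hparams=args)
    dm = MNISTDataModule(hparams=args)
    pl.Trainer(fast_dev_run=True).fit(model, dm)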
@@ -3,7 +3,7 @@
from setuptools import setup, find_packages
setup(
-    name='research_mnist',
+    name='src',
version='0.0.0',
description='Describe Your Cool Project',
author='',
......
File moved
import os
from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
super().__init__()
self.save_hyperparameters()
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
self.mnist_train = None
self.mnist_val = None
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss)
return result
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('val_loss', loss)
result.log('val_acc', accuracy(y_hat, y))
return result
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('test_loss', loss)
result.log('test_acc', accuracy(y_hat, y))
return result
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
def cli_main():
pl.seed_everything(1234)
# args
parser = ArgumentParser()
parser.add_argument('--gpus', default=0, type=int)
# optional... automatically add all the params
# parser = pl.Trainer.add_argparse_args(parser)
# parser = MNISTDataModule.add_argparse_args(parser)
args = parser.parse_args()
# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234))
mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)
test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
# model
model = LitClassifier(**vars(args))
# training
trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
trainer.fit(model, mnist_train, mnist_val)
trainer.test(test_dataloaders=test_dataset)
if __name__ == '__main__': # pragma: no cover
cli_main()
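# Example run (illustrative): a single GPU, using the short 2-epoch / 200-batch
# configuration hardcoded in cli_main above.
#   python mnist_classifier.py --gpus 1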