Unverified commit 56de532c, authored by William Falcon and committed by GitHub

Merge pull request #9 from asvskartheek/master

Update with Data Module and Loggers
Parents: 06f5a89b, 75b954bc
.gitignore:

@@ -120,6 +120,7 @@ venv.bak/
 # IDEs
 .idea
+.vscode

 # seed project
 lightning_logs/
requirements.txt:

-pytorch-lightning >= 0.7.5
+pytorch-lightning >= 0.9.0
\ No newline at end of file
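Aside (not part of the commit): the version bump is what enables the rest of this diff, since LightningDataModule and the TrainResult/EvalResult objects first shipped in pytorch-lightning 0.9.0. A quick sanity check, a sketch rather than project code:

    # Sketch: confirm the installed version supports the APIs used in this diff.
    from packaging.version import Version
    import pytorch_lightning as pl

    assert Version(pl.__version__) >= Version("0.9.0"), \
        "LightningDataModule and the Result APIs need pytorch-lightning >= 0.9.0"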
src/research_mnist/mnist.py:

@@ -29,52 +29,54 @@ class CoolSystem(pl.LightningModule):
         y_hat = self.forward(x)
         loss = F.cross_entropy(y_hat, y)
-        tensorboard_logs = {'train_loss': loss}
-        return {'loss': loss, 'log': tensorboard_logs}
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
+        return result

     def validation_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.valid_batch_loss = loss
+        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+        return result

     def validation_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        tensorboard_logs = {'avg_val_loss': avg_loss}
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
+        return result

     def test_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'test_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.test_batch_loss = loss
+        result.log('test_loss', loss, on_epoch=True)
+        return result

     def test_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
-        tensorboard_logs = {'test_val_loss': avg_loss}
-        return {'test_loss': avg_loss, 'log': tensorboard_logs}
+        avg_loss = outputs.test_batch_loss.mean()
+        result = pl.EvalResult()
+        result.log('test_loss', avg_loss, on_epoch=True)
+        return result

     def configure_optimizers(self):
         # REQUIRED
         # can return multiple optimizers and learning_rate schedulers
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

-    def train_dataloader(self):
-        # REQUIRED
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def val_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def test_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
     @staticmethod
     def add_model_specific_args(parent_parser):

@@ -84,7 +86,6 @@ class CoolSystem(pl.LightningModule):
         # MODEL specific
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument('--learning_rate', default=0.02, type=float)
-        parser.add_argument('--batch_size', default=32, type=int)

         # training specific (for this model)
         parser.add_argument('--max_nb_epochs', default=2, type=int)
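Aside (not part of the diff): two things change shape here. First, tensors assigned onto an EvalResult in validation_step/test_step (valid_batch_loss, test_batch_loss) are auto-collected across batches, so the *_epoch_end hooks receive a single result object and can call .mean() directly instead of stacking a list of dicts; checkpoint_on=avg_loss tells checkpointing which quantity to track. Second, the dataloaders and the --batch_size flag leave CoolSystem entirely and reappear in the new DataModule below. Since the commit message also mentions loggers: metrics routed through result.log land in whatever logger the Trainer holds, TensorBoard by default. A minimal sketch, with a made-up experiment name:

    # Sketch (not part of the commit): pointing the Trainer at an explicit
    # TensorBoard logger; 'research_mnist' is an invented experiment name.
    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers import TensorBoardLogger

    # 'lightning_logs' matches the directory ignored in .gitignore above
    logger = TensorBoardLogger(save_dir='lightning_logs', name='research_mnist')
    trainer = Trainer(logger=logger, max_epochs=2)
    # 'train_loss'/'valid_loss' logged via result.log() show up under
    # lightning_logs/research_mnist; prog_bar=True mirrors them to the progress bar.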
src/research_mnist/mnist_data_module.py (new file):

from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.data_dir = self.hparams.data_dir
        self.batch_size = self.hparams.batch_size

        # We hardcode dataset-specific values here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Basic check that the parameters passed are sensible.
        assert (
            self.hparams.train_size + self.hparams.valid_size == 60_000
        ), "Invalid train/valid split: the two sizes must add up to 60,000"

    def prepare_data(self):
        # download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
        )

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    @staticmethod
    def add_data_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningDataModule
        """
        # Dataset specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--batch_size", default=32, type=int)
        parser.add_argument("--data_dir", default="./", type=str)

        # training specific
        parser.add_argument("--train_size", default=55_000, type=int)
        parser.add_argument("--valid_size", default=5_000, type=int)
        parser.add_argument("--workers", default=8, type=int)

        return parser
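The DataModule can also be exercised on its own, outside the Trainer. A sketch, not part of the commit; the Namespace just mirrors the parser defaults above:

    # Sketch: drive MNISTDataModule by hand with the argparse defaults.
    from argparse import Namespace

    hparams = Namespace(data_dir="./", batch_size=32,
                        train_size=55_000, valid_size=5_000, workers=8)
    dm = MNISTDataModule(hparams=hparams)
    dm.prepare_data()      # downloads MNIST into data_dir (once)
    dm.setup(stage="fit")  # builds the 55k/5k train/valid split
    xb, yb = next(iter(dm.train_dataloader()))
    print(xb.shape)        # torch.Size([32, 1, 28, 28]), matching dm.dims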
Trainer script:

@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
 from pytorch_lightning import Trainer, seed_everything
 from argparse import ArgumentParser
 from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule

 # sets seeds for numpy, torch, etc...
 # must do for DDP to work well
 seed_everything(123)


 def main(args):
-    # init module
+    # init modules
+    dm = MNISTDataModule(hparams=args)
     model = CoolSystem(hparams=args)

     # most basic trainer, uses good defaults
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, dm)
     trainer.test()

@@ -29,6 +31,8 @@ if __name__ == '__main__':
     # give the module a chance to add own params
     # good practice to define LightningModule specific params in the module
     parser = CoolSystem.add_model_specific_args(parser)
+    # same goes for data modules
+    parser = MNISTDataModule.add_data_specific_args(parser)

     # parse params
     args = parser.parse_args()
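Both static methods now extend the same parser, so model flags and data flags coexist on one command line, and Trainer.from_argparse_args picks its own flags out of the same namespace. A sketch of the combined surface; it assumes the elided part of the file adds the Trainer flags via Trainer.add_argparse_args:

    # Sketch (not part of the commit): how the three flag sources compose.
    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)               # assumed, from the elided lines
    parser = CoolSystem.add_model_specific_args(parser)      # --learning_rate, --max_nb_epochs
    parser = MNISTDataModule.add_data_specific_args(parser)  # --batch_size, --data_dir, ...
    args = parser.parse_args(['--learning_rate', '0.05', '--batch_size', '64'])
    print(args.learning_rate, args.batch_size)  # 0.05 64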