Unverified commit 56de532c, authored by William Falcon, committed by GitHub

Merge pull request #9 from asvskartheek/master

Update with Data Module and Loggers
parents 06f5a89b 75b954bc
.gitignore
@@ -120,6 +120,7 @@ venv.bak/
 # IDEs
 .idea
 .vscode

 # seed project
+lightning_logs/
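The new ignore entry exists because Lightning's default logger writes its event files and checkpoints under lightning_logs/ in the working directory. A minimal illustration (not part of this commit, behavior as of the 0.9 series):

# When no logger is passed, Lightning creates a TensorBoardLogger that
# writes to ./lightning_logs/ -- hence the new .gitignore entry.
from pytorch_lightning import Trainer

trainer = Trainer()  # default logger saves to ./lightning_logs/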
requirements.txt
-pytorch-lightning >= 0.7.5
\ No newline at end of file
+pytorch-lightning >= 0.9.0
\ No newline at end of file
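The version bump matters: pl.TrainResult, pl.EvalResult, and pl.LightningDataModule used below first shipped in the 0.9 series. A minimal guard that fails fast on an older install (illustrative only, not part of this commit):

# Sketch: verify the installed pytorch-lightning supports the Result and
# DataModule APIs this seed project now relies on.
from packaging import version
import pytorch_lightning as pl

assert version.parse(pl.__version__) >= version.parse("0.9.0"), (
    f"pytorch-lightning >= 0.9.0 required, found {pl.__version__}"
)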
src/research_mnist/mnist.py
@@ -29,52 +29,54 @@ class CoolSystem(pl.LightningModule):
        y_hat = self.forward(x)
        loss = F.cross_entropy(y_hat, y)
-       tensorboard_logs = {'train_loss': loss}
-       return {'loss': loss, 'log': tensorboard_logs}
+       result = pl.TrainResult(minimize=loss)
+       result.log('train_loss', loss, prog_bar=True)
+       return result

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
-       return {'val_loss': F.cross_entropy(y_hat, y)}
+       loss = F.cross_entropy(y_hat, y)
+       result = pl.EvalResult()
+       result.valid_batch_loss = loss
+       result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+       return result

    def validation_epoch_end(self, outputs):
        # OPTIONAL
-       avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-       tensorboard_logs = {'avg_val_loss': avg_loss}
-       return {'val_loss': avg_loss, 'log': tensorboard_logs}
+       avg_loss = outputs.valid_batch_loss.mean()
+       result = pl.EvalResult(checkpoint_on=avg_loss)
+       result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
+       return result

    def test_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
-       return {'test_loss': F.cross_entropy(y_hat, y)}
+       loss = F.cross_entropy(y_hat, y)
+       result = pl.EvalResult()
+       result.test_batch_loss = loss
+       result.log('test_loss', loss, on_epoch=True)
+       return result

    def test_epoch_end(self, outputs):
        # OPTIONAL
-       avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
-       tensorboard_logs = {'test_val_loss': avg_loss}
-       return {'test_loss': avg_loss, 'log': tensorboard_logs}
+       avg_loss = outputs.test_batch_loss.mean()
+       result = pl.EvalResult()
+       result.log('test_loss', avg_loss, on_epoch=True)
+       return result

    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
@@ -84,7 +86,6 @@ class CoolSystem(pl.LightningModule):
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=0.02, type=float)
-       parser.add_argument('--batch_size', default=32, type=int)

        # training specific (for this model)
        parser.add_argument('--max_nb_epochs', default=2, type=int)
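The dict-based returns are replaced by the structured Result objects introduced in 0.9: TrainResult(minimize=loss) marks the tensor to backpropagate, EvalResult(checkpoint_on=avg_loss) tells the checkpoint callback which value to monitor, and log(..., on_epoch=True) has Lightning average the metric across the epoch. (The --batch_size flag leaves the model's parser because the new data module owns it.) With on_epoch=True the manual epoch-end reduction is in fact optional; a hedged sketch of the shorter form, using a hypothetical SketchSystem rather than the repo's CoolSystem:

# Sketch only, not part of this commit: with on_epoch=True Lightning
# aggregates per-batch values itself, so validation_epoch_end can be
# omitted entirely.
import pytorch_lightning as pl
import torch.nn.functional as F

class SketchSystem(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        result = pl.EvalResult(checkpoint_on=loss)  # reduced over the epoch for checkpointing
        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
        return result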
new file: src/research_mnist/mnist_data_module.py

from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.data_dir = self.hparams.data_dir
        self.batch_size = self.hparams.batch_size

        # We hardcode dataset-specific attributes here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Basic test that the parameters passed are sensible.
        assert (
            self.hparams.train_size + self.hparams.valid_size == 60_000
        ), "Invalid train/valid split: train_size + valid_size must equal 60,000"

    def prepare_data(self):
        # download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
        )

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    @staticmethod
    def add_data_specific_args(parent_parser):
        """
        Specify the hyperparams for this data module
        """
        # Dataset specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--batch_size", default=32, type=int)
        parser.add_argument("--data_dir", default="./", type=str)

        # training specific
        parser.add_argument("--train_size", default=55_000, type=int)
        parser.add_argument("--valid_size", default=5_000, type=int)
        parser.add_argument("--workers", default=8, type=int)

        return parser
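Because the data module encapsulates download, split, and loaders behind prepare_data/setup, it can be exercised on its own. A short sketch, with the hparams values taken from the argparse defaults above (not part of this commit):

# Illustrative stand-alone use of MNISTDataModule.
from argparse import Namespace

hparams = Namespace(data_dir="./", batch_size=32,
                    train_size=55_000, valid_size=5_000, workers=8)
dm = MNISTDataModule(hparams=hparams)
dm.prepare_data()        # downloads MNIST if it is missing
dm.setup(stage="fit")    # builds the 55,000/5,000 train/valid split
images, targets = next(iter(dm.train_dataloader()))
print(images.shape)      # torch.Size([32, 1, 28, 28]), matching self.dims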
@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
from pytorch_lightning import Trainer, seed_everything
from argparse import ArgumentParser
from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule

# sets seeds for numpy, torch, etc...
# must do for DDP to work well
seed_everything(123)


def main(args):
-   # init module
+   # init modules
+   dm = MNISTDataModule(hparams=args)
    model = CoolSystem(hparams=args)

    # most basic trainer, uses good defaults
    trainer = Trainer.from_argparse_args(args)
-   trainer.fit(model)
+   trainer.fit(model, dm)
    trainer.test()
@@ -29,6 +31,8 @@ if __name__ == '__main__':
    # give the module a chance to add own params
    # good practice to define LightningModule specific params in the module
    parser = CoolSystem.add_model_specific_args(parser)
+   # same goes for data modules
+   parser = MNISTDataModule.add_data_specific_args(parser)

    # parse params
    args = parser.parse_args()
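The parser is assembled in layers, one per component; the Trainer's own flags are presumably added in the elided lines above this hunk. A sketch of the assumed overall wiring (only the model and data lines are actually visible in the diff):

# Assumed wiring, for orientation only.
from argparse import ArgumentParser
from pytorch_lightning import Trainer

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)                # --max_epochs, --gpus, ...
parser = CoolSystem.add_model_specific_args(parser)       # --learning_rate, --max_nb_epochs
parser = MNISTDataModule.add_data_specific_args(parser)   # --batch_size, --data_dir, split sizes
args = parser.parse_args()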