Commit 75b954bc authored by Kartheek Akella

Add data module for handling dataset-specific tasks.

parent 6c2f52b5
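This commit moves dataset handling out of the LightningModule (CoolSystem) into a dedicated LightningDataModule and passes both to the Trainer. As a rough sketch of the change at the call site (illustration only, not part of the commit; identifiers are the ones defined in the diffs below, and `args` stands for the parsed argparse.Namespace):

    # Sketch only: the call-site difference this commit introduces.
    from pytorch_lightning import Trainer

    trainer = Trainer()
    model = CoolSystem(hparams=args)

    # Before: CoolSystem downloaded MNIST and built its own DataLoaders.
    trainer.fit(model)

    # After: download/split/loading is delegated to the data module.
    trainer.fit(model, MNISTDataModule(hparams=args))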
src/research_mnist/mnist.py
@@ -29,52 +29,54 @@ class CoolSystem(pl.LightningModule):
         y_hat = self.forward(x)
         loss = F.cross_entropy(y_hat, y)
-        tensorboard_logs = {'train_loss': loss}
-        return {'loss': loss, 'log': tensorboard_logs}
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
+        return result
 
     def validation_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.valid_batch_loss = loss
+        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+        return result
 
     def validation_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
-        tensorboard_logs = {'avg_val_loss': avg_loss}
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
+        return result
 
     def test_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'test_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
+        result.test_batch_loss = loss
+        result.log('test_loss', loss, on_epoch=True)
+        return result
 
     def test_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
-        tensorboard_logs = {'test_val_loss': avg_loss}
-        return {'test_loss': avg_loss, 'log': tensorboard_logs}
+        avg_loss = outputs.test_batch_loss.mean()
+        result = pl.EvalResult()
+        result.log('test_loss', avg_loss, on_epoch=True)
+        return result
 
     def configure_optimizers(self):
         # REQUIRED
         # can return multiple optimizers and learning_rate schedulers
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
 
     def train_dataloader(self):
         # REQUIRED
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
 
     def val_dataloader(self):
         # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
 
     def test_dataloader(self):
         # OPTIONAL
         return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
 
     @staticmethod
     def add_model_specific_args(parent_parser):
@@ -84,7 +86,6 @@ class CoolSystem(pl.LightningModule):
         # MODEL specific
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument('--learning_rate', default=0.02, type=float)
-        parser.add_argument('--batch_size', default=32, type=int)
 
         # training specific (for this model)
         parser.add_argument('--max_nb_epochs', default=2, type=int)
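The first hunk above swaps the old dict-based step returns for the structured TrainResult/EvalResult objects of the Lightning 0.9-era API. The detail that makes the new validation_epoch_end work: any custom attribute assigned to an EvalResult in a step (here valid_batch_loss) is stacked across batches by Lightning, so the epoch-end hook receives one tensor instead of a list of dicts. A compressed sketch of that pattern (illustration only; compute_loss is a hypothetical helper, not from this repo):

    # Pattern sketch, assuming the 0.9-era Result API used in the hunk above.
    import pytorch_lightning as pl

    class Sketch(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            result = pl.EvalResult()
            result.valid_batch_loss = self.compute_loss(batch)  # hypothetical helper
            return result

        def validation_epoch_end(self, outputs):
            # outputs.valid_batch_loss holds one entry per validation batch
            avg = outputs.valid_batch_loss.mean()
            return pl.EvalResult(checkpoint_on=avg)  # checkpoint on the epoch mean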
src/research_mnist/mnist_data_module.py (new file)
from argparse import ArgumentParser

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.data_dir = self.hparams.data_dir
        self.batch_size = self.hparams.batch_size

        # We hardcode dataset-specific stuff here.
        self.num_classes = 10
        self.dims = (1, 28, 28)
        self.transform = transforms.Compose([transforms.ToTensor()])

        # Basic test that the parameters passed are sensible.
        assert (
            self.hparams.train_size + self.hparams.valid_size == 60_000
        ), "Invalid train and valid split, make sure they add up to 60,000"

    def prepare_data(self):
        # download the dataset
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform
            )

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
        )

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(
            self.mnist_test,
            batch_size=self.batch_size,
            num_workers=self.hparams.workers,
        )

    @staticmethod
    def add_data_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningDataModule
        """
        # Dataset specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--batch_size", default=32, type=int)
        parser.add_argument("--data_dir", default="./", type=str)

        # training specific
        parser.add_argument("--train_size", default=55_000, type=int)
        parser.add_argument("--valid_size", default=5_000, type=int)
        parser.add_argument("--workers", default=8, type=int)

        return parser
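The new module can be sanity-checked outside the Trainer by driving its hooks by hand. A minimal smoke test (not part of the commit; it relies only on the defaults defined in add_data_specific_args above):

    # Standalone smoke test, illustration only.
    from argparse import ArgumentParser

    parser = MNISTDataModule.add_data_specific_args(ArgumentParser(add_help=False))
    args = parser.parse_args([])        # take the defaults: 55k/5k split, batch_size 32

    dm = MNISTDataModule(hparams=args)
    dm.prepare_data()                   # downloads MNIST into args.data_dir
    dm.setup("fit")                     # builds the train/val random_split
    x, y = next(iter(dm.train_dataloader()))
    assert x.shape == (32, 1, 28, 28)   # one batch of dm.dims-shaped images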
@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
 from pytorch_lightning import Trainer, seed_everything
 from argparse import ArgumentParser
 from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule
 
 # sets seeds for numpy, torch, etc...
 # must do for DDP to work well
 seed_everything(123)
 
 
 def main(args):
-    # init module
+    # init modules
+    dm = MNISTDataModule(hparams=args)
     model = CoolSystem(hparams=args)
 
     # most basic trainer, uses good defaults
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, dm)
     trainer.test()
@@ -29,6 +31,8 @@ if __name__ == '__main__':
     # give the module a chance to add its own params
     # good practice to define LightningModule-specific params in the module
     parser = CoolSystem.add_model_specific_args(parser)
+    # same goes for data modules
+    parser = MNISTDataModule.add_data_specific_args(parser)
 
     # parse params
     args = parser.parse_args()
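Since every flag registered above has a default, the combined parser can be exercised end to end. A sketch of an invocation (illustration only: the explicit argv list stands in for a real command line, the values are arbitrary, and train_size + valid_size must stay at 60,000 so the data module's assertion passes):

    # Illustration only: argv list in place of a real command line.
    args = parser.parse_args([
        "--learning_rate", "0.02",
        "--batch_size", "64",
        "--data_dir", "./data",
        "--train_size", "55000",
        "--valid_size", "5000",
        "--workers", "4",
    ])
    main(args)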