Commit faef5bbf authored by William Falcon

updated template to PL 1.0

parent a7d3627c
from argparse import ArgumentParser

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets.mnist import MNIST
from torchvision import transforms


class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28 * 28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser = pl.Trainer.add_argparse_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitAutoEncoder()

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    result = trainer.test(test_dataloaders=test_loader)
    print(result)


if __name__ == '__main__':
    cli_main()
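For orientation, a minimal sketch of how the autoencoder above might be used for inference after training. Everything here is an assumption rather than part of the commit: the module name `lit_autoencoder` and the checkpoint path are hypothetical placeholders; the only facts taken from the file above are that `__init__` takes no arguments and that `forward` returns the 3-dimensional embedding.

import torch
from lit_autoencoder import LitAutoEncoder  # hypothetical module name for the file above

# hypothetical checkpoint path; the Trainer writes one during fit()
model = LitAutoEncoder.load_from_checkpoint('path/to/checkpoint.ckpt')
model.eval()

with torch.no_grad():
    batch = torch.rand(4, 28 * 28)   # four flattened 28x28 "images"
    embeddings = model(batch)        # forward() returns the encoder output only
    print(embeddings.shape)          # torch.Size([4, 3])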
from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets.mnist import MNIST
from torchvision import transforms


class Backbone(torch.nn.Module):

    def __init__(self, hidden_dim=128):
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x


class LitClassifier(pl.LightningModule):

    def __init__(self, backbone, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = backbone

    def forward(self, x):
        # use forward for inference/predictions
        embedding = self.backbone(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('train_loss', loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss, on_step=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser


def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--hidden_dim', type=int, default=128)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    result = trainer.test(test_dataloaders=test_loader)
    print(result)


if __name__ == '__main__':
    cli_main()
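The point of this second template is that the backbone is a plain `torch.nn.Module` injected into the LightningModule, so architectures can be swapped without touching the training logic. A minimal sketch under that assumption; the `Sequential` backbone and the module name `lit_image_classifier` are illustrative, not part of the commit, and the only contract assumed is the one `Backbone` defines above (flattened 28*28 inputs in, 10 logits out):

import torch
import pytorch_lightning as pl
from lit_image_classifier import LitClassifier  # hypothetical module name for the file above

# any module mapping (N, 1, 28, 28) MNIST batches to 10 logits can stand in for Backbone
wider_backbone = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 256),
    torch.nn.ReLU(),
    torch.nn.Linear(256, 10),
)

model = LitClassifier(wider_backbone, learning_rate=1e-3)
trainer = pl.Trainer(max_epochs=1, limit_train_batches=100)  # short smoke run
# trainer.fit(model, train_loader, val_loader)  # loaders as built in cli_main()

Because the script wires in `pl.Trainer.add_argparse_args`, any Trainer argument (for example `--max_epochs`) can also be passed as a command-line flag when running it directly.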
@@ -2,21 +2,21 @@ from argparse import ArgumentParser
 import torch
 import pytorch_lightning as pl
-from pytorch_lightning.metrics.functional import accuracy
 from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms


 class LitClassifier(pl.LightningModule):
-    def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
+    def __init__(self, hidden_dim=128, learning_rate=1e-3):
         super().__init__()
         self.save_hyperparameters()

         self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
         self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

-        self.mnist_train = None
-        self.mnist_val = None
-
     def forward(self, x):
         x = x.view(x.size(0), -1)
         x = torch.relu(self.l1(x))
@@ -27,57 +27,70 @@ class LitClassifier(pl.LightningModule):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss)
-        return result
+        return loss

     def validation_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('val_loss', loss)
-        result.log('val_acc', accuracy(y_hat, y))
-        return result
+        self.log('valid_loss', loss)

     def test_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self(x)
         loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('test_loss', loss)
-        result.log('test_acc', accuracy(y_hat, y))
-        return result
+        self.log('test_loss', loss)

     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--hidden_dim', type=int, default=128)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+

 def cli_main():
-    from project.datasets.mnist import mnist
-
     pl.seed_everything(1234)

+    # ------------
     # args
+    # ------------
     parser = ArgumentParser()
-    parser.add_argument('--gpus', default=0, type=int)
-
-    # optional... automatically add all the params
-    # parser = pl.Trainer.add_argparse_args(parser)
-    # parser = MNISTDataModule.add_argparse_args(parser)
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
     args = parser.parse_args()

+    # ------------
     # data
-    mnist_train, mnist_val, test_dataset = mnist()
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

+    # ------------
     # model
-    model = LitClassifier(**vars(args))
+    # ------------
+    model = LitClassifier(args.hidden_dim, args.learning_rate)

+    # ------------
     # training
-    trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
-    trainer.fit(model, mnist_train, mnist_val)
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)

-    trainer.test(test_dataloaders=test_dataset)
+    # ------------
+    # testing
+    # ------------
+    trainer.test(test_dataloaders=test_loader)


-if __name__ == '__main__':  # pragma: no cover
+if __name__ == '__main__':
     cli_main()
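Condensing the hunk above: the 0.9 Result objects are gone in 1.0. Where 0.9 wrapped the loss in `pl.TrainResult` / `pl.EvalResult` and logged through the result object, 1.0 returns the loss directly from `training_step` and logs from anywhere via `self.log`. A before/after, condensed from the diff (the `train_loss` logging shown for 1.0 follows the classifier file above; this hunk itself drops it):

# PyTorch Lightning 0.9.x
result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss)
return result

# PyTorch Lightning 1.0.x
self.log('train_loss', loss)
return loss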
-pytorch-lightning >= 0.9.0
+pytorch-lightning >= 1.0.0rc2
 torch >= 1.3.0
 torchvision >= 0.6.0
 from pytorch_lightning import Trainer, seed_everything
-from project.lit_classifier_main import LitClassifier
-from project.datasets.mnist import mnist
+from project.lit_mnist import LitClassifier


 def test_lit_classifier():
...