Skip to content
Snippets Groups Projects
Commit faef5bbf authored by William Falcon's avatar William Falcon
Browse files

updated template to PL 1.0

parent a7d3627c
No related branches found
No related tags found
No related merge requests found
from argparse import ArgumentParser
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
class LitAutoEncoder(pl.LightningModule):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 64),
nn.ReLU(),
nn.Linear(64, 3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 64),
nn.ReLU(),
nn.Linear(64, 28 * 28)
)
def forward(self, x):
# in lightning, forward defines the prediction/inference actions
embedding = self.encoder(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
loss = F.mse_loss(x_hat, x)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
return optimizer
def cli_main():
pl.seed_everything(1234)
# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=128)
parser = pl.Trainer.add_argparse_args(parser)
args = parser.parse_args()
# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
# ------------
# model
# ------------
model = LitAutoEncoder()
# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)
# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
print(result)
if __name__ == '__main__':
cli_main()
from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
class Backbone(torch.nn.Module):
def __init__(self, hidden_dim=128):
super().__init__()
self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
self.l2 = torch.nn.Linear(hidden_dim, 10)
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
class LitClassifier(pl.LightningModule):
def __init__(self, backbone, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.backbone = backbone
def forward(self, x):
# use forward for inference/predictions
embedding = self.backbone(x)
return embedding
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('train_loss', loss, on_epoch=True)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('valid_loss', loss, on_step=True)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self.backbone(x)
loss = F.cross_entropy(y_hat, y)
self.log('test_loss', loss)
def configure_optimizers(self):
# self.hparams available because we called self.save_hyperparameters()
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser
def cli_main():
pl.seed_everything(1234)
# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--hidden_dim', type=int, default=128)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()
# ------------
# data
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
# ------------
# model
# ------------
model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)
# ------------
# training
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)
# ------------
# testing
# ------------
result = trainer.test(test_dataloaders=test_loader)
print(result)
if __name__ == '__main__':
cli_main()
......@@ -2,21 +2,21 @@ from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
class LitClassifier(pl.LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
def __init__(self, hidden_dim=128, learning_rate=1e-3):
super().__init__()
self.save_hyperparameters()
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
self.mnist_train = None
self.mnist_val = None
def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
......@@ -27,57 +27,70 @@ class LitClassifier(pl.LightningModule):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.TrainResult(minimize=loss)
result.log('train_loss', loss)
return result
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('val_loss', loss)
result.log('val_acc', accuracy(y_hat, y))
return result
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
result = pl.EvalResult(checkpoint_on=loss)
result.log('test_loss', loss)
result.log('test_acc', accuracy(y_hat, y))
return result
self.log('test_loss', loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.0001)
return parser
def cli_main():
from project.datasets.mnist import mnist
pl.seed_everything(1234)
# ------------
# args
# ------------
parser = ArgumentParser()
parser.add_argument('--gpus', default=0, type=int)
# optional... automatically add all the params
# parser = pl.Trainer.add_argparse_args(parser)
# parser = MNISTDataModule.add_argparse_args(parser)
parser.add_argument('--batch_size', default=32, type=int)
parser = pl.Trainer.add_argparse_args(parser)
parser = LitClassifier.add_model_specific_args(parser)
args = parser.parse_args()
# ------------
# data
mnist_train, mnist_val, test_dataset = mnist()
# ------------
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
# ------------
# model
model = LitClassifier(**vars(args))
# ------------
model = LitClassifier(args.hidden_dim, args.learning_rate)
# ------------
# training
trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
trainer.fit(model, mnist_train, mnist_val)
# ------------
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, train_loader, val_loader)
trainer.test(test_dataloaders=test_dataset)
# ------------
# testing
# ------------
trainer.test(test_dataloaders=test_loader)
if __name__ == '__main__': # pragma: no cover
if __name__ == '__main__':
cli_main()
pytorch-lightning >= 0.9.0
pytorch-lightning >= 1.0.0rc2
torch >= 1.3.0
torchvision >= 0.6.0
from pytorch_lightning import Trainer, seed_everything
from project.lit_classifier_main import LitClassifier
from project.lit_mnist import LitClassifier
from project.datasets.mnist import mnist
def test_lit_classifier():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment