From faef5bbf4f3c392b9e4d71e606edaab5ce4d0aaf Mon Sep 17 00:00:00 2001
From: William Falcon <waf2107@columbia.edu>
Date: Thu, 8 Oct 2020 23:36:54 -0400
Subject: [PATCH] Update template to PL 1.0
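
Replace the 0.9 TrainResult/EvalResult logging API with self.log calls,
split the single template script into three standalone examples
(autoencoder, backbone image classifier, MNIST classifier), and bump
the pytorch-lightning requirement to 1.0.0rc2.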

---
 project/lit_autoencoder.py      | 103 +++++++++++++++++++++++++++++++++
 project/lit_classifier_main.py  |  83 ------------------------
 project/lit_image_classifier.py | 111 ++++++++++++++++++++++++++++++++++
 project/lit_mnist.py            |  97 +++++++++++++++++++++++++++++
 requirements.txt                |   2 +-
 tests/test_classifier.py        |   4 +-
 6 files changed, 314 insertions(+), 86 deletions(-)
 create mode 100644 project/lit_autoencoder.py
 delete mode 100644 project/lit_classifier_main.py
 create mode 100644 project/lit_image_classifier.py
 create mode 100644 project/lit_mnist.py

diff --git a/project/lit_autoencoder.py b/project/lit_autoencoder.py
new file mode 100644
index 0000000..3f9ff0c
--- /dev/null
+++ b/project/lit_autoencoder.py
@@ -0,0 +1,103 @@
+from argparse import ArgumentParser
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
+import pytorch_lightning as pl
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class LitAutoEncoder(pl.LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(28 * 28, 64),
+            nn.ReLU(),
+            nn.Linear(64, 3)
+        )
+        self.decoder = nn.Sequential(
+            nn.Linear(3, 64),
+            nn.ReLU(),
+            nn.Linear(64, 28 * 28)
+        )
+
+    def forward(self, x):
+        # in lightning, forward defines the prediction/inference actions
+        embedding = self.encoder(x)
+        return embedding
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
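+        # flatten 28x28 images to 784-dim vectors; labels are unused here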
+        x = x.view(x.size(0), -1)
+        z = self.encoder(x)
+        x_hat = self.decoder(z)
+        loss = F.mse_loss(x_hat, x)
+        return loss
+
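+    def validation_step(self, batch, batch_idx):
+        # assumed addition: mirrors the training reconstruction loss so the
+        # val loop run by trainer.fit(...) below has a hook to execute
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        x_hat = self.decoder(self.encoder(x))
+        self.log('valid_loss', F.mse_loss(x_hat, x))
+
+    def test_step(self, batch, batch_idx):
+        # assumed addition: trainer.test(...) below requires a test_step
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        x_hat = self.decoder(self.encoder(x))
+        self.log('test_loss', F.mse_loss(x_hat, x))
+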
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
+        return optimizer
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser = pl.Trainer.add_argparse_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
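+    # root='' stores the MNIST download under the current working directory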
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitAutoEncoder()
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    result = trainer.test(test_dataloaders=test_loader)
+    print(result)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/project/lit_classifier_main.py b/project/lit_classifier_main.py
deleted file mode 100644
index 020b72c..0000000
--- a/project/lit_classifier_main.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from argparse import ArgumentParser
-
-import torch
-import pytorch_lightning as pl
-from pytorch_lightning.metrics.functional import accuracy
-from torch.nn import functional as F
-
-
-class LitClassifier(pl.LightningModule):
-    def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
-        super().__init__()
-        self.save_hyperparameters()
-
-        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
-        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
-
-        self.mnist_train = None
-        self.mnist_val = None
-
-    def forward(self, x):
-        x = x.view(x.size(0), -1)
-        x = torch.relu(self.l1(x))
-        x = torch.relu(self.l2(x))
-        return x
-
-    def training_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss)
-        return result
-
-    def validation_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('val_loss', loss)
-        result.log('val_acc', accuracy(y_hat, y))
-        return result
-
-    def test_step(self, batch, batch_idx):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult(checkpoint_on=loss)
-        result.log('test_loss', loss)
-        result.log('test_acc', accuracy(y_hat, y))
-        return result
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-
-
-def cli_main():
-    from project.datasets.mnist import mnist
-    pl.seed_everything(1234)
-
-    # args
-    parser = ArgumentParser()
-    parser.add_argument('--gpus', default=0, type=int)
-
-    # optional... automatically add all the params
-    # parser = pl.Trainer.add_argparse_args(parser)
-    # parser = MNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    # data
-    mnist_train, mnist_val, test_dataset = mnist()
-
-    # model
-    model = LitClassifier(**vars(args))
-
-    # training
-    trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
-    trainer.fit(model, mnist_train, mnist_val)
-
-    trainer.test(test_dataloaders=test_dataset)
-
-
-if __name__ == '__main__':  # pragma: no cover
-    cli_main()
diff --git a/project/lit_image_classifier.py b/project/lit_image_classifier.py
new file mode 100644
index 0000000..1296a3f
--- /dev/null
+++ b/project/lit_image_classifier.py
@@ -0,0 +1,111 @@
+from argparse import ArgumentParser
+
+import torch
+import pytorch_lightning as pl
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class Backbone(torch.nn.Module):
+    def __init__(self, hidden_dim=128):
+        super().__init__()
+        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
+        self.l2 = torch.nn.Linear(hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = self.l2(x)  # raw logits; cross_entropy applies log-softmax itself
+        return x
+
+
+class LitClassifier(pl.LightningModule):
+    def __init__(self, backbone, learning_rate=1e-3):
+        super().__init__()
+        self.save_hyperparameters()
+        self.backbone = backbone
+
+    def forward(self, x):
+        # use forward for inference/predictions
+        embedding = self.backbone(x)
+        return embedding
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
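+        # PL 1.0: self.log replaces the 0.9 TrainResult/EvalResult logging API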
+        self.log('train_loss', loss, on_epoch=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
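+        # on_step=True adds per-batch logging; validation logs per-epoch by default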
+        self.log('valid_loss', loss, on_step=True)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self.backbone(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        # self.hparams available because we called self.save_hyperparameters()
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser.add_argument('--hidden_dim', type=int, default=128)
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitClassifier(Backbone(hidden_dim=args.hidden_dim), args.learning_rate)
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    result = trainer.test(test_dataloaders=test_loader)
+    print(result)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/project/lit_mnist.py b/project/lit_mnist.py
new file mode 100644
index 0000000..8733378
--- /dev/null
+++ b/project/lit_mnist.py
@@ -0,0 +1,97 @@
+from argparse import ArgumentParser
+
+import torch
+import pytorch_lightning as pl
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+
+from torchvision.datasets.mnist import MNIST
+from torchvision import transforms
+
+
+class LitClassifier(pl.LightningModule):
+    def __init__(self, hidden_dim=128, learning_rate=1e-3):
+        super().__init__()
+        self.save_hyperparameters()
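+        # save_hyperparameters() exposes hidden_dim/learning_rate via self.hparams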
+
+        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
+        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = self.l2(x)  # raw logits; cross_entropy applies log-softmax itself
+        return x
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('valid_loss', loss)
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        self.log('test_loss', loss)
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+    @staticmethod
+    def add_model_specific_args(parent_parser):
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument('--hidden_dim', type=int, default=128)
+        parser.add_argument('--learning_rate', type=float, default=0.0001)
+        return parser
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # ------------
+    # args
+    # ------------
+    parser = ArgumentParser()
+    parser.add_argument('--batch_size', default=32, type=int)
+    parser = pl.Trainer.add_argparse_args(parser)
+    parser = LitClassifier.add_model_specific_args(parser)
+    args = parser.parse_args()
+
+    # ------------
+    # data
+    # ------------
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
+    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000])
+
+    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
+    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
+    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
+
+    # ------------
+    # model
+    # ------------
+    model = LitClassifier(args.hidden_dim, args.learning_rate)
+
+    # ------------
+    # training
+    # ------------
+    trainer = pl.Trainer.from_argparse_args(args)
+    trainer.fit(model, train_loader, val_loader)
+
+    # ------------
+    # testing
+    # ------------
+    trainer.test(test_dataloaders=test_loader)
+
+
+if __name__ == '__main__':
+    cli_main()
diff --git a/requirements.txt b/requirements.txt
index fb1b543..30840f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,3 @@
-pytorch-lightning >= 0.9.0
+pytorch-lightning >= 1.0.0rc2
 torch >= 1.3.0
 torchvision >= 0.6.0
diff --git a/tests/test_classifier.py b/tests/test_classifier.py
index 96be92b..e173fd5 100644
--- a/tests/test_classifier.py
+++ b/tests/test_classifier.py
@@ -1,6 +1,6 @@
 from pytorch_lightning import Trainer, seed_everything
-
-from project.lit_classifier_main import LitClassifier
+from project.lit_mnist import LitClassifier
+from project.datasets.mnist import mnist
 
 
 def test_lit_classifier():
-- 
GitLab