mnist.py
    """
    This file defines the core research contribution   
    """
    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from argparse import ArgumentParser
    
    import pytorch_lightning as pl
    
    
class CoolSystem(pl.LightningModule):

    def __init__(self, hparams):
        super().__init__()
        # not the best model...
        # (direct hparams assignment matches the legacy Lightning API this
        # file targets; newer releases expect self.save_hyperparameters())
        self.hparams = hparams
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        # flatten the 28x28 image and apply a single linear layer
        return torch.relu(self.l1(x.view(x.size(0), -1)))
    
    def training_step(self, batch, batch_idx):
        # REQUIRED
        x, y = batch
        y_hat = self.forward(x)
        return {'loss': F.cross_entropy(y_hat, y)}

    def validation_step(self, batch, batch_idx):
        # OPTIONAL
        x, y = batch
        y_hat = self.forward(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        return {'avg_val_loss': avg_loss}
    
    def configure_optimizers(self):
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
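        # a hedged sketch of that convention (illustrative, not part of the
        # original file): pair each optimizer with an LR scheduler, e.g.
        #   optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        #   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        #   return [optimizer], [scheduler]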
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    
    @pl.data_loader
    def train_dataloader(self):
        # REQUIRED
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

    @pl.data_loader
    def val_dataloader(self):
        # OPTIONAL
        # use the held-out split here; the original reused the training set
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)

    @pl.data_loader
    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
    
    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser])
        parser.add_argument('--learning_rate', default=0.02, type=float)
        parser.add_argument('--batch_size', default=32, type=int)

        # training specific (for this model)
        parser.add_argument('--max_nb_epochs', default=2, type=int)

        return parser
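

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal entry point showing
# how this module would typically be driven. It assumes the same legacy
# pytorch_lightning 0.x API the class above targets (e.g. a Trainer that
# accepts max_nb_epochs; later releases renamed this to max_epochs).
if __name__ == '__main__':
    # the parent parser must disable -h/--help so the child parser built in
    # add_model_specific_args (parents=[...]) can register its own
    root_parser = ArgumentParser(add_help=False)
    parser = CoolSystem.add_model_specific_args(root_parser)
    hparams = parser.parse_args()

    model = CoolSystem(hparams)
    trainer = pl.Trainer(max_nb_epochs=hparams.max_nb_epochs)
    trainer.fit(model)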