Skip to content
Snippets Groups Projects
Select Git revision
  • 329a4195454cb180831104c01563bdc61c3720d0
  • master default protected
  • develop
  • feat-features-update
  • esra-current
  • v0.1.2.1
6 results

apollon.commands.rst

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    mnist.py 2.63 KiB
    """
    This file defines the core research contribution   
    """
    import os
    import torch
    from torch.nn import functional as F
    from torch.utils.data import DataLoader
    from torchvision.datasets import MNIST
    import torchvision.transforms as transforms
    from argparse import ArgumentParser
    
    import pytorch_lightning as pl
    
    
    class CoolSystem(pl.LightningModule):
        """
        Minimal LightningModule: one linear layer over flattened inputs.

        Each input is flattened to 28 * 28 features and mapped to 10 output
        scores; every step uses cross-entropy against the batch targets.
        The ``hparams`` object must expose ``learning_rate``, which
        ``configure_optimizers`` reads.
        """

        def __init__(self, hparams):
            """
            Store the hyperparameters and build the single linear layer.
            """
            super().__init__()
            self.hparams = hparams
            # not the best model: a single 784 -> 10 linear map
            self.l1 = torch.nn.Linear(28 * 28, 10)

        def forward(self, x):
            """
            Flatten ``x`` per sample, apply the linear layer, then ReLU.
            """
            # keep the leading (batch) dimension, collapse everything else
            flattened = x.view(x.size(0), -1)
            scores = self.l1(flattened)
            # NOTE(review): the ReLU means the scores handed to
            # F.cross_entropy by the step methods are never negative --
            # confirm this is intended.
            return torch.relu(scores)

        def _batch_loss(self, batch):
            """
            Cross-entropy between the forward output and the batch targets.
            """
            inputs, targets = batch
            return F.cross_entropy(self.forward(inputs), targets)

        def training_step(self, batch, batch_idx):
            """
            REQUIRED. Compute the batch loss, mark it as the value to
            minimize and log it as 'train_loss' on the progress bar.
            """
            train_loss = self._batch_loss(batch)
            step_result = pl.TrainResult(minimize=train_loss)
            step_result.log('train_loss', train_loss, prog_bar=True)
            return step_result

        def validation_step(self, batch, batch_idx):
            """
            OPTIONAL. Compute the batch loss, stash it on the result as
            ``valid_batch_loss`` and log it as 'valid_loss'.
            """
            val_loss = self._batch_loss(batch)
            step_result = pl.EvalResult()
            # read back in validation_epoch_end
            step_result.valid_batch_loss = val_loss
            step_result.log('valid_loss', val_loss, on_epoch=True, prog_bar=True)
            return step_result

        def validation_epoch_end(self, outputs):
            """
            OPTIONAL. Average ``outputs.valid_batch_loss``, checkpoint on
            that mean and log it as 'valid_loss'.
            """
            epoch_loss = outputs.valid_batch_loss.mean()
            summary = pl.EvalResult(checkpoint_on=epoch_loss)
            summary.log('valid_loss', epoch_loss, on_epoch=True, prog_bar=True)
            return summary

        def test_step(self, batch, batch_idx):
            """
            OPTIONAL. Compute the batch loss, stash it on the result as
            ``test_batch_loss`` and log it as 'test_loss'.
            """
            held_out_loss = self._batch_loss(batch)
            step_result = pl.EvalResult()
            # read back in test_epoch_end
            step_result.test_batch_loss = held_out_loss
            step_result.log('test_loss', held_out_loss, on_epoch=True)
            return step_result

        def test_epoch_end(self, outputs):
            """
            OPTIONAL. Average ``outputs.test_batch_loss`` and log the mean
            as 'test_loss'.
            """
            epoch_loss = outputs.test_batch_loss.mean()
            summary = pl.EvalResult()
            summary.log('test_loss', epoch_loss, on_epoch=True)
            return summary

        def configure_optimizers(self):
            """
            REQUIRED. Return an Adam optimizer over all parameters, using
            ``hparams.learning_rate`` as the learning rate.

            Multiple optimizers and learning_rate schedulers could be
            returned here instead.
            """
            step_size = self.hparams.learning_rate
            return torch.optim.Adam(self.parameters(), lr=step_size)

        @staticmethod
        def add_model_specific_args(parent_parser):
            """
            Extend ``parent_parser`` with this module's hyperparameters.

            Returns a new ArgumentParser (help disabled) that inherits the
            parent's arguments and adds ``--learning_rate`` (float, default
            0.02) and ``--max_nb_epochs`` (int, default 2).
            """
            model_parser = ArgumentParser(parents=[parent_parser], add_help=False)

            # MODEL specific
            model_parser.add_argument('--learning_rate', default=0.02, type=float)

            # training specific (for this model)
            model_parser.add_argument('--max_nb_epochs', default=2, type=int)

            return model_parser