Skip to content
Snippets Groups Projects
Select Git revision
  • 75b954bce5f8e8672612957ea909e078531d61fc
  • master default protected
  • Dawit
  • maike-patrick-first-pipeline
  • Jonas
  • Kamal
  • Maike
  • Patrick
  • Uni-Bremen
  • update-setup
10 results

mnist_trainer.py

Blame
  • user avatar
    Kartheek Akella authored
    75b954bc
    History
    Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    mnist_trainer.py 1.10 KiB
    """
    This file runs the main training/val loop, etc... using Lightning Trainer    
    """
    from pytorch_lightning import Trainer, seed_everything
    from argparse import ArgumentParser
    from src.research_mnist.mnist import CoolSystem
    from src.research_mnist.mnist_data_module import MNISTDataModule
    
    # sets seeds for numpy, torch, etc...
    # must do for DDP to work well
    seed_everything(123)
    
    def main(args):
        # init modules
        dm = MNISTDataModule(hparams=args)
        model = CoolSystem(hparams=args)
    
        # most basic trainer, uses good defaults
        trainer = Trainer.from_argparse_args(args)
        trainer.fit(model, dm)
    
        trainer.test()
    
    
    if __name__ == '__main__':
        parser = ArgumentParser(add_help=False)
    
        # add args from trainer
        parser = Trainer.add_argparse_args(parser)
    
        # give the module a chance to add own params
        # good practice to define LightningModule speficic params in the module
        parser = CoolSystem.add_model_specific_args(parser)
        # same goes for data modules
        parser = MNISTDataModule.add_data_specific_args(parser)
    
        # parse params
        args = parser.parse_args()
    
        main(args)