Skip to content
Snippets Groups Projects
Select Git revision
  • e5b8d156bfee461b4e3fbbd91bf36f213ae9fc51
  • master default protected
  • Dawit
  • maike-patrick-first-pipeline
  • Jonas
  • Kamal
  • Maike
  • Patrick
  • Uni-Bremen
  • update-setup
10 results

mnist_baseline_trainer.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    mnist_baseline_trainer.py 928 B
    """
    This file runs the main training/val loop, etc... using Lightning Trainer    
    """
    from pytorch_lightning import Trainer
    from argparse import ArgumentParser
    from research_seed.mnist.mnist import CoolSystem
    
    
    def main(hparams):
        """Build the ``CoolSystem`` model and fit it with a Lightning ``Trainer``.

        Args:
            hparams: parsed command-line namespace. The whole namespace is
                passed to ``CoolSystem``; ``max_nb_epochs``, ``gpus`` and
                ``nodes`` are also read from it to configure the trainer.
        """
        system = CoolSystem(hparams)

        # NOTE(review): ``max_nb_epochs`` is not declared by the parser in this
        # file; presumably CoolSystem.add_model_specific_args adds it — verify.
        # NOTE(review): ``max_nb_epochs`` / ``nb_gpu_nodes`` look like the older
        # Lightning spellings (newer releases use ``max_epochs`` / ``num_nodes``);
        # keep these in step with the pinned pytorch_lightning version — confirm.
        trainer_kwargs = dict(
            max_nb_epochs=hparams.max_nb_epochs,
            gpus=hparams.gpus,
            nb_gpu_nodes=hparams.nodes,
        )
        Trainer(**trainer_kwargs).fit(system)
    
    
    if __name__ == '__main__':
        # Trainer-level command-line flags. add_help is switched off here;
        # presumably the parser returned by CoolSystem.add_model_specific_args
        # uses this one as a parent and owns -h/--help (a parent that also
        # defined -h would clash) — confirm against that method.
        base_parser = ArgumentParser(add_help=False)
        base_parser.add_argument('--gpus', default=None, type=str)
        base_parser.add_argument('--nodes', default=1, type=int)

        # Let the LightningModule register its own hyperparameters; defining
        # model-specific options inside the module keeps them next to the model.
        cli_parser = CoolSystem.add_model_specific_args(base_parser)

        # Parse the command line once and hand the namespace to main().
        main(cli_parser.parse_args())