test_classifier.py

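    from project.datasets.mnist import mnist
    from project.lit_classifier_main import LitClassifier
    from pytorch_lightning import Trainer, seed_everything
    
    
    def test_lit_classifier():
        # Seed Python, NumPy and PyTorch RNGs so the run is reproducible.
        seed_everything(1234)
    
        model = LitClassifier()
        # mnist() is expected to return (train, val, test) dataloaders.
        train, val, test = mnist()
    
        # Keep the run short: at most 50 train and 20 val batches per epoch, 2 epochs.
        trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
        trainer.fit(model, train, val)
    
        # Evaluate the freshly trained weights on the test set. The old
        # `test_dataloaders=` keyword was deprecated in PyTorch Lightning 1.4 and
        # later removed; `dataloaders=` is the current name.
        results = trainer.test(model, dataloaders=test)
        # trainer.test returns one dict of logged metrics per test dataloader;
        # 'test_acc' must be logged by LitClassifier's test step.
        assert results[0]['test_acc'] > 0.7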
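The test follows a common smoke-test pattern: fix the random seed, train briefly under a hard batch budget, then assert that accuracy clears a loose floor rather than an exact value. The sketch below reproduces that pattern without PyTorch Lightning so the logic can be read in isolation. It is an illustration under stated assumptions, not the project's code: the synthetic two-cluster data (`make_blobs`), the perceptron learner and every helper name are hypothetical stand-ins for `mnist()` and `LitClassifier`, and the `max_epochs`/`limit_train_batches` arguments only mimic the `Trainer` options of the same name.

```python
import random

# Hypothetical stand-ins for mnist() and LitClassifier: a 2-D two-cluster
# dataset and a perceptron, so the smoke-test pattern runs with the stdlib only.


def make_blobs(n, rng):
    """Return n ((x1, x2), label) pairs: label 0 near (-1.5, -1.5), label 1 near (1.5, 1.5)."""
    data = []
    for _ in range(n):
        label = rng.randint(0, 1)
        center = 1.5 if label else -1.5
        data.append(((rng.gauss(center, 0.5), rng.gauss(center, 0.5)), label))
    return data


def predict(w, b, x):
    # Linear decision rule: class 1 if w.x + b > 0, else class 0.
    return 1 if w[0] * x[0] + w[1] * x[1] + b > 0 else 0


def train_perceptron(train, max_epochs, limit_train_batches, batch_size=16, lr=0.1):
    """Perceptron updates, capped at limit_train_batches batches per epoch."""
    w, b = [0.0, 0.0], 0.0
    for _ in range(max_epochs):
        for batch_idx, start in enumerate(range(0, len(train), batch_size)):
            if batch_idx >= limit_train_batches:
                break  # end this epoch early, like Trainer(limit_train_batches=...)
            for x, y in train[start:start + batch_size]:
                err = y - predict(w, b, x)  # 0 when correct, +1/-1 on a mistake
                w[0] += lr * err * x[0]
                w[1] += lr * err * x[1]
                b += lr * err
    return w, b


def accuracy(w, b, data):
    return sum(predict(w, b, x) == y for x, y in data) / len(data)


def run_smoke_test(seed=1234):
    rng = random.Random(seed)  # local seeded RNG, analogous to seed_everything(1234)
    train, test = make_blobs(2000, rng), make_blobs(200, rng)
    # 2000 samples / 16 = 125 batches, so the 50-batch cap really truncates each epoch.
    w, b = train_perceptron(train, max_epochs=2, limit_train_batches=50)
    return accuracy(w, b, test)


# Same shape as the real test: assert a loose accuracy floor, not an exact value.
assert run_smoke_test(1234) > 0.7
```

Because the seed feeds a local `random.Random`, two calls with the same seed produce identical accuracy, which is the property `seed_everything(1234)` aims for in the real test. A loose floor such as 0.7 keeps the check meaningful while tolerating small numeric differences across hardware and library versions.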