Skip to content
Snippets Groups Projects
Select Git revision
  • 1a31a19e4dfabd30d94517520974c3a428c29e1f
  • main default protected
2 results

test.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    test_classifier.py 433 B
    from pytorch_lightning import Trainer, seed_everything
    
    from project.lit_classifier_main import LitClassifier
    
    
    def test_lit_classifier():
        """Smoke-test that ``LitClassifier`` trains briefly and reaches >70% test accuracy.

        Seeds the RNGs via ``seed_everything(1234)`` for reproducibility, fits the
        model for at most 2 epochs on a limited number of train/val batches, then
        runs the test loop and checks the ``'test_acc'`` entry of the first result.
        """
        # NOTE(review): `mnist` was called but never imported in this file, so the
        # test died with NameError before training. The path below follows the
        # Lightning project-template layout -- confirm it matches this repository.
        from project.datasets.mnist import mnist

        seed_everything(1234)

        model = LitClassifier()
        # Presumably (train, val, test) DataLoaders -- confirm against `mnist`.
        train, val, test = mnist()
        # Cap batches and epochs so the test stays fast.
        trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
        trainer.fit(model, train, val)

        # Pass the loader positionally: the `test_dataloaders=` keyword was
        # deprecated (renamed to `dataloaders=`) and later removed in PyTorch
        # Lightning, while the second positional slot takes the test loaders in
        # both APIs. Passing `model` evaluates the weights just trained.
        results = trainer.test(model, test)
        # `results` holds one metrics dict per test dataloader; the model must
        # log a 'test_acc' metric for this lookup to succeed.
        assert results[0]['test_acc'] > 0.7