test_classifier.py
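from pytorch_lightning import Trainer, seed_everything

from project.lit_classifier_main import LitClassifier
# Assumed location of the mnist() helper, which the test calls but never imported;
# adjust the module path to wherever mnist() lives in this project.
from project.datasets.mnist import mnist


def test_lit_classifier():
    # Seed Python, NumPy and PyTorch RNGs so the accuracy threshold is reproducible.
    seed_everything(1234)

    model = LitClassifier()
    # mnist() is expected to return (train, val, test) DataLoaders.
    train, val, test = mnist()

    # Short smoke-test run: 50 train and 20 val batches per epoch, 2 epochs.
    trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
    trainer.fit(model, train, val)

    # `test_dataloaders` was renamed to `dataloaders` in newer Lightning releases;
    # passing the model explicitly evaluates the weights that were just trained.
    results = trainer.test(model, dataloaders=test)
    # trainer.test returns one dict of logged metrics per test dataloader.
    assert results[0]['test_acc'] > 0.7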
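The test depends on an `mnist()` helper whose source is not shown on this page. The sketch below is one hypothetical way to write it, assuming torchvision's `MNIST` dataset and a 55,000/5,000 train/validation split of the 60,000-image training set. The names `mnist` and `split_loaders`, the split sizes, the batch size and the split seed are illustrative assumptions, not the project's actual code.

```python
import torch
from torch.utils.data import DataLoader, Dataset, random_split


def split_loaders(train_full: Dataset, test: Dataset, val_size: int,
                  batch_size: int = 32, seed: int = 1234):
    """Hypothetical helper: split train_full into train/val, wrap all splits in DataLoaders."""
    train_size = len(train_full) - val_size
    # A seeded generator keeps the train/val split identical across runs.
    generator = torch.Generator().manual_seed(seed)
    train, val = random_split(train_full, [train_size, val_size], generator=generator)
    return (
        DataLoader(train, batch_size=batch_size, shuffle=True),
        DataLoader(val, batch_size=batch_size),
        DataLoader(test, batch_size=batch_size),
    )


def mnist(data_dir: str = ".", batch_size: int = 32):
    """Hypothetical helper: MNIST (train, val, test) loaders with a 55k/5k train/val split."""
    # Imported here so split_loaders() can be used without torchvision installed.
    from torchvision import transforms
    from torchvision.datasets import MNIST

    to_tensor = transforms.ToTensor()
    train_full = MNIST(data_dir, train=True, download=True, transform=to_tensor)
    test = MNIST(data_dir, train=False, download=True, transform=to_tensor)
    return split_loaders(train_full, test, val_size=5000, batch_size=batch_size)
```

Keeping `split_loaders` separate from the download step lets the splitting logic be checked against a small in-memory `TensorDataset` without fetching MNIST.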