Select Git revision
test_classifier.py
-
William Falcon authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
test_classifier.py 473 B
from project.datasets.mnist import mnist
from project.lit_classifier_main import LitClassifier
from pytorch_lightning import Trainer, seed_everything
def test_lit_classifier():
    """Train ``LitClassifier`` briefly on MNIST and check its test accuracy.

    Seeds the RNGs, fits for at most 2 epochs on a capped number of
    batches, evaluates on the test split, and requires the reported
    ``test_acc`` to exceed 0.7.
    """
    # Seed the global RNGs so the run is reproducible.
    seed_everything(1234)

    model = LitClassifier()
    # NOTE(review): presumably returns (train, val, test) DataLoaders —
    # confirm against project.datasets.mnist.
    train, val, test = mnist()

    # Cap batches per epoch so the test stays fast.
    trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2)
    trainer.fit(model, train, val)

    # Pass (model, dataloaders) positionally. The keyword ``test_dataloaders``
    # was deprecated in Lightning 1.4 (renamed ``dataloaders``) and removed in
    # 1.6, but the positional order is the same across releases. Passing the
    # model explicitly evaluates the weights just trained in memory rather
    # than letting Lightning choose a checkpoint.
    results = trainer.test(model, test)

    # Fail with a readable message instead of an IndexError if nothing was returned.
    assert results, "trainer.test() returned no results"
    assert results[0]['test_acc'] > 0.7