diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tests/test_classifier.py b/tests/test_classifier.py new file mode 100644 index 0000000000000000000000000000000000000000..a11a314e075ffe6e61e21a46f198e355f1e71f40 --- /dev/null +++ b/tests/test_classifier.py @@ -0,0 +1,15 @@ +from project.datasets.mnist import mnist +from project.lit_classifier_main import LitClassifier +from pytorch_lightning import Trainer, seed_everything + + +def test_lit_classifier(): + seed_everything(1234) + + model = LitClassifier() + train, val, test = mnist() + trainer = Trainer(limit_train_batches=50, limit_val_batches=20, max_epochs=2) + trainer.fit(model, train, val) + + results = trainer.test(test_dataloaders=test) + assert results[0]['test_acc'] > 0.7