diff --git a/research_seed/mnist/mnist.py b/research_seed/mnist/mnist.py index c9ed3976228ca99fa937bd73d0d85de6813f6e90..e1d54bda640669e01b026046b3b8f00fbf12a0d9 100644 --- a/research_seed/mnist/mnist.py +++ b/research_seed/mnist/mnist.py @@ -45,17 +45,14 @@ class CoolSystem(pl.LightningModule): # can return multiple optimizers and learning_rate schedulers return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) - @pl.data_loader def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) - @pl.data_loader def val_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) - @pl.data_loader def test_dataloader(self): # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)