Skip to content
Snippets Groups Projects
Unverified Commit d2de2224 authored by William Falcon's avatar William Falcon Committed by GitHub
Browse files

Update mnist_baseline.py

parent b6d32d95
No related branches found
No related tags found
No related merge requests found
......@@ -45,17 +45,14 @@ class CoolSystem(pl.LightningModule):
# can return multiple optimizers and learning_rate schedulers
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@pl.data_loader
def train_dataloader(self):
# REQUIRED
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
@pl.data_loader
def val_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
@pl.data_loader
def test_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment