diff --git a/research_seed/mnist/mnist.py b/research_seed/mnist/mnist.py index c48ea849cfefb6e143e9bba6a769ec5ad27a9db9..c9ed3976228ca99fa937bd73d0d85de6813f6e90 100644 --- a/research_seed/mnist/mnist.py +++ b/research_seed/mnist/mnist.py @@ -23,13 +23,13 @@ class CoolSystem(pl.LightningModule): def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) - def training_step(self, batch, batch_nb): + def training_step(self, batch, batch_idx): # REQUIRED x, y = batch y_hat = self.forward(x) return {'loss': F.cross_entropy(y_hat, y)} - def validation_step(self, batch, batch_nb): + def validation_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) @@ -46,7 +46,7 @@ class CoolSystem(pl.LightningModule): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)