From b80536f44337e3fe008d1fb3a0f1e23fd95b697b Mon Sep 17 00:00:00 2001 From: William Falcon <waf2107@columbia.edu> Date: Thu, 26 Sep 2019 10:50:33 -0400 Subject: [PATCH] update with new syntax --- research_seed/mnist/mnist.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/research_seed/mnist/mnist.py b/research_seed/mnist/mnist.py index c48ea84..c9ed397 100644 --- a/research_seed/mnist/mnist.py +++ b/research_seed/mnist/mnist.py @@ -23,13 +23,13 @@ class CoolSystem(pl.LightningModule): def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) - def training_step(self, batch, batch_nb): + def training_step(self, batch, batch_idx): # REQUIRED x, y = batch y_hat = self.forward(x) return {'loss': F.cross_entropy(y_hat, y)} - def validation_step(self, batch, batch_nb): + def validation_step(self, batch, batch_idx): # OPTIONAL x, y = batch y_hat = self.forward(x) @@ -46,7 +46,7 @@ class CoolSystem(pl.LightningModule): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @pl.data_loader - def tng_dataloader(self): + def train_dataloader(self): # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) -- GitLab