Skip to content
Snippets Groups Projects
Unverified Commit 9506b577 authored by William Falcon's avatar William Falcon Committed by GitHub
Browse files

update with new syntax

parent b80536f4
No related branches found
No related tags found
No related merge requests found
......@@ -23,13 +23,13 @@ class CoolSystem(pl.LightningModule):
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
def training_step(self, batch, batch_idx):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
return {'loss': F.cross_entropy(y_hat, y)}
def validation_step(self, batch, batch_nb):
def validation_step(self, batch, batch_idx):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
......@@ -46,7 +46,7 @@ class CoolSystem(pl.LightningModule):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@pl.data_loader
def tng_dataloader(self):
def train_dataloader(self):
# REQUIRED
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment