Skip to content
Snippets Groups Projects
Unverified Commit b80536f4 authored by William Falcon's avatar William Falcon Committed by GitHub
Browse files

update with new syntax

parent 77c1232d
Branches
No related tags found
No related merge requests found
...@@ -23,13 +23,13 @@ class CoolSystem(pl.LightningModule): ...@@ -23,13 +23,13 @@ class CoolSystem(pl.LightningModule):
def forward(self, x): def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1))) return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb): def training_step(self, batch, batch_idx):
# REQUIRED # REQUIRED
x, y = batch x, y = batch
y_hat = self.forward(x) y_hat = self.forward(x)
return {'loss': F.cross_entropy(y_hat, y)} return {'loss': F.cross_entropy(y_hat, y)}
def validation_step(self, batch, batch_nb): def validation_step(self, batch, batch_idx):
# OPTIONAL # OPTIONAL
x, y = batch x, y = batch
y_hat = self.forward(x) y_hat = self.forward(x)
...@@ -46,7 +46,7 @@ class CoolSystem(pl.LightningModule): ...@@ -46,7 +46,7 @@ class CoolSystem(pl.LightningModule):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@pl.data_loader @pl.data_loader
def tng_dataloader(self): def train_dataloader(self):
# REQUIRED # REQUIRED
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size) return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment