Skip to content
Snippets Groups Projects
Commit e06cfec7 authored by William Falcon's avatar William Falcon
Browse files

clean up sample project

parent b527088b
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ pip install -r requirements.txt
cd src
# run module (example: mnist as your main contribution)
python mnist_classifier.py
python lit_classifier_main.py
```
## Main Contribution
......
File moved
import os
from argparse import ArgumentParser
import torch
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
class LitClassifier(pl.LightningModule):
......@@ -58,6 +54,7 @@ class LitClassifier(pl.LightningModule):
def cli_main():
from project.datasets.mnist import mnist
pl.seed_everything(1234)
# args
......@@ -70,11 +67,7 @@ def cli_main():
args = parser.parse_args()
# data
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234))
mnist_train = DataLoader(mnist_train, batch_size=32)
mnist_val = DataLoader(mnist_val, batch_size=32)
test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
mnist_train, mnist_val, test_dataset = mnist()
# model
model = LitClassifier(**vars(args))
......
......@@ -3,7 +3,7 @@
from setuptools import setup, find_packages
setup(
name='src',
name='project',
version='0.0.0',
description='Describe Your Cool Project',
author='',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment