From e06cfec7feee005c081cd38b56bdf8aa14676427 Mon Sep 17 00:00:00 2001
From: William Falcon <waf2107@columbia.edu>
Date: Fri, 11 Sep 2020 10:38:07 -0400
Subject: [PATCH] clean up sample project

---
 README.md                                             |  2 +-
 {src => project}/__init__.py                          |  0
 .../lit_classifier_main.py                            | 11 ++---------
 setup.py                                              |  2 +-
 4 files changed, 4 insertions(+), 11 deletions(-)
 rename {src => project}/__init__.py (100%)
 rename src/mnist_classifier.py => project/lit_classifier_main.py (80%)

diff --git a/README.md b/README.md
index c6cd424..b5ac7ac 100644
--- a/README.md
+++ b/README.md
@@ -55,7 +55,7 @@ pip install -r requirements.txt
 cd src
 
 # run module (example: mnist as your main contribution)   
-python mnist_classifier.py    
+python lit_classifier_main.py    
 ```
 
 ## Main Contribution      
diff --git a/src/__init__.py b/project/__init__.py
similarity index 100%
rename from src/__init__.py
rename to project/__init__.py
diff --git a/src/mnist_classifier.py b/project/lit_classifier_main.py
similarity index 80%
rename from src/mnist_classifier.py
rename to project/lit_classifier_main.py
index 1cd3180..020b72c 100644
--- a/src/mnist_classifier.py
+++ b/project/lit_classifier_main.py
@@ -1,13 +1,9 @@
-import os
 from argparse import ArgumentParser
 
 import torch
 import pytorch_lightning as pl
 from pytorch_lightning.metrics.functional import accuracy
 from torch.nn import functional as F
-from torch.utils.data import DataLoader, random_split
-from torchvision import transforms
-from torchvision.datasets import MNIST
 
 
 class LitClassifier(pl.LightningModule):
@@ -58,6 +54,7 @@ class LitClassifier(pl.LightningModule):
 
 
 def cli_main():
+    from project.datasets.mnist import mnist
     pl.seed_everything(1234)
 
     # args
@@ -70,11 +67,7 @@ def cli_main():
     args = parser.parse_args()
 
     # data
-    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
-    mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234))
-    mnist_train = DataLoader(mnist_train, batch_size=32)
-    mnist_val = DataLoader(mnist_val, batch_size=32)
-    test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
+    mnist_train, mnist_val, test_dataset = mnist()
 
     # model
     model = LitClassifier(**vars(args))
diff --git a/setup.py b/setup.py
index 0aea5a3..ff891d7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='src',
+    name='project',
     version='0.0.0',
     description='Describe Your Cool Project',
     author='',
-- 
GitLab