From e06cfec7feee005c081cd38b56bdf8aa14676427 Mon Sep 17 00:00:00 2001 From: William Falcon <waf2107@columbia.edu> Date: Fri, 11 Sep 2020 10:38:07 -0400 Subject: [PATCH] clean up sample project --- README.md | 2 +- {src => project}/__init__.py | 0 .../lit_classifier_main.py | 11 ++--------- setup.py | 2 +- 4 files changed, 4 insertions(+), 11 deletions(-) rename {src => project}/__init__.py (100%) rename src/mnist_classifier.py => project/lit_classifier_main.py (80%) diff --git a/README.md b/README.md index c6cd424..b5ac7ac 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ pip install -r requirements.txt cd src # run module (example: mnist as your main contribution) -python mnist_classifier.py +python lit_classifier_main.py ``` ## Main Contribution diff --git a/src/__init__.py b/project/__init__.py similarity index 100% rename from src/__init__.py rename to project/__init__.py diff --git a/src/mnist_classifier.py b/project/lit_classifier_main.py similarity index 80% rename from src/mnist_classifier.py rename to project/lit_classifier_main.py index 1cd3180..020b72c 100644 --- a/src/mnist_classifier.py +++ b/project/lit_classifier_main.py @@ -1,13 +1,9 @@ -import os from argparse import ArgumentParser import torch import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split -from torchvision import transforms -from torchvision.datasets import MNIST class LitClassifier(pl.LightningModule): @@ -58,6 +54,7 @@ class LitClassifier(pl.LightningModule): def cli_main(): + from project.datasets.mnist import mnist pl.seed_everything(1234) # args @@ -70,11 +67,7 @@ def cli_main(): args = parser.parse_args() # data - dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234)) - mnist_train = DataLoader(mnist_train, batch_size=32) - mnist_val = DataLoader(mnist_val, batch_size=32) - test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32) + mnist_train, mnist_val, test_dataset = mnist() # model model = LitClassifier(**vars(args)) diff --git a/setup.py b/setup.py index 0aea5a3..ff891d7 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( - name='src', + name='project', version='0.0.0', description='Describe Your Cool Project', author='', -- GitLab