diff --git a/README.md b/README.md index c6cd424fe24667a5201d285111264febd8e365a5..b5ac7acf18f0f1ed9646aa3275c856999727ca67 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ pip install -r requirements.txt cd src # run module (example: mnist as your main contribution) -python mnist_classifier.py +python lit_classifier_main.py ``` ## Main Contribution diff --git a/src/__init__.py b/project/__init__.py similarity index 100% rename from src/__init__.py rename to project/__init__.py diff --git a/src/mnist_classifier.py b/project/lit_classifier_main.py similarity index 80% rename from src/mnist_classifier.py rename to project/lit_classifier_main.py index 1cd318056c237bb1e23dfd8ab7ecb886c648f3f8..020b72c6c0e843b1166d4a95824b196a31eec45f 100644 --- a/src/mnist_classifier.py +++ b/project/lit_classifier_main.py @@ -1,13 +1,9 @@ -import os from argparse import ArgumentParser import torch import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split -from torchvision import transforms -from torchvision.datasets import MNIST class LitClassifier(pl.LightningModule): @@ -58,6 +54,7 @@ class LitClassifier(pl.LightningModule): def cli_main(): + from project.datasets.mnist import mnist pl.seed_everything(1234) # args @@ -70,11 +67,7 @@ def cli_main(): args = parser.parse_args() # data - dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor()) - mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234)) - mnist_train = DataLoader(mnist_train, batch_size=32) - mnist_val = DataLoader(mnist_val, batch_size=32) - test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32) + mnist_train, mnist_val, test_dataset = mnist() # model model = LitClassifier(**vars(args)) diff --git a/setup.py b/setup.py index 0aea5a3a9724c894a7257db8a834b9afb0509d8f..ff891d70bb3ca1b3c94415a7df2c95af3346e0b9 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( - name='src', + name='project', version='0.0.0', description='Describe Your Cool Project', author='',