diff --git a/README.md b/README.md
index 665d714328f54d3c6a7d649c8b368cf0cabb0250..c6cd424fe24667a5201d285111264febd8e365a5 100644
--- a/README.md
+++ b/README.md
@@ -52,10 +52,13 @@ pip install -r requirements.txt
  Next, navigate to [Your Main Contribution (MNIST here)] and run it.   
  ```bash
 # module folder
-cd research_mnist/    
+cd src
 
 # run module (example: mnist as your main contribution)   
-python simplest_mnist.py    
+python mnist_classifier.py    
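+
+# optionally run on a single GPU (the script defines a --gpus flag)
+python mnist_classifier.py --gpus 1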
 ```
 
 ## Main Contribution      
diff --git a/research_mnist/README.md b/research_mnist/README.md
deleted file mode 100644
index 029c0970b8c66b4ca7120485c8722a6256bc7e4c..0000000000000000000000000000000000000000
--- a/research_mnist/README.md
+++ /dev/null
@@ -1,35 +0,0 @@
-## Research Seed Folder   
-Create a folder for each contribution (e.g. MNIST, BERT, etc.).   
-Each folder will have:
-
-##### contribution_name_trainer.py    
-Runs your LightningModule. Abstracts training loop, distributed training, etc...   
-
-##### contribution_name.py  
-Holds your main contribution   
-
-## Example  
-The folder here gives an example for mnist.   
-
-### MNIST    
-In this readme, give instructions on how to run your code.   
-
-#### CPU   
-```bash   
-python mnist_trainer.py     
-```
-
-#### Multiple-GPUs   
-```bash   
-python mnist_trainer.py --gpus 4
-```   
-
-or specific GPUs
-```bash   
-python mnist_trainer.py --gpus '0,3'
-```   
-
-#### On multiple nodes   
-```bash  
-python mnist_trainer.py --gpus 4 --nodes 4  --precision 16
-```   
diff --git a/research_mnist/mnist_data_module.py b/research_mnist/mnist_data_module.py
deleted file mode 100644
index a55b8f9512a70cefbee95e71b2929227e2b9fe64..0000000000000000000000000000000000000000
--- a/research_mnist/mnist_data_module.py
+++ /dev/null
@@ -1,88 +0,0 @@
-from argparse import ArgumentParser
-
-import pytorch_lightning as pl
-from torch.utils.data import random_split, DataLoader
-from torchvision import transforms
-# Note - you must have torchvision installed for this example
-from torchvision.datasets import MNIST
-
-
-class MNISTDataModule(pl.LightningDataModule):
-
-    def __init__(self, hparams):
-
-        super().__init__()
-
-        self.hparams = hparams
-
-        self.data_dir = self.hparams.data_dir
-        self.batch_size = self.hparams.batch_size
-
-        # We hardcode dataset specific stuff here.
-        self.num_classes = 10
-        self.dims = (1, 28, 28)
-        self.transform = transforms.Compose([transforms.ToTensor(), ])
-
-        # Basic test that parameters passed are sensible.
-        assert (
-            self.hparams.train_size + self.hparams.valid_size == 60_000
-        ), "Invalid Train and Valid Split, make sure they add up to 60,000"
-
-    def prepare_data(self):
-        # download the dataset
-        MNIST(self.data_dir, train=True, download=True)
-        MNIST(self.data_dir, train=False, download=True)
-
-    def setup(self, stage=None):
-
-        # Assign train/val datasets for use in dataloaders
-        if stage == "fit" or stage is None:
-            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
-            self.mnist_train, self.mnist_val = random_split(
-                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
-            )
-
-        # Assign test dataset for use in dataloader(s)
-        if stage == "test" or stage is None:
-            self.mnist_test = MNIST(
-                self.data_dir, train=False, transform=self.transform
-            )
-
-    def train_dataloader(self):
-        # REQUIRED
-        return DataLoader(
-            self.mnist_train,
-            batch_size=self.batch_size,
-            num_workers=self.hparams.workers,
-        )
-
-    def val_dataloader(self):
-        # OPTIONAL
-        return DataLoader(
-            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
-        )
-
-    def test_dataloader(self):
-        # OPTIONAL
-        return DataLoader(
-            self.mnist_test,
-            batch_size=self.batch_size,
-            num_workers=self.hparams.workers,
-        )
-
-    @staticmethod
-    def add_data_specific_args(parent_parser):
-        """
-        Specify the hyperparams for this LightningModule
-        """
-        # Dataset specific
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument("--batch_size", default=32, type=int)
-        parser.add_argument("--data_dir", default="./", type=str)
-
-        # training specific
-        parser.add_argument("--train_size", default=55_000, type=int)
-        parser.add_argument("--valid_size", default=5_000, type=int)
-        parser.add_argument("--workers", default=8, type=int)
-
-        return parser
diff --git a/research_mnist/mnist_trainer.py b/research_mnist/mnist_trainer.py
deleted file mode 100644
index 23d257576fa5a948ebd663c68746c7f2285eb74a..0000000000000000000000000000000000000000
--- a/research_mnist/mnist_trainer.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""
-This file runs the main training/val loop, etc. using Lightning Trainer.
-"""
-from argparse import ArgumentParser
-
-from pytorch_lightning import Trainer, seed_everything
-
-from research_mnist.simplest_mnist import CoolSystem
-from research_mnist.mnist_data_module import MNISTDataModule
-
-
-def main(args):
-    # init modules
-    dm = MNISTDataModule(hparams=args)
-    model = CoolSystem(hparams=args)
-
-    # most basic trainer, uses good defaults
-    trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model, dm)
-
-    trainer.test()
-
-
-def main_cli():
-    # sets seeds for numpy, torch, etc...
-    # must do for DDP to work well
-    seed_everything(123)
-    parser = ArgumentParser(add_help=False)
-
-    # add args from trainer
-    parser = Trainer.add_argparse_args(parser)
-
-    # give the module a chance to add own params
-    # good practice to define LightningModule-specific params in the module
-    parser = CoolSystem.add_model_specific_args(parser)
-    # same goes for data modules
-    parser = MNISTDataModule.add_data_specific_args(parser)
-
-    # parse params
-    args = parser.parse_args()
-
-    main(args)
-
-
-if __name__ == '__main__':
-    main_cli()
diff --git a/research_mnist/simplest_mnist.py b/research_mnist/simplest_mnist.py
deleted file mode 100644
index 6f81a84bbad1d26080faaa13488f7f5f7cbe84e0..0000000000000000000000000000000000000000
--- a/research_mnist/simplest_mnist.py
+++ /dev/null
@@ -1,88 +0,0 @@
-"""
-This file defines the core research contribution.
-"""
-from argparse import ArgumentParser
-
-import pytorch_lightning as pl
-import torch
-from torch.nn import functional as F
-
-
-class CoolSystem(pl.LightningModule):
-
-    def __init__(self, hparams):
-        super(CoolSystem, self).__init__()
-        # not the best model...
-        self.hparams = hparams
-        self.l1 = torch.nn.Linear(28 * 28, 10)
-
-    def forward(self, x):
-        return torch.relu(self.l1(x.view(x.size(0), -1)))
-
-    def training_step(self, batch, batch_idx):
-        # REQUIRED
-        x, y = batch
-        y_hat = self.forward(x)
-        loss = F.cross_entropy(y_hat, y)
-
-        result = pl.TrainResult(minimize=loss)
-        result.log('train_loss', loss, prog_bar=True)
-
-        return result
-
-    def validation_step(self, batch, batch_idx):
-        # OPTIONAL
-        x, y = batch
-        y_hat = self.forward(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult()
-        result.valid_batch_loss = loss
-        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
-
-        return result
-
-    def validation_epoch_end(self, outputs):
-        # OPTIONAL
-        avg_loss = outputs.valid_batch_loss.mean()
-        result = pl.EvalResult(checkpoint_on=avg_loss)
-        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
-
-        return result
-
-    def test_step(self, batch, batch_idx):
-        # OPTIONAL
-        x, y = batch
-        y_hat = self.forward(x)
-        loss = F.cross_entropy(y_hat, y)
-        result = pl.EvalResult()
-        result.test_batch_loss = loss
-        result.log('test_loss', loss, on_epoch=True)
-
-        return result
-
-    def test_epoch_end(self, outputs):
-        # OPTIONAL
-        avg_loss = outputs.test_batch_loss.mean()
-
-        result = pl.EvalResult()
-        result.log('test_loss', avg_loss, on_epoch=True)
-        return result
-
-    def configure_optimizers(self):
-        # REQUIRED
-        # can return multiple optimizers and learning_rate schedulers
-        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
-
-    @staticmethod
-    def add_model_specific_args(parent_parser):
-        """
-        Specify the hyperparams for this LightningModule
-        """
-        # MODEL specific
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument('--learning_rate', default=0.02, type=float)
-
-        # training specific (for this model)
-        parser.add_argument('--max_nb_epochs', default=2, type=int)
-
-        return parser
diff --git a/setup.py b/setup.py
index 70be95a18d08c2da6aef0f8172d9a5d1178919a7..0aea5a3a9724c894a7257db8a834b9afb0509d8f 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 from setuptools import setup, find_packages
 
 setup(
-    name='research_mnist',
+    name='src',
     version='0.0.0',
     description='Describe Your Cool Project',
     author='',
diff --git a/research_mnist/__init__.py b/src/__init__.py
similarity index 100%
rename from research_mnist/__init__.py
rename to src/__init__.py
diff --git a/src/mnist_classifier.py b/src/mnist_classifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cd318056c237bb1e23dfd8ab7ecb886c648f3f8
--- /dev/null
+++ b/src/mnist_classifier.py
@@ -0,0 +1,94 @@
+import os
+from argparse import ArgumentParser
+
+import torch
+import pytorch_lightning as pl
+from pytorch_lightning.metrics.functional import accuracy
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, random_split
+from torchvision import transforms
+from torchvision.datasets import MNIST
+
+
+class LitClassifier(pl.LightningModule):
+    def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, **kwargs):
+        super().__init__()
+        self.save_hyperparameters()
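+        # save_hyperparameters() stores the init args on self.hparams and in checkpoints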
+
+        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
+        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = torch.relu(self.l1(x))
+        x = self.l2(x)  # return raw logits; F.cross_entropy applies log-softmax itself
+        return x
+
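+    # NOTE: pl.TrainResult / pl.EvalResult are the 0.9-era Result API; on
+    # Lightning >= 1.0 these steps would return the loss and call self.log().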
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss)
+        return result
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult(checkpoint_on=loss)
+        result.log('val_loss', loss)
+        result.log('val_acc', accuracy(y_hat, y))
+        return result
+
+    def test_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult(checkpoint_on=loss)
+        result.log('test_loss', loss)
+        result.log('test_acc', accuracy(y_hat, y))
+        return result
+
+    def configure_optimizers(self):
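+        # a single Adam optimizer; Lightning also accepts lists of optimizers and schedulers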
+        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
+
+
+def cli_main():
+    pl.seed_everything(1234)
+
+    # args
+    parser = ArgumentParser()
+    parser.add_argument('--gpus', default=0, type=int)
+
+    # optional... automatically add all the params
+    # parser = pl.Trainer.add_argparse_args(parser)
+    # parser = YourDataModule.add_argparse_args(parser)  # if you add a LightningDataModule
+    args = parser.parse_args()
+
+    # data
+    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
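+    # root='' downloads MNIST into ./MNIST on first run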
+    mnist_train, mnist_val = random_split(dataset, [55000, 5000], generator=torch.Generator().manual_seed(1234))
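+    # the seeded generator makes the 55k/5k split reproducible across runs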
+    mnist_train = DataLoader(mnist_train, batch_size=32)
+    mnist_val = DataLoader(mnist_val, batch_size=32)
+    test_dataset = DataLoader(MNIST('', train=False, download=True, transform=transforms.ToTensor()), batch_size=32)
+
+    # model
+    model = LitClassifier(**vars(args))
+
+    # training
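+    # limit_train_batches=200 caps each epoch at 200 batches to keep the demo fast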
+    trainer = pl.Trainer(gpus=args.gpus, max_epochs=2, limit_train_batches=200)
+    trainer.fit(model, mnist_train, mnist_val)
+
+    trainer.test(test_dataloaders=test_dataset)
+
+
+if __name__ == '__main__':  # pragma: no cover
+    cli_main()