diff --git a/.gitignore b/.gitignore
index 73c615f495e411a8b0f83d29460d4279dce974d0..1dc6a3997d333ca448b9226e8e5cf2d9ef63557a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -120,6 +120,7 @@ venv.bak/
 
 # IDEs
 .idea
+.vscode
 
 # seed project
 lightning_logs/
diff --git a/requirements.txt b/requirements.txt
index 299248718de256390ef419166ee3f51be197d3c0..260eeda4423b12fd41ed697dc29276f2eef18c50 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-pytorch-lightning >= 0.7.5
\ No newline at end of file
+pytorch-lightning >= 0.9.0
\ No newline at end of file
diff --git a/src/research_mnist/mnist.py b/src/research_mnist/mnist.py
index d441101d3af99b6910c5965ae2f8abd1faadba9c..dffe47f0dd3a702986b2a215e5fc0118cad95751 100644
--- a/src/research_mnist/mnist.py
+++ b/src/research_mnist/mnist.py
@@ -29,52 +29,55 @@ class CoolSystem(pl.LightningModule):
         y_hat = self.forward(x)
         loss = F.cross_entropy(y_hat, y)
 
-        tensorboard_logs = {'train_loss': loss}
+        result = pl.TrainResult(minimize=loss)
+        result.log('train_loss', loss, prog_bar=True)
 
-        return {'loss': loss, 'log': tensorboard_logs}
+        return result
 
     def validation_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'val_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
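+        # stash the raw batch loss on the result; Lightning gathers it
+        # across batches so validation_epoch_end can average it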
+        result.valid_batch_loss = loss
+        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
+
+        return result
 
     def validation_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = outputs.valid_batch_loss.mean()
+        result = pl.EvalResult(checkpoint_on=avg_loss)
+        result.log('valid_loss', avg_loss, on_epoch=True, prog_bar=True)
 
-        tensorboard_logs = {'avg_val_loss': avg_loss}
-        return {'val_loss': avg_loss, 'log': tensorboard_logs}
+        return result
 
     def test_step(self, batch, batch_idx):
         # OPTIONAL
         x, y = batch
         y_hat = self.forward(x)
-        return {'test_loss': F.cross_entropy(y_hat, y)}
+        loss = F.cross_entropy(y_hat, y)
+        result = pl.EvalResult()
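+        # stash the raw batch loss on the result; Lightning gathers it
+        # across batches so test_epoch_end can average it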
+        result.test_batch_loss = loss
+        result.log('test_loss', loss, on_epoch=True)
+
+        return result
 
     def test_epoch_end(self, outputs):
         # OPTIONAL
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+        avg_loss = outputs.test_batch_loss.mean()
 
-        tensorboard_logs = {'test_val_loss': avg_loss}
-        return {'test_loss': avg_loss, 'log': tensorboard_logs}
+        result = pl.EvalResult()
+        result.log('test_loss', avg_loss, on_epoch=True)
+        return result
 
     def configure_optimizers(self):
         # REQUIRED
         # can return multiple optimizers and learning_rate schedulers
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
 
-    def train_dataloader(self):
-        # REQUIRED
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def val_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
-    def test_dataloader(self):
-        # OPTIONAL
-        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=self.hparams.batch_size)
-
     @staticmethod
     def add_model_specific_args(parent_parser):
@@ -84,7 +87,6 @@ class CoolSystem(pl.LightningModule):
         # MODEL specific
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
         parser.add_argument('--learning_rate', default=0.02, type=float)
-        parser.add_argument('--batch_size', default=32, type=int)
 
         # training specific (for this model)
         parser.add_argument('--max_nb_epochs', default=2, type=int)
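
Note on the Result API above: pytorch-lightning 0.9 can also reduce logged
values across batches automatically, which makes the *_epoch_end hooks
optional. A minimal sketch of that shorthand for the validation hook
(written inside the LightningModule, using the file's existing imports;
the explicit epoch_end variant above keeps manual control of the
aggregation):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.forward(x), y)
        # checkpoint_on tells 0.9 which tensor to monitor for checkpointing
        result = pl.EvalResult(checkpoint_on=loss)
        # on_epoch=True averages the value over all validation batches
        result.log('valid_loss', loss, on_epoch=True, prog_bar=True)
        return result
    # with this form, validation_epoch_end can be dropped entirely
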
diff --git a/src/research_mnist/mnist_data_module.py b/src/research_mnist/mnist_data_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f7c256383bd287bdd8732b8da548e8617b02af
--- /dev/null
+++ b/src/research_mnist/mnist_data_module.py
@@ -0,0 +1,87 @@
+from argparse import ArgumentParser
+
+import pytorch_lightning as pl
+from torch.utils.data import random_split, DataLoader
+
+# Note - you must have torchvision installed for this example
+from torchvision.datasets import MNIST
+from torchvision import transforms
+
+
+class MNISTDataModule(pl.LightningDataModule):
+    def __init__(self, hparams):
+        super().__init__()
+
+        self.hparams = hparams
+
+        self.data_dir = self.hparams.data_dir
+        self.batch_size = self.hparams.batch_size
+
+        # We hardcode dataset specific stuff here.
+        self.num_classes = 10
+        self.dims = (1, 28, 28)
+        self.transform = transforms.Compose([transforms.ToTensor(),])
+
+        # Basic sanity check that the requested split sizes are sensible.
+        assert (
+            self.hparams.train_size + self.hparams.valid_size == 60_000
+        ), "Invalid train/valid split: train_size + valid_size must equal 60,000 (the MNIST training set size)"
+
+    def prepare_data(self):
+        # download the dataset
+        MNIST(self.data_dir, train=True, download=True)
+        MNIST(self.data_dir, train=False, download=True)
+
+    def setup(self, stage=None):
+        # Assign train/val datasets for use in dataloaders
+        if stage == "fit" or stage is None:
+            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
+            self.mnist_train, self.mnist_val = random_split(
+                mnist_full, [self.hparams.train_size, self.hparams.valid_size]
+            )
+
+        # Assign test dataset for use in dataloader(s)
+        if stage == "test" or stage is None:
+            self.mnist_test = MNIST(
+                self.data_dir, train=False, transform=self.transform
+            )
+
+    def train_dataloader(self):
+        # REQUIRED
+        return DataLoader(
+            self.mnist_train,
+            batch_size=self.batch_size,
+            shuffle=True,  # reshuffle training data each epoch
+            num_workers=self.hparams.workers,
+        )
+
+    def val_dataloader(self):
+        # OPTIONAL
+        return DataLoader(
+            self.mnist_val, batch_size=self.batch_size, num_workers=self.hparams.workers
+        )
+
+    def test_dataloader(self):
+        # OPTIONAL
+        return DataLoader(
+            self.mnist_test,
+            batch_size=self.batch_size,
+            num_workers=self.hparams.workers,
+        )
+
+    @staticmethod
+    def add_data_specific_args(parent_parser):
+        """
+        Specify the data-specific hyperparams for this LightningDataModule
+        """
+        # Dataset specific
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        parser.add_argument("--batch_size", default=32, type=int)
+        parser.add_argument("--data_dir", default="./", type=str)
+
+        # training specific
+        parser.add_argument("--train_size", default=55_000, type=int)
+        parser.add_argument("--valid_size", default=5_000, type=int)
+        parser.add_argument("--workers", default=8, type=int)
+
+        return parser
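
For quick checks the datamodule also works outside the Trainer. A minimal
sketch, relying only on the argparse defaults registered above
(data_dir='./', 55,000/5,000 split, batch size 32):

    from argparse import ArgumentParser
    from src.research_mnist.mnist_data_module import MNISTDataModule

    parser = MNISTDataModule.add_data_specific_args(ArgumentParser())
    args = parser.parse_args([])       # empty argv -> the defaults above
    dm = MNISTDataModule(hparams=args)
    dm.prepare_data()                  # downloads MNIST into data_dir
    dm.setup('fit')                    # builds the 55k/5k train/val split
    x, y = next(iter(dm.train_dataloader()))
    print(x.shape)                     # expected: torch.Size([32, 1, 28, 28])
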
diff --git a/src/research_mnist/mnist_trainer.py b/src/research_mnist/mnist_trainer.py
index 74396b93af5ba4ca8f3d0c0b2ccc14c6420ddaa8..6e05f059cad4305de31998bfb39f5316379b9ad3 100644
--- a/src/research_mnist/mnist_trainer.py
+++ b/src/research_mnist/mnist_trainer.py
@@ -4,18 +4,20 @@ This file runs the main training/val loop, etc... using Lightning Trainer
 from pytorch_lightning import Trainer, seed_everything
 from argparse import ArgumentParser
 from src.research_mnist.mnist import CoolSystem
+from src.research_mnist.mnist_data_module import MNISTDataModule
 
 # sets seeds for numpy, torch, etc...
 # must do for DDP to work well
 seed_everything(123)
 
 def main(args):
-    # init module
+    # init modules
+    dm = MNISTDataModule(hparams=args)
     model = CoolSystem(hparams=args)
 
     # most basic trainer, uses good defaults
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(model)
+    trainer.fit(model, dm)
 
-    trainer.test()
+    trainer.test(datamodule=dm)
 
@@ -29,6 +31,8 @@ if __name__ == '__main__':
     # give the module a chance to add own params
-    # good practice to define LightningModule speficic params in the module
+    # good practice to define LightningModule specific params in the module
     parser = CoolSystem.add_model_specific_args(parser)
+    # same goes for data modules
+    parser = MNISTDataModule.add_data_specific_args(parser)
 
     # parse params
     args = parser.parse_args()
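
The same wiring can also be driven without the CLI, e.g. as a smoke test. A
minimal sketch (the Namespace carries every field the parsers above
register; fast_dev_run runs a single batch through train/val/test):

    from argparse import Namespace
    from pytorch_lightning import Trainer
    from src.research_mnist.mnist import CoolSystem
    from src.research_mnist.mnist_data_module import MNISTDataModule

    args = Namespace(learning_rate=0.02, max_nb_epochs=2, batch_size=32,
                     data_dir='./', train_size=55_000, valid_size=5_000,
                     workers=8)
    dm = MNISTDataModule(hparams=args)
    model = CoolSystem(hparams=args)
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, dm)
    trainer.test(datamodule=dm)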