diff --git a/master_file.ipynb b/master_file.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..3f776daecd93706ec542f7813b7bdd084b5bff74
--- /dev/null
+++ b/master_file.ipynb
@@ -0,0 +1,633 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Implementation of\n",
+    "    - Unet from https://arxiv.org/pdf/1910.01113v2.pdf\n",
+    "    - LPD Net from https://arxiv.org/pdf/1707.06474.pdf"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "import astra\n",
+    "import odl\n",
+    "import numpy as np\n",
+    "import dival\n",
+    "#from dival import get_standard_dataset\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "from matplotlib import pyplot as plt\n",
+    "import torch.utils.data\n",
+    "import custom_odl_op8 as op\n",
+    "import time \n",
+    "from torch.optim.lr_scheduler import StepLR\n",
+    "from skimage.metrics import peak_signal_noise_ratio as PSNR\n",
+    "from skimage.metrics import structural_similarity as SSIM\n",
+    "import dival.datasets.lodopab_dataset as lodopab\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'import numpy as np\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport astra\\nimport dival\\nfrom matplotlib import pyplot as plt\\nimport torch.utils.data\\nimport odl\\nimport custom_odl_op as op\\nimport custom_odl_op8 as op8\\nfrom skimage.metrics import peak_signal_noise_ratio as PSNR\\nfrom skimage.metrics import structural_similarity as SSIM\\n\\nimport dival.datasets.lodopab_dataset as lodopab\\nimport time'"
+      ]
+     },
+     "execution_count": 2,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "\"\"\"import numpy as np\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "import astra\n",
+    "import dival\n",
+    "from matplotlib import pyplot as plt\n",
+    "import torch.utils.data\n",
+    "import odl\n",
+    "import custom_odl_op as op\n",
+    "import custom_odl_op8 as op8\n",
+    "from skimage.metrics import peak_signal_noise_ratio as PSNR\n",
+    "from skimage.metrics import structural_similarity as SSIM\n",
+    "\n",
+    "import dival.datasets.lodopab_dataset as lodopab\n",
+    "import time\"\"\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set chosen parameters\n",
+    "\n",
+    "    BATCH_SIZE:           number of images to process before the optimizer step\n",
+    "    EPOCHS:               number of epochs\n",
+    "    LEARNING_RATE         initial learning rate\n",
+    "    IMG_TO_TRAIN          total number of images per epoch\n",
+    "    DEVICE                device to use for computing\n",
+    "    PRINT_AFTER           number of batches to process before printing an update\n",
+    "    LR_UPDATE_AFTER       number of epochs before a learning rate update\n",
+    "    LR_UPDATE_FACTOR      facter of the learning rate update"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\"\"\"\n",
+    "constants from the dataset:\n",
+    "input_size: 1000*513\n",
+    "output_size: 362*362\n",
+    "\n",
+    "train_len = 35820\n",
+    "validation_len = 3522\n",
+    "test_len = 3553\n",
+    "\"\"\"\n",
+    "\n",
+    "\n",
+    "BATCH_SIZE = 64\n",
+    "EPOCHS = 10\n",
+    "LEARNING_RATE = 0.01\n",
+    "IMG_TO_TRAIN = 35820\n",
+    "IMG_TO_TEST = 50\n",
+    "IMG_TO_VAL = 100\n",
+    "DEVICE = \"cuda:3\"\n",
+    "PRINT_AFTER = 15\n",
+    "LR_UPDATE_AFTER = 5\n",
+    "LR_UPDATE_FACTOR = 0.1\n",
+    "SAVE_AFTER = 2"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Definition of the dataset and dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = lodopab.LoDoPaBDataset(impl= 'astra_cpu')\n",
+    "\n",
+    "trainset = torch.utils.data.Subset(\n",
+    "    dataset.create_torch_dataset('train'),\n",
+    "    list(range(IMG_TO_TRAIN))) \n",
+    "trainloader = torch.utils.data.DataLoader(\n",
+    "    trainset, \n",
+    "    batch_size=BATCH_SIZE, \n",
+    "    shuffle=True)\n",
+    "\n",
+    "testset = torch.utils.data.Subset(\n",
+    "    dataset.create_torch_dataset('test'),\n",
+    "    list(range(IMG_TO_TEST))) \n",
+    "testloader = torch.utils.data.DataLoader(\n",
+    "    testset, \n",
+    "    batch_size=BATCH_SIZE, \n",
+    "    shuffle=True)\n",
+    "\n",
+    "evalset = torch.utils.data.Subset(\n",
+    "    dataset.create_torch_dataset('test'),\n",
+    "    list(range(IMG_TO_VAL)))\n",
+    "evalloader = torch.utils.data.DataLoader(\n",
+    "    evalset, \n",
+    "    batch_size=BATCH_SIZE, \n",
+    "    shuffle=True)"
+   ]
+  },
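+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optional sanity check (an illustrative sketch, not part of the original pipeline): fetch one sample from the training subset and confirm that the shapes match the constants above (sinogram 1000 x 513, ground truth 362 x 362)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# illustrative sketch: inspect one training sample to confirm the shapes\n",
+    "# quoted in the constants cell (sinogram 1000 x 513, image 362 x 362)\n",
+    "sample_sino, sample_gt = trainset[0]\n",
+    "print('sinogram shape:    ', tuple(sample_sino.shape))\n",
+    "print('ground truth shape:', tuple(sample_gt.shape))"
+   ]
+  },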
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Definition the helper functions\n",
+    "\n",
+    "plot\n",
+    "    sinogram\n",
+    "    FBP + PSNR + SSIM\n",
+    "    model output + PSNR / SSIM\n",
+    "    ground truth"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "ename": "NameError",
+     "evalue": "name 'model' is not defined",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-5-7be6f06ed16e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     86\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mpsnr\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mssim\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0mplot_Unet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     89\u001b[0m \u001b[0mplot_LPD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     90\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
+     ]
+    }
+   ],
+   "source": [
+    "fbp_op = odl.tomo.fbp_op(dataset.ray_trafo, filter_type='Ram-Lak', frequency_scaling=1.0)\n",
+    "#fbp=odl.tomo.analytic.filtered_back_projection.fbp_op(dataset.get_ray_trafo())\n",
+    "\n",
+    "FBP = op.OperatorAsModule(fbp_op)\n",
+    "\n",
+    "def plot_Unet(n, model, part = \"test\"):\n",
+    "        sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M\n",
+    "        sinogram = torch.from_numpy(sinogram[0])\n",
+    "        ground_truth = torch.from_numpy(ground_truth[0])\n",
+    "        fbp_out = FBP(sinogram.unsqueeze(0).unsqueeze(0)).squeeze()\n",
+    "        out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M\n",
+    "        min = torch.min(torch.tensor([torch.min(sinogram), torch.min(fbp_out), torch.min(out), torch.min(ground_truth)]))\n",
+    "        max = torch.max(torch.tensor([torch.max(sinogram), torch.max(fbp_out), torch.max(out), torch.max(ground_truth)]))\n",
+    "                \n",
+    "        psnr_fbp = PSNR(ground_truth.squeeze().numpy(), fbp_out.numpy())\n",
+    "        psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())\n",
+    "        ssim_fbp = SSIM(ground_truth.squeeze().numpy(), fbp_out.numpy())\n",
+    "        ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())\n",
+    "\n",
+    "        plt.figure().set_dpi(150)\n",
+    "\n",
+    "        plt.subplot(1, 4, 1)\n",
+    "        plt.imshow(sinogram.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"sinogram\")\n",
+    "        \n",
+    "        plt.subplot(1, 4, 2)\n",
+    "        plt.imshow(fbp_out.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"FBP\")\n",
+    "        plt.figtext(0.2, 0.1, \"PSNR:\" + str(psnr_fbp) + \"\\nSSIM:\"+ str(ssim_fbp))\n",
+    "        \n",
+    "        plt.subplot(1, 4, 3)\n",
+    "        plt.imshow(out.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"model output\")\n",
+    "        plt.figtext(0.6, 0.1, \"PSNR:\"+ str(psnr_model) + \"\\nSSIM:\" + str(ssim_model))\n",
+    "        \n",
+    "        plt.subplot(1, 4, 4)\n",
+    "        plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"ground truth\")\n",
+    "        \n",
+    "        plt.suptitle(\"Unet Plots\")\n",
+    "\n",
+    "        \n",
+    "def plot_LPD(n, model, part = \"test\"):\n",
+    "        sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M\n",
+    "        sinogram = torch.from_numpy(sinogram[0])\n",
+    "        ground_truth = torch.from_numpy(ground_truth[0])\n",
+    "        out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M\n",
+    "        min = torch.min(torch.tensor([torch.min(sinogram), torch.min(out), torch.min(ground_truth)]))\n",
+    "        max = torch.max(torch.tensor([torch.max(sinogram), torch.max(out), torch.max(ground_truth)]))\n",
+    "                \n",
+    "        psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())\n",
+    "        ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())\n",
+    "\n",
+    "        plt.figure().set_dpi(150)\n",
+    "\n",
+    "        plt.subplot(1, 3, 1)\n",
+    "        plt.imshow(sinogram.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"sinogram\")\n",
+    "        \n",
+    "        plt.subplot(1, 3, 2)\n",
+    "        plt.imshow(out.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"model output\")\n",
+    "        plt.figtext(0.4, 0.1, \"PSNR:\"+ str(psnr_model) + \"\\nSSIM:\" + str(ssim_model))\n",
+    "        \n",
+    "        plt.subplot(1, 3, 3)\n",
+    "        plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)\n",
+    "        plt.title(\"ground truth\")\n",
+    "        \n",
+    "        plt.suptitle(\"LPD Plots\")\n",
+    "        \n",
+    "        \n",
+    "        \n",
+    "def evaluate(model, loader):\n",
+    "    psnr = 0\n",
+    "    ssim = 0\n",
+    "    nb = 0\n",
+    "    for batch_id, samples in enumerate(loader):\n",
+    "        sinogram, ground_truth = samples\n",
+    "        out = model(sinogram.unsqueeze(1).to(DEVICE)).squeeze(dim=1).cpu().detach()\n",
+    "        \n",
+    "        for n in range(sinogram.shape[0]):\n",
+    "\n",
+    "            psnr += PSNR(ground_truth.numpy()[n], out.numpy()[n])\n",
+    "            ssim += SSIM(ground_truth.numpy()[n], out.numpy()[n])\n",
+    "        nb += sinogram.shape[0]\n",
+    "    return psnr/nb, ssim/nb\n",
+    "\n",
+    "plot_Unet(10, model)\n",
+    "plot_LPD(10, model)\n",
+    "evaluate(model, evalloader)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Definition of the Nets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class UNet_5Layer(nn.Module):\n",
+    "    def __init__(self, m = 16, n = 32, o = 64, p = 64, q = 128):\n",
+    "        super(UNet_5Layer, self).__init__()\n",
+    "        #U-net from https://arxiv.org/pdf/1910.01113v2.pdf\n",
+    "\n",
+    "        self.conv1 = nn.Conv2d(1, m, 3)\n",
+    "        self.norm1 = torch.nn.BatchNorm2d(m)\n",
+    "        \n",
+    "        self.conv2 = nn.Conv2d(m, n, 5, stride = 2)\n",
+    "        self.norm2 = torch.nn.BatchNorm2d(n)\n",
+    "        self.conv3 = nn.Conv2d(n, n, 3)\n",
+    "        self.norm3 = torch.nn.BatchNorm2d(n)\n",
+    "        \n",
+    "        self.conv4 = nn.Conv2d(n, o, 3, stride = 2)\n",
+    "        self.norm4 = torch.nn.BatchNorm2d(o)\n",
+    "        self.conv5 = nn.Conv2d(o, o, 3)\n",
+    "        self.norm5 = torch.nn.BatchNorm2d(o)\n",
+    "        \n",
+    "        self.conv6 = nn.Conv2d(o, p, 3, stride = 2)\n",
+    "        self.norm6 = torch.nn.BatchNorm2d(p)\n",
+    "        self.conv7 = nn.Conv2d(p, p, 3)\n",
+    "        self.norm7 = torch.nn.BatchNorm2d(p)\n",
+    "\n",
+    "        self.conv8 = nn.Conv2d(p, q, 3, stride = 2)\n",
+    "        self.norm8 = torch.nn.BatchNorm2d(q)\n",
+    "        self.conv9 = nn.Conv2d(q, q, 3)\n",
+    "        self.norm9 = torch.nn.BatchNorm2d(q)\n",
+    "        self.up4 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])\n",
+    "        self.conv10 = nn.Conv2d(q, p, 3)\n",
+    "        self.norm10 = torch.nn.BatchNorm2d(p)\n",
+    "        \n",
+    "        self.conv11 = nn.Conv2d(p+4, p, 3)\n",
+    "        self.norm11 = torch.nn.BatchNorm2d(p)\n",
+    "        self.up3 = nn.Upsample(scale_factor=2)\n",
+    "        self.conv12 = nn.Conv2d(p, o, 3)\n",
+    "        self.norm12 = torch.nn.BatchNorm2d(o)\n",
+    "        \n",
+    "        self.conv13 = nn.Conv2d(o+4, o, 3)\n",
+    "        self.norm13 = torch.nn.BatchNorm2d(o)\n",
+    "        self.up2 = nn.Upsample(scale_factor=2)\n",
+    "        self.conv14 = nn.Conv2d(o, n, 3)\n",
+    "        self.norm14 = torch.nn.BatchNorm2d(n)\n",
+    "        \n",
+    "        self.conv15 = nn.Conv2d(n+4, n, 3)\n",
+    "        self.norm15 = torch.nn.BatchNorm2d(n)\n",
+    "        self.up2 = nn.Upsample(scale_factor=2)\n",
+    "        self.conv16 = nn.Conv2d(n, m, 3)\n",
+    "        self.norm16 = torch.nn.BatchNorm2d(m)\n",
+    "        \n",
+    "        self.conv17 = nn.Conv2d(m+4, 1, 1)\n",
+    "        \n",
+    "        self.skip1 = nn.Conv2d(m, 4, 1)\n",
+    "        self.skip2 = nn.Conv2d(n, 4, 1)\n",
+    "        self.skip3 = nn.Conv2d(o, 4, 1)\n",
+    "        self.skip4 = nn.Conv2d(p, 4, 1)\n",
+    "\n",
+    "    def forward(self, inp):\n",
+    "        with torch.cuda.device(DEVICE):\n",
+    "            a = F.leaky_relu(self.norm1(self.conv1(inp)), negative_slope=0.2)#torch.Size([1, 16, 360, 360])\n",
+    "\n",
+    "\n",
+    "            b = F.leaky_relu(self.norm2(self.conv2(a)), negative_slope=0.2)\n",
+    "            b = F.leaky_relu(self.norm3(self.conv3(b)), negative_slope=0.2)#torch.Size([1, 32, 176, 176])\n",
+    "\n",
+    "            c = F.leaky_relu(self.norm4(self.conv4(b)), negative_slope=0.2)\n",
+    "            c = F.leaky_relu(self.norm5(self.conv5(c)), negative_slope=0.2) #torch.Size([1, 64, 85, 85])\n",
+    "\n",
+    "            d = F.leaky_relu(self.norm6(self.conv6(c)), negative_slope=0.2)\n",
+    "            d = F.leaky_relu(self.norm7(self.conv7(d)), negative_slope=0.2) #torch.Size([1, 64, 40, 40])\n",
+    "\n",
+    "            e = F.leaky_relu(self.norm8(self.conv8(d)), negative_slope=0.2)\n",
+    "            e = F.leaky_relu(self.norm9(self.conv9(e)), negative_slope=0.2)\n",
+    "            e = F.leaky_relu(self.norm10(self.conv10(self.up4(e))), negative_slope=0.2) #torch.Size([1, 64, 32, 32])\n",
+    "\n",
+    "            d = self.skip4(d[:,:,4:-4, 4:-4])\n",
+    "\n",
+    "            d = F.leaky_relu(self.norm11(self.conv11(torch.cat((d, e), 1))), negative_slope=0.2)\n",
+    "            d = F.leaky_relu(self.norm12(self.conv12(self.up3(d))), negative_slope=0.2)#torch.Size([1, 64, 58, 58])\n",
+    "\n",
+    "            c = self.skip3(c[:,:,13:-14,13:-14])\n",
+    "\n",
+    "            c = F.leaky_relu(self.norm13(self.conv13(torch.cat((c, d), 1))), negative_slope=0.2)\n",
+    "            c = F.leaky_relu(self.norm14(self.conv14(self.up2(c))), negative_slope=0.2)#torch.Size([1, 32, 110, 110])\n",
+    "\n",
+    "            b = self.skip2(b[:,:,33:-33,33:-33])\n",
+    "\n",
+    "            b = F.leaky_relu(self.norm15(self.conv15(torch.cat((b, c), 1))), negative_slope=0.2)\n",
+    "            b = F.leaky_relu(self.norm16(self.conv16(self.up2(b))), negative_slope=0.2)#torch.Size([1, 16, 214, 214])\n",
+    "\n",
+    "            a = self.skip1(a[:,:,73:-73,73:-73])\n",
+    "            a = torch.sigmoid(self.conv17(torch.cat((a,b), 1)))\n",
+    "    #         out = a\n",
+    "            out = F.interpolate(a, [362,362])\n",
+    "\n",
+    "        return out\n",
+    "        \n",
+    "        "
+   ]
+  },
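+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Minimal shape check for the U-Net (an illustrative sketch, assuming a CUDA device is available at DEVICE): a dummy 1 x 1 x 362 x 362 input, i.e. the size of an FBP reconstruction, should come back with the same spatial size."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# illustrative sketch (assumes a CUDA device at DEVICE): the U-Net should map a\n",
+    "# 362 x 362 input (the FBP output size) back to a 362 x 362 image\n",
+    "check_net = UNet_5Layer().to(DEVICE).eval()\n",
+    "with torch.no_grad():\n",
+    "    dummy = torch.zeros(1, 1, 362, 362).to(DEVICE)\n",
+    "    print(tuple(check_net(dummy).shape))  # expected: (1, 1, 362, 362)"
+   ]
+  },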
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class dual_iterate(nn.Module):\n",
+    "    def __init__(self, N_dual):\n",
+    "        super(dual_iterate, self).__init__()\n",
+    "        self.d_modules = nn.Sequential(\n",
+    "            nn.Conv2d(2+N_dual,32, 3, padding = 1),\n",
+    "            nn.PReLU(),\n",
+    "            nn.Conv2d(32 ,32, 3, padding = 1),\n",
+    "            nn.PReLU(),\n",
+    "            nn.Conv2d(32 ,5, 3, padding = 1)\n",
+    "        )\n",
+    "        \n",
+    "    def forward(self, h, f, g):\n",
+    "        out = self.d_modules(torch.cat((h, f[:, 1:2, :, :], g), dim=1))\n",
+    "        return out + h\n",
+    "\n",
+    "class primal_iterate(nn.Module):\n",
+    "    def __init__(self, N_primal):\n",
+    "        super(primal_iterate, self).__init__()\n",
+    "        self.p_modules = nn.Sequential(\n",
+    "            nn.Conv2d(1+N_primal,32, 3, padding = 1),\n",
+    "            nn.PReLU(),\n",
+    "            nn.Conv2d(32 ,32, 3, padding = 1),\n",
+    "            nn.PReLU(),\n",
+    "            nn.Conv2d(32 ,5, 3, padding = 1)\n",
+    "        )\n",
+    "        \n",
+    "    def forward(self, inp, f):\n",
+    "        out = self.p_modules(torch.cat((inp[:, 0:1, :, :], f), 1))\n",
+    "        return out + f\n",
+    "    \n",
+    "class LPD_Net(nn.Module):\n",
+    "    def __init__(self, I = 10,  N_primal = 5, N_dual = 5):\n",
+    "        super(LPD_Net, self).__init__()\n",
+    "        self.I = I\n",
+    "        self.N_primal = N_primal\n",
+    "        self.N_dual = N_dual\n",
+    "        self.primal_iterates = nn.ModuleList([primal_iterate(N_primal) for i in range(I)])\n",
+    "        self.dual_iterates = nn.ModuleList([dual_iterate(N_dual) for i in range(I)])\n",
+    "        self.T = op.OperatorAsModule(dataset.get_ray_trafo())\n",
+    "        self.Tstar = op.OperatorAsModule(dataset.get_ray_trafo().adjoint)\n",
+    "        self.dual_iterates.to(DEVICE)\n",
+    "        self.primal_iterates.to(DEVICE)\n",
+    "\n",
+    "    def forward(self, g):\n",
+    "        with torch.cuda.device(DEVICE):\n",
+    "\n",
+    "            #g = g.reshape(1, 1, 1000, 513)\n",
+    "            h = torch.zeros(g.shape[0], self.N_dual, 1000, 513).to(DEVICE)\n",
+    "            f = torch.zeros(g.shape[0], self.N_primal, 362, 362).to(DEVICE)\n",
+    "            T_f = self.T(f)\n",
+    "            \"\"\"print(g.shape)\n",
+    "            print(f.shape)\n",
+    "            print(h.shape)\n",
+    "            print(T_f.shape)\n",
+    "            print(T_f[:, 1:2, :, :].shape)\"\"\"\n",
+    "            for i in range(self.I):\n",
+    "                \"\"\"print(h.shape)\n",
+    "                print(T_f.shape)\n",
+    "                print(g.shape)\"\"\"\n",
+    "                h = self.dual_iterates[i](h, T_f, g)\n",
+    "                Tstar_h = self.Tstar(h)\n",
+    "                f = self.primal_iterates[i](Tstar_h, f)\n",
+    "                T_f = self.T(f)\n",
+    "        return f[:,0:1,:,:]\n"
+   ]
+  },
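+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Illustrative sketch: compare the trainable parameter counts of the two architectures. Instantiating LPD_Net assumes dataset (for the ray transform) and a CUDA device at DEVICE, since its constructor moves the iterates to the GPU."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# illustrative sketch: trainable parameter counts of the two architectures\n",
+    "# (LPD_Net() needs `dataset` for the ray transform and a CUDA device at DEVICE)\n",
+    "for name, m in [('UNet_5Layer', UNet_5Layer()), ('LPD_Net', LPD_Net(I=10))]:\n",
+    "    print(name, sum(p.numel() for p in m.parameters() if p.requires_grad))"
+   ]
+  },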
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "get Net and Object\n",
+    "    net is just the net\n",
+    "    model is a function of the net which takes the sinogram as input"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "net = UNet_5Layer().to(DEVICE)\n",
+    "model = lambda x : net(FBP(x))\n",
+    "#net = LPD_Net(I=10).to(DEVICE)\n",
+    "#model = lambda x : net(x)"
+   ]
+  },
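+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Example usage of the helper functions from above, now that model is defined (use plot_LPD with the LPD variant of model)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# example usage of the helpers defined earlier; call the plot function\n",
+    "# that matches the currently selected net\n",
+    "plot_Unet(10, model)\n",
+    "#plot_LPD(10, model)\n",
+    "evaluate(model, evalloader)"
+   ]
+  },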
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optimizer\n",
+    "    "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1, 15/560] loss: 0.052\n",
+      "[1, 30/560] loss: 0.012\n"
+     ]
+    }
+   ],
+   "source": [
+    "optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)\n",
+    "start_time = time.time()\n",
+    "\n",
+    "for epoch in range(EPOCHS):\n",
+    "    running_loss = 0.0\n",
+    "    for batch_id, sample in enumerate(trainloader):\n",
+    "        \n",
+    "        sinograms, ground_truths = sample\n",
+    "        sinograms = torch.unsqueeze(sinograms, 1).to(DEVICE) \n",
+    "        ground_truths = torch.unsqueeze(ground_truths, 1).to(DEVICE)\n",
+    "                \n",
+    "        optimizer.zero_grad()\n",
+    "        outputs = model(sinograms)\n",
+    "\n",
+    "        loss = nn.functional.mse_loss(outputs, ground_truths)\n",
+    "        # print(\"loss: \", loss)\n",
+    "        loss.backward()\n",
+    "        optimizer.step()\n",
+    "        \n",
+    "        running_loss += float(loss)\n",
+    "        if batch_id % PRINT_AFTER == PRINT_AFTER-1:    # print every PRINT_AFTER mini-batches\n",
+    "            print('[%d, %s] loss: %.3f' %\n",
+    "                  (epoch + 1, str(batch_id + 1)+\"/\"+str(len(trainloader)), running_loss/ PRINT_AFTER))\n",
+    "            running_loss = 0.0\n",
+    "        \n",
+    "        if epoch % LR_UPDATE_AFTER == LR_UPDATE_AFTER-1:\n",
+    "            LEARNING_RATE *= LR_UPDATE_FACTOR\n",
+    "            optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)\n",
+    "        \n",
+    "        if epoch % SAVE_AFTER == SAVE_AFTER-1:\n",
+    "            torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/U-Net_200621.pth')\n",
+    "        \n",
+    "end_time = time.time()\n",
+    "run_time = (end_time - start_time)/60\n",
+    "print(\"run time in minutes: \", run_time)\n",
+    "print('Finished Training')"
+   ]
+  },
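+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Alternative sketch (illustrative, not used above): the StepLR scheduler imported at the top can handle the per-epoch learning-rate decay instead of rebuilding the optimizer by hand."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# alternative sketch (illustrative only): decay the learning rate with StepLR\n",
+    "# instead of rebuilding the Adam optimizer inside the epoch loop\n",
+    "demo_optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)\n",
+    "demo_scheduler = StepLR(demo_optimizer, step_size=LR_UPDATE_AFTER, gamma=LR_UPDATE_FACTOR)\n",
+    "for demo_epoch in range(EPOCHS):\n",
+    "    # ... one pass over trainloader with demo_optimizer, as in the cell above ...\n",
+    "    demo_scheduler.step()\n",
+    "    print('epoch', demo_epoch + 1, 'lr', demo_optimizer.param_groups[0]['lr'])"
+   ]
+  },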
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "save net manually"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "load Unet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<All keys matched successfully>"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "#loaded_model =  UNet_5Layer().to(DEVICE)\n",
+    "net.load_state_dict(torch.load('../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth'))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.cuda.empty_cache()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "forastra2",
+   "language": "python",
+   "name": "forastra2"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}