From 18f8a2ceac4f11a2e18e89fc23dea325e91b1e3f Mon Sep 17 00:00:00 2001 From: Patrick Horn <baw4310@mathgpu1.physnet.uni-hamburg.de> Date: Sun, 20 Jun 2021 17:38:17 +0200 Subject: [PATCH] add file --- master_file.ipynb | 633 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 633 insertions(+) create mode 100644 master_file.ipynb diff --git a/master_file.ipynb b/master_file.ipynb new file mode 100644 index 0000000..3f776da --- /dev/null +++ b/master_file.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Implementation of\n", + " - Unet from https://arxiv.org/pdf/1910.01113v2.pdf\n", + " - LPD Net from https://arxiv.org/pdf/1707.06474.pdf" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import astra\n", + "import odl\n", + "import numpy as np\n", + "import dival\n", + "#from dival import get_standard_dataset\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from matplotlib import pyplot as plt\n", + "import torch.utils.data\n", + "import custom_odl_op8 as op\n", + "import time \n", + "from torch.optim.lr_scheduler import StepLR\n", + "from skimage.metrics import peak_signal_noise_ratio as PSNR\n", + "from skimage.metrics import structural_similarity as SSIM\n", + "import dival.datasets.lodopab_dataset as lodopab\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'import numpy as np\\nimport torch\\nimport torch.nn as nn\\nimport torch.nn.functional as F\\nimport astra\\nimport dival\\nfrom matplotlib import pyplot as plt\\nimport torch.utils.data\\nimport odl\\nimport custom_odl_op as op\\nimport custom_odl_op8 as op8\\nfrom skimage.metrics import peak_signal_noise_ratio as PSNR\\nfrom skimage.metrics import structural_similarity as SSIM\\n\\nimport dival.datasets.lodopab_dataset as lodopab\\nimport time'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import astra\n", + "import dival\n", + "from matplotlib import pyplot as plt\n", + "import torch.utils.data\n", + "import odl\n", + "import custom_odl_op as op\n", + "import custom_odl_op8 as op8\n", + "from skimage.metrics import peak_signal_noise_ratio as PSNR\n", + "from skimage.metrics import structural_similarity as SSIM\n", + "\n", + "import dival.datasets.lodopab_dataset as lodopab\n", + "import time\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set chosen parameters\n", + "\n", + " BATCH_SIZE: number of images to process before the optimizer step\n", + " EPOCHS: number of epochs\n", + " LEARNING_RATE initial learning rate\n", + " IMG_TO_TRAIN total number of images per epoch\n", + " DEVICE device to use for computing\n", + " PRINT_AFTER number of batches to process before printing an update\n", + " LR_UPDATE_AFTER number of epochs before a learning rate update\n", + " LR_UPDATE_FACTOR facter of the learning rate update" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "constants from the dataset:\n", + "input_size: 1000*513\n", + "output_size: 362*362\n", + "\n", + "train_len = 35820\n", + "validation_len = 3522\n", + "test_len = 3553\n", + "\"\"\"\n", + "\n", + "\n", + "BATCH_SIZE = 64\n", + "EPOCHS = 10\n", + "LEARNING_RATE = 0.01\n", + "IMG_TO_TRAIN = 35820\n", + "IMG_TO_TEST = 50\n", + "IMG_TO_VAL = 100\n", + "DEVICE = \"cuda:3\"\n", + "PRINT_AFTER = 15\n", + "LR_UPDATE_AFTER = 5\n", + "LR_UPDATE_FACTOR = 0.1\n", + "SAVE_AFTER = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Definition of the dataset and dataloader" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = lodopab.LoDoPaBDataset(impl= 'astra_cpu')\n", + "\n", + "trainset = torch.utils.data.Subset(\n", + " dataset.create_torch_dataset('train'),\n", + " list(range(IMG_TO_TRAIN))) \n", + "trainloader = torch.utils.data.DataLoader(\n", + " trainset, \n", + " batch_size=BATCH_SIZE, \n", + " shuffle=True)\n", + "\n", + "testset = torch.utils.data.Subset(\n", + " dataset.create_torch_dataset('test'),\n", + " list(range(IMG_TO_TEST))) \n", + "testloader = torch.utils.data.DataLoader(\n", + " testset, \n", + " batch_size=BATCH_SIZE, \n", + " shuffle=True)\n", + "\n", + "evalset = torch.utils.data.Subset(\n", + " dataset.create_torch_dataset('test'),\n", + " list(range(IMG_TO_VAL)))\n", + "evalloader = torch.utils.data.DataLoader(\n", + " evalset, \n", + " batch_size=BATCH_SIZE, \n", + " shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Definition the helper functions\n", + "\n", + "plot\n", + " sinogram\n", + " FBP + PSNR + SSIM\n", + " model output + PSNR / SSIM\n", + " ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'model' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m<ipython-input-5-7be6f06ed16e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mpsnr\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mssim\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mnb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0mplot_Unet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0mplot_LPD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevalloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined" + ] + } + ], + "source": [ + "fbp_op = odl.tomo.fbp_op(dataset.ray_trafo, filter_type='Ram-Lak', frequency_scaling=1.0)\n", + "#fbp=odl.tomo.analytic.filtered_back_projection.fbp_op(dataset.get_ray_trafo())\n", + "\n", + "FBP = op.OperatorAsModule(fbp_op)\n", + "\n", + "def plot_Unet(n, model, part = \"test\"):\n", + " sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M\n", + " sinogram = torch.from_numpy(sinogram[0])\n", + " ground_truth = torch.from_numpy(ground_truth[0])\n", + " fbp_out = FBP(sinogram.unsqueeze(0).unsqueeze(0)).squeeze()\n", + " out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M\n", + " min = torch.min(torch.tensor([torch.min(sinogram), torch.min(fbp_out), torch.min(out), torch.min(ground_truth)]))\n", + " max = torch.max(torch.tensor([torch.max(sinogram), torch.max(fbp_out), torch.max(out), torch.max(ground_truth)]))\n", + " \n", + " psnr_fbp = PSNR(ground_truth.squeeze().numpy(), fbp_out.numpy())\n", + " psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())\n", + " ssim_fbp = SSIM(ground_truth.squeeze().numpy(), fbp_out.numpy())\n", + " ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())\n", + "\n", + " plt.figure().set_dpi(150)\n", + "\n", + " plt.subplot(1, 4, 1)\n", + " plt.imshow(sinogram.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"sinogram\")\n", + " \n", + " plt.subplot(1, 4, 2)\n", + " plt.imshow(fbp_out.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"FBP\")\n", + " plt.figtext(0.2, 0.1, \"PSNR:\" + str(psnr_fbp) + \"\\nSSIM:\"+ str(ssim_fbp))\n", + " \n", + " plt.subplot(1, 4, 3)\n", + " plt.imshow(out.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"model output\")\n", + " plt.figtext(0.6, 0.1, \"PSNR:\"+ str(psnr_model) + \"\\nSSIM:\" + str(ssim_model))\n", + " \n", + " plt.subplot(1, 4, 4)\n", + " plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"ground truth\")\n", + " \n", + " plt.suptitle(\"Unet Plots\")\n", + "\n", + " \n", + "def plot_LPD(n, model, part = \"test\"):\n", + " sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M\n", + " sinogram = torch.from_numpy(sinogram[0])\n", + " ground_truth = torch.from_numpy(ground_truth[0])\n", + " out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M\n", + " min = torch.min(torch.tensor([torch.min(sinogram), torch.min(out), torch.min(ground_truth)]))\n", + " max = torch.max(torch.tensor([torch.max(sinogram), torch.max(out), torch.max(ground_truth)]))\n", + " \n", + " psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())\n", + " ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())\n", + "\n", + " plt.figure().set_dpi(150)\n", + "\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(sinogram.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"sinogram\")\n", + " \n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(out.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"model output\")\n", + " plt.figtext(0.4, 0.1, \"PSNR:\"+ str(psnr_model) + \"\\nSSIM:\" + str(ssim_model))\n", + " \n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)\n", + " plt.title(\"ground truth\")\n", + " \n", + " plt.suptitle(\"LPD Plots\")\n", + " \n", + " \n", + " \n", + "def evaluate(model, loader):\n", + " psnr = 0\n", + " ssim = 0\n", + " nb = 0\n", + " for batch_id, samples in enumerate(loader):\n", + " sinogram, ground_truth = samples\n", + " out = model(sinogram.unsqueeze(1).to(DEVICE)).squeeze(dim=1).cpu().detach()\n", + " \n", + " for n in range(sinogram.shape[0]):\n", + "\n", + " psnr += PSNR(ground_truth.numpy()[n], out.numpy()[n])\n", + " ssim += SSIM(ground_truth.numpy()[n], out.numpy()[n])\n", + " nb += sinogram.shape[0]\n", + " return psnr/nb, ssim/nb\n", + "\n", + "plot_Unet(10, model)\n", + "plot_LPD(10, model)\n", + "evaluate(model, evalloader)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Definition of the Nets" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class UNet_5Layer(nn.Module):\n", + " def __init__(self, m = 16, n = 32, o = 64, p = 64, q = 128):\n", + " super(UNet_5Layer, self).__init__()\n", + " #U-net from https://arxiv.org/pdf/1910.01113v2.pdf\n", + "\n", + " self.conv1 = nn.Conv2d(1, m, 3)\n", + " self.norm1 = torch.nn.BatchNorm2d(m)\n", + " \n", + " self.conv2 = nn.Conv2d(m, n, 5, stride = 2)\n", + " self.norm2 = torch.nn.BatchNorm2d(n)\n", + " self.conv3 = nn.Conv2d(n, n, 3)\n", + " self.norm3 = torch.nn.BatchNorm2d(n)\n", + " \n", + " self.conv4 = nn.Conv2d(n, o, 3, stride = 2)\n", + " self.norm4 = torch.nn.BatchNorm2d(o)\n", + " self.conv5 = nn.Conv2d(o, o, 3)\n", + " self.norm5 = torch.nn.BatchNorm2d(o)\n", + " \n", + " self.conv6 = nn.Conv2d(o, p, 3, stride = 2)\n", + " self.norm6 = torch.nn.BatchNorm2d(p)\n", + " self.conv7 = nn.Conv2d(p, p, 3)\n", + " self.norm7 = torch.nn.BatchNorm2d(p)\n", + "\n", + " self.conv8 = nn.Conv2d(p, q, 3, stride = 2)\n", + " self.norm8 = torch.nn.BatchNorm2d(q)\n", + " self.conv9 = nn.Conv2d(q, q, 3)\n", + " self.norm9 = torch.nn.BatchNorm2d(q)\n", + " self.up4 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])\n", + " self.conv10 = nn.Conv2d(q, p, 3)\n", + " self.norm10 = torch.nn.BatchNorm2d(p)\n", + " \n", + " self.conv11 = nn.Conv2d(p+4, p, 3)\n", + " self.norm11 = torch.nn.BatchNorm2d(p)\n", + " self.up3 = nn.Upsample(scale_factor=2)\n", + " self.conv12 = nn.Conv2d(p, o, 3)\n", + " self.norm12 = torch.nn.BatchNorm2d(o)\n", + " \n", + " self.conv13 = nn.Conv2d(o+4, o, 3)\n", + " self.norm13 = torch.nn.BatchNorm2d(o)\n", + " self.up2 = nn.Upsample(scale_factor=2)\n", + " self.conv14 = nn.Conv2d(o, n, 3)\n", + " self.norm14 = torch.nn.BatchNorm2d(n)\n", + " \n", + " self.conv15 = nn.Conv2d(n+4, n, 3)\n", + " self.norm15 = torch.nn.BatchNorm2d(n)\n", + " self.up2 = nn.Upsample(scale_factor=2)\n", + " self.conv16 = nn.Conv2d(n, m, 3)\n", + " self.norm16 = torch.nn.BatchNorm2d(m)\n", + " \n", + " self.conv17 = nn.Conv2d(m+4, 1, 1)\n", + " \n", + " self.skip1 = nn.Conv2d(m, 4, 1)\n", + " self.skip2 = nn.Conv2d(n, 4, 1)\n", + " self.skip3 = nn.Conv2d(o, 4, 1)\n", + " self.skip4 = nn.Conv2d(p, 4, 1)\n", + "\n", + " def forward(self, inp):\n", + " with torch.cuda.device(DEVICE):\n", + " a = F.leaky_relu(self.norm1(self.conv1(inp)), negative_slope=0.2)#torch.Size([1, 16, 360, 360])\n", + "\n", + "\n", + " b = F.leaky_relu(self.norm2(self.conv2(a)), negative_slope=0.2)\n", + " b = F.leaky_relu(self.norm3(self.conv3(b)), negative_slope=0.2)#torch.Size([1, 32, 176, 176])\n", + "\n", + " c = F.leaky_relu(self.norm4(self.conv4(b)), negative_slope=0.2)\n", + " c = F.leaky_relu(self.norm5(self.conv5(c)), negative_slope=0.2) #torch.Size([1, 64, 85, 85])\n", + "\n", + " d = F.leaky_relu(self.norm6(self.conv6(c)), negative_slope=0.2)\n", + " d = F.leaky_relu(self.norm7(self.conv7(d)), negative_slope=0.2) #torch.Size([1, 64, 40, 40])\n", + "\n", + " e = F.leaky_relu(self.norm8(self.conv8(d)), negative_slope=0.2)\n", + " e = F.leaky_relu(self.norm9(self.conv9(e)), negative_slope=0.2)\n", + " e = F.leaky_relu(self.norm10(self.conv10(self.up4(e))), negative_slope=0.2) #torch.Size([1, 64, 32, 32])\n", + "\n", + " d = self.skip4(d[:,:,4:-4, 4:-4])\n", + "\n", + " d = F.leaky_relu(self.norm11(self.conv11(torch.cat((d, e), 1))), negative_slope=0.2)\n", + " d = F.leaky_relu(self.norm12(self.conv12(self.up3(d))), negative_slope=0.2)#torch.Size([1, 64, 58, 58])\n", + "\n", + " c = self.skip3(c[:,:,13:-14,13:-14])\n", + "\n", + " c = F.leaky_relu(self.norm13(self.conv13(torch.cat((c, d), 1))), negative_slope=0.2)\n", + " c = F.leaky_relu(self.norm14(self.conv14(self.up2(c))), negative_slope=0.2)#torch.Size([1, 32, 110, 110])\n", + "\n", + " b = self.skip2(b[:,:,33:-33,33:-33])\n", + "\n", + " b = F.leaky_relu(self.norm15(self.conv15(torch.cat((b, c), 1))), negative_slope=0.2)\n", + " b = F.leaky_relu(self.norm16(self.conv16(self.up2(b))), negative_slope=0.2)#torch.Size([1, 16, 214, 214])\n", + "\n", + " a = self.skip1(a[:,:,73:-73,73:-73])\n", + " a = torch.sigmoid(self.conv17(torch.cat((a,b), 1)))\n", + " # out = a\n", + " out = F.interpolate(a, [362,362])\n", + "\n", + " return out\n", + " \n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class dual_iterate(nn.Module):\n", + " def __init__(self, N_dual):\n", + " super(dual_iterate, self).__init__()\n", + " self.d_modules = nn.Sequential(\n", + " nn.Conv2d(2+N_dual,32, 3, padding = 1),\n", + " nn.PReLU(),\n", + " nn.Conv2d(32 ,32, 3, padding = 1),\n", + " nn.PReLU(),\n", + " nn.Conv2d(32 ,5, 3, padding = 1)\n", + " )\n", + " \n", + " def forward(self, h, f, g):\n", + " out = self.d_modules(torch.cat((h, f[:, 1:2, :, :], g), dim=1))\n", + " return out + h\n", + "\n", + "class primal_iterate(nn.Module):\n", + " def __init__(self, N_primal):\n", + " super(primal_iterate, self).__init__()\n", + " self.p_modules = nn.Sequential(\n", + " nn.Conv2d(1+N_primal,32, 3, padding = 1),\n", + " nn.PReLU(),\n", + " nn.Conv2d(32 ,32, 3, padding = 1),\n", + " nn.PReLU(),\n", + " nn.Conv2d(32 ,5, 3, padding = 1)\n", + " )\n", + " \n", + " def forward(self, inp, f):\n", + " out = self.p_modules(torch.cat((inp[:, 0:1, :, :], f), 1))\n", + " return out + f\n", + " \n", + "class LPD_Net(nn.Module):\n", + " def __init__(self, I = 10, N_primal = 5, N_dual = 5):\n", + " super(LPD_Net, self).__init__()\n", + " self.I = I\n", + " self.N_primal = N_primal\n", + " self.N_dual = N_dual\n", + " self.primal_iterates = nn.ModuleList([primal_iterate(N_primal) for i in range(I)])\n", + " self.dual_iterates = nn.ModuleList([dual_iterate(N_dual) for i in range(I)])\n", + " self.T = op.OperatorAsModule(dataset.get_ray_trafo())\n", + " self.Tstar = op.OperatorAsModule(dataset.get_ray_trafo().adjoint)\n", + " self.dual_iterates.to(DEVICE)\n", + " self.primal_iterates.to(DEVICE)\n", + "\n", + " def forward(self, g):\n", + " with torch.cuda.device(DEVICE):\n", + "\n", + " #g = g.reshape(1, 1, 1000, 513)\n", + " h = torch.zeros(g.shape[0], self.N_dual, 1000, 513).to(DEVICE)\n", + " f = torch.zeros(g.shape[0], self.N_primal, 362, 362).to(DEVICE)\n", + " T_f = self.T(f)\n", + " \"\"\"print(g.shape)\n", + " print(f.shape)\n", + " print(h.shape)\n", + " print(T_f.shape)\n", + " print(T_f[:, 1:2, :, :].shape)\"\"\"\n", + " for i in range(self.I):\n", + " \"\"\"print(h.shape)\n", + " print(T_f.shape)\n", + " print(g.shape)\"\"\"\n", + " h = self.dual_iterates[i](h, T_f, g)\n", + " Tstar_h = self.Tstar(h)\n", + " f = self.primal_iterates[i](Tstar_h, f)\n", + " T_f = self.T(f)\n", + " return f[:,0:1,:,:]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "get Net and Object\n", + " net is just the net\n", + " model is a function of the net which takes the sinogram as input" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "net = UNet_5Layer().to(DEVICE)\n", + "model = lambda x : net(FBP(x))\n", + "#net = LPD_Net(I=10).to(DEVICE)\n", + "#model = lambda x : net(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Optimizer\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 15/560] loss: 0.052\n", + "[1, 30/560] loss: 0.012\n" + ] + } + ], + "source": [ + "optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)\n", + "start_time = time.time()\n", + "\n", + "for epoch in range(EPOCHS):\n", + " running_loss = 0.0\n", + " for batch_id, sample in enumerate(trainloader):\n", + " \n", + " sinograms, ground_truths = sample\n", + " sinograms = torch.unsqueeze(sinograms, 1).to(DEVICE) \n", + " ground_truths = torch.unsqueeze(ground_truths, 1).to(DEVICE)\n", + " \n", + " optimizer.zero_grad()\n", + " outputs = model(sinograms)\n", + "\n", + " loss = nn.functional.mse_loss(outputs, ground_truths)\n", + " # print(\"loss: \", loss)\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " running_loss += float(loss)\n", + " if batch_id % PRINT_AFTER == PRINT_AFTER-1: # print every PRINT_AFTER mini-batches\n", + " print('[%d, %s] loss: %.3f' %\n", + " (epoch + 1, str(batch_id + 1)+\"/\"+str(len(trainloader)), running_loss/ PRINT_AFTER))\n", + " running_loss = 0.0\n", + " \n", + " if epoch % LR_UPDATE_AFTER == LR_UPDATE_AFTER-1:\n", + " LEARNING_RATE *= LR_UPDATE_FACTOR\n", + " optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)\n", + " \n", + " if epoch % SAVE_AFTER == SAVE_AFTER-1:\n", + " torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/U-Net_200621.pth')\n", + " \n", + "end_time = time.time()\n", + "run_time = (end_time - start_time)/60\n", + "print(\"run time in minutes: \", run_time)\n", + "print('Finished Training')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "save net manually" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "load Unet" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<All keys matched successfully>" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#loaded_model = UNet_5Layer().to(DEVICE)\n", + "net.load_state_dict(torch.load('../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth'))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "forastra2", + "language": "python", + "name": "forastra2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} -- GitLab