Skip to content
Snippets Groups Projects
Commit 18f8a2ce authored by Patrick Horn's avatar Patrick Horn
Browse files

add file

parent a10b4cd8
Branches
No related tags found
No related merge requests found
%% Cell type:markdown id: tags:
Implementation of
- Unet from https://arxiv.org/pdf/1910.01113v2.pdf
- LPD Net from https://arxiv.org/pdf/1707.06474.pdf
%% Cell type:code id: tags:
``` python
import os
import astra
import odl
import numpy as np
import dival
#from dival import get_standard_dataset
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torch.utils.data
import custom_odl_op8 as op
import time
from torch.optim.lr_scheduler import StepLR
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
import dival.datasets.lodopab_dataset as lodopab
```
%% Cell type:code id: tags:
``` python
"""import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import astra
import dival
from matplotlib import pyplot as plt
import torch.utils.data
import odl
import custom_odl_op as op
import custom_odl_op8 as op8
from skimage.metrics import peak_signal_noise_ratio as PSNR
from skimage.metrics import structural_similarity as SSIM
import dival.datasets.lodopab_dataset as lodopab
import time"""
```
%% Output
'import numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport astra\nimport dival\nfrom matplotlib import pyplot as plt\nimport torch.utils.data\nimport odl\nimport custom_odl_op as op\nimport custom_odl_op8 as op8\nfrom skimage.metrics import peak_signal_noise_ratio as PSNR\nfrom skimage.metrics import structural_similarity as SSIM\n\nimport dival.datasets.lodopab_dataset as lodopab\nimport time'
%% Cell type:markdown id: tags:
Set chosen parameters
BATCH_SIZE: number of images to process before the optimizer step
EPOCHS: number of epochs
LEARNING_RATE initial learning rate
IMG_TO_TRAIN total number of images per epoch
DEVICE device to use for computing
PRINT_AFTER number of batches to process before printing an update
LR_UPDATE_AFTER number of epochs before a learning rate update
LR_UPDATE_FACTOR facter of the learning rate update
%% Cell type:code id: tags:
``` python
"""
constants from the dataset:
input_size: 1000*513
output_size: 362*362
train_len = 35820
validation_len = 3522
test_len = 3553
"""
BATCH_SIZE = 64
EPOCHS = 10
LEARNING_RATE = 0.01
IMG_TO_TRAIN = 35820
IMG_TO_TEST = 50
IMG_TO_VAL = 100
DEVICE = "cuda:3"
PRINT_AFTER = 15
LR_UPDATE_AFTER = 5
LR_UPDATE_FACTOR = 0.1
SAVE_AFTER = 2
```
%% Cell type:markdown id: tags:
Definition of the dataset and dataloader
%% Cell type:code id: tags:
``` python
dataset = lodopab.LoDoPaBDataset(impl= 'astra_cpu')
trainset = torch.utils.data.Subset(
dataset.create_torch_dataset('train'),
list(range(IMG_TO_TRAIN)))
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=BATCH_SIZE,
shuffle=True)
testset = torch.utils.data.Subset(
dataset.create_torch_dataset('test'),
list(range(IMG_TO_TEST)))
testloader = torch.utils.data.DataLoader(
testset,
batch_size=BATCH_SIZE,
shuffle=True)
evalset = torch.utils.data.Subset(
dataset.create_torch_dataset('test'),
list(range(IMG_TO_VAL)))
evalloader = torch.utils.data.DataLoader(
evalset,
batch_size=BATCH_SIZE,
shuffle=True)
```
%% Cell type:markdown id: tags:
Definition the helper functions
plot
sinogram
FBP + PSNR + SSIM
model output + PSNR / SSIM
ground truth
%% Cell type:code id: tags:
``` python
fbp_op = odl.tomo.fbp_op(dataset.ray_trafo, filter_type='Ram-Lak', frequency_scaling=1.0)
#fbp=odl.tomo.analytic.filtered_back_projection.fbp_op(dataset.get_ray_trafo())
FBP = op.OperatorAsModule(fbp_op)
def plot_Unet(n, model, part = "test"):
sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M
sinogram = torch.from_numpy(sinogram[0])
ground_truth = torch.from_numpy(ground_truth[0])
fbp_out = FBP(sinogram.unsqueeze(0).unsqueeze(0)).squeeze()
out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M
min = torch.min(torch.tensor([torch.min(sinogram), torch.min(fbp_out), torch.min(out), torch.min(ground_truth)]))
max = torch.max(torch.tensor([torch.max(sinogram), torch.max(fbp_out), torch.max(out), torch.max(ground_truth)]))
psnr_fbp = PSNR(ground_truth.squeeze().numpy(), fbp_out.numpy())
psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())
ssim_fbp = SSIM(ground_truth.squeeze().numpy(), fbp_out.numpy())
ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())
plt.figure().set_dpi(150)
plt.subplot(1, 4, 1)
plt.imshow(sinogram.numpy(), vmin=min, vmax=max)
plt.title("sinogram")
plt.subplot(1, 4, 2)
plt.imshow(fbp_out.numpy(), vmin=min, vmax=max)
plt.title("FBP")
plt.figtext(0.2, 0.1, "PSNR:" + str(psnr_fbp) + "\nSSIM:"+ str(ssim_fbp))
plt.subplot(1, 4, 3)
plt.imshow(out.numpy(), vmin=min, vmax=max)
plt.title("model output")
plt.figtext(0.6, 0.1, "PSNR:"+ str(psnr_model) + "\nSSIM:" + str(ssim_model))
plt.subplot(1, 4, 4)
plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)
plt.title("ground truth")
plt.suptitle("Unet Plots")
def plot_LPD(n, model, part = "test"):
sinogram, ground_truth = dataset.get_samples(slice(n,n+1,1), part=part) #numpy, 1 x N x M
sinogram = torch.from_numpy(sinogram[0])
ground_truth = torch.from_numpy(ground_truth[0])
out = model(sinogram.unsqueeze(0).unsqueeze(0).to(DEVICE)).squeeze().cpu().detach() #torch N x M
min = torch.min(torch.tensor([torch.min(sinogram), torch.min(out), torch.min(ground_truth)]))
max = torch.max(torch.tensor([torch.max(sinogram), torch.max(out), torch.max(ground_truth)]))
psnr_model = PSNR(ground_truth.squeeze().numpy(), out.numpy())
ssim_model = SSIM(ground_truth.squeeze().numpy(), out.numpy())
plt.figure().set_dpi(150)
plt.subplot(1, 3, 1)
plt.imshow(sinogram.numpy(), vmin=min, vmax=max)
plt.title("sinogram")
plt.subplot(1, 3, 2)
plt.imshow(out.numpy(), vmin=min, vmax=max)
plt.title("model output")
plt.figtext(0.4, 0.1, "PSNR:"+ str(psnr_model) + "\nSSIM:" + str(ssim_model))
plt.subplot(1, 3, 3)
plt.imshow(ground_truth.numpy(), vmin=min, vmax=max)
plt.title("ground truth")
plt.suptitle("LPD Plots")
def evaluate(model, loader):
psnr = 0
ssim = 0
nb = 0
for batch_id, samples in enumerate(loader):
sinogram, ground_truth = samples
out = model(sinogram.unsqueeze(1).to(DEVICE)).squeeze(dim=1).cpu().detach()
for n in range(sinogram.shape[0]):
psnr += PSNR(ground_truth.numpy()[n], out.numpy()[n])
ssim += SSIM(ground_truth.numpy()[n], out.numpy()[n])
nb += sinogram.shape[0]
return psnr/nb, ssim/nb
plot_Unet(10, model)
plot_LPD(10, model)
evaluate(model, evalloader)
```
%% Output
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-5-7be6f06ed16e> in <module>
86 return psnr/nb, ssim/nb
87
---> 88 plot_Unet(10, model)
89 plot_LPD(10, model)
90 evaluate(model, evalloader)
NameError: name 'model' is not defined
%% Cell type:markdown id: tags:
Definition of the Nets
%% Cell type:code id: tags:
``` python
class UNet_5Layer(nn.Module):
def __init__(self, m = 16, n = 32, o = 64, p = 64, q = 128):
super(UNet_5Layer, self).__init__()
#U-net from https://arxiv.org/pdf/1910.01113v2.pdf
self.conv1 = nn.Conv2d(1, m, 3)
self.norm1 = torch.nn.BatchNorm2d(m)
self.conv2 = nn.Conv2d(m, n, 5, stride = 2)
self.norm2 = torch.nn.BatchNorm2d(n)
self.conv3 = nn.Conv2d(n, n, 3)
self.norm3 = torch.nn.BatchNorm2d(n)
self.conv4 = nn.Conv2d(n, o, 3, stride = 2)
self.norm4 = torch.nn.BatchNorm2d(o)
self.conv5 = nn.Conv2d(o, o, 3)
self.norm5 = torch.nn.BatchNorm2d(o)
self.conv6 = nn.Conv2d(o, p, 3, stride = 2)
self.norm6 = torch.nn.BatchNorm2d(p)
self.conv7 = nn.Conv2d(p, p, 3)
self.norm7 = torch.nn.BatchNorm2d(p)
self.conv8 = nn.Conv2d(p, q, 3, stride = 2)
self.norm8 = torch.nn.BatchNorm2d(q)
self.conv9 = nn.Conv2d(q, q, 3)
self.norm9 = torch.nn.BatchNorm2d(q)
self.up4 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])
self.conv10 = nn.Conv2d(q, p, 3)
self.norm10 = torch.nn.BatchNorm2d(p)
self.conv11 = nn.Conv2d(p+4, p, 3)
self.norm11 = torch.nn.BatchNorm2d(p)
self.up3 = nn.Upsample(scale_factor=2)
self.conv12 = nn.Conv2d(p, o, 3)
self.norm12 = torch.nn.BatchNorm2d(o)
self.conv13 = nn.Conv2d(o+4, o, 3)
self.norm13 = torch.nn.BatchNorm2d(o)
self.up2 = nn.Upsample(scale_factor=2)
self.conv14 = nn.Conv2d(o, n, 3)
self.norm14 = torch.nn.BatchNorm2d(n)
self.conv15 = nn.Conv2d(n+4, n, 3)
self.norm15 = torch.nn.BatchNorm2d(n)
self.up2 = nn.Upsample(scale_factor=2)
self.conv16 = nn.Conv2d(n, m, 3)
self.norm16 = torch.nn.BatchNorm2d(m)
self.conv17 = nn.Conv2d(m+4, 1, 1)
self.skip1 = nn.Conv2d(m, 4, 1)
self.skip2 = nn.Conv2d(n, 4, 1)
self.skip3 = nn.Conv2d(o, 4, 1)
self.skip4 = nn.Conv2d(p, 4, 1)
def forward(self, inp):
with torch.cuda.device(DEVICE):
a = F.leaky_relu(self.norm1(self.conv1(inp)), negative_slope=0.2)#torch.Size([1, 16, 360, 360])
b = F.leaky_relu(self.norm2(self.conv2(a)), negative_slope=0.2)
b = F.leaky_relu(self.norm3(self.conv3(b)), negative_slope=0.2)#torch.Size([1, 32, 176, 176])
c = F.leaky_relu(self.norm4(self.conv4(b)), negative_slope=0.2)
c = F.leaky_relu(self.norm5(self.conv5(c)), negative_slope=0.2) #torch.Size([1, 64, 85, 85])
d = F.leaky_relu(self.norm6(self.conv6(c)), negative_slope=0.2)
d = F.leaky_relu(self.norm7(self.conv7(d)), negative_slope=0.2) #torch.Size([1, 64, 40, 40])
e = F.leaky_relu(self.norm8(self.conv8(d)), negative_slope=0.2)
e = F.leaky_relu(self.norm9(self.conv9(e)), negative_slope=0.2)
e = F.leaky_relu(self.norm10(self.conv10(self.up4(e))), negative_slope=0.2) #torch.Size([1, 64, 32, 32])
d = self.skip4(d[:,:,4:-4, 4:-4])
d = F.leaky_relu(self.norm11(self.conv11(torch.cat((d, e), 1))), negative_slope=0.2)
d = F.leaky_relu(self.norm12(self.conv12(self.up3(d))), negative_slope=0.2)#torch.Size([1, 64, 58, 58])
c = self.skip3(c[:,:,13:-14,13:-14])
c = F.leaky_relu(self.norm13(self.conv13(torch.cat((c, d), 1))), negative_slope=0.2)
c = F.leaky_relu(self.norm14(self.conv14(self.up2(c))), negative_slope=0.2)#torch.Size([1, 32, 110, 110])
b = self.skip2(b[:,:,33:-33,33:-33])
b = F.leaky_relu(self.norm15(self.conv15(torch.cat((b, c), 1))), negative_slope=0.2)
b = F.leaky_relu(self.norm16(self.conv16(self.up2(b))), negative_slope=0.2)#torch.Size([1, 16, 214, 214])
a = self.skip1(a[:,:,73:-73,73:-73])
a = torch.sigmoid(self.conv17(torch.cat((a,b), 1)))
# out = a
out = F.interpolate(a, [362,362])
return out
```
%% Cell type:code id: tags:
``` python
class dual_iterate(nn.Module):
def __init__(self, N_dual):
super(dual_iterate, self).__init__()
self.d_modules = nn.Sequential(
nn.Conv2d(2+N_dual,32, 3, padding = 1),
nn.PReLU(),
nn.Conv2d(32 ,32, 3, padding = 1),
nn.PReLU(),
nn.Conv2d(32 ,5, 3, padding = 1)
)
def forward(self, h, f, g):
out = self.d_modules(torch.cat((h, f[:, 1:2, :, :], g), dim=1))
return out + h
class primal_iterate(nn.Module):
def __init__(self, N_primal):
super(primal_iterate, self).__init__()
self.p_modules = nn.Sequential(
nn.Conv2d(1+N_primal,32, 3, padding = 1),
nn.PReLU(),
nn.Conv2d(32 ,32, 3, padding = 1),
nn.PReLU(),
nn.Conv2d(32 ,5, 3, padding = 1)
)
def forward(self, inp, f):
out = self.p_modules(torch.cat((inp[:, 0:1, :, :], f), 1))
return out + f
class LPD_Net(nn.Module):
def __init__(self, I = 10, N_primal = 5, N_dual = 5):
super(LPD_Net, self).__init__()
self.I = I
self.N_primal = N_primal
self.N_dual = N_dual
self.primal_iterates = nn.ModuleList([primal_iterate(N_primal) for i in range(I)])
self.dual_iterates = nn.ModuleList([dual_iterate(N_dual) for i in range(I)])
self.T = op.OperatorAsModule(dataset.get_ray_trafo())
self.Tstar = op.OperatorAsModule(dataset.get_ray_trafo().adjoint)
self.dual_iterates.to(DEVICE)
self.primal_iterates.to(DEVICE)
def forward(self, g):
with torch.cuda.device(DEVICE):
#g = g.reshape(1, 1, 1000, 513)
h = torch.zeros(g.shape[0], self.N_dual, 1000, 513).to(DEVICE)
f = torch.zeros(g.shape[0], self.N_primal, 362, 362).to(DEVICE)
T_f = self.T(f)
"""print(g.shape)
print(f.shape)
print(h.shape)
print(T_f.shape)
print(T_f[:, 1:2, :, :].shape)"""
for i in range(self.I):
"""print(h.shape)
print(T_f.shape)
print(g.shape)"""
h = self.dual_iterates[i](h, T_f, g)
Tstar_h = self.Tstar(h)
f = self.primal_iterates[i](Tstar_h, f)
T_f = self.T(f)
return f[:,0:1,:,:]
```
%% Cell type:markdown id: tags:
get Net and Object
net is just the net
model is a function of the net which takes the sinogram as input
%% Cell type:code id: tags:
``` python
net = UNet_5Layer().to(DEVICE)
model = lambda x : net(FBP(x))
#net = LPD_Net(I=10).to(DEVICE)
#model = lambda x : net(x)
```
%% Cell type:markdown id: tags:
Optimizer
%% Cell type:code id: tags:
``` python
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
start_time = time.time()
for epoch in range(EPOCHS):
running_loss = 0.0
for batch_id, sample in enumerate(trainloader):
sinograms, ground_truths = sample
sinograms = torch.unsqueeze(sinograms, 1).to(DEVICE)
ground_truths = torch.unsqueeze(ground_truths, 1).to(DEVICE)
optimizer.zero_grad()
outputs = model(sinograms)
loss = nn.functional.mse_loss(outputs, ground_truths)
# print("loss: ", loss)
loss.backward()
optimizer.step()
running_loss += float(loss)
if batch_id % PRINT_AFTER == PRINT_AFTER-1: # print every PRINT_AFTER mini-batches
print('[%d, %s] loss: %.3f' %
(epoch + 1, str(batch_id + 1)+"/"+str(len(trainloader)), running_loss/ PRINT_AFTER))
running_loss = 0.0
if epoch % LR_UPDATE_AFTER == LR_UPDATE_AFTER-1:
LEARNING_RATE *= LR_UPDATE_FACTOR
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
if epoch % SAVE_AFTER == SAVE_AFTER-1:
torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/U-Net_200621.pth')
end_time = time.time()
run_time = (end_time - start_time)/60
print("run time in minutes: ", run_time)
print('Finished Training')
```
%% Output
[1, 15/560] loss: 0.052
[1, 30/560] loss: 0.012
%% Cell type:markdown id: tags:
save net manually
%% Cell type:code id: tags:
``` python
torch.save(net.state_dict(), '../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth')
```
%% Cell type:markdown id: tags:
load Unet
%% Cell type:code id: tags:
``` python
#loaded_model = UNet_5Layer().to(DEVICE)
net.load_state_dict(torch.load('../../../../../scratch/s21gpu1project/saves/LPD_net200621.pth'))
```
%% Output
<All keys matched successfully>
%% Cell type:code id: tags:
``` python
torch.cuda.empty_cache()
```
%% Cell type:code id: tags:
``` python
```
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment