Skip to content
Snippets Groups Projects
Commit 3ecdbd1a authored by Dawit Hailu's avatar Dawit Hailu
Browse files

our model definition, including for Primal Dual

parent b1e8ddfa
No related branches found
No related tags found
No related merge requests found
models.py 0 → 100644
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys as sys
from skimage.transform import resize
class Linear_Net(nn.Module):
"""
Defines a NN with 4 hidden layers with 200 nodes each. All layers are fully connected.
Uses ReLu activation function after each layer.
"""
def __init__(self):
super(Linear_Net, self).__init__()
#define each layer:
self.inputlayer = nn.Linear(1000*513, 200, True)
self.layer2 = nn.Linear(200,200, True)
self.layer3 = nn.Linear(200, 200, True)
self.layer4 = nn.Linear(200, 200, True)
self.layer5 = nn.Linear(200, 362*362, True)
def forward(self, inp):
"""
Computes the output of the NN for the input inp.
Applies Layers and the activation function.
"""
x = self.inputlayer(inp)
x = F.relu(x)
x = self.layer2(x)
x = F.relu(x)
x = self.layer3(x)
x = F.relu(x)
x = self.layer4(x)
x = F.relu(x)
x = self.layer5(x)
x = F.relu(x)
return x
class Conv_Net(nn.Module):
def __init__(self):
super(Conv_Net, self).__init__()
self.CL1 = nn.Conv2d(1, 64, 3)
self.PL1 = nn.MaxPool2d(2) #output ~ 250 x 500 x X
self.CL2 = nn.Conv2d(64, 64, 3)
self.PL2 = nn.MaxPool2d(2) #output ~ 120 x 250 x X
self.CL3 = nn.Conv2d(64, 128, 3)
self.PL3 = nn.MaxPool2d(2) #output ~ 60 x 120 x X
self.CL4 = nn.Conv2d(128, 128, 3)
self.PL4 = nn.MaxPool2d(2) #output torch.Size([1, X, 61, 31])
self.CL5 = nn.Conv2d(128, 256, 3)
self.PL5 = nn.MaxPool2d(2) #torch.Size([1, X, 30, 15])
self.CL6 = nn.Conv2d(256, 256, 3)
self.PL6 = nn.MaxPool2d(2) #torch.Size([1, 16, 14, 7])
#self.CL7 = nn.Conv2d(256, 256, 2)
#self.PL7 = nn.MaxPool2d(2) #torch.Size([1, X, 6, 3])
self.Layer8 = nn.Linear(256*13*6, 362*362)
def forward(self, inp): #inp is a vector of inputsize
x = inp.reshape(1,1,1000,513)
x = F.relu(self.CL1(x))
x = self.PL1(x)
x = F.relu(self.CL2(x))
x = self.PL2(x)
x = F.relu(self.CL3(x))
x = self.PL3(x)
x = F.relu(self.CL4(x))
x = self.PL4(x)
#torch.Size([1, 128, 61, 31])
x = F.relu(self.CL5(x))
x = self.PL5(x)
x = F.relu(self.CL6(x))
x = self.PL6(x)
#x = F.relu(self.CL7(x))
#x = self.PL7(x)
x = x.reshape(256*13*6)
x = self.Layer8(x)
x = nn.Sigmoid()(x)
return x
class UNet_4Layer_without_normalizing(nn.Module):
def __init__(self):
super(UNet_4Layer_without_normalizing, self).__init__()
#U-net from https://arxiv.org/pdf/1505.04597v1.pdf, 1 down less
self.conv1 = nn.Conv2d(1, 64, 3)
self.conv2 = nn.Conv2d(64, 64, 3)
self.down1 = nn.MaxPool2d(2)
self.conv3 = nn.Conv2d(64, 128, 3)
self.conv4 = nn.Conv2d(128, 128, 3)
self.down2 = nn.MaxPool2d(2)
self.conv5 = nn.Conv2d(128, 256, 3)
self.conv6 = nn.Conv2d(256, 256, 3)
self.down3 = nn.MaxPool2d(2)
self.conv7 = nn.Conv2d(256, 512, 3)
self.conv8 = nn.Conv2d(512, 512, 3)
self.up1 = nn.ConvTranspose2d(512, 256, 3, stride=2)
self.conv9 = nn.Conv2d(512, 256, 3)
self.conv10 = nn.Conv2d(256, 256, 3)
self.up2 = nn.ConvTranspose2d(256, 128, 3, stride=2)
self.conv11 = nn.Conv2d(256, 128, 3)
self.conv12 = nn.Conv2d(128, 128, 3)
self.up3 = nn.ConvTranspose2d(128, 64, 3, stride=2)
self.conv13 = nn.Conv2d(128, 64, 3)
self.conv14 = nn.Conv2d(64, 64, 3)
self.conv15 = nn.Conv2d(64, 1, 1)
self.down4 = nn.MaxPool2d(2)
#self.lin = nn.Linear(457*213,362**2)
self.skip1 = nn.Conv2d(256,256,1)
self.skip2 = nn.Conv2d(128,128,1)
self.skip3 = nn.Conv2d(64,64,1)
def forward(self, inp):
x = inp.reshape(1,1,362,362).float()
x = F.relu(self.conv1(x)) # 1 x 64 x 998 x 511
x = F.relu(self.conv2(x)) # 1 x 64 x 996 x 509
y = self.down1(x) # 1 x 64 x 498 x 254
y = F.relu(self.conv3(y)) # 1 x 128 x 496 x 252
y = F.relu(self.conv4(y)) # 1 x 128 x 494 x 250
z = self.down1(y) # 1 x 128 x 247 x 125
z = F.relu(self.conv5(z)) # 1 x 256 x 245 x 123
z = F.relu(self.conv6(z)) # 1 x 256 x 243 x 121
a = self.down1(z) # 1 x 256 x 121 x 60
a = F.relu(self.conv7(a)) # 1 x 512 x 119 x 58
a = F.relu(self.conv8(a)) # 1 x 512 x 117 x 56
a = self.up1(a) # 1 x 256 x 235 x 113
z = torch.cat( (a, self.skip1(z[:,:,4:-4,4:-4])) , 1) # 1 x 512 x 235 x113
z = F.relu(self.conv9(z)) # 1 x 256 x 233 x 111
z = F.relu(self.conv10(z)) # 1 x 256 x 231 x 109
z = self.up2(z) # 1 x 128 x 463 x 219
y = torch.cat( (z, self.skip2(y[:,:,15:-16,15:-16])) , 1) # !!!!!!
y = F.relu(self.conv11(y)) # 1 x 128 x 461 x 217
y = F.relu(self.conv12(y)) # 1 x 128 x 459 x 215
y = self.up3(y) # 1 x 64 x 919 x 431
x = torch.cat( (y, self.skip3(x[:,:,38:-39,39:-39])), 1)
x = F.relu(self.conv13(x)) # 1 x 64 x 917 x 429
x = F.relu(self.conv14(x)) # 1 x 64 x 915 x 427
x = F.relu(self.conv15(x)) # 1 x 1 x 915 x 427
#x = self.down4(x) #1 x 1x 457 x 213
#out = torch.sigmoid(self.lin(x.reshape(457*213)))
out = F.interpolate(x, [362,362])
return out.reshape(362,362)
class UNet_4Layer(nn.Module):
def __init__(self, m=128, n=256, o=512, p=512):
super(UNet_4Layer, self).__init__()
#U-net from https://arxiv.org/pdf/1505.04597v1.pdf, 1 down less
self.conv1 = nn.Conv2d(1, m, 3)
self.norm1 = torch.nn.BatchNorm2d(m)
self.conv2 = nn.Conv2d(m, m, 3)
self.norm2 = torch.nn.BatchNorm2d(m)
self.down1 = nn.MaxPool2d(2)
self.norm3 = torch.nn.BatchNorm2d(m)
self.conv3 = nn.Conv2d(m, n, 3)
self.norm4 = torch.nn.BatchNorm2d(n)
self.conv4 = nn.Conv2d(n, n, 3)
self.norm5 = torch.nn.BatchNorm2d(n)
self.down2 = nn.MaxPool2d(2)
self.norm6 = torch.nn.BatchNorm2d(n)
self.conv5 = nn.Conv2d(n, o, 3)
self.norm7 = torch.nn.BatchNorm2d(o)
self.conv6 = nn.Conv2d(o, o, 3)
self.norm8 = torch.nn.BatchNorm2d(o)
self.down3 = nn.MaxPool2d(2)
self.norm9 = torch.nn.BatchNorm2d(o)
self.conv7 = nn.Conv2d(o, p, 3)
self.norm10 = torch.nn.BatchNorm2d(p)
self.conv8 = nn.Conv2d(p, p, 3)
self.norm11 = torch.nn.BatchNorm2d(p)
self.up1 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])
self.norm12 = torch.nn.BatchNorm2d(o+p)
self.conv9 = nn.Conv2d(o+p, o, 3)
self.norm13 = torch.nn.BatchNorm2d(o)
self.conv10 = nn.Conv2d(o, o, 3)
self.norm14 = torch.nn.BatchNorm2d(o)
self.up2 = nn.Upsample(scale_factor=2)
self.norm15 = torch.nn.BatchNorm2d(o+n)
self.conv11 = nn.Conv2d(o+n, n, 3)
self.norm16 = torch.nn.BatchNorm2d(n)
self.conv12 = nn.Conv2d(n, n, 3)
self.norm17 = torch.nn.BatchNorm2d(n)
self.up3 = nn.Upsample(scale_factor=2)
self.norm18 = torch.nn.BatchNorm2d(n+m)
self.conv13 = nn.Conv2d(n+m, m, 3)
self.norm19 = torch.nn.BatchNorm2d(m)
self.conv14 = nn.Conv2d(m, m, 3)
self.norm20 = torch.nn.BatchNorm2d(m)
self.conv15 = nn.Conv2d(m, 1, 1)
#self.down4 = nn.MaxPool2d(2)
#self.lin = nn.Linear(457*213,362**2)
self.skip1 = nn.Conv2d(o, o,1)
self.skip2 = nn.Conv2d(n,n,1)
self.skip3 = nn.Conv2d(m,m,1)
def forward(self, inp):
x = inp.reshape(1,1,362,362).float()
x = F.leaky_relu(self.norm1(self.conv1(x))) # 1 x 64 x 998 x 511
x = F.leaky_relu(self.norm2(self.conv2(x))) # 1 x 64 x 996 x 509
y = self.norm3(self.down1(x)) # 1 x 64 x 498 x 254
y = F.leaky_relu(self.norm4(self.conv3(y))) # 1 x 128 x 496 x 252
y = F.leaky_relu(self.norm5(self.conv4(y)))# 1 x 128 x 494 x 250
z = self.norm6(self.down1(y)) # 1 x 128 x 247 x 125
z = F.leaky_relu(self.norm7(self.conv5(z))) # 1 x 256 x 245 x 123
z = F.leaky_relu(self.norm8(self.conv6(z))) # 1 x 256 x 243 x 121
a = self.norm9(self.down1(z)) # 1 x 256 x 121 x 60
a = F.leaky_relu(self.norm10(self.conv7(a))) # 1 x 512 x 119 x 58
a = F.leaky_relu(self.norm11(self.conv8(a))) # 1 x 512 x 117 x 56
a = self.up1(a) # 1 x 256 x 235 x 113
z = self.norm12(torch.cat( (a, self.skip1(z[:,:,4:-5,4:-5])) , 1)) # 1 x 512 x 235 x113
z = F.leaky_relu(self.norm13(self.conv9(z))) # 1 x 256 x 233 x 111
z = F.leaky_relu(self.norm14(self.conv10(z))) # 1 x 256 x 231 x 109
z = self.up2(z) # 1 x 128 x 463 x 219
y = self.norm15(torch.cat( (z, self.skip2(y[:,:,17:-18,17:-18])) , 1)) # !!!!!!
y = F.leaky_relu(self.norm16(self.conv11(y))) # 1 x 128 x 461 x 217
y = F.leaky_relu(self.norm17(self.conv12(y))) # 1 x 128 x 459 x 215
y = self.up3(y) # 1 x 64 x 919 x 431
x = self.norm18(torch.cat( (y, self.skip3(x[:,:,43:-43,43:-43])), 1))
x = F.leaky_relu(self.norm19(self.conv13(x))) # 1 x 64 x 917 x 429
x = F.leaky_relu(self.norm20(self.conv14(x))) # 1 x 64 x 915 x 427
x = torch.sigmoid(self.conv15(x)) # 1 x 1 x 915 x 427
#x = self.down4(x) #1 x 1x 457 x 213
#out = torch.sigmoid(self.lin(x.reshape(457*213)))
out = F.interpolate(x, [362,362])
return out.reshape(362,362)
class UNet_5x5conv(nn.Module): #same as above, 5x5 conv and 1 padding
def __init__(self):
super(UNet_5x5conv, self).__init__()
#U-net from https://arxiv.org/pdf/1505.04597v1.pdf, 1 down less
m = 128
n = 256
o = 512
p = 512
self.conv1 = nn.Conv2d(1, m, 5, padding=1)
self.norm1 = torch.nn.BatchNorm2d(m)
self.conv2 = nn.Conv2d(m, m, 5, padding=1)
self.norm2 = torch.nn.BatchNorm2d(m)
self.down1 = nn.MaxPool2d(2)
self.norm3 = torch.nn.BatchNorm2d(m)
self.conv3 = nn.Conv2d(m, n, 5, padding=1)
self.norm4 = torch.nn.BatchNorm2d(n)
self.conv4 = nn.Conv2d(n, n, 5, padding=1)
self.norm5 = torch.nn.BatchNorm2d(n)
self.down2 = nn.MaxPool2d(2)
self.norm6 = torch.nn.BatchNorm2d(n)
self.conv5 = nn.Conv2d(n, o, 5, padding=1)
self.norm7 = torch.nn.BatchNorm2d(o)
self.conv6 = nn.Conv2d(o, o, 5, padding=1)
self.norm8 = torch.nn.BatchNorm2d(o)
self.down3 = nn.MaxPool2d(2)
self.norm9 = torch.nn.BatchNorm2d(o)
self.conv7 = nn.Conv2d(o, p, 5, padding=1)
self.norm10 = torch.nn.BatchNorm2d(p)
self.conv8 = nn.Conv2d(p, p, 5, padding=1)
self.norm11 = torch.nn.BatchNorm2d(p)
self.up1 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])
self.norm12 = torch.nn.BatchNorm2d(o+p)
self.conv9 = nn.Conv2d(o+p, o, 5, padding=1)
self.norm13 = torch.nn.BatchNorm2d(o)
self.conv10 = nn.Conv2d(o, o, 5, padding=1)
self.norm14 = torch.nn.BatchNorm2d(o)
self.up2 = nn.Upsample(scale_factor=2)
self.norm15 = torch.nn.BatchNorm2d(o+n)
self.conv11 = nn.Conv2d(o+n, n, 5, padding=1)
self.norm16 = torch.nn.BatchNorm2d(n)
self.conv12 = nn.Conv2d(n, n, 5, padding=1)
self.norm17 = torch.nn.BatchNorm2d(n)
self.up3 = nn.Upsample(scale_factor=2)
self.norm18 = torch.nn.BatchNorm2d(n+m)
self.conv13 = nn.Conv2d(n+m, m, 5, padding=1)
self.norm19 = torch.nn.BatchNorm2d(m)
self.conv14 = nn.Conv2d(m, m, 5, padding=1)
self.norm20 = torch.nn.BatchNorm2d(m)
self.conv15 = nn.Conv2d(m, 1, 1)
#self.down4 = nn.MaxPool2d(2)
#self.lin = nn.Linear(457*213,362**2)
self.skip1 = nn.Conv2d(o, o,1)
self.skip2 = nn.Conv2d(n,n,1)
self.skip3 = nn.Conv2d(m,m,1)
def forward(self, inp):
x = inp.reshape(1,1,362,362).float()
x = F.leaky_relu(self.norm1(self.conv1(x))) # 1 x 64 x 998 x 511
x = F.leaky_relu(self.norm2(self.conv2(x))) # 1 x 64 x 996 x 509
y = self.norm3(self.down1(x)) # 1 x 64 x 498 x 254
y = F.leaky_relu(self.norm4(self.conv3(y))) # 1 x 128 x 496 x 252
y = F.leaky_relu(self.norm5(self.conv4(y)))# 1 x 128 x 494 x 250
z = self.norm6(self.down1(y)) # 1 x 128 x 247 x 125
z = F.leaky_relu(self.norm7(self.conv5(z))) # 1 x 256 x 245 x 123
z = F.leaky_relu(self.norm8(self.conv6(z))) # 1 x 256 x 243 x 121
a = self.norm9(self.down1(z)) # 1 x 256 x 121 x 60
a = F.leaky_relu(self.norm10(self.conv7(a))) # 1 x 512 x 119 x 58
a = F.leaky_relu(self.norm11(self.conv8(a))) # 1 x 512 x 117 x 56
a = self.up1(a) # 1 x 256 x 235 x 113
z = self.norm12(torch.cat( (a, self.skip1(z[:,:,4:-5,4:-5])) , 1)) # 1 x 512 x 235 x113
z = F.leaky_relu(self.norm13(self.conv9(z))) # 1 x 256 x 233 x 111
z = F.leaky_relu(self.norm14(self.conv10(z))) # 1 x 256 x 231 x 109
z = self.up2(z) # 1 x 128 x 463 x 219
y = self.norm15(torch.cat( (z, self.skip2(y[:,:,17:-18,17:-18])) , 1)) # !!!!!!
y = F.leaky_relu(self.norm16(self.conv11(y))) # 1 x 128 x 461 x 217
y = F.leaky_relu(self.norm17(self.conv12(y))) # 1 x 128 x 459 x 215
y = self.up3(y) # 1 x 64 x 919 x 431
x = self.norm18(torch.cat( (y, self.skip3(x[:,:,43:-43,43:-43])), 1))
x = F.leaky_relu(self.norm19(self.conv13(x))) # 1 x 64 x 917 x 429
x = F.leaky_relu(self.norm20(self.conv14(x))) # 1 x 64 x 915 x 427
x = torch.sigmoid(self.conv15(x)) # 1 x 1 x 915 x 427
#x = self.down4(x) #1 x 1x 457 x 213
#out = torch.sigmoid(self.lin(x.reshape(457*213)))
out = F.interpolate(x, [362,362])
return out.reshape(362,362)
class UNet_5Layer(nn.Module): #exactly the same as in the paper
def __init__(self, m = 16, n = 32, o = 64, p = 64, q = 128):
super(UNet_5Layer, self).__init__()
#U-net from https://arxiv.org/pdf/1910.01113v2.pdf
self.conv1 = nn.Conv2d(1, m, 3)
self.norm1 = torch.nn.BatchNorm2d(m)
self.conv2 = nn.Conv2d(m, n, 5, stride = 2)
self.norm2 = torch.nn.BatchNorm2d(n)
self.conv3 = nn.Conv2d(n, n, 3)
self.norm3 = torch.nn.BatchNorm2d(n)
self.conv4 = nn.Conv2d(n, o, 3, stride = 2)
self.norm4 = torch.nn.BatchNorm2d(o)
self.conv5 = nn.Conv2d(o, o, 3)
self.norm5 = torch.nn.BatchNorm2d(o)
self.conv6 = nn.Conv2d(o, p, 3, stride = 2)
self.norm6 = torch.nn.BatchNorm2d(p)
self.conv7 = nn.Conv2d(p, p, 3)
self.norm7 = torch.nn.BatchNorm2d(p)
self.conv8 = nn.Conv2d(p, q, 3, stride = 2)
self.norm8 = torch.nn.BatchNorm2d(q)
self.conv9 = nn.Conv2d(q, q, 3)
self.norm9 = torch.nn.BatchNorm2d(q)
self.up4 = nn.Upsample(scale_factor=2) #nn.Upsample([74, 74])
self.conv10 = nn.Conv2d(q, p, 3)
self.norm10 = torch.nn.BatchNorm2d(p)
self.conv11 = nn.Conv2d(p+4, p, 3)
self.norm11 = torch.nn.BatchNorm2d(p)
self.up3 = nn.Upsample(scale_factor=2)
self.conv12 = nn.Conv2d(p, o, 3)
self.norm12 = torch.nn.BatchNorm2d(o)
self.conv13 = nn.Conv2d(o+4, o, 3)
self.norm13 = torch.nn.BatchNorm2d(o)
self.up2 = nn.Upsample(scale_factor=2)
self.conv14 = nn.Conv2d(o, n, 3)
self.norm14 = torch.nn.BatchNorm2d(n)
self.conv15 = nn.Conv2d(n+4, n, 3)
self.norm15 = torch.nn.BatchNorm2d(n)
self.up2 = nn.Upsample(scale_factor=2)
self.conv16 = nn.Conv2d(n, m, 3)
self.norm16 = torch.nn.BatchNorm2d(m)
self.conv17 = nn.Conv2d(m+4, 1, 1)
self.skip1 = nn.Conv2d(m, 4, 1)
self.skip2 = nn.Conv2d(n, 4, 1)
self.skip3 = nn.Conv2d(o, 4, 1)
self.skip4 = nn.Conv2d(p, 4, 1)
def forward(self, inp):
a = inp.reshape(1,1,362,362).float()
a = F.leaky_relu(self.norm1(self.conv1(a)), negative_slope=0.2)#torch.Size([1, 16, 360, 360])
b = F.leaky_relu(self.norm2(self.conv2(a)), negative_slope=0.2)
b = F.leaky_relu(self.norm3(self.conv3(b)), negative_slope=0.2)#torch.Size([1, 32, 176, 176])
c = F.leaky_relu(self.norm4(self.conv4(b)), negative_slope=0.2)
c = F.leaky_relu(self.norm5(self.conv5(c)), negative_slope=0.2) #torch.Size([1, 64, 85, 85])
d = F.leaky_relu(self.norm6(self.conv6(c)), negative_slope=0.2)
d = F.leaky_relu(self.norm7(self.conv7(d)), negative_slope=0.2) #torch.Size([1, 64, 40, 40])
e = F.leaky_relu(self.norm8(self.conv8(d)), negative_slope=0.2)
e = F.leaky_relu(self.norm9(self.conv9(e)), negative_slope=0.2)
e = F.leaky_relu(self.norm10(self.conv10(self.up4(e))), negative_slope=0.2) #torch.Size([1, 64, 32, 32])
d = self.skip4(d[:,:,4:-4, 4:-4])
d = F.leaky_relu(self.norm11(self.conv11(torch.cat((d, e), 1))), negative_slope=0.2)
d = F.leaky_relu(self.norm12(self.conv12(self.up3(d))), negative_slope=0.2)#torch.Size([1, 64, 58, 58])
c = self.skip3(c[:,:,13:-14,13:-14])
c = F.leaky_relu(self.norm13(self.conv13(torch.cat((c, d), 1))), negative_slope=0.2)
c = F.leaky_relu(self.norm14(self.conv14(self.up2(c))), negative_slope=0.2)#torch.Size([1, 32, 110, 110])
b = self.skip2(b[:,:,33:-33,33:-33])
b = F.leaky_relu(self.norm15(self.conv15(torch.cat((b, c), 1))), negative_slope=0.2)
b = F.leaky_relu(self.norm16(self.conv16(self.up2(b))), negative_slope=0.2)#torch.Size([1, 16, 214, 214])
a = self.skip1(a[:,:,73:-73,73:-73])
a = torch.sigmoid(self.conv17(torch.cat((a,b), 1)))
out = F.interpolate(a, [362,362])
return out.reshape(362,362)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment