Commit 6fc7e700 authored by Maike Rösch's avatar Maike Rösch

final code jupyter notebook added

parent 62967873
# Copyright 2014-2018 The ODL contributors
#
# This file is part of ODL.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
from __future__ import division
from packaging.version import parse as parse_version
import warnings
import numpy as np
import torch
if parse_version(torch.__version__) < parse_version('0.4'):
    warnings.warn("This interface is designed to work with Pytorch >= 0.4",
                  RuntimeWarning)
__all__ = ('OperatorAsAutogradFunction', 'OperatorAsModule')
# TODO: ProductSpaceOperator as implementation of channels_in and channels_out?
class OperatorAsAutogradFunction(torch.autograd.Function):
    # Constructor removed, methods are static: the operator is no longer
    # stored on the instance but in ``ctx`` inside ``forward``, since the
    # operator is now passed to ``forward`` as an argument and ``self`` is
    # gone.
    @staticmethod
    def forward(ctx, input, operator):
        ctx.operator = operator
        # TODO: batched evaluation
        if not operator.is_linear:
            # Only needed for nonlinear operators
            ctx.save_for_backward(input)

        # TODO: use GPU memory directly if possible
        input_arr = input.cpu().detach().numpy()
        if any(s == 0 for s in input_arr.strides):
            # TODO: remove when Numpy issue #9165 is fixed
            # https://github.com/numpy/numpy/pull/9177
            input_arr = input_arr.copy()

        op_result = operator(input_arr)
        if np.isscalar(op_result):
            # For functionals, the result is funnelled through `float`,
            # so we wrap it into a Numpy array with the same dtype as
            # `operator.domain`
            op_result = np.array(op_result, ndmin=1,
                                 dtype=operator.domain.dtype)
        tensor = torch.from_numpy(np.array(op_result, copy=False, ndmin=1))
        # Push back to the device of the input (no-op for CPU tensors)
        tensor = tensor.to(input.device)
        return tensor
    # Also static: ``self.op`` became ``ctx.operator`` and ``self`` became
    # ``ctx``.
    @staticmethod
    def backward(ctx, grad_output):
        # TODO: implement directly for GPU data
        if not ctx.operator.is_linear:
            input_arr = ctx.saved_tensors[0].detach().cpu().numpy()
            if any(s == 0 for s in input_arr.strides):
                # TODO: remove when Numpy issue #9165 is fixed
                # https://github.com/numpy/numpy/pull/9177
                input_arr = input_arr.copy()

        grad = None

        # ODL weights spaces, pytorch doesn't, so we need to handle this
        try:
            dom_weight = ctx.operator.domain.weighting.const
        except AttributeError:
            dom_weight = 1.0
        try:
            ran_weight = ctx.operator.range.weighting.const
        except AttributeError:
            ran_weight = 1.0
        scaling = dom_weight / ran_weight

        if ctx.needs_input_grad[0]:
            grad_output_arr = grad_output.cpu().numpy()
            if any(s == 0 for s in grad_output_arr.strides):
                # TODO: remove when Numpy issue #9165 is fixed
                # https://github.com/numpy/numpy/pull/9177
                grad_output_arr = grad_output_arr.copy()

            if ctx.operator.is_linear:
                adjoint = ctx.operator.adjoint
            else:
                adjoint = ctx.operator.derivative(input_arr).adjoint

            grad_odl = adjoint(grad_output_arr)

            if scaling != 1.0:
                grad_odl *= scaling

            grad = torch.from_numpy(np.array(grad_odl, copy=False, ndmin=1))
            if grad_output.is_cuda:
                # Push back to GPU
                grad = grad.cuda()

        # ``forward`` takes ``(input, operator)``, so one gradient per
        # argument is returned; the operator gets ``None``.
        return grad, None

    def __repr__(self):
        """Return ``repr(self)``."""
        # Leftover from the instance-based implementation: with the static
        # ``forward``/``backward`` the class is not instantiated and
        # ``self.operator`` is never set.
        return '{}(\n    {!r}\n)'.format(self.__class__.__name__,
                                         self.operator)
class OperatorAsModule(torch.nn.Module):
    # ``self.opfc.op`` became ``self.operator``; the ``opfc`` attribute no
    # longer exists.  In ``forward``, ``opfc(x_flat_xtra[i])`` became
    # ``OperatorAsAutogradFunction.apply(x_flat_xtra[i], self.operator)``.
    def __init__(self, operator):
        super(OperatorAsModule, self).__init__()
        self.operator = operator

    def forward(self, x):
        in_shape = x.data.shape
        op_in_shape = self.operator.domain.shape
        op_out_shape = self.operator.range.shape

        extra_shape = in_shape[:-len(op_in_shape)]

        if in_shape[-len(op_in_shape):] != op_in_shape or not extra_shape:
            shp_str = str(op_in_shape).strip('()')
            raise ValueError('expected input of shape (N, *, {}), got input '
                             'with shape {}'.format(shp_str, in_shape))

        # Flatten extra axes, then do one entry at a time
        newshape = (int(np.prod(extra_shape)),) + op_in_shape
        x_flat_xtra = x.reshape(*newshape)
        results = []
        for i in range(x_flat_xtra.data.shape[0]):
            results.append(
                OperatorAsAutogradFunction.apply(x_flat_xtra[i],
                                                 self.operator))

        # Reshape the resulting stack to the expected output shape
        stack_flat_xtra = torch.stack(results)
        return stack_flat_xtra.view(extra_shape + op_out_shape)

    def __repr__(self):
        """Return ``repr(self)``."""
        op_name = self.operator.__class__.__name__
        op_dom_shape = self.operator.domain.shape
        if len(op_dom_shape) == 1:
            op_dom_shape = op_dom_shape[0]
        op_ran_shape = self.operator.range.shape
        if len(op_ran_shape) == 1:
            op_ran_shape = op_ran_shape[0]

        return '{}({}) ({} -> {})'.format(self.__class__.__name__,
                                          op_name, op_dom_shape, op_ran_shape)
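
# Usage sketch (added for illustration, not part of the original module),
# assuming ``import odl``; ``MatrixOperator`` is just an illustrative choice
# of a simple linear ODL operator:
#
#     matrix = np.array([[1, 0], [0, 2], [1, 1]], dtype='float32')
#     odl_op = odl.MatrixOperator(matrix)        # maps R^2 -> R^3
#     torch_op = OperatorAsModule(odl_op)
#     x = torch.ones(1, 2, requires_grad=True)   # leading batch axis required
#     y = torch_op(x)                            # shape (1, 3)
#     y.sum().backward()                         # x.grad filled via the adjoint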
if __name__ == '__main__':
    from odl.util.testutils import run_doctests
    import odl
    from torch import autograd, nn
    run_doctests(extraglobs={'np': np, 'odl': odl, 'torch': torch,
                             'nn': nn, 'autograd': autograd})
import os
import astra
import odl
import numpy as np
import dival
import torch
import torch.nn as nn
import torch.nn.functional as F
from matplotlib import pyplot as plt
import torch.utils.data
import custom_odl_op as op
class Linear_Net(nn.Module):
    """
    Defines a NN with 4 hidden layers of 200 nodes each. All layers are
    fully connected, with a ReLU activation after each layer.
    """
    def __init__(self):
        super(Linear_Net, self).__init__()
        # Define each layer: flattened 1000x513 sinogram in,
        # flattened 362x362 image out.
        self.inputlayer = nn.Linear(1000 * 513, 200, True)
        self.layer2 = nn.Linear(200, 200, True)
        self.layer3 = nn.Linear(200, 200, True)
        self.layer4 = nn.Linear(200, 200, True)
        self.layer5 = nn.Linear(200, 362 * 362, True)

    def forward(self, inp):
        """
        Computes the output of the NN for the input ``inp`` by applying the
        layers and the activation function.
        """
        inp = inp.reshape(inp.shape[0], 1, 1, inp.shape[2] * inp.shape[3])
        x = self.inputlayer(inp)
        x = F.relu(x)
        x = self.layer2(x)
        x = F.relu(x)
        x = self.layer3(x)
        x = F.relu(x)
        x = self.layer4(x)
        x = F.relu(x)
        x = self.layer5(x)
        x = F.relu(x)
        x = x.reshape(x.shape[0], 1, 362, 362)
        return x
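
# Usage sketch (illustrative, not part of the original notebook): the network
# expects sinograms of shape (batch, 1, 1000, 513) and returns reconstructions
# of shape (batch, 1, 362, 362):
#
#     net = Linear_Net()
#     sinogram = torch.zeros(2, 1, 1000, 513)
#     reco = net(sinogram)            # -> torch.Size([2, 1, 362, 362])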
class first_Unet(nn.Module):
    def __init__(self, m=16, n=32, o=64, p=64, q=128, device=None):
        super(first_Unet, self).__init__()
        # U-net from https://arxiv.org/pdf/1910.01113v2.pdf
        self.DEVICE = device
        self.conv1 = nn.Conv2d(1, m, 3)
        self.norm1 = torch.nn.BatchNorm2d(m)
        self.conv2 = nn.Conv2d(m, n, 5, stride=2)
        self.norm2 = torch.nn.BatchNorm2d(n)
        self.conv3 = nn.Conv2d(n, n, 3)
        self.norm3 = torch.nn.BatchNorm2d(n)
        self.conv4 = nn.Conv2d(n, o, 3, stride=2)
        self.norm4 = torch.nn.BatchNorm2d(o)
        self.conv5 = nn.Conv2d(o, o, 3)
        self.norm5 = torch.nn.BatchNorm2d(o)
        self.conv6 = nn.Conv2d(o, p, 3, stride=2)
        self.norm6 = torch.nn.BatchNorm2d(p)
        self.conv7 = nn.Conv2d(p, p, 3)
        self.norm7 = torch.nn.BatchNorm2d(p)
        self.conv8 = nn.Conv2d(p, q, 3, stride=2)
        self.norm8 = torch.nn.BatchNorm2d(q)
        self.conv9 = nn.Conv2d(q, q, 3)
        self.norm9 = torch.nn.BatchNorm2d(q)
        self.up4 = nn.Upsample(scale_factor=2)  # nn.Upsample([74, 74])
        self.conv10 = nn.Conv2d(q, p, 3)
        self.norm10 = torch.nn.BatchNorm2d(p)
        self.conv11 = nn.Conv2d(p + 4, p, 3)
        self.norm11 = torch.nn.BatchNorm2d(p)
        self.up3 = nn.Upsample(scale_factor=2)
        self.conv12 = nn.Conv2d(p, o, 3)
        self.norm12 = torch.nn.BatchNorm2d(o)
        self.conv13 = nn.Conv2d(o + 4, o, 3)
        self.norm13 = torch.nn.BatchNorm2d(o)
        self.up2 = nn.Upsample(scale_factor=2)
        self.conv14 = nn.Conv2d(o, n, 3)
        self.norm14 = torch.nn.BatchNorm2d(n)
        self.conv15 = nn.Conv2d(n + 4, n, 3)
        self.norm15 = torch.nn.BatchNorm2d(n)
        # Note: this reassignment overwrites the previous ``up2``; both are
        # parameter-free upsampling modules, so the behaviour is unchanged.
        self.up2 = nn.Upsample(scale_factor=2)
        self.conv16 = nn.Conv2d(n, m, 3)
        self.norm16 = torch.nn.BatchNorm2d(m)
        self.conv17 = nn.Conv2d(m + 4, 1, 1)
        self.skip1 = nn.Conv2d(m, 4, 1)
        self.skip2 = nn.Conv2d(n, 4, 1)
        self.skip3 = nn.Conv2d(o, 4, 1)
        self.skip4 = nn.Conv2d(p, 4, 1)

    def forward(self, inp):
        with torch.cuda.device(self.DEVICE):
            a = F.leaky_relu(self.norm1(self.conv1(inp)), negative_slope=0.2)  # torch.Size([1, 16, 360, 360])
            b = F.leaky_relu(self.norm2(self.conv2(a)), negative_slope=0.2)
            b = F.leaky_relu(self.norm3(self.conv3(b)), negative_slope=0.2)  # torch.Size([1, 32, 176, 176])
            c = F.leaky_relu(self.norm4(self.conv4(b)), negative_slope=0.2)
            c = F.leaky_relu(self.norm5(self.conv5(c)), negative_slope=0.2)  # torch.Size([1, 64, 85, 85])
            d = F.leaky_relu(self.norm6(self.conv6(c)), negative_slope=0.2)
            d = F.leaky_relu(self.norm7(self.conv7(d)), negative_slope=0.2)  # torch.Size([1, 64, 40, 40])
            e = F.leaky_relu(self.norm8(self.conv8(d)), negative_slope=0.2)
            e = F.leaky_relu(self.norm9(self.conv9(e)), negative_slope=0.2)
            e = F.leaky_relu(self.norm10(self.conv10(self.up4(e))), negative_slope=0.2)  # torch.Size([1, 64, 32, 32])
            d = self.skip4(d[:, :, 4:-4, 4:-4])
            d = F.leaky_relu(self.norm11(self.conv11(torch.cat((d, e), 1))), negative_slope=0.2)
            d = F.leaky_relu(self.norm12(self.conv12(self.up3(d))), negative_slope=0.2)  # torch.Size([1, 64, 58, 58])
            c = self.skip3(c[:, :, 13:-14, 13:-14])
            c = F.leaky_relu(self.norm13(self.conv13(torch.cat((c, d), 1))), negative_slope=0.2)
            c = F.leaky_relu(self.norm14(self.conv14(self.up2(c))), negative_slope=0.2)  # torch.Size([1, 32, 110, 110])
            b = self.skip2(b[:, :, 33:-33, 33:-33])
            b = F.leaky_relu(self.norm15(self.conv15(torch.cat((b, c), 1))), negative_slope=0.2)
            b = F.leaky_relu(self.norm16(self.conv16(self.up2(b))), negative_slope=0.2)  # torch.Size([1, 16, 214, 214])
            a = self.skip1(a[:, :, 73:-73, 73:-73])
            a = torch.sigmoid(self.conv17(torch.cat((a, b), 1)))
            # out = a
            out = F.interpolate(a, [362, 362])
            return out
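
# Usage sketch (illustrative): the hard-coded crop offsets assume a 362x362
# input image (e.g. an FBP reconstruction), and the output is interpolated
# back to 362x362:
#
#     net = first_Unet(device=None)
#     x = torch.zeros(1, 1, 362, 362)
#     y = net(x)                      # -> torch.Size([1, 1, 362, 362])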
def get_paper_unet_model(in_ch=1, out_ch=1, scales=5, skip=4,
                         channels=(32, 32, 64, 64, 128, 128), use_sigmoid=True,
                         use_norm=True):
    assert (1 <= scales <= 6)
    skip_channels = [skip] * (scales)
    return paper_UNet(in_ch=in_ch, out_ch=out_ch, channels=channels[:scales],
                      skip_channels=skip_channels, use_sigmoid=use_sigmoid,
                      use_norm=use_norm)
class paper_UNet(nn.Module):
    def __init__(self, in_ch, out_ch, channels, skip_channels,
                 use_sigmoid=True, use_norm=True):
        super(paper_UNet, self).__init__()
        assert (len(channels) == len(skip_channels))
        self.scales = len(channels)
        self.use_sigmoid = use_sigmoid
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.inc = InBlock(in_ch, channels[0], use_norm=use_norm)
        for i in range(1, self.scales):
            self.down.append(DownBlock(in_ch=channels[i - 1],
                                       out_ch=channels[i],
                                       use_norm=use_norm))
        for i in range(1, self.scales):
            self.up.append(UpBlock(in_ch=channels[-i],
                                   out_ch=channels[-i - 1],
                                   skip_ch=skip_channels[-i],
                                   use_norm=use_norm))
        self.outc = OutBlock(in_ch=channels[0],
                             out_ch=out_ch)

    def forward(self, x0):
        xs = [self.inc(x0), ]
        for i in range(self.scales - 1):
            xs.append(self.down[i](xs[-1]))
        x = xs[-1]
        for i in range(self.scales - 1):
            x = self.up[i](x, xs[-2 - i])
        return torch.sigmoid(self.outc(x)) if self.use_sigmoid else self.outc(x)
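
# Usage sketch (illustrative): building the 5-scale U-Net via the factory
# above and running a dummy 362x362 image through it:
#
#     unet = get_paper_unet_model(in_ch=1, out_ch=1, scales=5, skip=4)
#     x = torch.zeros(1, 1, 362, 362)
#     y = unet(x)                     # -> torch.Size([1, 1, 362, 362])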
def get_mod_unet_model(in_ch=1, out_ch=1, scales=5, skip=4,
                       channels=(32, 32, 64, 64, 128, 128), use_sigmoid=True,
                       use_norm=True):
    assert (1 <= scales <= 6)
    skip_channels = [skip] * (scales)
    return mod_UNet(in_ch=in_ch, out_ch=out_ch, channels=channels[:scales],
                    skip_channels=skip_channels, use_sigmoid=use_sigmoid,
                    use_norm=use_norm)
class mod_UNet(nn.Module):
    # Same architecture as ``paper_UNet``, but with a residual connection:
    # the network output is added to the input ``x0``.
    def __init__(self, in_ch, out_ch, channels, skip_channels,
                 use_sigmoid=True, use_norm=True):
        super(mod_UNet, self).__init__()
        assert (len(channels) == len(skip_channels))
        self.scales = len(channels)
        self.use_sigmoid = use_sigmoid
        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.inc = InBlock(in_ch, channels[0], use_norm=use_norm)
        for i in range(1, self.scales):
            self.down.append(DownBlock(in_ch=channels[i - 1],
                                       out_ch=channels[i],
                                       use_norm=use_norm))
        for i in range(1, self.scales):
            self.up.append(UpBlock(in_ch=channels[-i],
                                   out_ch=channels[-i - 1],
                                   skip_ch=skip_channels[-i],
                                   use_norm=use_norm))
        self.outc = OutBlock(in_ch=channels[0],
                             out_ch=out_ch)

    def forward(self, x0):
        xs = [self.inc(x0), ]
        for i in range(self.scales - 1):
            xs.append(self.down[i](xs[-1]))
        x = xs[-1]
        for i in range(self.scales - 1):
            x = self.up[i](x, xs[-2 - i])
        return (torch.sigmoid(self.outc(x)) + x0 if self.use_sigmoid
                else self.outc(x) + x0)
class DownBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True):
        super(DownBlock, self).__init__()
        to_pad = int((kernel_size - 1) / 2)
        if use_norm:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size,
                          stride=2, padding=to_pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True))
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size,
                          stride=2, padding=to_pad),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x
class InBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=3, use_norm=True):
        super(InBlock, self).__init__()
        to_pad = int((kernel_size - 1) / 2)
        if use_norm:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True))
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        x = self.conv(x)
        return x
class UpBlock(nn.Module):
    def __init__(self, in_ch, out_ch, skip_ch=4, kernel_size=3, use_norm=True):
        super(UpBlock, self).__init__()
        to_pad = int((kernel_size - 1) / 2)
        self.skip = skip_ch > 0
        if skip_ch == 0:
            skip_ch = 1
        if use_norm:
            self.conv = nn.Sequential(
                nn.BatchNorm2d(in_ch + skip_ch),
                nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size, stride=1,
                          padding=to_pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True))
        else:
            self.conv = nn.Sequential(
                nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size, stride=1,
                          padding=to_pad),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size,
                          stride=1, padding=to_pad),
                nn.LeakyReLU(0.2, inplace=True))
        if use_norm:
            self.skip_conv = nn.Sequential(
                nn.Conv2d(out_ch, skip_ch, kernel_size=1, stride=1),
                nn.BatchNorm2d(skip_ch),
                nn.LeakyReLU(0.2, inplace=True))
        else:
            self.skip_conv = nn.Sequential(
                nn.Conv2d(out_ch, skip_ch, kernel_size=1, stride=1),
                nn.LeakyReLU(0.2, inplace=True))
        self.up = nn.Upsample(scale_factor=2, mode='bilinear',
                              align_corners=True)
        self.concat = Concat()

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x2 = self.skip_conv(x2)
        if not self.skip:
            x2 = x2 * 0
        x = self.concat(x1, x2)
        x = self.conv(x)
        return x
class Concat(nn.Module):
    def __init__(self):
        super(Concat, self).__init__()

    def forward(self, *inputs):
        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if (np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and
                np.all(np.array(inputs_shapes3) == min(inputs_shapes3))):
            inputs_ = inputs
        else:
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2:diff2 + target_shape2,
                                   diff3:diff3 + target_shape3])

        return torch.cat(inputs_, dim=1)
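
# Illustrative example: ``Concat`` center-crops all inputs to the smallest
# spatial size before concatenating along the channel axis:
#
#     cat = Concat()
#     a = torch.zeros(1, 4, 91, 91)
#     b = torch.zeros(1, 8, 92, 92)
#     cat(a, b).shape                 # -> torch.Size([1, 12, 91, 91])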
class OutBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(OutBlock, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1)

    def forward(self, x):
        x = self.conv(x)
        return x

    def __len__(self):
        return len(self._modules)
class DualNet(nn.Module):
    def __init__(self, N_dual):
        super(DualNet, self).__init__()
        # The last convolution must produce N_dual channels so that the
        # residual addition in ``forward`` works (here N_dual = 5 is assumed).
        self.d_modules = nn.Sequential(
            nn.Conv2d(2 + N_dual, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 5, 3, padding=1)
        )

    def forward(self, h, Op_f, g):
        out = self.d_modules(torch.cat((h, Op_f, g), dim=1))
        return out + h
class PrimalNet(nn.Module):
    def __init__(self, N_primal):
        super(PrimalNet, self).__init__()
        # The last convolution must produce N_primal channels so that the
        # residual addition in ``forward`` works (here N_primal = 5 is
        # assumed).
        self.p_modules = nn.Sequential(
            nn.Conv2d(1 + N_primal, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 5, 3, padding=1)
        )

    def forward(self, f, OpAdj_h):
        out = self.p_modules(torch.cat((f, OpAdj_h), 1))
        return out + f
class first_LearnedPrimalDual(nn.Module):
    def __init__(self, dataset, device, I=10, N_primal=5, N_dual=5):
        super(first_LearnedPrimalDual, self).__init__()
        self.DEVICE = device
        self.I = I
        self.N_primal = N_primal
        self.N_dual = N_dual
        self.Primal_nets = nn.ModuleList([PrimalNet(N_primal) for i in range(I)])
        self.Dual_nets = nn.ModuleList([DualNet(N_dual) for i in range(I)])
        self.T = op.OperatorAsModule(dataset.get_ray_trafo())
        self.Tstar = op.OperatorAsModule(dataset.get_ray_trafo().adjoint)
        # self.Dual_nets.to(DEVICE)
        # self.Primal_nets.to(DEVICE)

    def forward(self, g):
        with torch.cuda.device(self.DEVICE):
            h = torch.zeros(g.shape[0], self.N_dual, 1000, 513).to(self.DEVICE)
            f = torch.zeros(g.shape[0], self.N_primal, 362, 362).to(self.DEVICE)
            for i in range(self.I):
                f_2 = f[:, 1:2]
                Op_f = self.T(f_2)
                h = self.Dual_nets[i](h, Op_f, g)
                h_1 = h[:, 0:1]
                OpAdj_h = self.Tstar(h_1)
                f = self.Primal_nets[i](f, OpAdj_h)
            return f[:, 0:1]
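
# Usage sketch (illustrative): the hard-coded sinogram shape (1000 x 513) and
# image shape (362 x 362) match the dival LoDoPaB-CT dataset, so ``dataset``
# is expected to be an object like ``dival.get_standard_dataset('lodopab')``
# that provides ``get_ray_trafo()`` as used above:
#
#     dataset = dival.get_standard_dataset('lodopab')
#     model = first_LearnedPrimalDual(dataset, device=0).to('cuda:0')
#     g = torch.zeros(1, 1, 1000, 513, device='cuda:0')   # observed sinogram
#     reco = model(g)                                      # -> (1, 1, 362, 362)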
class DualNet2(nn.Module):
    def __init__(self, N_dual):
        super(DualNet2, self).__init__()
        self.d_modules = nn.Sequential(
            nn.Conv2d(2 + N_dual, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, h, Op_f, g):
        out = self.d_modules(torch.cat((h, Op_f, g), dim=1))
        return out + h
class PrimalNet2(nn.Module):
    def __init__(self, N_primal):
        super(PrimalNet2, self).__init__()
        self.p_modules = nn.Sequential(
            nn.Conv2d(1 + N_primal, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.PReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, f, OpAdj_h):
        out = self.p_modules(torch.cat((f, OpAdj_h), 1))
        return out + f
class second_LearnedPrimalDual(nn.Module):
    def __init__(self, dataset, device, I=5, N_primal=3, N_dual=3):
        super(second_LearnedPrimalDual, self).__init__()
        self.DEVICE = device
        self.I = I
        self.N_primal = N_primal
        self.N_dual = N_dual
        self.Primal_nets = nn.ModuleList([PrimalNet2(N_primal) for i in range(I)])
        self.Dual_nets = nn.ModuleList([DualNet2(N_dual) for i in range(I)])
        self.T = op.OperatorAsModule(dataset.get_ray_trafo())
        self.Tstar = op.OperatorAsModule(dataset.get_ray_trafo().adjoint)
        # self.Dual_nets.to(DEVICE)
        # self.Primal_nets.to(DEVICE)

    def forward(self, g):
        with torch.cuda.device(self.DEVICE):
            h = torch.zeros(g.shape[0], self.N_dual, 1000, 513).to(self.DEVICE)
            f = torch.zeros(g.shape[0], self.N_primal, 362, 362).to(self.DEVICE)
            for i in range(self.I):
                f_2 = f[:, 1:2]
                Op_f = self.T(f_2)
                h = self.Dual_nets[i](h, Op_f, g)
                h_1 = h[:, 0:1]
                OpAdj_h = self.Tstar(h_1)
                f = self.Primal_nets[i](f, OpAdj_h)
            return f[:, 0:1]