diff --git a/sgmse/backbones/ncsnpp.py b/sgmse/backbones/ncsnpp.py index c5be09e539b6e8bee3da474c397b643b11cbefdb..f5c810e7ec1a20ccbc20a61600d867ab8b1e7b7f 100644 --- a/sgmse/backbones/ncsnpp.py +++ b/sgmse/backbones/ncsnpp.py @@ -39,7 +39,12 @@ class NCSNpp(nn.Module): @staticmethod def add_argparse_args(parser): - # TODO: add additional arguments of constructor, if you wish to modify them. + parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2]) + parser.add_argument("--num_res_blocks", type=int, default=2) + parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16]) + parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]") + parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]") + parser.set_defaults(centered=True) return parser def __init__(self, @@ -246,13 +251,7 @@ class NCSNpp(nn.Module): modules.append(conv3x3(in_ch, channels, init_scale=init_scale)) self.all_modules = nn.ModuleList(modules) - - @staticmethod - def add_argparse_args(parser): - parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]") - parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]") - parser.set_defaults(centered=True) - return parser + def forward(self, x, time_cond): # timestep/noise_level embedding; only for continuous training