Skip to content
Snippets Groups Projects
Unverified Commit c76f7ad0 authored by Julius Richter's avatar Julius Richter Committed by GitHub
Browse files

Merge pull request #29 from taltalim/add_backbone_args

Expose NCSN++ architecture args to CLI
parents 2573ecad 5790c9b7
No related branches found
No related tags found
No related merge requests found
...@@ -39,7 +39,12 @@ class NCSNpp(nn.Module): ...@@ -39,7 +39,12 @@ class NCSNpp(nn.Module):
@staticmethod @staticmethod
def add_argparse_args(parser): def add_argparse_args(parser):
# TODO: add additional arguments of constructor, if you wish to modify them. parser.add_argument("--ch_mult",type=int, nargs='+', default=[1,1,2,2,2,2,2])
parser.add_argument("--num_res_blocks", type=int, default=2)
parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[16])
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
parser.set_defaults(centered=True)
return parser return parser
def __init__(self, def __init__(self,
...@@ -247,12 +252,6 @@ class NCSNpp(nn.Module): ...@@ -247,12 +252,6 @@ class NCSNpp(nn.Module):
self.all_modules = nn.ModuleList(modules) self.all_modules = nn.ModuleList(modules)
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]")
parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]")
parser.set_defaults(centered=True)
return parser
def forward(self, x, time_cond): def forward(self, x, time_cond):
# timestep/noise_level embedding; only for continuous training # timestep/noise_level embedding; only for continuous training
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment