diff --git a/sgmse/backbones/ncsnpp_48k.py b/sgmse/backbones/ncsnpp_48k.py index 1737baf0c8d518f59bc354635b415dbe48bffb20..a3beeb818e295d3efa4b86da47d4f36fc91af487 100644 --- a/sgmse/backbones/ncsnpp_48k.py +++ b/sgmse/backbones/ncsnpp_48k.py @@ -43,7 +43,6 @@ class NCSNpp_48k(nn.Module): parser.add_argument("--num_res_blocks", type=int, default=2) parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[]) parser.add_argument("--nf", type=int, default=128, help="Number of channels to use in the model") - parser.add_argument("--attn_resolutions", type=int, nargs='+', default=[]) parser.add_argument("--no-centered", dest="centered", action="store_false", help="The data is not centered [-1, 1]") parser.add_argument("--centered", dest="centered", action="store_true", help="The data is centered [-1, 1]") parser.add_argument("--progressive", type=str, default='none', help="Progressive downsampling method") @@ -81,7 +80,7 @@ class NCSNpp_48k(nn.Module): self.nf = nf = nf ch_mult = ch_mult self.num_res_blocks = num_res_blocks = num_res_blocks - self.attn_resolutions = attn_resolutions = attn_resolutions + self.attn_resolutions = attn_resolutions dropout = dropout resamp_with_conv = resamp_with_conv self.num_resolutions = num_resolutions = len(ch_mult)