diff --git a/sgmse/model.py b/sgmse/model.py index a97a56680b44a397827ad21afecfdf57a6ad4b0c..3c4bf8b494f7cbaf3f5f0a61d9e27e4a4066098f 100644 --- a/sgmse/model.py +++ b/sgmse/model.py @@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule): def add_argparse_args(parser): parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)") parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)") - parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)") + parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)") parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).") parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.") return parser def __init__( - self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, + self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03, num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs ): """