diff --git a/README.md b/README.md
index 8f360680a2068b40a898982aaacb6d375800d71b..893947e76d3a49ddfb7c98bbaab184376672ae4f 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ To see all available training options, run `python train.py --help`. Note that t
 
 **Note:**
 - Our journal preprint [2] uses `--backbone ncsnpp`.
-- For the 48 kHz model [3], use `--backbone ncsnpp_48k --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
+- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
 - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
 - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
 
diff --git a/sgmse/data_module.py b/sgmse/data_module.py
index 0eb4535ed1b7ebf62e2dd543099791207facb07b..474f5b55271bf0b3a4ae3a18bc0fe55bac905b22 100644
--- a/sgmse/data_module.py
+++ b/sgmse/data_module.py
@@ -26,11 +26,19 @@ class Specs(Dataset):
 
         # Read file paths according to file naming format.
         if format == "default":
-            self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
-            self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav'))
+            self.clean_files = []
+            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
+            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
+            self.noisy_files = []
+            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
+            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
         elif format == "reverb":
-            self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav'))
-            self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav'))
+            self.clean_files = []
+            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
+            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
+            self.noisy_files = []
+            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
+            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
         else:
             # Feel free to add your own directory format
             raise NotImplementedError(f"Directory format {format} unknown!")
diff --git a/sgmse/model.py b/sgmse/model.py
index a97a56680b44a397827ad21afecfdf57a6ad4b0c..3c4bf8b494f7cbaf3f5f0a61d9e27e4a4066098f 100644
--- a/sgmse/model.py
+++ b/sgmse/model.py
@@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule):
     def add_argparse_args(parser):
         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
-        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
+        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
         parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
         return parser
 
     def __init__(
-        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2,
+        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
         num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
     ):
         """
diff --git a/train.py b/train.py
index 56266debe9a3362a7d50e0d3e362e492a6acbb86..cd545bb0fabf983a020e4b62ec9fba7f8ece808f 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
+import wandb
 import argparse
-from argparse import ArgumentParser
-
 import pytorch_lightning as pl
+
+from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
 
-import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -31,6 +35,7 @@ if __name__ == '__main__':
     parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
     parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
    parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
+    parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
 
     temp_args, _ = base_parser.parse_known_args()
 
@@ -76,11 +81,11 @@ if __name__ == '__main__':
 
     # Set up callbacks for logger
     if logger != None:
-        callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+        callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
         if args.num_eval_files:
-            checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
+            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
                 save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
             callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
         else:
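
For context on the `set_torch_cuda_arch_list()` call added to `train.py`: the helper's body is not part of this diff, so the sketch below is only an assumption of what such a helper typically does, namely populating the `TORCH_CUDA_ARCH_LIST` environment variable from the compute capabilities of the visible GPUs.

```python
# Hypothetical sketch of set_torch_cuda_arch_list (sgmse/util/other.py) --
# the actual implementation is not shown in this diff.
import os
import torch

def set_torch_cuda_arch_list():
    """Set TORCH_CUDA_ARCH_LIST to the compute capabilities of the local GPUs.

    PyTorch's cpp_extension reads this variable to decide which architectures
    to target when JIT-compiling CUDA kernels; restricting it to the GPUs that
    are actually present avoids compiling for every supported architecture.
    """
    if not torch.cuda.is_available():
        return  # Nothing to do on CPU-only machines.
    arch_list = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        arch_list.append(f"{major}.{minor}")
    os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(sorted(set(arch_list)))
```

Placing the call before `from sgmse.backbones.shared import BackboneRegistry` matters if the backbone modules compile custom CUDA extensions at import time, since the environment variable must be set before that compilation is triggered.

One note on the `data_module.py` hunk: Python's `glob` only expands `**` across arbitrary directory depth when called with `recursive=True`; without it, `**` behaves like `*`, so the added patterns pick up `.wav` files exactly one subdirectory level deep, in addition to the top-level files matched by the first pattern.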