From bcdd2104bdf8ab138fa860a92d87b68387b050ec Mon Sep 17 00:00:00 2001 From: jrichter <jrichter@exchange.informatik.uni-hamburg.de> Date: Tue, 18 Jun 2024 08:19:03 +0200 Subject: [PATCH] Revert "update README.md" This reverts commit a1e20b3c9b18287b31d47dc728b6fee1c84796fd. --- README.md | 2 +- sgmse/data_module.py | 16 ++++------------ sgmse/model.py | 4 ++-- train.py | 17 ++++++----------- 4 files changed, 13 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 893947e..8f36068 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ To see all available training options, run `python train.py --help`. Note that t **Note:** - Our journal preprint [2] uses `--backbone ncsnpp`. -- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0` +- For the 48 kHz model [3], use `--backbone ncsnpp_48k --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0` - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work. - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2]. diff --git a/sgmse/data_module.py b/sgmse/data_module.py index 474f5b5..0eb4535 100644 --- a/sgmse/data_module.py +++ b/sgmse/data_module.py @@ -26,19 +26,11 @@ class Specs(Dataset): # Read file paths according to file naming format. if format == "default": - self.clean_files = [] - self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav"))) - self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav"))) - self.noisy_files = [] - self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav"))) - self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav"))) + self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav')) + self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav')) elif format == "reverb": - self.clean_files = [] - self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav"))) - self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav"))) - self.noisy_files = [] - self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav"))) - self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav"))) + self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav')) + self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav')) else: # Feel free to add your own directory format raise NotImplementedError(f"Directory format {format} unknown!") diff --git a/sgmse/model.py b/sgmse/model.py index 3c4bf8b..a97a566 100644 --- a/sgmse/model.py +++ b/sgmse/model.py @@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule): def add_argparse_args(parser): parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)") parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)") - parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)") + parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)") 
parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).") parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.") return parser def __init__( - self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03, + self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2, num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs ): """ diff --git a/train.py b/train.py index cd545bb..56266de 100644 --- a/train.py +++ b/train.py @@ -1,15 +1,11 @@ -import wandb import argparse -import pytorch_lightning as pl - from argparse import ArgumentParser + +import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.callbacks import ModelCheckpoint -from os.path import join -# Set CUDA architecture list -from sgmse.util.other import set_torch_cuda_arch_list -set_torch_cuda_arch_list() +import wandb from sgmse.backbones.shared import BackboneRegistry from sgmse.data_module import SpecsDataModule @@ -35,7 +31,6 @@ if __name__ == '__main__': parser_.add_argument("--nolog", action='store_true', help="Turn off logging.") parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.") parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.") - parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.") temp_args, _ = base_parser.parse_known_args() @@ -81,11 +76,11 @@ if __name__ == '__main__': # Set up callbacks for logger if logger != None: - callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')] + callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')] if args.num_eval_files: - checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), + checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}') - checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), + checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}') callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr] else: -- GitLab