Commit bcdd2104 authored by jrichter

Revert "update README.md"

This reverts commit a1e20b3c.
parent a1e20b3c
@@ -54,7 +54,7 @@ To see all available training options, run `python train.py --help`. Note that t
 **Note:**
 - Our journal preprint [2] uses `--backbone ncsnpp`.
-- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
+- For the 48 kHz model [3], use `--backbone ncsnpp_48k --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
 - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
 - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
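For orientation, the `--spec_factor` / `--spec_abs_exponent` flags mentioned above control the magnitude compression applied to the complex STFT coefficients before training. A minimal sketch of that transformation, assuming the form beta * |X|^alpha * exp(i*angle(X)) described in the journal preprint [2] (the function name and defaults here are illustrative, not the repository's exact code):

```python
import torch

def compress_spec(spec: torch.Tensor, spec_factor: float = 0.15,
                  spec_abs_exponent: float = 0.5) -> torch.Tensor:
    # Compress the magnitude (an exponent < 1 amplifies quiet
    # time-frequency bins), scale by a global factor, and keep
    # the phase unchanged.
    mag = spec.abs() ** spec_abs_exponent
    return spec_factor * mag * torch.exp(1j * spec.angle())
```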
@@ -26,19 +26,11 @@ class Specs(Dataset):
         # Read file paths according to file naming format.
         if format == "default":
-            self.clean_files = []
-            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
-            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
-            self.noisy_files = []
-            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
-            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
+            self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
+            self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav'))
         elif format == "reverb":
-            self.clean_files = []
-            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
-            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
-            self.noisy_files = []
-            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
-            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
+            self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav'))
+            self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav'))
         else:
             # Feel free to add your own directory format
             raise NotImplementedError(f"Directory format {format} unknown!")
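One detail worth knowing about the removed variant: Python's `glob` only expands `**` across arbitrary directory depths when called with `recursive=True`; without it, `**` behaves like a plain `*` and matches exactly one path segment, which is presumably why each flat `*.wav` pattern is paired with a `**/*.wav` pattern above. A small illustration (the directory layout is made up):

```python
from glob import glob
from os.path import join

data_dir, subset = "data", "train"  # hypothetical layout

# Without recursive=True, "**" matches exactly one directory level,
# so this only finds files nested one level below "clean/".
one_level = sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))

# With recursive=True, "**" matches zero or more directory levels,
# so flat files and deeply nested files are both found.
all_levels = sorted(glob(join(data_dir, subset, "clean", "**", "*.wav"),
                         recursive=True))
```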
@@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule):
     def add_argparse_args(parser):
         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
-        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
+        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
         parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
         return parser
 
     def __init__(
-        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
+        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2,
         num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
     ):
         """
-import wandb
 import argparse
-import pytorch_lightning as pl
 from argparse import ArgumentParser
+import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
-from os.path import join
-# Set CUDA architecture list
-from sgmse.util.other import set_torch_cuda_arch_list
-set_torch_cuda_arch_list()
+import wandb
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
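`set_torch_cuda_arch_list` itself is not shown in this diff. As a rough sketch of what such a helper typically does, assuming it populates the `TORCH_CUDA_ARCH_LIST` environment variable from the visible GPUs (this is an assumption, not the repository's actual implementation):

```python
import os
import torch

def set_torch_cuda_arch_list():
    # Hypothetical sketch: record the compute capabilities of the visible
    # GPUs so that JIT-compiled CUDA extensions are built for them.
    if not torch.cuda.is_available():
        return
    archs = {f"{major}.{minor}"
             for major, minor in (torch.cuda.get_device_capability(i)
                                  for i in range(torch.cuda.device_count()))}
    os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(sorted(archs))
```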
@@ -35,7 +31,6 @@ if __name__ == '__main__':
parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
temp_args, _ = base_parser.parse_known_args()
@@ -81,11 +76,11 @@ if __name__ == '__main__':
     # Set up callbacks for logger
     if logger != None:
-        callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
+        callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
         if args.num_eval_files:
-            checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
+            checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}",
                 save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)),
+            checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}",
                 save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
             callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
     else:
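For readers unfamiliar with the API: each `ModelCheckpoint` above monitors one logged metric (`pesq` or `si_sdr`) and keeps the `save_top_k` best checkpoints, and the resulting callbacks list is later handed to the `Trainer`. A minimal sketch of that wiring (the directory path and the commented-out `fit` call are illustrative, not this script's exact code):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the two checkpoints with the highest validation PESQ; "pesq" must
# match a metric the model logs via self.log("pesq", ...).
checkpoint_pesq = ModelCheckpoint(
    dirpath="logs/example", save_top_k=2, monitor="pesq", mode="max",
    filename="{epoch}-{pesq:.2f}",
)
trainer = pl.Trainer(callbacks=[checkpoint_pesq])
# trainer.fit(model, datamodule=data_module)  # model/data_module defined elsewhere
```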