From a1e20b3c9b18287b31d47dc728b6fee1c84796fd Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Tue, 18 Jun 2024 08:15:53 +0200
Subject: [PATCH] Update README 48 kHz usage, support one-level subdir globs in data loading, and make the log directory configurable

---
 README.md            |  2 +-
 sgmse/data_module.py | 16 ++++++++++++----
 sgmse/model.py       |  4 ++--
 train.py             | 17 +++++++++++------
 4 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/README.md b/README.md
index 8f36068..893947e 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ To see all available training options, run `python train.py --help`. Note that t
 
 **Note:**
 - Our journal preprint [2] uses `--backbone ncsnpp`.
-- For the 48 kHz model [3], use `--backbone ncsnpp_48k --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
+- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
 - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
     - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
 
diff --git a/sgmse/data_module.py b/sgmse/data_module.py
index 0eb4535..474f5b5 100644
--- a/sgmse/data_module.py
+++ b/sgmse/data_module.py
@@ -26,11 +26,19 @@ class Specs(Dataset):
 
         # Read file paths according to file naming format.
         if format == "default":
-            self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
-            self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav'))
+            self.clean_files = []
+            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
+            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
+            self.noisy_files = []
+            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
+            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
         elif format == "reverb":
-            self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav'))
-            self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav'))
+            self.clean_files = []
+            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
+            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
+            self.noisy_files = []
+            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
+            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
         else:
             # Feel free to add your own directory format
             raise NotImplementedError(f"Directory format {format} unknown!")
diff --git a/sgmse/model.py b/sgmse/model.py
index a97a566..3c4bf8b 100644
--- a/sgmse/model.py
+++ b/sgmse/model.py
@@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule):
     def add_argparse_args(parser):
         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
-        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
+        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
         parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
         return parser
 
     def __init__(
-        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2,
+        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
         num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
     ):
         """
diff --git a/train.py b/train.py
index 56266de..cd545bb 100644
--- a/train.py
+++ b/train.py
@@ -1,11 +1,15 @@
+import wandb
 import argparse
-from argparse import ArgumentParser
-
 import pytorch_lightning as pl
+
+from argparse import ArgumentParser
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
+from os.path import join
 
-import wandb
+# Set CUDA architecture list
+from sgmse.util.other import set_torch_cuda_arch_list
+set_torch_cuda_arch_list()
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -31,6 +35,7 @@ if __name__ == '__main__':
           parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
           parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
           parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
+          parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
           
      temp_args, _ = base_parser.parse_known_args()
 
@@ -76,11 +81,11 @@ if __name__ == '__main__':
 
      # Set up callbacks for logger
      if logger != None:
-          callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
+          callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
           if args.num_eval_files:
-               checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
+               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
                     save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
                callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
      else:
-- 
GitLab