From bcdd2104bdf8ab138fa860a92d87b68387b050ec Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Tue, 18 Jun 2024 08:19:03 +0200
Subject: [PATCH] Revert "update README.md"

This reverts commit a1e20b3c9b18287b31d47dc728b6fee1c84796fd.
---
 README.md            |  2 +-
 sgmse/data_module.py | 16 ++++------------
 sgmse/model.py       |  4 ++--
 train.py             | 17 ++++++-----------
 4 files changed, 13 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 893947e..8f36068 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ To see all available training options, run `python train.py --help`. Note that t
 
 **Note:**
 - Our journal preprint [2] uses `--backbone ncsnpp`.
-- For the 48 kHz model [3], use `--backbone ncsnpp_48k --n_fft 1534 --hop_length 384 --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
+- For the 48 kHz model [3], use `--backbone ncsnpp_48k --spec_factor 0.065 --spec_abs_exponent 0.667 --sigma-min 0.1 --sigma-max 1.0 --theta 2.0`
 - Our Interspeech paper [1] uses `--backbone dcunet`. You need to pass `--n_fft 512` to make it work.
     - Also note that the default parameters for the spectrogram transformation in this repository are slightly different from the ones listed in the first (Interspeech) paper (`--spec_factor 0.15` rather than `--spec_factor 0.333`), but we've found the value in this repository to generally perform better for both models [1] and [2].
 
diff --git a/sgmse/data_module.py b/sgmse/data_module.py
index 474f5b5..0eb4535 100644
--- a/sgmse/data_module.py
+++ b/sgmse/data_module.py
@@ -26,19 +26,11 @@ class Specs(Dataset):
 
         # Read file paths according to file naming format.
         if format == "default":
-            self.clean_files = []
-            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav")))
-            self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav")))
-            self.noisy_files = []
-            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav")))
-            self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav")))
+            self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
+            self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav'))
         elif format == "reverb":
-            self.clean_files = []
-            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav")))
-            self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav")))
-            self.noisy_files = []
-            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav")))
-            self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav")))
+            self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav'))
+            self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav'))
         else:
             # Feel free to add your own directory format
             raise NotImplementedError(f"Directory format {format} unknown!")
diff --git a/sgmse/model.py b/sgmse/model.py
index 3c4bf8b..a97a566 100644
--- a/sgmse/model.py
+++ b/sgmse/model.py
@@ -18,13 +18,13 @@ class ScoreModel(pl.LightningModule):
     def add_argparse_args(parser):
         parser.add_argument("--lr", type=float, default=1e-4, help="The learning rate (1e-4 by default)")
         parser.add_argument("--ema_decay", type=float, default=0.999, help="The parameter EMA decay constant (0.999 by default)")
-        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum process time (0.03 by default)")
+        parser.add_argument("--t_eps", type=float, default=0.03, help="The minimum time (3e-2 by default)")
         parser.add_argument("--num_eval_files", type=int, default=20, help="Number of files for speech enhancement performance evaluation during training. Pass 0 to turn off (no checkpoints based on evaluation metrics will be generated).")
         parser.add_argument("--loss_type", type=str, default="mse", choices=("mse", "mae"), help="The type of loss function to use.")
         return parser
 
     def __init__(
-        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=0.03,
+        self, backbone, sde, lr=1e-4, ema_decay=0.999, t_eps=3e-2,
         num_eval_files=20, loss_type='mse', data_module_cls=None, **kwargs
     ):
         """
diff --git a/train.py b/train.py
index cd545bb..56266de 100644
--- a/train.py
+++ b/train.py
@@ -1,15 +1,11 @@
-import wandb
 import argparse
-import pytorch_lightning as pl
-
 from argparse import ArgumentParser
+
+import pytorch_lightning as pl
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks import ModelCheckpoint
-from os.path import join
 
-# Set CUDA architecture list
-from sgmse.util.other import set_torch_cuda_arch_list
-set_torch_cuda_arch_list()
+import wandb
 
 from sgmse.backbones.shared import BackboneRegistry
 from sgmse.data_module import SpecsDataModule
@@ -35,7 +31,6 @@ if __name__ == '__main__':
           parser_.add_argument("--nolog", action='store_true', help="Turn off logging.")
           parser_.add_argument("--wandb_name", type=str, default=None, help="Name for wandb logger. If not set, a random name is generated.")
           parser_.add_argument("--ckpt", type=str, default=None, help="Resume training from checkpoint.")
-          parser_.add_argument("--log_dir", type=str, default="logs", help="Directory to save logs.")
           
      temp_args, _ = base_parser.parse_known_args()
 
@@ -81,11 +76,11 @@ if __name__ == '__main__':
 
      # Set up callbacks for logger
      if logger != None:
-          callbacks = [ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), save_last=True, filename='{epoch}-last')]
+          callbacks = [ModelCheckpoint(dirpath=f"logs/{logger.version}", save_last=True, filename='{epoch}-last')]
           if args.num_eval_files:
-               checkpoint_callback_pesq = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
+               checkpoint_callback_pesq = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
                     save_top_k=2, monitor="pesq", mode="max", filename='{epoch}-{pesq:.2f}')
-               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=join(args.log_dir, str(logger.version)), 
+               checkpoint_callback_si_sdr = ModelCheckpoint(dirpath=f"logs/{logger.version}", 
                     save_top_k=2, monitor="si_sdr", mode="max", filename='{epoch}-{si_sdr:.2f}')
                callbacks += [checkpoint_callback_pesq, checkpoint_callback_si_sdr]
      else:
-- 
GitLab