diff --git a/sgmse/data_module.py b/sgmse/data_module.py index e1c1c6c483f41066a4786c9c3e313e957bd33f19..0eb4535ed1b7ebf62e2dd543099791207facb07b 100644 --- a/sgmse/data_module.py +++ b/sgmse/data_module.py @@ -28,6 +28,9 @@ class Specs(Dataset): if format == "default": self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav')) self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav')) + elif format == "reverb": + self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav')) + self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav')) else: # Feel free to add your own directory format raise NotImplementedError(f"Directory format {format} unknown!") @@ -93,7 +96,7 @@ class SpecsDataModule(pl.LightningDataModule): @staticmethod def add_argparse_args(parser): parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.") - parser.add_argument("--format", type=str, choices=("default", "dns"), default="default", help="Read file paths according to file naming format.") + parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.") parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.") parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.") # to assure 256 freq bins parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")