Skip to content
Snippets Groups Projects
Commit 9ea06796 authored by jrichter's avatar jrichter
Browse files

add file naming format for WSJ0_REVERB

parent c76f7ad0
No related branches found
No related tags found
No related merge requests found
......@@ -28,6 +28,9 @@ class Specs(Dataset):
if format == "default":
self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav'))
self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav'))
elif format == "reverb":
self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav'))
self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav'))
else:
# Feel free to add your own directory format
raise NotImplementedError(f"Directory format {format} unknown!")
......@@ -93,7 +96,7 @@ class SpecsDataModule(pl.LightningDataModule):
@staticmethod
def add_argparse_args(parser):
parser.add_argument("--base_dir", type=str, required=True, help="The base directory of the dataset. Should contain `train`, `valid` and `test` subdirectories, each of which contain `clean` and `noisy` subdirectories.")
parser.add_argument("--format", type=str, choices=("default", "dns"), default="default", help="Read file paths according to file naming format.")
parser.add_argument("--format", type=str, choices=("default", "reverb"), default="default", help="Read file paths according to file naming format.")
parser.add_argument("--batch_size", type=int, default=8, help="The batch size. 8 by default.")
parser.add_argument("--n_fft", type=int, default=510, help="Number of FFT bins. 510 by default.") # to assure 256 freq bins
parser.add_argument("--hop_length", type=int, default=128, help="Window hop length. 128 by default.")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment