diff --git a/sgmse/data_module.py b/sgmse/data_module.py index 0eb4535ed1b7ebf62e2dd543099791207facb07b..474f5b55271bf0b3a4ae3a18bc0fe55bac905b22 100644 --- a/sgmse/data_module.py +++ b/sgmse/data_module.py @@ -26,11 +26,19 @@ class Specs(Dataset): # Read file paths according to file naming format. if format == "default": - self.clean_files = sorted(glob(join(data_dir, subset) + '/clean/*.wav')) - self.noisy_files = sorted(glob(join(data_dir, subset) + '/noisy/*.wav')) + self.clean_files = [] + self.clean_files += sorted(glob(join(data_dir, subset, "clean", "*.wav"))) + self.clean_files += sorted(glob(join(data_dir, subset, "clean", "**", "*.wav"))) + self.noisy_files = [] + self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "*.wav"))) + self.noisy_files += sorted(glob(join(data_dir, subset, "noisy", "**", "*.wav"))) elif format == "reverb": - self.clean_files = sorted(glob(join(data_dir, subset) + '/anechoic/*.wav')) - self.noisy_files = sorted(glob(join(data_dir, subset) + '/reverb/*.wav')) + self.clean_files = [] + self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "*.wav"))) + self.clean_files += sorted(glob(join(data_dir, subset, "anechoic", "**", "*.wav"))) + self.noisy_files = [] + self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "*.wav"))) + self.noisy_files += sorted(glob(join(data_dir, subset, "reverb", "**", "*.wav"))) else: # Feel free to add your own directory format raise NotImplementedError(f"Directory format {format} unknown!")