From 636a70e9591bc8b995018d39652c83efc9b8551b Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Wed, 10 Jul 2024 10:28:59 +0200
Subject: [PATCH] resampling if sr is not target sr

---
Review notes (text in this section is ignored when the patch is applied):
- librosa's resample() operates on NumPy arrays and returns an ndarray,
  while the rest of the loop (y.size(1), normalisation, the model call)
  needs a torch.Tensor. The resampled signal is therefore wrapped back
  into a tensor with y.new_tensor(), which keeps y's dtype and device.
- Indentation of the context lines was reconstructed from the Python
  structure; check it against the target file before applying.
- The "index" line still carries the blob ids of the original commit.

 enhancement.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/enhancement.py b/enhancement.py
index b6aaefa..4da78b8 100644
--- a/enhancement.py
+++ b/enhancement.py
@@ -6,6 +6,7 @@ from soundfile import write
 from torchaudio import load
 from os.path import join, dirname
 from argparse import ArgumentParser
+from librosa import resample
 
 # Set CUDA architecture list
 from sgmse.util.other import set_torch_cuda_arch_list
@@ -38,10 +39,10 @@ if __name__ == '__main__':
 
     # Check if the model is trained on 48 kHz data
     if model.backbone == 'ncsnpp_48k':
-        sr = 48000
+        target_sr = 48000
         pad_mode = "reflection"
     else:
-        sr = 16000
+        target_sr = 16000
         pad_mode = "zero_pad"
 
     # Enhance files
@@ -50,7 +51,12 @@ if __name__ == '__main__':
         filename = noisy_file.replace(args.test_dir, "")[1:] # Remove the first character which is a slash
 
         # Load wav
-        y, _ = load(noisy_file)
+        y, sr = load(noisy_file)
+
+        # Resample if necessary (librosa needs a NumPy array; convert the result back to a tensor)
+        if sr != target_sr:
+            y = y.new_tensor(resample(y.numpy(), orig_sr=sr, target_sr=target_sr))
+
         T_orig = y.size(1)
 
         # Normalize
@@ -75,4 +81,4 @@ if __name__ == '__main__':
 
         # Write enhanced wav file
         makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True)
-        write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), sr)
+        write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), target_sr)
-- 
GitLab