diff --git a/enhancement.py b/enhancement.py index b6aaefa3fdc5c6f1f0634e343212a389fbd25424..4da78b809a9e8e094e7e6041c7649b7ea6617247 100644 --- a/enhancement.py +++ b/enhancement.py @@ -6,6 +6,7 @@ from soundfile import write from torchaudio import load from os.path import join, dirname from argparse import ArgumentParser +from librosa import resample # Set CUDA architecture list from sgmse.util.other import set_torch_cuda_arch_list @@ -38,10 +39,10 @@ if __name__ == '__main__': # Check if the model is trained on 48 kHz data if model.backbone == 'ncsnpp_48k': - sr = 48000 + target_sr = 48000 pad_mode = "reflection" else: - sr = 16000 + target_sr = 16000 pad_mode = "zero_pad" # Enhance files @@ -50,7 +51,12 @@ if __name__ == '__main__': filename = noisy_file.replace(args.test_dir, "")[1:] # Remove the first character which is a slash # Load wav - y, _ = load(noisy_file) + y, sr = load(noisy_file) + + # Resample if necessary + if sr != target_sr: + y = resample(y, orig_sr=sr, target_sr=target_sr) + T_orig = y.size(1) # Normalize @@ -75,4 +81,4 @@ if __name__ == '__main__': # Write enhanced wav file makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True) - write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), sr) + write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), target_sr)