Skip to content
Snippets Groups Projects
Commit 636a70e9 authored by jrichter's avatar jrichter
Browse files

resampling if sr is not target sr

parent 215325c4
No related branches found
No related tags found
No related merge requests found
...@@ -6,6 +6,7 @@ from soundfile import write ...@@ -6,6 +6,7 @@ from soundfile import write
from torchaudio import load from torchaudio import load
from os.path import join, dirname from os.path import join, dirname
from argparse import ArgumentParser from argparse import ArgumentParser
from librosa import resample
# Set CUDA architecture list # Set CUDA architecture list
from sgmse.util.other import set_torch_cuda_arch_list from sgmse.util.other import set_torch_cuda_arch_list
...@@ -38,10 +39,10 @@ if __name__ == '__main__': ...@@ -38,10 +39,10 @@ if __name__ == '__main__':
# Check if the model is trained on 48 kHz data # Check if the model is trained on 48 kHz data
if model.backbone == 'ncsnpp_48k': if model.backbone == 'ncsnpp_48k':
sr = 48000 target_sr = 48000
pad_mode = "reflection" pad_mode = "reflection"
else: else:
sr = 16000 target_sr = 16000
pad_mode = "zero_pad" pad_mode = "zero_pad"
# Enhance files # Enhance files
...@@ -50,7 +51,12 @@ if __name__ == '__main__': ...@@ -50,7 +51,12 @@ if __name__ == '__main__':
filename = noisy_file.replace(args.test_dir, "")[1:] # Remove the first character which is a slash filename = noisy_file.replace(args.test_dir, "")[1:] # Remove the first character which is a slash
# Load wav # Load wav
y, _ = load(noisy_file) y, sr = load(noisy_file)
# Resample if necessary
if sr != target_sr:
y = resample(y, orig_sr=sr, target_sr=target_sr)
T_orig = y.size(1) T_orig = y.size(1)
# Normalize # Normalize
...@@ -75,4 +81,4 @@ if __name__ == '__main__': ...@@ -75,4 +81,4 @@ if __name__ == '__main__':
# Write enhanced wav file # Write enhanced wav file
makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True) makedirs(dirname(join(args.enhanced_dir, filename)), exist_ok=True)
write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), sr) write(join(args.enhanced_dir, filename), x_hat.cpu().numpy(), target_sr)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment