diff --git a/enhancement.py b/enhancement.py index c479fd18b2523989335e1986ce31870b3351223f..b6aaefa3fdc5c6f1f0634e343212a389fbd25424 100644 --- a/enhancement.py +++ b/enhancement.py @@ -1,11 +1,11 @@ import glob import torch +from tqdm import tqdm from os import makedirs -from os.path import join, dirname -from argparse import ArgumentParser from soundfile import write from torchaudio import load -from tqdm import tqdm +from os.path import join, dirname +from argparse import ArgumentParser # Set CUDA architecture list from sgmse.util.other import set_torch_cuda_arch_list @@ -19,17 +19,17 @@ if __name__ == '__main__': parser = ArgumentParser() parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data') parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data') - parser.add_argument("--ckpt", type=str, help='Path to model checkpoint.') + parser.add_argument("--ckpt", type=str, help='Path to model checkpoint') parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.") parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps") - parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.") + parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics") parser.add_argument("--N", type=int, default=30, help="Number of reverse steps") + parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference") args = parser.parse_args() # Load score model - model = ScoreModel.load_from_checkpoint(args.ckpt, base_dir='', batch_size=16, num_workers=0, kwargs=dict(gpu=False)) - model.eval(no_ema=False) - model.cuda() + model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device) + model.eval() # Get list of noisy files noisy_files = [] @@ -58,12 +58,12 @@ if __name__ == '__main__': y = y / norm_factor # Prepare DNN input - Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0) + Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0) Y = pad_spec(Y, mode=pad_mode) # Reverse sampling sampler = model.get_pc_sampler( - 'reverse_diffusion', args.corrector, Y.cuda(), N=args.N, + 'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N, corrector_steps=args.corrector_steps, snr=args.snr) sample, _ = sampler()