Skip to content
Snippets Groups Projects
Commit b5015c27 authored by Julius Richter's avatar Julius Richter
Browse files

add --device argument for running on cpu

parent a18cdb7b
No related branches found
No related tags found
No related merge requests found
import glob import glob
import torch import torch
from tqdm import tqdm
from os import makedirs from os import makedirs
from os.path import join, dirname
from argparse import ArgumentParser
from soundfile import write from soundfile import write
from torchaudio import load from torchaudio import load
from tqdm import tqdm from os.path import join, dirname
from argparse import ArgumentParser
# Set CUDA architecture list # Set CUDA architecture list
from sgmse.util.other import set_torch_cuda_arch_list from sgmse.util.other import set_torch_cuda_arch_list
...@@ -19,17 +19,17 @@ if __name__ == '__main__': ...@@ -19,17 +19,17 @@ if __name__ == '__main__':
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data') parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data')
parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data') parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
parser.add_argument("--ckpt", type=str, help='Path to model checkpoint.') parser.add_argument("--ckpt", type=str, help='Path to model checkpoint')
parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.") parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps") parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.") parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics")
parser.add_argument("--N", type=int, default=30, help="Number of reverse steps") parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
args = parser.parse_args() args = parser.parse_args()
# Load score model # Load score model
model = ScoreModel.load_from_checkpoint(args.ckpt, base_dir='', batch_size=16, num_workers=0, kwargs=dict(gpu=False)) model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device)
model.eval(no_ema=False) model.eval()
model.cuda()
# Get list of noisy files # Get list of noisy files
noisy_files = [] noisy_files = []
...@@ -58,12 +58,12 @@ if __name__ == '__main__': ...@@ -58,12 +58,12 @@ if __name__ == '__main__':
y = y / norm_factor y = y / norm_factor
# Prepare DNN input # Prepare DNN input
Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0) Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
Y = pad_spec(Y, mode=pad_mode) Y = pad_spec(Y, mode=pad_mode)
# Reverse sampling # Reverse sampling
sampler = model.get_pc_sampler( sampler = model.get_pc_sampler(
'reverse_diffusion', args.corrector, Y.cuda(), N=args.N, 'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N,
corrector_steps=args.corrector_steps, snr=args.snr) corrector_steps=args.corrector_steps, snr=args.snr)
sample, _ = sampler() sample, _ = sampler()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment