diff --git a/enhancement.py b/enhancement.py
index c479fd18b2523989335e1986ce31870b3351223f..b6aaefa3fdc5c6f1f0634e343212a389fbd25424 100644
--- a/enhancement.py
+++ b/enhancement.py
@@ -1,11 +1,11 @@
 import glob
 import torch
+from tqdm import tqdm
 from os import makedirs
-from os.path import join, dirname
-from argparse import ArgumentParser
 from soundfile import write
 from torchaudio import load
-from tqdm import tqdm
+from os.path import join, dirname
+from argparse import ArgumentParser
 
 # Set CUDA architecture list
 from sgmse.util.other import set_torch_cuda_arch_list
@@ -19,17 +19,17 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--test_dir", type=str, required=True, help='Directory containing the test data')
     parser.add_argument("--enhanced_dir", type=str, required=True, help='Directory containing the enhanced data')
-    parser.add_argument("--ckpt", type=str,  help='Path to model checkpoint.')
+    parser.add_argument("--ckpt", type=str,  help='Path to model checkpoint')
     parser.add_argument("--corrector", type=str, choices=("ald", "langevin", "none"), default="ald", help="Corrector class for the PC sampler.")
     parser.add_argument("--corrector_steps", type=int, default=1, help="Number of corrector steps")
-    parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics.")
+    parser.add_argument("--snr", type=float, default=0.5, help="SNR value for (annealed) Langevin dynmaics")
     parser.add_argument("--N", type=int, default=30, help="Number of reverse steps")
+    parser.add_argument("--device", type=str, default="cuda", help="Device to use for inference")
     args = parser.parse_args()
 
     # Load score model 
-    model = ScoreModel.load_from_checkpoint(args.ckpt, base_dir='', batch_size=16, num_workers=0, kwargs=dict(gpu=False))
-    model.eval(no_ema=False)
-    model.cuda()
+    model = ScoreModel.load_from_checkpoint(args.ckpt, map_location=args.device)
+    model.eval()
 
     # Get list of noisy files
     noisy_files = []
@@ -58,12 +58,12 @@ if __name__ == '__main__':
         y = y / norm_factor
         
         # Prepare DNN input
-        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
+        Y = torch.unsqueeze(model._forward_transform(model._stft(y.to(args.device))), 0)
         Y = pad_spec(Y, mode=pad_mode)
         
         # Reverse sampling
         sampler = model.get_pc_sampler(
-            'reverse_diffusion', args.corrector, Y.cuda(), N=args.N, 
+            'reverse_diffusion', args.corrector, Y.to(args.device), N=args.N, 
             corrector_steps=args.corrector_steps, snr=args.snr)
         sample, _ = sampler()