diff --git a/sgmse/util/other.py b/sgmse/util/other.py index fbe5052f348def7eb05ed4682cfed3700f8e2ab0..83ef70e994a90cbb7dc2444bc060a37f84b77651 100644 --- a/sgmse/util/other.py +++ b/sgmse/util/other.py @@ -1,19 +1,14 @@ import os - +import torch import numpy as np - import scipy.stats from scipy.signal import butter, sosfilt -import torch - from pesq import pesq from pystoi import stoi def si_sdr_components(s_hat, s, n): - """ - """ # s_target alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2 s_target = alpha_s * s @@ -28,8 +23,6 @@ def si_sdr_components(s_hat, s, n): return s_target, e_noise, e_art def energy_ratios(s_hat, s, n): - """ - """ s_target, e_noise, e_art = si_sdr_components(s_hat, s, n) si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2) @@ -129,3 +122,20 @@ def print_mean_std(data, decimal=2): elif decimal == 1: string = f'{mean:.1f} ± {std:.1f}' return string + +def set_torch_cuda_arch_list(): + if not torch.cuda.is_available(): + print("CUDA is not available. No GPUs found.") + return + + num_gpus = torch.cuda.device_count() + compute_capabilities = [] + + for i in range(num_gpus): + cc_major, cc_minor = torch.cuda.get_device_capability(i) + cc = f"{cc_major}.{cc_minor}" + compute_capabilities.append(cc) + + cc_string = ";".join(compute_capabilities) + os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string + print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}") \ No newline at end of file