From 4b5970c299b43df85c9b955856e47924fd306dd9 Mon Sep 17 00:00:00 2001 From: jrichter <jrichter@exchange.informatik.uni-hamburg.de> Date: Sun, 16 Jun 2024 14:21:52 +0200 Subject: [PATCH] add set_torch_cuda_arch_list() --- sgmse/util/other.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/sgmse/util/other.py b/sgmse/util/other.py index fbe5052..83ef70e 100644 --- a/sgmse/util/other.py +++ b/sgmse/util/other.py @@ -1,19 +1,14 @@ import os - +import torch import numpy as np - import scipy.stats from scipy.signal import butter, sosfilt -import torch - from pesq import pesq from pystoi import stoi def si_sdr_components(s_hat, s, n): - """ - """ # s_target alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2 s_target = alpha_s * s @@ -28,8 +23,6 @@ def si_sdr_components(s_hat, s, n): return s_target, e_noise, e_art def energy_ratios(s_hat, s, n): - """ - """ s_target, e_noise, e_art = si_sdr_components(s_hat, s, n) si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2) @@ -129,3 +122,20 @@ def print_mean_std(data, decimal=2): elif decimal == 1: string = f'{mean:.1f} ± {std:.1f}' return string + +def set_torch_cuda_arch_list(): + if not torch.cuda.is_available(): + print("CUDA is not available. No GPUs found.") + return + + num_gpus = torch.cuda.device_count() + compute_capabilities = [] + + for i in range(num_gpus): + cc_major, cc_minor = torch.cuda.get_device_capability(i) + cc = f"{cc_major}.{cc_minor}" + compute_capabilities.append(cc) + + cc_string = ";".join(compute_capabilities) + os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string + print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}") \ No newline at end of file -- GitLab