Skip to content
Snippets Groups Projects
Commit 4b5970c2 authored by jrichter's avatar jrichter
Browse files

add set_torch_cuda_arch_list()

parent 089a625f
No related branches found
No related tags found
No related merge requests found
import os
import torch
import numpy as np
import scipy.stats
from scipy.signal import butter, sosfilt
import torch
from pesq import pesq
from pystoi import stoi
def si_sdr_components(s_hat, s, n):
"""
"""
# s_target
alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
s_target = alpha_s * s
......@@ -28,8 +23,6 @@ def si_sdr_components(s_hat, s, n):
return s_target, e_noise, e_art
def energy_ratios(s_hat, s, n):
"""
"""
s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
......@@ -129,3 +122,20 @@ def print_mean_std(data, decimal=2):
elif decimal == 1:
string = f'{mean:.1f} ± {std:.1f}'
return string
def set_torch_cuda_arch_list():
if not torch.cuda.is_available():
print("CUDA is not available. No GPUs found.")
return
num_gpus = torch.cuda.device_count()
compute_capabilities = []
for i in range(num_gpus):
cc_major, cc_minor = torch.cuda.get_device_capability(i)
cc = f"{cc_major}.{cc_minor}"
compute_capabilities.append(cc)
cc_string = ";".join(compute_capabilities)
os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment