From 4b5970c299b43df85c9b955856e47924fd306dd9 Mon Sep 17 00:00:00 2001
From: jrichter <jrichter@exchange.informatik.uni-hamburg.de>
Date: Sun, 16 Jun 2024 14:21:52 +0200
Subject: [PATCH] add set_torch_cuda_arch_list()

---
 sgmse/util/other.py | 26 ++++++++++++++++++--------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/sgmse/util/other.py b/sgmse/util/other.py
index fbe5052..83ef70e 100644
--- a/sgmse/util/other.py
+++ b/sgmse/util/other.py
@@ -1,19 +1,14 @@
 import os
-
+import torch
 import numpy as np
-
 import scipy.stats
 from scipy.signal import butter, sosfilt
 
-import torch
-
 from pesq import pesq
 from pystoi import stoi
 
 
 def si_sdr_components(s_hat, s, n):
-    """
-    """
     # s_target
     alpha_s = np.dot(s_hat, s) / np.linalg.norm(s)**2
     s_target = alpha_s * s
@@ -28,8 +23,6 @@ def si_sdr_components(s_hat, s, n):
     return s_target, e_noise, e_art
 
 def energy_ratios(s_hat, s, n):
-    """
-    """
     s_target, e_noise, e_art = si_sdr_components(s_hat, s, n)
 
     si_sdr = 10*np.log10(np.linalg.norm(s_target)**2 / np.linalg.norm(e_noise + e_art)**2)
@@ -129,3 +122,20 @@ def print_mean_std(data, decimal=2):
     elif decimal == 1:
         string = f'{mean:.1f} ± {std:.1f}'
     return string
+
+def set_torch_cuda_arch_list():
+    if not torch.cuda.is_available():
+        print("CUDA is not available. No GPUs found.")
+        return
+    
+    num_gpus = torch.cuda.device_count()
+    compute_capabilities = []
+
+    for i in range(num_gpus):
+        cc_major, cc_minor = torch.cuda.get_device_capability(i)
+        cc = f"{cc_major}.{cc_minor}"
+        compute_capabilities.append(cc)
+    
+    cc_string = ";".join(compute_capabilities)
+    os.environ['TORCH_CUDA_ARCH_LIST'] = cc_string
+    print(f"Set TORCH_CUDA_ARCH_LIST to: {cc_string}")
\ No newline at end of file
-- 
GitLab