From b033d3ea5e3ae01bf4619193b55597484f291e1c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Michael=20Bla=C3=9F?= <michael.blass@uni-hamburg.de>
Date: Sun, 2 Aug 2020 14:49:35 +0200
Subject: [PATCH] Updated ``sample_stm``.

---
 src/apollon/som/utilities.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/apollon/som/utilities.py b/src/apollon/som/utilities.py
index 6eb7e91..6db0706 100644
--- a/src/apollon/som/utilities.py
+++ b/src/apollon/som/utilities.py
@@ -155,8 +155,8 @@ def sample_rnd(data: Array, shape: Shape) -> Array:
     """
     n_units = np.prod(shape)
     data_limits = np.column_stack((data.max(axis=0), data.min(axis=0)))
-    return np.column_stack((np.random.uniform(*data_limits[0], n_units),
-                            np.random.uniform(*data_limits[1], n_units)))
+    weights = [np.random.uniform(*lim, n_units) for lim in data_limits]
+    return np.column_stack(weights)
 
 
 def sample_stm(data: Array, shape: Shape):
@@ -190,8 +190,7 @@ def sample_stm(data: Array, shape: Shape):
 
     n_rows = int(n_rows)
     n_units = np.prod(shape)
-    alpha = np.full((n_rows, n_rows), 500)
-    np.fill_diagonal(alpha, 1000)
+    alpha = np.random.randint(1, 10, (n_rows, n_rows))
     st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units)
                            for a in alpha])
     return st_matrix
-- 
GitLab