From b033d3ea5e3ae01bf4619193b55597484f291e1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20Bla=C3=9F?= <michael.blass@uni-hamburg.de> Date: Sun, 2 Aug 2020 14:49:35 +0200 Subject: [PATCH] Updated ``sample_stm``. --- src/apollon/som/utilities.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/apollon/som/utilities.py b/src/apollon/som/utilities.py index 6eb7e91..6db0706 100644 --- a/src/apollon/som/utilities.py +++ b/src/apollon/som/utilities.py @@ -155,8 +155,8 @@ def sample_rnd(data: Array, shape: Shape) -> Array: """ n_units = np.prod(shape) data_limits = np.column_stack((data.max(axis=0), data.min(axis=0))) - return np.column_stack((np.random.uniform(*data_limits[0], n_units), - np.random.uniform(*data_limits[1], n_units))) + weights = [np.random.uniform(*lim, n_units) for lim in data_limits] + return np.column_stack(weights) def sample_stm(data: Array, shape: Shape): @@ -190,8 +190,7 @@ def sample_stm(data: Array, shape: Shape): n_rows = int(n_rows) n_units = np.prod(shape) - alpha = np.full((n_rows, n_rows), 500) - np.fill_diagonal(alpha, 1000) + alpha = np.random.randint(1, 10, (n_rows, n_rows)) st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units) for a in alpha]) return st_matrix -- GitLab