diff --git a/src/apollon/som/utilities.py b/src/apollon/som/utilities.py index 6eb7e915002ebbee2a2680b5f3bbfc79f012b139..6db07069656474f5f9cc84dda77085aa184f5708 100644 --- a/src/apollon/som/utilities.py +++ b/src/apollon/som/utilities.py @@ -155,8 +155,8 @@ def sample_rnd(data: Array, shape: Shape) -> Array: """ n_units = np.prod(shape) data_limits = np.column_stack((data.max(axis=0), data.min(axis=0))) - return np.column_stack((np.random.uniform(*data_limits[0], n_units), - np.random.uniform(*data_limits[1], n_units))) + weights = [np.random.uniform(*lim, n_units) for lim in data_limits] + return np.column_stack(weights) def sample_stm(data: Array, shape: Shape): @@ -190,8 +190,7 @@ def sample_stm(data: Array, shape: Shape): n_rows = int(n_rows) n_units = np.prod(shape) - alpha = np.full((n_rows, n_rows), 500) - np.fill_diagonal(alpha, 1000) + alpha = np.random.randint(1, 10, (n_rows, n_rows)) st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units) for a in alpha]) return st_matrix