Skip to content
Snippets Groups Projects
Commit b033d3ea authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Updated ``sample_stm``.

parent f5aa4de3
No related branches found
No related tags found
No related merge requests found
...@@ -155,8 +155,8 @@ def sample_rnd(data: Array, shape: Shape) -> Array: ...@@ -155,8 +155,8 @@ def sample_rnd(data: Array, shape: Shape) -> Array:
""" """
n_units = np.prod(shape) n_units = np.prod(shape)
data_limits = np.column_stack((data.max(axis=0), data.min(axis=0))) data_limits = np.column_stack((data.max(axis=0), data.min(axis=0)))
return np.column_stack((np.random.uniform(*data_limits[0], n_units), weights = [np.random.uniform(*lim, n_units) for lim in data_limits]
np.random.uniform(*data_limits[1], n_units))) return np.column_stack(weights)
def sample_stm(data: Array, shape: Shape): def sample_stm(data: Array, shape: Shape):
...@@ -190,8 +190,7 @@ def sample_stm(data: Array, shape: Shape): ...@@ -190,8 +190,7 @@ def sample_stm(data: Array, shape: Shape):
n_rows = int(n_rows) n_rows = int(n_rows)
n_units = np.prod(shape) n_units = np.prod(shape)
alpha = np.full((n_rows, n_rows), 500) alpha = np.random.randint(1, 10, (n_rows, n_rows))
np.fill_diagonal(alpha, 1000)
st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units) st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units)
for a in alpha]) for a in alpha])
return st_matrix return st_matrix
... ...
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment