Skip to content
Snippets Groups Projects
Commit 292fd4c6 authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Updated utilities.sample_stm.

parent 080d7bf3
No related branches found
No related tags found
No related merge requests found
......@@ -159,39 +159,39 @@ def sample_rnd(data: Array, shape: Shape) -> Array:
np.random.uniform(*data_limits[1], n_units)))
def sample_stm(n_features, n_units):
"""Initialize the weights with stochastic matrices.
def sample_stm(data: Array, shape: Shape):
"""Compute initial SOM weights by sampling stochastic matrices from
Dirichlet distribution.
The rows of each n by n stochastic matrix are sampes drawn from the
Dirichlet distribution, where n is the number of rows and cols of the
matrix. The diagonal elemets of the matrices are set to twice the
probability of the remaining elements.
The square root n of the weight vectors' size must be element of the
natural numbers, so that the weight vector is reshapeable to a square
matrix.
The square root of the weight vectors' size must be a real integer.
Args:
n_features: Number of features in each vector.
n_units: Number of units on the SOM.
data: Input data set.
shape: Shape of SOM.
Returns:
Two-dimensional array of shape (n_units, n_features), in which each
row is a flattened random stochastic matrix.
Array of SOM weights.
Notes:
Each row of the output array is to be considered a flattened
stochastic matrix, such that each ``N = sqrt(data.shape[1])`` values
are a discrete probability distribution forming the ``N``th row of
the matrix.
"""
# check for square matrix
n_rows = np.sqrt(n_features)
n_rows = np.sqrt(data.shape[1])
if bool(n_rows - int(n_rows)):
msg = (f'Weight vector (len={n_features}) is not '
msg = (f'Weight vector with {n_rows} elements is not '
'reshapeable to square matrix.')
raise ValueError(msg)
else:
n_rows = int(n_rows)
# set alpha
n_rows = int(n_rows)
n_units = np.prod(shape)
alpha = np.full((n_rows, n_rows), 500)
np.fill_diagonal(alpha, 1000)
# sample from dirichlet distributions
st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units)
for a in alpha])
return st_matrix
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment