Commit 292fd4c6 authored by Blaß, Michael

Updated utilities.sample_stm.

parent 080d7bf3
@@ -159,39 +159,39 @@ def sample_rnd(data: Array, shape: Shape) -> Array:
                     np.random.uniform(*data_limits[1], n_units)))
 
 
-def sample_stm(n_features, n_units):
-    """Initialize the weights with stochastic matrices.
+def sample_stm(data: Array, shape: Shape):
+    """Compute initial SOM weights by sampling stochastic matrices from
+    Dirichlet distribution.
 
     The rows of each n by n stochastic matrix are samples drawn from the
     Dirichlet distribution, where n is the number of rows and cols of the
     matrix. The diagonal elements of the matrices are set to twice the
     probability of the remaining elements.
-    The square root n of the weight vectors' size must be element of the
-    natural numbers, so that the weight vector is reshapeable to a square
-    matrix.
+    The square root of the weight vectors' size must be a real integer.
 
     Args:
-        n_features: Number of features in each vector.
-        n_units: Number of units on the SOM.
+        data: Input data set.
+        shape: Shape of SOM.
 
     Returns:
-        Two-dimensional array of shape (n_units, n_features), in which each
-        row is a flattened random stochastic matrix.
+        Array of SOM weights.
+
+    Notes:
+        Each row of the output array is to be considered a flattened
+        stochastic matrix, such that each ``N = sqrt(data.shape[1])`` values
+        are a discrete probability distribution forming the ``N``th row of
+        the matrix.
     """
-    # check for square matrix
-    n_rows = np.sqrt(n_features)
+    n_rows = np.sqrt(data.shape[1])
     if bool(n_rows - int(n_rows)):
-        msg = (f'Weight vector (len={n_features}) is not '
+        msg = (f'Weight vector with {n_rows} elements is not '
                'reshapeable to square matrix.')
         raise ValueError(msg)
-    else:
-        n_rows = int(n_rows)
 
-    # set alpha
+    n_rows = int(n_rows)
+    n_units = np.prod(shape)
     alpha = np.full((n_rows, n_rows), 500)
     np.fill_diagonal(alpha, 1000)
-
-    # sample from dirichlet distributions
     st_matrix = np.hstack([_stats.dirichlet.rvs(alpha=a, size=n_units)
                            for a in alpha])
     return st_matrix
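
To make the output layout described in the new docstring concrete, here is a minimal standalone sketch of the updated sampling logic. It is not the committed code: the name `sample_stm_sketch` and the `seed` parameter are hypothetical, `scipy.stats` is imported directly instead of through the module's `_stats` alias, and the error message reports the vector length (`data.shape[1]`) rather than its square root. The squareness check is also an assumption, done with integer arithmetic instead of the float comparison in the diff.

```python
import numpy as np
from scipy import stats


def sample_stm_sketch(data, shape, seed=None):
    """Sample flattened stochastic matrices as initial SOM weights.

    Hypothetical standalone variant of ``sample_stm``; ``seed`` is an
    illustrative addition and not part of the committed signature.
    """
    n_features = data.shape[1]
    n_rows = int(round(np.sqrt(n_features)))
    # Integer check: the weight vector must reshape to an n_rows x n_rows matrix.
    if n_rows * n_rows != n_features:
        raise ValueError(f'Weight vector with {n_features} elements is not '
                         'reshapeable to square matrix.')

    n_units = int(np.prod(shape))
    rng = np.random.default_rng(seed)

    # Concentration parameters as in the commit: 1000 on the diagonal,
    # 500 elsewhere, so a diagonal entry has twice the expected
    # probability of an off-diagonal entry in the same row.
    alpha = np.full((n_rows, n_rows), 500)
    np.fill_diagonal(alpha, 1000)

    # One Dirichlet draw per matrix row; each block has shape
    # (n_units, n_rows), and stacking them horizontally yields
    # (n_units, n_rows * n_rows).
    blocks = [stats.dirichlet.rvs(alpha=a, size=n_units, random_state=rng)
              for a in alpha]
    return np.hstack(blocks)


# Usage: 9 features per vector -> 3 x 3 matrices; a 4 x 3 SOM has 12 units.
data = np.zeros((5, 9))
weights = sample_stm_sketch(data, (4, 3), seed=0)
matrices = weights.reshape(-1, 3, 3)
print(weights.shape)                            # (12, 9)
print(np.allclose(matrices.sum(axis=2), 1.0))   # every matrix row sums to 1
```

Reshaping each output row to `(n_rows, n_rows)` recovers one stochastic matrix per SOM unit, which is the property the Notes section of the docstring describes.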