diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py index 4dec4a2cd0c497f047da922dcd6df5da29119065..0552a17fc569a96667b971ee7cb50018218a531b 100644 --- a/src/apollon/som/som.py +++ b/src/apollon/som/som.py @@ -1,7 +1,6 @@ # Licensed under the terms of the BSD-3-Clause license. # Copyright (C) 2019 Michael Blaß # mblass@posteo.net -import collections from typing import Dict, List, Optional, Tuple import numpy as np @@ -16,7 +15,11 @@ from .. types import Array, Shape, Coord class SomGrid: - def __init__(self, shape: Shape): + def __init__(self, shape: Shape) -> None: + + if not all(isinstance(val, int) and val >= 1 for val in shape): + raise ValueError('Dimensions must be integer > 0.') + self.shape = shape self.pos = np.asarray(list(np.ndindex(shape)), dtype=int) self.tree = cKDTree(self.pos) self.rows, self.cols = np.indices(shape) @@ -63,12 +66,8 @@ class SomBase: nhr: float, nh_shape: str, init_distr: str, metric: str, seed: Optional[float] = None): - # check dimensions - for d in dims: - if not isinstance(d, int) or not d >= 1: - raise ValueError('Dimensions must be integer > 0.') - - self._dims = dims + self._grid = SomGrid(dims[:2]) + self.n_features = dims[2] self._hit_counts = np.zeros(self.n_units) self.n_iter = n_iter self.metric = metric @@ -110,28 +109,27 @@ class SomBase: raise ValueError(f'Unknown initializer "{init_distr}". Use' '"uniform", "simplex", or "pca".') - self._grid = SomGrid(self.shape) self._dists: Optional[Array] = None @property def dims(self) -> Tuple[int, int, int]: """Return the SOM dimensions.""" - return self._dims + return (*self._grid.shape, self.n_features) @property def dx(self) -> int: """Return the number of units along the first dimension.""" - return self.dims[0] + return self._grid.shape[0] @property def dy(self) -> int: """Return the number of units along the second dimension.""" - return self.dims[1] + return self._grid.shape[1] @property def dw(self) -> int: """Return the dimension of the weight vectors.""" - return self.dims[2] + return self.n_features @property def n_units(self) -> int: @@ -141,7 +139,7 @@ class SomBase: @property def shape(self) -> Shape: """Return the map shape.""" - return (self.dx, self.dy) + return self._grid.shape @property def grid(self) -> Array: