diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py
index bf9188c588f9b29f79d44c4a9ed886af81913e98..e4cc4121d9840be8a185a4e6fd4fb3f1582d457c 100644
--- a/src/apollon/som/som.py
+++ b/src/apollon/som/som.py
@@ -1,7 +1,7 @@
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# mblass@posteo.net
-from typing import Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from scipy.spatial import cKDTree
@@ -13,10 +13,13 @@ from . import neighbors as _neighbors
from . import utilities as asu
from .. types import Array, Shape, Coord
+WeightInit = Union[Callable[[Array, Shape], Array], str]
+Metric = Union[Callable[[Array, Array], float], str]
+SomDims = Tuple[int, int, int]
+
class SomGrid:
def __init__(self, shape: Shape) -> None:
-
if not all(isinstance(val, int) and val >= 1 for val in shape):
raise ValueError('Dimensions must be integer > 0.')
self.shape = shape
@@ -62,9 +65,9 @@ class SomGrid:
class SomBase:
- def __init__(self, dims: Tuple[int, int, int], n_iter: int, eta: float,
- nhr: float, nh_shape: str, init_distr: str, metric: str,
- seed: Optional[float] = None):
+ def __init__(self, dims: SomDims, n_iter: int, eta: float,
+ nhr: float, nh_shape: str, init_weights: WeightInit,
+ metric: Metric, seed: Optional[float] = None):
self._grid = SomGrid(dims[:2])
self.n_features = dims[2]
@@ -73,6 +76,7 @@ class SomBase:
self.metric = metric
self._qrr = np.zeros(n_iter)
self._trr = np.zeros(n_iter)
+ self._weights: Optional[Array] = None
try:
self._neighbourhood = getattr(_neighbors, nh_shape)
@@ -98,16 +102,13 @@ class SomBase:
if seed is not None:
np.random.seed(seed)
- if init_distr == 'uniform':
- self._weights = np.random.uniform(0, 1,
- size=(self.n_units, self.dw))
- elif init_distr == 'simplex':
- self._weights = asu.init_simplex(self.dw, self.n_units)
- elif init_distr == 'pca':
- raise NotImplementedError
+ if isinstance(init_weights, str):
+ self.init_weights = asu.weight_initializer[init_weights]
+ elif callable(init_weights):
+ self.init_weights = init_weights
else:
- raise ValueError(f'Unknown initializer "{init_distr}". Use'
- '"uniform", "simplex", or "pca".')
+ msg = f'Initializer must be string or callable.'
+ raise ValueError(msg)
self._dists: Optional[Array] = None
@@ -302,23 +303,24 @@ class SomBase:
class BatchMap(SomBase):
- def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
- nh_shape: str = 'gaussian', init_distr: str = 'uniform',
- metric: str = 'euclidean', seed: int = None):
+ def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
+ nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
+ metric: Metric = 'euclidean', seed: int = None):
- super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
+ super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
seed=seed)
class IncrementalMap(SomBase):
- def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
- nh_shape: str = 'gaussian', init_distr: str = 'uniform',
- metric: str = 'euclidean', seed: int = None):
+ def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
+ nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
+ metric: Metric = 'euclidean', seed: int = None):
- super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
+ super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
seed=seed)
def fit(self, train_data, verbose=False, output_weights=False):
+ self._weights = self.init_weights(train_data, self.shape)
eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
@@ -353,6 +355,7 @@ class IncrementalKDTReeMap(SomBase):
def fit(self, train_data, verbose=False):
"""Fit SOM to input data."""
+ self._weights = self.init_weights(train_data, self.shape)
eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
iter_ = range(self.n_iter)
diff --git a/src/apollon/som/utilities.py b/src/apollon/som/utilities.py
index 6db07069656474f5f9cc84dda77085aa184f5708..64ef5a7885c9d29bd79df52f01ad451122d94237 100644
--- a/src/apollon/som/utilities.py
+++ b/src/apollon/som/utilities.py
@@ -218,3 +218,9 @@ def distribute(bmu_idx: Iterable[int], n_units: int
for data_idx, bmu in enumerate(bmu_idx):
unit_matches[bmu].append(data_idx)
return unit_matches
+
+
+weight_initializer = {
+ 'rnd': sample_rnd,
+ 'stm': sample_stm,
+ 'pca': sample_pca,}
diff --git a/tests/som/test_som.py b/tests/som/test_som.py
index b4cd0a0ec53f66c28b4c0e14a97d54245e16b38f..5d6d92a39245b9ac1c5c01052588733cc7be7f5d 100644
--- a/tests/som/test_som.py
+++ b/tests/som/test_som.py
@@ -14,68 +14,72 @@ som_dims = hst.tuples(dimension, dimension, dimension)
class TestSomBase(unittest.TestCase):
+
@given(som_dims)
def test_dims(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dims, dims)
@given(som_dims)
def test_dx(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dx, dims[0])
@given(som_dims)
def test_dy(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dy, dims[1])
@given(som_dims)
def test_dw(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dw, dims[2])
@given(som_dims)
def test_n_units(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.n_units, dims[0]*dims[1])
@given(som_dims)
def test_shape(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.shape, (dims[0], dims[1]))
@given(som_dims)
def test_grid(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertIsInstance(som.grid, SomGrid)
"""
@given(som_dims)
def test_dists(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertIsInstance(som.dists, np.ndarray)
"""
@given(som_dims)
def test_weights(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
- self.assertIsInstance(som.weights, np.ndarray)
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+ self.assertIsNone(som.weights)
@given(som_dims)
def test_match(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
data = np.random.rand(100, dims[2])
+ som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+ som._weights = som.init_weights(data, som.shape)
self.assertIsInstance(som.match(data), np.ndarray)
@given(som_dims)
def test_umatrix_has_map_shape(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ data = np.random.rand(100, dims[2])
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+ som._weights = som.init_weights(data, som.shape)
um = som.umatrix()
self.assertEqual(um.shape, som.shape)
@given(som_dims)
def test_umatrix_scale(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
som._weights = np.tile(np.arange(som.n_features), (som.n_units, 1))
som._weights[:, -1] = np.arange(som.n_units)
um = som.umatrix(scale=True, norm=False)
@@ -84,7 +88,9 @@ class TestSomBase(unittest.TestCase):
@given(som_dims)
def test_umatrix_norm(self, dims: SomDim) -> None:
- som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+ data = np.random.rand(100, dims[2])
+ som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+ som._weights = som.init_weights(data, som.shape)
um = som.umatrix(norm=True)
self.assertEqual(um.max(), 1.0)