From f8c67bdcf7819f62187e1b3d168796997c856aff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Michael=20Bla=C3=9F?= <michael.blass@uni-hamburg.de>
Date: Sat, 15 Aug 2020 10:26:33 +0200
Subject: [PATCH] metric and init_weights now accept functions and strings.

---
 src/apollon/som/som.py       | 47 +++++++++++++++++++-----------------
 src/apollon/som/utilities.py |  6 +++++
 tests/som/test_som.py        | 34 +++++++++++++++-----------
 3 files changed, 51 insertions(+), 36 deletions(-)

diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py
index bf9188c..e4cc412 100644
--- a/src/apollon/som/som.py
+++ b/src/apollon/som/som.py
@@ -1,7 +1,7 @@
 # Licensed under the terms of the BSD-3-Clause license.
 # Copyright (C) 2019 Michael Blaß
 # mblass@posteo.net
-from typing import Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 from scipy.spatial import cKDTree
@@ -13,10 +13,13 @@ from . import neighbors as _neighbors
 from . import utilities as asu
 from .. types import Array, Shape, Coord
 
+WeightInit = Union[Callable[[Array, Shape], Array], str]
+Metric = Union[Callable[[Array, Array], float], str]
+SomDims = Tuple[int, int, int]
+
 
 class SomGrid:
     def __init__(self, shape: Shape) -> None:
-
         if not all(isinstance(val, int) and val >= 1 for val in shape):
             raise ValueError('Dimensions must be integer > 0.')
         self.shape = shape
@@ -62,9 +65,9 @@ class SomGrid:
 
 
 class SomBase:
-    def __init__(self, dims: Tuple[int, int, int], n_iter: int, eta: float,
-                 nhr: float, nh_shape: str, init_distr: str, metric: str,
-                 seed: Optional[float] = None):
+    def __init__(self, dims: SomDims, n_iter: int, eta: float,
+                 nhr: float, nh_shape: str, init_weights: WeightInit,
+                 metric: Metric, seed: Optional[float] = None):
 
         self._grid = SomGrid(dims[:2])
         self.n_features = dims[2]
@@ -73,6 +76,7 @@ class SomBase:
         self.metric = metric
         self._qrr = np.zeros(n_iter)
         self._trr = np.zeros(n_iter)
+        self._weights: Optional[Array] = None
 
         try:
             self._neighbourhood = getattr(_neighbors, nh_shape)
@@ -98,16 +102,13 @@ class SomBase:
         if seed is not None:
             np.random.seed(seed)
 
-        if init_distr == 'uniform':
-            self._weights = np.random.uniform(0, 1,
-                                              size=(self.n_units, self.dw))
-        elif init_distr == 'simplex':
-            self._weights = asu.init_simplex(self.dw, self.n_units)
-        elif init_distr == 'pca':
-            raise NotImplementedError
+        if isinstance(init_weights, str):
+            init_weights = asu.weight_initializer.get(init_weights)
+        if callable(init_weights):  # ``None`` when the name is unknown
+            self.init_weights = init_weights
         else:
-            raise ValueError(f'Unknown initializer "{init_distr}". Use'
-                             '"uniform", "simplex", or "pca".')
+            raise ValueError('Initializer must be callable or one of '
+                             f'{sorted(asu.weight_initializer)}.')
 
         self._dists: Optional[Array] = None
 
@@ -302,23 +303,24 @@ class SomBase:
 
 
 class BatchMap(SomBase):
-    def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
-                 nh_shape: str = 'gaussian', init_distr: str = 'uniform',
-                 metric: str = 'euclidean', seed: int = None):
+    def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
+                 nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
+                 metric: Metric = 'euclidean', seed: Optional[int] = None):
 
-        super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
+        super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
                          seed=seed)
 
 
 class IncrementalMap(SomBase):
-    def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
-                 nh_shape: str = 'gaussian', init_distr: str = 'uniform',
-                 metric: str = 'euclidean', seed: int = None):
+    def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
+                 nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
+                 metric: Metric = 'euclidean', seed: Optional[int] = None):
 
-        super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
+        super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
                          seed=seed)
 
     def fit(self, train_data, verbose=False, output_weights=False):
+        self._weights = self.init_weights(train_data, self.shape)
         eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
         nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
 
@@ -353,6 +355,7 @@ class IncrementalKDTReeMap(SomBase):
 
     def fit(self, train_data, verbose=False):
         """Fit SOM to input data."""
+        self._weights = self.init_weights(train_data, self.shape)
         eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
         nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
         iter_ = range(self.n_iter)
diff --git a/src/apollon/som/utilities.py b/src/apollon/som/utilities.py
index 6db0706..64ef5a7 100644
--- a/src/apollon/som/utilities.py
+++ b/src/apollon/som/utilities.py
@@ -218,3 +218,9 @@ def distribute(bmu_idx: Iterable[int], n_units: int
     for data_idx, bmu in enumerate(bmu_idx):
         unit_matches[bmu].append(data_idx)
     return unit_matches
+
+
+weight_initializer = {
+    'rnd': sample_rnd,
+    'stm': sample_stm,
+    'pca': sample_pca,}
diff --git a/tests/som/test_som.py b/tests/som/test_som.py
index b4cd0a0..5d6d92a 100644
--- a/tests/som/test_som.py
+++ b/tests/som/test_som.py
@@ -14,68 +14,72 @@ som_dims = hst.tuples(dimension, dimension, dimension)
 
 
 class TestSomBase(unittest.TestCase):
+
     @given(som_dims)
     def test_dims(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.dims, dims)
 
     @given(som_dims)
     def test_dx(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.dx, dims[0])
 
     @given(som_dims)
     def test_dy(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.dy, dims[1])
 
     @given(som_dims)
     def test_dw(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.dw, dims[2])
 
     @given(som_dims)
     def test_n_units(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.n_units, dims[0]*dims[1])
 
     @given(som_dims)
     def test_shape(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertEqual(som.shape, (dims[0], dims[1]))
 
     @given(som_dims)
     def test_grid(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertIsInstance(som.grid, SomGrid)
 
     """
     @given(som_dims)
     def test_dists(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         self.assertIsInstance(som.dists, np.ndarray)
     """
 
     @given(som_dims)
     def test_weights(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
-        self.assertIsInstance(som.weights, np.ndarray)
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+        self.assertIsNone(som.weights)
 
     @given(som_dims)
     def test_match(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
         data = np.random.rand(100, dims[2])
+        som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+        som._weights = som.init_weights(data, som.shape)
         self.assertIsInstance(som.match(data), np.ndarray)
 
     @given(som_dims)
     def test_umatrix_has_map_shape(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        data = np.random.rand(100, dims[2])
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+        som._weights = som.init_weights(data, som.shape)
         um = som.umatrix()
         self.assertEqual(um.shape, som.shape)
 
     @given(som_dims)
     def test_umatrix_scale(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
         som._weights = np.tile(np.arange(som.n_features), (som.n_units, 1))
         som._weights[:, -1] = np.arange(som.n_units)
         um = som.umatrix(scale=True, norm=False)
@@ -84,7 +88,9 @@ class TestSomBase(unittest.TestCase):
 
     @given(som_dims)
     def test_umatrix_norm(self, dims: SomDim) -> None:
-        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        data = np.random.rand(100, dims[2])
+        som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
+        som._weights = som.init_weights(data, som.shape)
         um = som.umatrix(norm=True)
         self.assertEqual(um.max(), 1.0)
 
-- 
GitLab