Skip to content
Snippets Groups Projects
Commit f8c67bdc authored by Blaß, Michael :speech_balloon:
Browse files

`metric` and `init_weights` now accept either a function or a string.

parent 16d1a162
No related branches found
No related tags found
No related merge requests found
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# mblass@posteo.net
from typing import Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from scipy.spatial import cKDTree
......@@ -13,10 +13,13 @@ from . import neighbors as _neighbors
from . import utilities as asu
from .. types import Array, Shape, Coord
# Weight initializer: either a callable that takes an Array and a Shape and
# returns the initial weight Array, or a string that is looked up in
# ``asu.weight_initializer``.
WeightInit = Union[Callable[[Array, Shape], Array], str]
# Distance metric: either a callable that takes two Arrays and returns a
# float, or a string naming the metric (e.g. 'euclidean').
Metric = Union[Callable[[Array, Array], float], str]
# SOM dimensions as three ints: the first two are passed to SomGrid as the
# grid shape, the third is stored as the number of features.
SomDims = Tuple[int, int, int]
class SomGrid:
def __init__(self, shape: Shape) -> None:
if not all(isinstance(val, int) and val >= 1 for val in shape):
raise ValueError('Dimensions must be integer > 0.')
self.shape = shape
......@@ -62,9 +65,9 @@ class SomGrid:
class SomBase:
def __init__(self, dims: Tuple[int, int, int], n_iter: int, eta: float,
nhr: float, nh_shape: str, init_distr: str, metric: str,
seed: Optional[float] = None):
def __init__(self, dims: SomDims, n_iter: int, eta: float,
nhr: float, nh_shape: str, init_weights: WeightInit,
metric: Metric, seed: Optional[float] = None):
self._grid = SomGrid(dims[:2])
self.n_features = dims[2]
......@@ -73,6 +76,7 @@ class SomBase:
self.metric = metric
self._qrr = np.zeros(n_iter)
self._trr = np.zeros(n_iter)
self._weights: Optional[Array] = None
try:
self._neighbourhood = getattr(_neighbors, nh_shape)
......@@ -98,16 +102,13 @@ class SomBase:
if seed is not None:
np.random.seed(seed)
if init_distr == 'uniform':
self._weights = np.random.uniform(0, 1,
size=(self.n_units, self.dw))
elif init_distr == 'simplex':
self._weights = asu.init_simplex(self.dw, self.n_units)
elif init_distr == 'pca':
raise NotImplementedError
if isinstance(init_weights, str):
self.init_weights = asu.weight_initializer[init_weights]
elif callable(init_weights):
self.init_weights = init_weights
else:
raise ValueError(f'Unknown initializer "{init_distr}". Use'
'"uniform", "simplex", or "pca".')
msg = f'Initializer must be string or callable.'
raise ValueError(msg)
self._dists: Optional[Array] = None
......@@ -302,23 +303,24 @@ class SomBase:
class BatchMap(SomBase):
def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
nh_shape: str = 'gaussian', init_distr: str = 'uniform',
metric: str = 'euclidean', seed: int = None):
def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
metric: Metric = 'euclidean', seed: int = None):
super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
seed=seed)
class IncrementalMap(SomBase):
def __init__(self, dims: tuple, n_iter: int, eta: float, nhr: float,
nh_shape: str = 'gaussian', init_distr: str = 'uniform',
metric: str = 'euclidean', seed: int = None):
def __init__(self, dims: SomDims, n_iter: int, eta: float, nhr: float,
nh_shape: str = 'gaussian', init_weights: WeightInit = 'rnd',
metric: Metric = 'euclidean', seed: int = None):
super().__init__(dims, n_iter, eta, nhr, nh_shape, init_distr, metric,
super().__init__(dims, n_iter, eta, nhr, nh_shape, init_weights, metric,
seed=seed)
def fit(self, train_data, verbose=False, output_weights=False):
self._weights = self.init_weights(train_data, self.shape)
eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
......@@ -353,6 +355,7 @@ class IncrementalKDTReeMap(SomBase):
def fit(self, train_data, verbose=False):
"""Fit SOM to input data."""
self._weights = self.init_weights(train_data, self.shape)
eta_ = asu.decrease_linear(self.init_eta, self.n_iter, _defaults.final_eta)
nhr_ = asu.decrease_expo(self.init_nhr, self.n_iter, _defaults.final_nhr)
iter_ = range(self.n_iter)
......
......@@ -218,3 +218,9 @@ def distribute(bmu_idx: Iterable[int], n_units: int
for data_idx, bmu in enumerate(bmu_idx):
unit_matches[bmu].append(data_idx)
return unit_matches
# Lookup table from initializer name to the sampling function that
# implements it; string-valued initializers are resolved through this map.
weight_initializer = dict(
    rnd=sample_rnd,
    stm=sample_stm,
    pca=sample_pca,
)
......@@ -14,68 +14,72 @@ som_dims = hst.tuples(dimension, dimension, dimension)
class TestSomBase(unittest.TestCase):
@given(som_dims)
def test_dims(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dims, dims)
@given(som_dims)
def test_dx(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dx, dims[0])
@given(som_dims)
def test_dy(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dy, dims[1])
@given(som_dims)
def test_dw(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.dw, dims[2])
@given(som_dims)
def test_n_units(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.n_units, dims[0]*dims[1])
@given(som_dims)
def test_shape(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertEqual(som.shape, (dims[0], dims[1]))
@given(som_dims)
def test_grid(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertIsInstance(som.grid, SomGrid)
"""
@given(som_dims)
def test_dists(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertIsInstance(som.dists, np.ndarray)
"""
@given(som_dims)
def test_weights(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
self.assertIsInstance(som.weights, np.ndarray)
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
self.assertIsNone(som.weights)
@given(som_dims)
def test_match(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
data = np.random.rand(100, dims[2])
som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
som._weights = som.init_weights(data, som.shape)
self.assertIsInstance(som.match(data), np.ndarray)
@given(som_dims)
def test_umatrix_has_map_shape(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
data = np.random.rand(100, dims[2])
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
som._weights = som.init_weights(data, som.shape)
um = som.umatrix()
self.assertEqual(um.shape, som.shape)
@given(som_dims)
def test_umatrix_scale(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
som._weights = np.tile(np.arange(som.n_features), (som.n_units, 1))
som._weights[:, -1] = np.arange(som.n_units)
um = som.umatrix(scale=True, norm=False)
......@@ -84,7 +88,9 @@ class TestSomBase(unittest.TestCase):
@given(som_dims)
def test_umatrix_norm(self, dims: SomDim) -> None:
som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
data = np.random.rand(100, dims[2])
som = SomBase(dims, 10, 0.1, 10, 'gaussian', 'rnd', 'euclidean')
som._weights = som.init_weights(data, som.shape)
um = som.umatrix(norm=True)
self.assertEqual(um.max(), 1.0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment