From feef5375f28aaaca22b84df232a7bb73ea093f66 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Michael=20Bla=C3=9F?= <michael.blass@uni-hamburg.de>
Date: Tue, 4 Aug 2020 08:59:56 +0200
Subject: [PATCH] Moved SOM shape attribute to SomGrid

---
 src/apollon/som/som.py | 26 ++++++++++++--------------
 1 file changed, 12 insertions(+), 14 deletions(-)

diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py
index 4dec4a2..0552a17 100644
--- a/src/apollon/som/som.py
+++ b/src/apollon/som/som.py
@@ -1,7 +1,6 @@
 # Licensed under the terms of the BSD-3-Clause license.
 # Copyright (C) 2019 Michael Blaß
 # mblass@posteo.net
-import collections
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
@@ -16,7 +15,11 @@ from .. types import Array, Shape, Coord
 
 
 class SomGrid:
-    def __init__(self, shape: Shape):
+    def __init__(self, shape: Shape) -> None:
+
+        if not all(isinstance(val, int) and val >= 1 for val in shape):
+            raise ValueError('Dimensions must be integer > 0.')
+        self.shape = shape
         self.pos = np.asarray(list(np.ndindex(shape)), dtype=int)
         self.tree = cKDTree(self.pos)
         self.rows, self.cols = np.indices(shape)
@@ -63,12 +66,8 @@ class SomBase:
                  nhr: float, nh_shape: str, init_distr: str, metric: str,
                  seed: Optional[float] = None):
 
-        # check dimensions
-        for d in dims:
-            if not isinstance(d, int) or not d >= 1:
-                raise ValueError('Dimensions must be integer > 0.')
-
-        self._dims = dims
+        self._grid = SomGrid(dims[:2])
+        self.n_features = dims[2]
         self._hit_counts = np.zeros(self.n_units)
         self.n_iter = n_iter
         self.metric = metric
@@ -110,28 +109,27 @@ class SomBase:
             raise ValueError(f'Unknown initializer "{init_distr}". Use'
                              '"uniform", "simplex", or "pca".')
 
-        self._grid = SomGrid(self.shape)
         self._dists: Optional[Array] = None
 
     @property
     def dims(self) -> Tuple[int, int, int]:
         """Return the SOM dimensions."""
-        return self._dims
+        return (*self._grid.shape, self.n_features)
 
     @property
     def dx(self) -> int:
         """Return the number of units along the first dimension."""
-        return self.dims[0]
+        return self._grid.shape[0]
 
     @property
     def dy(self) -> int:
         """Return the number of units along the second dimension."""
-        return self.dims[1]
+        return self._grid.shape[1]
 
     @property
     def dw(self) -> int:
         """Return the dimension of the weight vectors."""
-        return self.dims[2]
+        return self.n_features
 
     @property
     def n_units(self) -> int:
@@ -141,7 +139,7 @@ class SomBase:
     @property
     def shape(self) -> Shape:
         """Return the map shape."""
-        return (self.dx, self.dy)
+        return self._grid.shape
 
     @property
     def grid(self) -> Array:
-- 
GitLab