diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py
index 4dec4a2cd0c497f047da922dcd6df5da29119065..0552a17fc569a96667b971ee7cb50018218a531b 100644
--- a/src/apollon/som/som.py
+++ b/src/apollon/som/som.py
@@ -1,7 +1,6 @@
 # Licensed under the terms of the BSD-3-Clause license.
 # Copyright (C) 2019 Michael Blaß
 # mblass@posteo.net
-import collections
 from typing import Dict, List, Optional, Tuple
 
 import numpy as np
@@ -16,7 +15,11 @@ from .. types import Array, Shape, Coord
 
 
 class SomGrid:
-    def __init__(self, shape: Shape):
+    def __init__(self, shape: Shape) -> None:
+
+        if not all(isinstance(val, int) and val >= 1 for val in shape):
+            raise ValueError('Dimensions must be integer > 0.')
+        self.shape = shape
         self.pos = np.asarray(list(np.ndindex(shape)), dtype=int)
         self.tree = cKDTree(self.pos)
         self.rows, self.cols = np.indices(shape)
@@ -63,12 +66,8 @@ class SomBase:
                  nhr: float, nh_shape: str, init_distr: str, metric: str,
                  seed: Optional[float] = None):
 
-        # check dimensions
-        for d in dims:
-            if not isinstance(d, int) or not d >= 1:
-                raise ValueError('Dimensions must be integer > 0.')
-
-        self._dims = dims
+        self._grid = SomGrid(dims[:2])
+        self.n_features = dims[2]
         self._hit_counts = np.zeros(self.n_units)
         self.n_iter = n_iter
         self.metric = metric
@@ -110,28 +109,27 @@ class SomBase:
             raise ValueError(f'Unknown initializer "{init_distr}". Use'
                              '"uniform", "simplex", or "pca".')
 
-        self._grid = SomGrid(self.shape)
         self._dists: Optional[Array] = None
 
     @property
     def dims(self) -> Tuple[int, int, int]:
         """Return the SOM dimensions."""
-        return self._dims
+        return (*self._grid.shape, self.n_features)
 
     @property
     def dx(self) -> int:
         """Return the number of units along the first dimension."""
-        return self.dims[0]
+        return self._grid.shape[0]
 
     @property
     def dy(self) -> int:
         """Return the number of units along the second dimension."""
-        return self.dims[1]
+        return self._grid.shape[1]
 
     @property
     def dw(self) -> int:
         """Return the dimension of the weight vectors."""
-        return self.dims[2]
+        return self.n_features
 
     @property
     def n_units(self) -> int:
@@ -141,7 +139,7 @@ class SomBase:
     @property
     def shape(self) -> Shape:
         """Return the map shape."""
-        return (self.dx, self.dy)
+        return self._grid.shape
 
     @property
     def grid(self) -> Array: