diff --git a/src/apollon/som/som.py b/src/apollon/som/som.py
index 0552a17fc569a96667b971ee7cb50018218a531b..bf9188c588f9b29f79d44c4a9ed886af81913e98 100644
--- a/src/apollon/som/som.py
+++ b/src/apollon/som/som.py
@@ -289,11 +289,15 @@ class SomBase:
         for i, nhd_idx in enumerate(nhd_per_unit):
             cwv = self._weights[[i]]
             nhd = self._weights[nhd_idx]
-            u_height[i] = distance.cdist(cwv, nhd).sum()
+            u_height[i] = distance.cdist(cwv, nhd, self.metric).sum()
             if scale:
                 u_height[i] /= len(nhd_idx)
         if norm:
-            u_height /= u_height.max()
+            umax = u_height.max()
+            if umax == 0:
+                u_height = np.zeros_like(u_height)
+            else:
+                u_height /= u_height.max()
         return u_height.reshape(self.shape)
 
 
diff --git a/tests/som/test_som.py b/tests/som/test_som.py
index 9e269b16e3213bd97e35ff68ef4a35ae10fb7f1d..b4cd0a0ec53f66c28b4c0e14a97d54245e16b38f 100644
--- a/tests/som/test_som.py
+++ b/tests/som/test_som.py
@@ -9,7 +9,7 @@ import scipy as sp
 from apollon.som.som import SomBase, SomGrid
 
 SomDim = Tuple[int, int, int]
-dimension = hst.integers(min_value=1, max_value=100)
+dimension = hst.integers(min_value=2, max_value=50)
 som_dims = hst.tuples(dimension, dimension, dimension)
 
 
@@ -67,23 +67,27 @@ class TestSomBase(unittest.TestCase):
         data = np.random.rand(100, dims[2])
         self.assertIsInstance(som.match(data), np.ndarray)
 
+    @given(som_dims)
+    def test_umatrix_has_map_shape(self, dims: SomDim) -> None:
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        um = som.umatrix()
+        self.assertEqual(um.shape, som.shape)
 
-"""
-class TestSelfOrganizingMap(unittest.TestCase):
-    def setUp(self):
-        N = 100
-
-        m1 = (0, 0)
-        m2 = (10, 15)
-        c1 = ((10, 0), (0, 10))
-        c2 = ((2, 0), (0, 2))
+    @given(som_dims)
+    def test_umatrix_scale(self, dims: SomDim) -> None:
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        som._weights = np.tile(np.arange(som.n_features), (som.n_units, 1))
+        som._weights[:, -1] = np.arange(som.n_units)
+        um = som.umatrix(scale=True, norm=False)
+        self.assertEqual(um[0, 0], um[-1, -1])
+        self.assertEqual(um[0, -1], um[-1, 0])
 
-        seg1 = np.random.multivariate_normal(m1, c1, N)
-        seg2 = np.random.multivariate_normal(m2, c2, N)
+    @given(som_dims)
+    def test_umatrix_norm(self, dims: SomDim) -> None:
+        som = SomBase(dims, 100, 0.1, 10, 'gaussian', 'uniform', 'euclidean')
+        um = som.umatrix(norm=True)
+        self.assertEqual(um.max(), 1.0)
 
-        self.data = np.vstack((seg1, seg2))
-        self.dims = (10, 10, 2)
-"""
 
 if __name__ == '__main__':
     unittest.main()