Skip to content
Snippets Groups Projects
Commit 4cceb363 authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Updated plot utilities.

parent 88744b85
No related branches found
No related tags found
No related merge requests found
...@@ -15,26 +15,30 @@ from apollon import aplot ...@@ -15,26 +15,30 @@ from apollon import aplot
from apollon.types import Array, Axis, Shape from apollon.types import Array, Axis, Shape
def umatrix(ax: Axis, umx: Array, outline: bool = False, **kwargs) -> None: def umatrix(ax: Axis, umx: Array, outline: bool = False,
pad_mode: str = 'constant', **kwargs) -> None:
"""Plot the U-matrix. """Plot the U-matrix.
Args: Args:
ax: Axis subplot. ax: Axis subplot.
umx: U-matrix data. umx: U-matrix data.
Returns: Note:
Image. Figure aspect is set to 'eqaul'.
""" """
defaults = { defaults = {
'cmap': 'terrain', 'cmap': 'terrain',
'levels': 20} 'levels': 20}
defaults.update(kwargs)
sdx, sdy = umx.shape
umx_padded = np.pad(umx, 1, mode=pad_mode)
for k, v in kwargs.items(): _ = ax.contourf(umx, **defaults, extent=(-0.5, sdy-0.5, -0.5, sdx-0.5))
_ = kwargs.setdefault(k, v) _ = ax.set_xticks(range(sdy))
ax.contourf(umx, **kwargs) _ = ax.set_yticks(range(sdx))
if outline: if outline:
ax.contour(umx, cmap='Greys_r', alpha=.7) ax.contour(umx, cmap='Greys_r', alpha=.7)
return ax ax.set_aspect('equal')
def plot_calibration(self, lables=None, ax=None, cmap='plasma', **kwargs): def plot_calibration(self, lables=None, ax=None, cmap='plasma', **kwargs):
...@@ -94,7 +98,7 @@ def plot_datamap(self, data, targets, interp='None', marker=False, ...@@ -94,7 +98,7 @@ def plot_datamap(self, data, targets, interp='None', marker=False,
return (ax, udm, (x, y)) return (ax, udm, (x, y))
def plot_qerror(self, ax=None, **kwargs): def plot_qerror(ax=None, **kwargs):
"""Plot quantization error.""" """Plot quantization error."""
if ax is None: if ax is None:
fig, ax = _new_axis(**kwargs) fig, ax = _new_axis(**kwargs)
...@@ -109,7 +113,7 @@ def plot_qerror(self, ax=None, **kwargs): ...@@ -109,7 +113,7 @@ def plot_qerror(self, ax=None, **kwargs):
def plot_umatrix3d(self, w=1, cmap='viridis', **kwargs): def plot_umatrix3d(w=1, cmap='viridis', **kwargs):
"""Plot the umatrix in 3d. The color on each unit (x, y) represents its """Plot the umatrix in 3d. The color on each unit (x, y) represents its
mean distance to all direct neighbours. mean distance to all direct neighbours.
......
...@@ -215,9 +215,11 @@ class SomBase: ...@@ -215,9 +215,11 @@ class SomBase:
return bmu return bmu
def match(self, data: Array) -> Array: def match(self, data: Array) -> Array:
"""Return the multi_index of the best matching unit for each vector in """Return the multi index of the best matching unit for each vector in
``data``. ``data``.
Caution: This function returns the multi index into the array.
Args: Args:
data: Input data set. data: Input data set.
...@@ -225,8 +227,7 @@ class SomBase: ...@@ -225,8 +227,7 @@ class SomBase:
Array of SOM unit indices. Array of SOM unit indices.
""" """
bmu = self.match_flat(data) bmu = self.match_flat(data)
pos_y, pos_x = np.unravel_index(bmu, self.shape) return np.column_stack(np.unravel_index(bmu, self.shape))
return np.column_stack((pos_x, pos_y))
def predict(self, data: Array) -> Array: def predict(self, data: Array) -> Array:
"""Predict the SOM index of the best matching unit """Predict the SOM index of the best matching unit
......
...@@ -138,9 +138,11 @@ def sample_pca(dims: SomDims, data: Optional[Array] = None, **kwargs) -> Array: ...@@ -138,9 +138,11 @@ def sample_pca(dims: SomDims, data: Optional[Array] = None, **kwargs) -> Array:
data_limits = np.column_stack((trans_data.min(axis=0), data_limits = np.column_stack((trans_data.min(axis=0),
trans_data.max(axis=0))) trans_data.max(axis=0)))
if 'adapt' in kwargs and kwargs['adapt'] is True: if 'adapt' in kwargs and kwargs['adapt'] is True:
shape = sorted(shape, reverse=True) shape = sorted((n_rows, n_cols), reverse=True)
dim_x = np.linspace(*data_limits[0], n_rows) else:
dim_y = np.linspace(*data_limits[1], n_cols) shape = (n_rows, n_cols)
dim_x = np.linspace(*data_limits[0], shape[0])
dim_y = np.linspace(*data_limits[1], shape[1])
grid_x, grid_y = np.meshgrid(dim_x, dim_y) grid_x, grid_y = np.meshgrid(dim_x, dim_y)
points = np.vstack((grid_x.ravel(), grid_y.ravel())) points = np.vstack((grid_x.ravel(), grid_y.ravel()))
weights = points.T @ vects + data.mean(axis=0) weights = points.T @ vects + data.mean(axis=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment