Skip to content
Snippets Groups Projects
Commit d4a55d8f authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Merge branch 'feat-som-plots' into develop

New plotting functionality.
parents b3e59b61 a22c5a70
No related branches found
No related tags found
No related merge requests found
Pipeline #6925 passed
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# mblass@posteo.net
"""apollon/som/plot.py
Plotting functions for SOMs.
Licensed under the terms of the BSD-3-Clause license.
Copyright (C) 2019 Michael Blaß
mblass@posteo.net
"""
from typing import Optional, Tuple, Union
import matplotlib as mpl
import matplotlib.pyplot as plt
__all__ = ['cluster_by', 'component', 'hit_counts', 'qerror', 'label_target',
'umatrix', 'wire']
from typing import Callable, Optional, Union
import numpy as np
from apollon import tools
from apollon import aplot
from apollon.types import Array, Axis, Shape
from apollon.types import Array, Axis
def umatrix(ax: Axis, umx: Array, outline: bool = False, **kwargs) -> None:
def umatrix(ax: Axis, som, outline: bool = False, **kwargs) -> None:
"""Plot the U-matrix.
Args:
ax: Axis subplot.
umx: U-matrix data.
som: SOM instance.
Returns:
Image.
Note:
Figure aspect is set to 'equal'.
"""
defaults = {
props = {
'cmap': 'terrain',
'levels': 20}
props.update(kwargs)
_generic_contour(ax, som.umatrix(), outline, **props)
for k, v in kwargs.items():
_ = kwargs.setdefault(k, v)
ax.contourf(umx, **kwargs)
if outline:
ax.contour(umx, cmap='Greys_r', alpha=.7)
return ax
def plot_calibration(self, lables=None, ax=None, cmap='plasma', **kwargs):
"""Plot calibrated map.
Args:
labels:
ax
cmap:
Returns:
"""
if not self.isCalibrated:
raise ValueError('Map not calibrated.')
else:
if ax is None:
fig, ax = _new_axis()
ax.set_title('Calibration')
ax.set_xlabel('# units')
ax.set_ylabel('# units')
ax.imshow(self._cmap.reshape(self.dx, self.dy), origin='lower',
cmap=cmap)
#return ax
def plot_datamap(self, data, targets, interp='None', marker=False,
cmap='viridis', **kwargs):
"""Represent the input data on the map by retrieving the best
matching unit for every element in `data`. Mark each map unit
with the corresponding target value.
def umatrix3d(ax: Axis, som, **kwargs) -> None:
"""Plot the U-matrix in three dimensions.
Args:
data: Input data set.
targets: Class labels or values.
interp: matplotlib interpolation method name.
marker: Plot markers in bmu position if True.
Returns:
axis, umatrix, bmu_xy
"""
ax, udm = self.plot_umatrix(interp=interp, cmap=cmap, **kwargs)
#
# TODO: Use .transform() instead
#
bmu, err = self.get_winners(data)
x, y = _np.unravel_index(bmu, (self.dx, self.dy))
fd = {'color':'#cccccc'}
if marker:
ax.scatter(y, x, s=40, marker='x', color='r')
for i, j, t in zip(x, y, targets):
ax.text(j, i, t, fontdict=fd,
horizontalalignment='center',
verticalalignment='center')
return (ax, udm, (x, y))
def plot_qerror(self, ax=None, **kwargs):
"""Plot quantization error."""
if ax is None:
fig, ax = _new_axis(**kwargs)
ax.set_title('Quantization Errors per iteration')
ax.set_xlabel('# interation')
ax.set_ylabel('Error')
ax.plot(self.quantization_error, lw=3, alpha=.8,
label='Quantizationerror')
def plot_umatrix3d(self, w=1, cmap='viridis', **kwargs):
"""Plot the umatrix in 3d. The color on each unit (x, y) represents its
mean distance to all direct neighbours.
Args:
w: Neighbourhood width.
ax: Axis subplot.
som: SOM instance.
Returns:
axis, umatrix
Note:
Figure aspect is set to 'equal'.
"""
fig, ax = _new_axis_3d(**kwargs)
udm = _som_utils.umatrix(self.weights, self.shape, metric=self.metric)
X, Y = _np.mgrid[:self.dx, :self.dy]
ax.plot_surface(X, Y, udm, cmap=cmap)
return ax, udm
props = {
'cmap': 'terrain',
}
props.update(kwargs)
ax.plot_surface(*np.mgrid[:som.dx, :som.dy], som.umatrix(), **props)
def plot_features(self, figsize=(8, 8)):
"""Values of each feature of the weight matrix per map unit.
This currently works only for feature vectors of len dw**2.
def component(ax: Axis, som, comp: int, outline: bool = False,
**kwargs) -> None:
"""Plot a component plane.
Args:
Size of figure.
ax: Axis subplot.
som: SOM instance.
comp: Component number.
"""
d = _np.sqrt(self.dw).astype(int)
rweigths = self.weights.reshape(self.dims)
fig, _ = _plt.subplots(d, d, figsize=figsize, sharex=True, sharey=True)
for i, ax in enumerate(fig.axes):
ax.axison=False
ax.imshow(rweigths[..., i], origin='lower')
props = {
'cmap': 'magma',
'levels': 20,}
props.update(kwargs)
_generic_contour(ax, som.weights[:, comp].reshape(som.shape), outline,
**props)
def plot_whist(self, interp='None', ax=None, **kwargs):
"""Plot the winner histogram.
The darker the color at position (x, y), the more often neuron (x, y)
was chosen as winner. The number of winners at edge neurons is
orders of magnitude higher than on the rest of the map. Thus, the
histogram is shown in log-mode.
def label_target(ax: Axis, som, data: Array, target: Array, **kwargs) -> None:
"""Add target labels for each bmu.
Args:
interp: matplotlib interpolation method name.
ax: Provide custom axis object.
Returns:
The axis.
ax: Axis subplot.
som: SOM instance.
data: Input data.
target: Target labels.
"""
if ax is None:
fig, ax = _new_axis(**kwargs)
ax.imshow(_np.log1p(self.whist.reshape(self.dx, self.dy)),
vmin=0, cmap='Greys', interpolation=interp, origin='lower')
return ax
def inspect(self):
fig = _plt.figure(figsize=(12, 5))
ax1 = _new_axis(sp_pos=(1, 3, 1), fig=fig)
ax2 = _new_axis(sp_pos=(1, 3, 2), fig=fig)
ax3 = _new_axis(sp_pos=(1, 3, 3), fig=fig)
props = {
'fontsize': 9,
'ha': 'left',
'va': 'bottom',
}
props.update(kwargs)
_, _ = self.plot_umatrix(ax=ax1)
bmu = som.match(data)
bmu_xy = np.fliplr(np.atleast_2d(bmu)).T
for x, y, t in zip(*bmu_xy, target):
ax.text(x, y, t, fontdict=props)
if self.isCalibrated:
_ = self.plot_calibration(ax=ax2)
else:
_ = self.plot_whist(ax=ax2)
self.plot_qerror(ax=ax3)
def qerror(ax: Axis, som, **kwargs) -> None:
    """Plot the quantization error of a SOM as a line.

    Args:
        ax:  Axis subplot to draw on.
        som: SOM instance; its ``quantization_error`` attribute is plotted.
        **kwargs: Forwarded to ``ax.plot``. Entries given here take
            precedence over the defaults ``lw=3`` and ``alpha=.8``.
    """
    # Defaults first, caller-supplied settings last so that they win.
    line_style = {'lw': 3, 'alpha': .8, **kwargs}
    ax.plot(som.quantization_error, **line_style)
def weights(weights: Array, dims: Tuple, cmap: str = 'tab20',
figsize: Tuple = (15, 15), stand: bool =False) -> Tuple:
"""Plot a bar chart of the weights of each map unit.
def cluster_by(ax: Axis, som, data: Array, target: Array,
**kwargs) -> None:
"""Plot bmu colored by ``target``.
Args:
weights: Two-dimensional array of weights.
dims: SOM dimensions (dx, dy, dw).
cmap: Matplotlib color map name.
figsize: Figure size.
stand: Standardize the weights if ``True``.
Returns:
Figure and axes.
ax: Axis subplot.
som: SOM instance.
data: Input data.
target: Target labels.
"""
dx, dy, dw = dims
if stand:
weights = tools.standardize(weights)
lower = np.floor(weights.min())
upper = np.ceil(weights.max())
yticks = np.linspace(lower, upper, 5)
xr = range(dw)
bar_colors = getattr(plt.cm, cmap)(xr)
fig, axs = plt.subplots(dx, dy, figsize=figsize, sharex=True, sharey=True,
subplot_kw={'xticks': [], 'yticks': yticks})
axs = np.flipud(axs).flatten()
for ax, wv in zip(axs, weights):
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_position('zero')
ax.bar(xr, wv, color=bar_colors)
props = {
's': 50,
'c': target,
'marker': 'o',
}
props.update(kwargs)
bmu = som.match(data)
bmu_xy = np.fliplr(np.atleast_2d(bmu)).T
ax.scatter(*bmu_xy, **props)
return fig, axs
def hit_counts(ax: Axis, som, transform: Optional[Callable] = None,
**kwargs) -> None:
"""Plot the winner histogram.
def weights_line(weights: Array, dims: Tuple, color: str = 'r',
figsize: Tuple = (15, 15), stand: bool =False) -> Tuple:
"""Plot a line chart of the weights of each map unit.
Each unit is colored according to the number of times it was bmu.
Args:
weights: Two-dimensional array of weights.
dims: SOM dimensions (dx, dy, dw).
cmap: Matplotlib color map name.
figsize: Figure size.
stand: Standardize the weights if ``True``.
Returns:
Figure and axes.
ax: Axis subplot.
som: SOM instance.
mode: Choose either 'linear', or 'log'.
"""
dx, dy, dw = dims
if stand:
weights = tools.standardize(weights)
lower = np.floor(weights.min())
upper = np.ceil(weights.max())
fig, axs = plt.subplots(dx, dy, figsize=figsize, sharex=True, sharey=True,
subplot_kw={'xticks': [], 'yticks': [], 'frame_on': False})
axs = np.flipud(axs).flatten()
for ax, wv in zip(axs, weights):
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_position('zero')
ax.plot(wv, color=color)
return fig, axs
props = {
'interpolation': None,
'origin': 'lower',
'cmap': 'Greys',
}
props.update(kwargs)
data = som.hit_counts.reshape(som.shape)
if transform is not None:
data = transform(data)
ax.imshow(data, **props)
def wire(ax: Axis, weights: Array, shape: Shape, *,
def wire(ax: Axis, som,
unit_size: Union[int, float, Array] = 100.0,
line_width: Union[int, float] = 1.0,
highlight: Optional[Array] = None, labels: bool = False, **kwargs):
highlight: Optional[Array] = None, labels: bool = False,
**kwargs) -> None:
"""Plot the weight vectors of a SOM with two-dimensional feature space.
Neighbourhood relations are indicated by connecting lines.
Args:
ax: The axis subplot.
weights: SOM weight matrix.
shape: SOM shape.
som: SOM instance.
unit_size: Size for each unit.
line_width: Width of the wire lines.
highlight: Index of units to be marked in different color.
......@@ -271,38 +166,57 @@ def wire(ax: Axis, weights: Array, shape: Shape, *,
Returns:
vlines, hlines, bgmarker, umarker
"""
unit_color = 'k'
bg_color = 'w'
hl_color = 'r'
alpha = .7
if isinstance(unit_size, np.ndarray):
marker_size = tools.scale(unit_size, 10, 110)
elif isinstance(unit_size, float) or isinstance(unit_size, int):
marker_size = np.repeat(unit_size, weights.shape[0])
marker_size = np.repeat(unit_size, som.n_units)
else:
msg = (f'Argument of parameter ``unit_size`` must be real scalar '
'or one-dimensional numpy array.')
raise ValueError(msg)
marker_size_bg = marker_size + marker_size / 100 * 30
bg_color: str = 'w'
hl_color: str = 'r'
line_props = {
'color': 'k',
'alpha': 0.7,
'lw': 1.0,
'zorder': 9,
}
line_props.update(kwargs)
marker_bg_props = {
's': marker_size_bg,
'c': bg_color,
'edgecolors': None,
'zorder': 11,
}
marker_hl_props = {
's': marker_size,
'c': unit_color,
'alpha': line_props['alpha'],
}
if highlight is not None:
bg_color = np.where(highlight, hl_color, bg_color)
rsw = weights.reshape(*shape, 2)
vx, vy = rsw.T
hx, hy = np.rollaxis(rsw, 1).T
ax.set_aspect('equal')
vlines = ax.plot(vx, vy, unit_color, alpha=alpha, lw=line_width, zorder=9)
hlines = ax.plot(hx, hy, unit_color, alpha=alpha, lw=line_width, zorder=9)
bgmarker = ax.scatter(vx, vy, s=marker_size_bg, c=bg_color,
rsw = som.weights.reshape(som.shape, 2)
v_wx, v_wy = rsw.T
h_wx, h_wy = np.rollaxis(rsw, 1).T
vlines = ax.plot(v_wx, v_wy, **line_props)
hlines = ax.plot(h_wx, h_wy, **line_props)
bgmarker = ax.scatter(v_wx, v_wy, s=marker_size_bg, c=bg_color,
edgecolors='None', zorder=11)
umarker = ax.scatter(vx, vy, s=marker_size, c=unit_color, alpha=alpha,
umarker = ax.scatter(v_wx, v_wy, s=marker_size, c=unit_color, alpha=alpha,
edgecolors='None', zorder=12)
font = {'fontsize': 4,
'va': 'bottom',
'ha': 'center'}
'ha': 'center',
}
bbox = {'alpha': 0.7,
'boxstyle': 'round',
......@@ -312,14 +226,15 @@ def wire(ax: Axis, weights: Array, shape: Shape, *,
}
if labels is True:
for (px, py), (ix, iy) in zip(weights, np.ndindex(shape)):
for (px, py), (ix, iy) in zip(som.weights, np.ndindex(shape)):
ax.text(px+1.3, py, f'({ix}, {iy})', font, bbox=bbox, zorder=13)
ax.set_aspect('equal')
return None
return vlines, hlines, bgmarker, umarker
def data_2d(ax: Axis, data: Array, colors: Array,
**kwargs) -> mpl.collections.PathCollection:
**kwargs) -> None:
"""Scatter plot a data set with two-dimensional feature space.
This is just the usual scatter command with some reasonable defaults.
......@@ -332,13 +247,33 @@ def data_2d(ax: Axis, data: Array, colors: Array,
Returns:
PathCollection.
"""
defaults = {
props = {
'alpha': 0.2,
'c': colors,
'cmap': 'plasma',
'edgecolors': 'None',
's': 10}
for k, v in defaults.items():
_ = kwargs.setdefault(k, v)
props.update(kwargs)
aplot.outward_spines(ax)
return ax.scatter(*data.T, **kwargs)
_ = ax.scatter(*data.T, **props)
def _generic_contour(ax: Axis, data: Array, outline: bool = False,
                     **kwargs) -> None:
    """Draw a filled contour plot of ``data`` on ``ax``.

    The ``extent`` of the plot is always set by this function and runs
    from -0.5 to ``n - 0.5`` along each axis, where ``n`` is the size of
    ``data`` along that axis. A caller-supplied ``extent`` is overwritten.
    Ticks are placed at every integer position and the axis aspect is set
    to 'equal'.

    Args:
        ax:      Axis subplot.
        data:    Two-dimensional array.
        outline: If ``True``, overlay grey contour lines on top of the
                 filled contours.
        **kwargs: Forwarded to ``ax.contourf``.
    """
    sdx, sdy = data.shape
    extent = (-0.5, sdy-0.5, -0.5, sdx-0.5)
    kwargs['extent'] = extent

    ax.contourf(data, **kwargs)
    ax.set_xticks(range(sdy))
    ax.set_yticks(range(sdx))

    if outline:
        # The outline has to live in the same coordinate frame as the filled
        # contours; without ``extent`` (and a matching ``origin``) the lines
        # are drawn on the raw index grid and do not line up with the fill.
        outline_props = {'cmap': 'Greys_r', 'alpha': .7, 'extent': extent}
        if 'origin' in kwargs:
            outline_props['origin'] = kwargs['origin']
        ax.contour(data, **outline_props)

    ax.set_aspect('equal')
......@@ -215,9 +215,11 @@ class SomBase:
return bmu
def match(self, data: Array) -> Array:
"""Return the multi_index of the best matching unit for each vector in
"""Return the multi index of the best matching unit for each vector in
``data``.
Caution: This function returns the multi index into the array.
Args:
data: Input data set.
......@@ -225,8 +227,7 @@ class SomBase:
Array of SOM unit indices.
"""
bmu = self.match_flat(data)
pos_y, pos_x = np.unravel_index(bmu, self.shape)
return np.column_stack((pos_x, pos_y))
return np.column_stack(np.unravel_index(bmu, self.shape))
def predict(self, data: Array) -> Array:
"""Predict the SOM index of the best matching unit
......
......@@ -138,9 +138,11 @@ def sample_pca(dims: SomDims, data: Optional[Array] = None, **kwargs) -> Array:
data_limits = np.column_stack((trans_data.min(axis=0),
trans_data.max(axis=0)))
if 'adapt' in kwargs and kwargs['adapt'] is True:
shape = sorted(shape, reverse=True)
dim_x = np.linspace(*data_limits[0], n_rows)
dim_y = np.linspace(*data_limits[1], n_cols)
shape = sorted((n_rows, n_cols), reverse=True)
else:
shape = (n_rows, n_cols)
dim_x = np.linspace(*data_limits[0], shape[0])
dim_y = np.linspace(*data_limits[1], shape[1])
grid_x, grid_y = np.meshgrid(dim_x, dim_y)
points = np.vstack((grid_x.ravel(), grid_y.ravel()))
weights = points.T @ vects + data.mean(axis=0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment