Skip to content
Snippets Groups Projects
Commit a22c5a70 authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Updated all plot functions.

parent 4cceb363
No related branches found
No related tags found
No related merge requests found
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# mblass@posteo.net
"""apollon/som/plot.py """apollon/som/plot.py
Plotting functions for SOMs.
Licensed under the terms of the BSD-3-Clause license.
Copyright (C) 2019 Michael Blaß
mblass@posteo.net
""" """
from typing import Optional, Tuple, Union
import matplotlib as mpl __all__ = ['cluster_by', 'component', 'hit_counts', 'qerror', 'label_target',
import matplotlib.pyplot as plt 'umatrix', 'wire']
from typing import Callable, Optional, Union
import numpy as np import numpy as np
from apollon import tools from apollon import tools
from apollon import aplot from apollon import aplot
from apollon.types import Array, Axis, Shape from apollon.types import Array, Axis
def umatrix(ax: Axis, umx: Array, outline: bool = False, def umatrix(ax: Axis, som, outline: bool = False, **kwargs) -> None:
pad_mode: str = 'constant', **kwargs) -> None:
"""Plot the U-matrix. """Plot the U-matrix.
Args: Args:
ax: Axis subplot. ax: Axis subplot.
umx: U-matrix data. som: SOM instance.
Note: Note:
Figure aspect is set to 'eqaul'. Figure aspect is set to 'eqaul'.
""" """
defaults = { props = {
'cmap': 'terrain', 'cmap': 'terrain',
'levels': 20} 'levels': 20}
defaults.update(kwargs) props.update(kwargs)
sdx, sdy = umx.shape _generic_contour(ax, som.umatrix(), outline, **props)
umx_padded = np.pad(umx, 1, mode=pad_mode)
_ = ax.contourf(umx, **defaults, extent=(-0.5, sdy-0.5, -0.5, sdx-0.5))
_ = ax.set_xticks(range(sdy))
_ = ax.set_yticks(range(sdx))
if outline:
ax.contour(umx, cmap='Greys_r', alpha=.7)
ax.set_aspect('equal')
def plot_calibration(self, lables=None, ax=None, cmap='plasma', **kwargs):
"""Plot calibrated map.
Args:
labels:
ax
cmap:
Returns:
"""
if not self.isCalibrated:
raise ValueError('Map not calibrated.')
else:
if ax is None:
fig, ax = _new_axis()
ax.set_title('Calibration')
ax.set_xlabel('# units')
ax.set_ylabel('# units')
ax.imshow(self._cmap.reshape(self.dx, self.dy), origin='lower',
cmap=cmap)
#return ax
def plot_datamap(self, data, targets, interp='None', marker=False,
cmap='viridis', **kwargs):
"""Represent the input data on the map by retrieving the best
matching unit for every element in `data`. Mark each map unit
with the corresponding target value.
Args:
data: Input data set.
targets: Class labels or values.
interp: matplotlib interpolation method name.
marker: Plot markers in bmu position if True.
Returns:
axis, umatrix, bmu_xy
"""
ax, udm = self.plot_umatrix(interp=interp, cmap=cmap, **kwargs)
#
# TODO: Use .transform() instead
#
bmu, err = self.get_winners(data)
x, y = _np.unravel_index(bmu, (self.dx, self.dy)) def umatrix3d(ax: Axis, som, **kwargs) -> None:
fd = {'color':'#cccccc'} """Plot the U-matrix in three dimensions.
if marker:
ax.scatter(y, x, s=40, marker='x', color='r')
for i, j, t in zip(x, y, targets):
ax.text(j, i, t, fontdict=fd,
horizontalalignment='center',
verticalalignment='center')
return (ax, udm, (x, y))
def plot_qerror(ax=None, **kwargs):
"""Plot quantization error."""
if ax is None:
fig, ax = _new_axis(**kwargs)
ax.set_title('Quantization Errors per iteration')
ax.set_xlabel('# interation')
ax.set_ylabel('Error')
ax.plot(self.quantization_error, lw=3, alpha=.8,
label='Quantizationerror')
def plot_umatrix3d(w=1, cmap='viridis', **kwargs):
"""Plot the umatrix in 3d. The color on each unit (x, y) represents its
mean distance to all direct neighbours.
Args: Args:
w: Neighbourhood width. ax: Axis subplot.
som: SOM instance.
Returns: Note:
axis, umatrix Figure aspect is set to 'eqaul'.
""" """
fig, ax = _new_axis_3d(**kwargs) props = {
udm = _som_utils.umatrix(self.weights, self.shape, metric=self.metric) 'cmap': 'terrain',
X, Y = _np.mgrid[:self.dx, :self.dy] }
ax.plot_surface(X, Y, udm, cmap=cmap) props.update(kwargs)
return ax, udm ax.plot_surface(*np.mgrid[:som.dx, :som.dy], som.umatrix(), **props)
def plot_features(self, figsize=(8, 8)):
"""Values of each feature of the weight matrix per map unit.
This works currently ony for feature vectors of len dw**2. def component(ax: Axis, som, comp: int, outline: bool = False,
**kwargs) -> None:
"""Plot a component plane.
Args: Args:
Size of figure. ax: Axis subplot.
som: SOM instance.
comp: Component number.
""" """
d = _np.sqrt(self.dw).astype(int) props = {
rweigths = self.weights.reshape(self.dims) 'cmap': 'magma',
'levels': 20,}
props.update(kwargs)
_generic_contour(ax, som.weights[:, comp].reshape(som.shape), outline,
**props)
fig, _ = _plt.subplots(d, d, figsize=figsize, sharex=True, sharey=True)
for i, ax in enumerate(fig.axes):
ax.axison=False
ax.imshow(rweigths[..., i], origin='lower')
def label_target(ax: Axis, som, data: Array, target: Array, **kwargs) -> None:
def plot_whist(self, interp='None', ax=None, **kwargs): """Add target labels for each bmu.
"""Plot the winner histogram.
The darker the color on position (x, y) the more often neuron (x, y)
was choosen as winner. The number of winners at edge neuros is
magnitudes of order higher than on the rest of the map. Thus, the
histogram is shown in log-mode.
Args: Args:
interp: matplotlib interpolation method name. ax: Axis subplot.
ax: Provide custom axis object. som: SOM instance.
data: Input data.
Returns: target: Target labels.
The axis.
""" """
if ax is None: props = {
fig, ax = _new_axis(**kwargs) 'fontsize': 9,
ax.imshow(_np.log1p(self.whist.reshape(self.dx, self.dy)), 'ha': 'left',
vmin=0, cmap='Greys', interpolation=interp, origin='lower') 'va': 'bottom',
return ax }
props.update(kwargs)
def inspect(self):
fig = _plt.figure(figsize=(12, 5))
ax1 = _new_axis(sp_pos=(1, 3, 1), fig=fig)
ax2 = _new_axis(sp_pos=(1, 3, 2), fig=fig)
ax3 = _new_axis(sp_pos=(1, 3, 3), fig=fig)
_, _ = self.plot_umatrix(ax=ax1) bmu = som.match(data)
bmu_xy = np.fliplr(np.atleast_2d(bmu)).T
for x, y, t in zip(*bmu_xy, target):
ax.text(x, y, t, fontdict=props)
if self.isCalibrated:
_ = self.plot_calibration(ax=ax2)
else:
_ = self.plot_whist(ax=ax2)
self.plot_qerror(ax=ax3) def qerror(ax: Axis, som, **kwargs) -> None:
"""Plot quantization error."""
props = {
'lw': 3,
'alpha': .8,
}
props.update(kwargs)
ax.plot(som.quantization_error, **props)
def weights(weights: Array, dims: Tuple, cmap: str = 'tab20', def cluster_by(ax: Axis, som, data: Array, target: Array,
figsize: Tuple = (15, 15), stand: bool =False) -> Tuple: **kwargs) -> None:
"""Plot a bar chart of the weights of each map unit. """Plot bmu colored by ``traget``.
Args: Args:
weights: Two-dimensional array of weights. ax: Axis subplot.
dims: SOM dimensions (dx, dy, dw). som: SOM instance.
cmap: Matplotlib color map name. data: Input data.
figsize: Figure size. target: Target labels.
stand: Standardize the weights if ``True``.
Returns:
Figure and axes.
""" """
dx, dy, dw = dims props = {
if stand: 's': 50,
weights = tools.standardize(weights) 'c': target,
lower = np.floor(weights.min()) 'marker': 'o',
upper = np.ceil(weights.max()) }
yticks = np.linspace(lower, upper, 5) props.update(kwargs)
xr = range(dw) bmu = som.match(data)
bar_colors = getattr(plt.cm, cmap)(xr) bmu_xy = np.fliplr(np.atleast_2d(bmu)).T
ax.scatter(*bmu_xy, **props)
fig, axs = plt.subplots(dx, dy, figsize=figsize, sharex=True, sharey=True,
subplot_kw={'xticks': [], 'yticks': yticks})
axs = np.flipud(axs).flatten()
for ax, wv in zip(axs, weights):
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_position('zero')
ax.bar(xr, wv, color=bar_colors)
return fig, axs
def hit_counts(ax: Axis, som, transform: Optional[Callable] = None,
**kwargs) -> None:
"""Plot the winner histogram.
def weights_line(weights: Array, dims: Tuple, color: str = 'r', Each unit is colored according to the number of times it was bmu.
figsize: Tuple = (15, 15), stand: bool =False) -> Tuple:
"""Plot a line chart of the weights of each map unit.
Args: Args:
weights: Two-dimensional array of weights. ax: Axis subplot.
dims: SOM dimensions (dx, dy, dw). som: SOM instance.
cmap: Matplotlib color map name. mode: Choose either 'linear', or 'log'.
figsize: Figure size.
stand: Standardize the weights if ``True``.
Returns:
Figure and axes.
""" """
dx, dy, dw = dims props = {
if stand: 'interpolation': None,
weights = tools.standardize(weights) 'origin': 'lower',
lower = np.floor(weights.min()) 'cmap': 'Greys',
upper = np.ceil(weights.max()) }
props.update(kwargs)
fig, axs = plt.subplots(dx, dy, figsize=figsize, sharex=True, sharey=True, data = som.hit_counts.reshape(som.shape)
subplot_kw={'xticks': [], 'yticks': [], 'frame_on': False}) if transform is not None:
axs = np.flipud(axs).flatten() data = transform(data)
ax.imshow(data, **props)
for ax, wv in zip(axs, weights):
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_position('zero')
ax.plot(wv, color=color)
return fig, axs
def wire(ax: Axis, weights: Array, shape: Shape, *, def wire(ax: Axis, som,
unit_size: Union[int, float, Array] = 100.0, unit_size: Union[int, float, Array] = 100.0,
line_width: Union[int, float] = 1.0, line_width: Union[int, float] = 1.0,
highlight: Optional[Array] = None, labels: bool = False, **kwargs): highlight: Optional[Array] = None, labels: bool = False,
**kwargs) -> None:
"""Plot the weight vectors of a SOM with two-dimensional feature space. """Plot the weight vectors of a SOM with two-dimensional feature space.
Neighbourhood relations are indicate by connecting lines. Neighbourhood relations are indicate by connecting lines.
Args: Args:
ax: The axis subplot. ax: The axis subplot.
weights: SOM weigth matrix. som: SOM instance.
shape: SOM shape.
unit_size: Size for each unit. unit_size: Size for each unit.
line_width: Width of the wire lines. line_width: Width of the wire lines.
highlight: Index of units to be marked in different color. highlight: Index of units to be marked in different color.
...@@ -275,38 +166,57 @@ def wire(ax: Axis, weights: Array, shape: Shape, *, ...@@ -275,38 +166,57 @@ def wire(ax: Axis, weights: Array, shape: Shape, *,
Returns: Returns:
vlines, hlines, bgmarker, umarker vlines, hlines, bgmarker, umarker
""" """
unit_color = 'k'
bg_color = 'w'
hl_color = 'r'
alpha = .7
if isinstance(unit_size, np.ndarray): if isinstance(unit_size, np.ndarray):
marker_size = tools.scale(unit_size, 10, 110) marker_size = tools.scale(unit_size, 10, 110)
elif isinstance(unit_size, float) or isinstance(unit_size, int): elif isinstance(unit_size, float) or isinstance(unit_size, int):
marker_size = np.repeat(unit_size, weights.shape[0]) marker_size = np.repeat(unit_size, som.n_units)
else: else:
msg = (f'Argument of parameter ``unit_size`` must be real scalar ' msg = (f'Argument of parameter ``unit_size`` must be real scalar '
'or one-dimensional numpy array.') 'or one-dimensional numpy array.')
raise ValueError(msg) raise ValueError(msg)
marker_size_bg = marker_size + marker_size / 100 * 30 marker_size_bg = marker_size + marker_size / 100 * 30
bg_color: str = 'w'
hl_color: str = 'r'
line_props = {
'color': 'k',
'alpha': 0.7,
'lw': 1.0,
'zorder': 9,
}
line_props.update(kwargs)
marker_bg_props = {
's': marker_size_bg,
'c': bg_color,
'edgecolors': None,
'zorder': 11,
}
marker_hl_props = {
's': marker_size,
'c': unit_color,
'alpha': line_props['alpha'],
}
if highlight is not None: if highlight is not None:
bg_color = np.where(highlight, hl_color, bg_color) bg_color = np.where(highlight, hl_color, bg_color)
rsw = weights.reshape(*shape, 2) rsw = som.weights.reshape(som.shape, 2)
vx, vy = rsw.T v_wx, v_wy = rsw.T
hx, hy = np.rollaxis(rsw, 1).T h_wx, h_wy = np.rollaxis(rsw, 1).T
ax.set_aspect('equal') vlines = ax.plot(v_wx, v_wy, **line_props)
vlines = ax.plot(vx, vy, unit_color, alpha=alpha, lw=line_width, zorder=9) hlines = ax.plot(h_wx, h_wy, **line_props)
hlines = ax.plot(hx, hy, unit_color, alpha=alpha, lw=line_width, zorder=9) bgmarker = ax.scatter(v_wx, v_wy, s=marker_size_bg, c=bg_color,
bgmarker = ax.scatter(vx, vy, s=marker_size_bg, c=bg_color,
edgecolors='None', zorder=11) edgecolors='None', zorder=11)
umarker = ax.scatter(vx, vy, s=marker_size, c=unit_color, alpha=alpha, umarker = ax.scatter(v_wx, v_wy, s=marker_size, c=unit_color, alpha=alpha,
edgecolors='None', zorder=12) edgecolors='None', zorder=12)
font = {'fontsize': 4, font = {'fontsize': 4,
'va': 'bottom', 'va': 'bottom',
'ha': 'center'} 'ha': 'center',
}
bbox = {'alpha': 0.7, bbox = {'alpha': 0.7,
'boxstyle': 'round', 'boxstyle': 'round',
...@@ -316,14 +226,15 @@ def wire(ax: Axis, weights: Array, shape: Shape, *, ...@@ -316,14 +226,15 @@ def wire(ax: Axis, weights: Array, shape: Shape, *,
} }
if labels is True: if labels is True:
for (px, py), (ix, iy) in zip(weights, np.ndindex(shape)): for (px, py), (ix, iy) in zip(som.weights, np.ndindex(shape)):
ax.text(px+1.3, py, f'({ix}, {iy})', font, bbox=bbox, zorder=13) ax.text(px+1.3, py, f'({ix}, {iy})', font, bbox=bbox, zorder=13)
ax.set_aspect('equal')
return None
return vlines, hlines, bgmarker, umarker
def data_2d(ax: Axis, data: Array, colors: Array, def data_2d(ax: Axis, data: Array, colors: Array,
**kwargs) -> mpl.collections.PathCollection: **kwargs) -> None:
"""Scatter plot a data set with two-dimensional feature space. """Scatter plot a data set with two-dimensional feature space.
This just the usual scatter command with some reasonable defaults. This just the usual scatter command with some reasonable defaults.
...@@ -336,13 +247,33 @@ def data_2d(ax: Axis, data: Array, colors: Array, ...@@ -336,13 +247,33 @@ def data_2d(ax: Axis, data: Array, colors: Array,
Returns: Returns:
PathCollection. PathCollection.
""" """
defaults = { props = {
'alpha': 0.2, 'alpha': 0.2,
'c': colors, 'c': colors,
'cmap': 'plasma', 'cmap': 'plasma',
'edgecolors': 'None', 'edgecolors': 'None',
's': 10} 's': 10}
for k, v in defaults.items(): props.update(kwargs)
_ = kwargs.setdefault(k, v)
aplot.outward_spines(ax) aplot.outward_spines(ax)
return ax.scatter(*data.T, **kwargs) _ = ax.scatter(*data.T, **props)
def _generic_contour(ax: Axis, data: Array, outline: bool = False,
**kwargs) -> None:
"""Contour plot.
Args:
ax: Axis subplot.
data: Two-dimensional array.
"""
sdx, sdy = data.shape
overwrites = {
'extent': (-0.5, sdy-0.5, -0.5, sdx-0.5),
}
kwargs.update(overwrites)
_ = ax.contourf(data, **kwargs)
_ = ax.set_xticks(range(sdy))
_ = ax.set_yticks(range(sdx))
if outline:
ax.contour(data, cmap='Greys_r', alpha=.7)
ax.set_aspect('equal')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment