Skip to content
Snippets Groups Projects
Commit c7335912 authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Merge branch 'develop'

Minor changes for premature upload to pypi.
parents 588a347f 7556bcba
No related branches found
No related tags found
No related merge requests found
Pipeline #8495 passed
Showing
with 35 additions and 1048 deletions
......@@ -31,6 +31,8 @@ wheels/
*.ipynb
.ipynb_checkpoints
.mypy_cache
.apollon-devel/
notebooks/
# PyInstaller
# Usually these files are written by a python script from a template
......@@ -69,7 +71,11 @@ instance/
.scrapy
# Sphinx documentation
docs/_build/
docs/source/_build/
# PyBuilder
target/
......@@ -90,7 +96,7 @@ celerybeat-schedule
.env
# virtualenv
.venv
.venv*
venv/
ENV/
......
build:
image: python:latest
before_script:
- pip install numpy
- which python
script:
- pip install .
......@@ -414,16 +414,7 @@ function-naming-style=snake_case
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_,
m,
X,
N,
ax # matplotlib axis
good-names=_, ax, i, j, k, m, n, t, x, y, z
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
......@@ -528,10 +519,10 @@ valid-metaclass-classmethod-first-arg=cls
[DESIGN]
# Maximum number of arguments for function / method.
max-args=5
max-args=8 # original was 5
# Maximum number of attributes for a class (see R0902).
max-attributes=7
max-attributes=10 # original was 8
# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
......
......@@ -2,10 +2,8 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Build documentation in the docs/ directory with Sphinx
sphinx:
builder: html
configuration: docs/source/conf.py
......@@ -16,9 +14,8 @@ formats: all
# Optionally set the version of Python and requirements required to build your docs
python:
version: 3.7
version: 3.8
install:
- requirements: docs/requirements.txt
- method: pip
- method: setuptools
path: .
system_packages: true
\ No newline at end of file
File moved
graft json
include *.md
include LICENSE.txt
recursive-include include *.h
# Apollon
Apollon is a Python framework for audio feature extraction and music similarity
estimation. It includes subpackages for
Apollon is a tool for music modelling. It comprises
* low-level audio feature extraction
* Hidden-Markov Models
* Audio feature extraction
* Hidden Markov Models
* Self-Organizing Map
## 1. Installation
This repository. Navigate the packages root directory
and install apollon using pip.
### 1.1 Install from PyPI
The latest version of apollon is available on PyPI. Just open a terminal and run
the following command to download and install apollon:
```
cd path/to/apollon
pip install .
pip install apollon
```
Note that the period on the end of the last line is necessary.
## 2. Documentation
Full [documentation](https://apollon.readthedocs.io) is available on readthedocs.
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
from . audio import load_audio
from . io import dump_json, decode_array
from . signal.spectral import stft, Spectrum
from . signal.features import FeatureSpace
from . tools import time_stamp
from . types import PathType
from . types import Array as _Array
from . onsets import FluxOnsetDetector
from . import segment
def rhythm_track(file_path: PathType) -> dict:
    """Run the rhythm track analysis on one audio file.

    The file is loaded, onsets are detected with a ``FluxOnsetDetector``,
    the signal is segmented by these onsets, and a ``Spectrum`` is computed
    from the segments.

    Args:
        file_path: Path to audio file.

    Returns:
        Dictionary with the keys ``'meta'``, ``'params'`` and ``'features'``.
    """
    audio = load_audio(file_path)
    detector = FluxOnsetDetector(audio.data, audio.fps)
    segments = segment.by_onsets(audio.data, 2**11, detector.index())
    spectrum = Spectrum(segments, audio.fps, window='hamming')

    # Collect the onset results first, then assemble the sections one by one.
    onset_features = dict(peaks=detector.peaks,
                          index=detector.index(),
                          times=detector.times(audio.fps))
    meta = {'source': file_path, 'time_stamp': time_stamp()}
    params = {'onsets': detector.params(), 'spectrum': spectrum.params()}
    features = {'onsets': onset_features,
                'spectrum': spectrum.extract().as_dict()}
    return {'meta': meta, 'params': params, 'features': features}
def timbre_track(file_path: PathType) -> dict:
    """Run the timbre track analysis on one audio file.

    The file is loaded and a spectrogram is computed with ``stft`` using
    segments of 1024 samples and a hop size of 512 samples.

    Args:
        file_path: Path to input file.

    Returns:
        Dictionary with the keys ``'meta'``, ``'params'`` and ``'features'``.
    """
    audio = load_audio(file_path)
    spectrogram = stft(audio.data, audio.fps, n_perseg=1024, hop_size=512)

    meta = {'source': file_path, 'time_stamp': time_stamp()}
    params = {'spectrogram': spectrogram.params()}
    features = spectrogram.extract().as_dict()
    return {'meta': meta, 'params': params, 'features': features}
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
"""
apollon/audio.py -- Wrapper classes for audio data.
Classes:
AudioFile Representation of an audio file.
Functions:
load_audio Load .wav file.
"""
import pathlib as _pathlib
import matplotlib.pyplot as _plt
import soundfile as _sf
from . signal import tools as _ast
from . types import PathType
class AudioFile:
    """In-memory representation of an audio file.

    Args:
        path: Path to file.
        norm: If True, the signal is passed through ``signal.tools.normalize``.
        mono: If True, multi-channel data is mixed down to a single channel.
    """
    def __init__(self, path: PathType, norm: bool = False, mono: bool = True) -> None:
        """Read the file, then optionally mix down and normalize the data."""
        self.file = _pathlib.Path(path)
        samples, rate = _sf.read(self.file, dtype='float')
        self.data = samples
        self.fps = rate
        # The size is taken from the first axis before any mixdown happens.
        self.size = self.data.shape[0]

        needs_mixdown = mono and self.data.ndim > 1
        if needs_mixdown:
            # Average over the second axis (the channels).
            n_channels = self.data.shape[1]
            self.data = self.data.sum(axis=1) / n_channels
        if norm:
            self.data = _ast.normalize(self.data)

    def plot(self) -> None:
        """Plot audio as wave form on a new 14 x 7 figure."""
        figure = _plt.figure(figsize=(14, 7))
        axes = figure.add_subplot(1, 1, 1)
        axes.plot(self.data)

    def __str__(self):
        # File name, sample rate in kHz and duration in seconds.
        name = self.file.name
        rate_khz = self.fps/1000
        duration = self.size/self.fps
        return "<{}, {} kHz, {:.3} s>".format(name, rate_khz, duration)

    def __repr__(self):
        return self.__str__()

    def __len__(self):
        return self.size

    def __getitem__(self, item):
        return self.data[item]
def load_audio(path: PathType, norm: bool = False, mono: bool = True) -> AudioFile:
    """Create an ``AudioFile`` from the file at ``path``.

    Args:
        path: Path to audio file.
        norm: True if data should be normalized.
        mono: If True, mixdown channels.

    Return:
        Audio file representation.
    """
    audio = AudioFile(path, norm=norm, mono=mono)
    return audio
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
from . import apollon_export
from . import apollon_features
from . import apollon_onsets
from . import apollon_hmm
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
import argparse
import json
import sys
import typing
from .. io import dump_json, decode_array
from .. signal.spectral import stft
from .. signal.features import FeatureSpace
from .. tools import time_stamp
from .. types import PathType
def _parse_cml(argv):
    """Parse the command line of the export tool.

    Args:
        argv: List of command line arguments handed to ``parse_args``.

    Returns:
        Namespace with the attributes ``csv``, ``outpath`` and ``csv_data``.
    """
    cli = argparse.ArgumentParser(description='Apollon feature extraction engine')
    # (flags, keyword arguments) in the order they are registered.
    arguments = (
        (('--csv',), {'action': 'store_true', 'help': 'Export csv'}),
        (('-o', '--outpath'), {'action': 'store', 'help': 'Output file path'}),
        (('csv_data',), {'type': str, 'nargs': 1}),
    )
    for flags, spec in arguments:
        cli.add_argument(*flags, **spec)
    return cli.parse_args(argv)
def _export_csv(data: typing.Dict[str, typing.Any], path: PathType = None) -> None:
    """Decode a JSON document into a ``FeatureSpace`` and call its ``to_csv``.

    Args:
        data: Document handed to ``json.loads``.
        path: Accepted, but not used by this function.
    """
    # NOTE(review): ``data`` is annotated as a dict but is passed to
    # ``json.loads`` -- presumably callers hand in a JSON string; confirm.
    decoded = json.loads(data, object_hook=decode_array)
    FeatureSpace(**decoded).to_csv()
def main(argv=None):
    """Entry point of the export command line tool.

    Args:
        argv: Command line arguments without the program name. If ``None``,
              ``sys.argv[1:]`` is used.

    Returns:
        Exit status ``0``.
    """
    if argv is None:
        # Skip the program name; argparse would otherwise consume it as the
        # positional ``csv_data`` argument.
        argv = sys.argv[1:]
    args = _parse_cml(argv)
    if args.csv:
        _export_csv(args.csv_data[0], args.outpath)
    return 0


if __name__ == '__main__':
    # ``main`` has to be called: handing the function object itself to
    # ``sys.exit`` prints its repr and always exits with status 1.
    sys.exit(main())
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
import argparse
import json
import sys
import typing
from .. import analyses
from .. types import PathType
from .. import io
from .. signal.features import FeatureSpace
def _export_csv(
        data: typing.Dict[str, typing.Any],
        path: PathType = None) -> None:
    """Decode a JSON document into a ``FeatureSpace`` and call its ``to_csv``.

    Args:
        data: Document handed to ``json.loads``.
        path: Accepted, but not used by this function.
    """
    # NOTE(review): ``data`` is annotated as a dict but is passed to
    # ``json.loads`` -- presumably callers hand in a JSON string; confirm.
    decoded = json.loads(data, object_hook=io.decode_array)
    FeatureSpace(**decoded).to_csv()
def main(args: argparse.Namespace) -> int:
    """Run the feature extraction selected on the command line.

    If ``args.export`` is set, only the export is handled (currently
    ``'csv'``) and the function returns. Otherwise every track enabled through
    ``args.rhythm``, ``args.timbre`` or ``args.pitch`` is computed for
    ``args.file[0]`` and the collected results are written with
    ``io.dump_json``.

    Args:
        args: Parsed command line arguments.

    Returns:
        Exit status ``0``.
    """
    if args.export:
        if args.export == 'csv':
            _export_csv(args.file[0], args.outpath)
        return 0

    track_data = {}
    # The analysis function is looked up only when its track was requested.
    for track in ('rhythm', 'timbre', 'pitch'):
        if getattr(args, track):
            analyse = getattr(analyses, '{}_track'.format(track))
            track_data[track] = analyse(args.file[0])
    io.dump_json(track_data, args.outpath)
    return 0


if __name__ == '__main__':
    # NOTE(review): ``main`` requires ``args``; this call passes none.
    sys.exit(main())
#!/usr/bin/env python3
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
import argparse
import json
import pathlib
import sys
import typing
from .. import io
from .. hmm import PoissonHmm
from .. types import Array as _Array
def _load_track_file(track_file: str) -> dict:
    """Read a JSON track file.

    Args:
        track_file: Path to the track file.

    Returns:
        Content of the file, decoded with ``io.decode_array`` as object hook.
    """
    with pathlib.Path(track_file).open('r') as handle:
        content = json.load(handle, object_hook=io.decode_array)
    return content
def _parse_feature(track_data: dict, feature_path: str) -> '_Array':
    """Look up a feature in nested track data.

    Args:
        track_data:   Nested dictionary as loaded from a track file.
        feature_path: Dot-separated keys leading to the feature,
                      e.g. ``'features.onsets.peaks'``.

    Returns:
        The object stored under ``feature_path``.

    Raises:
        SystemExit: With code 10 if a node of the path cannot be resolved.
    """
    feature = track_data
    for key in feature_path.split('.'):
        try:
            feature = feature[key]
        except (KeyError, IndexError, TypeError):
            # KeyError: missing key. IndexError/TypeError: the path descends
            # into a node that is not a mapping (e.g. an array or a number).
            print('Error. Invalid node "{}" in feature path.'.format(key))
            # ``sys.exit`` instead of the ``exit`` helper, which only exists
            # when the ``site`` module has been loaded.
            sys.exit(10)
    return feature
def _generate_outpath(in_path, out_path: str, feature_path: str) -> pathlib.Path:
    """Derive the output path for a trained model.

    The default file name is the stem of ``in_path`` with the suffix ``.hmm``.
    It is used on its own if ``out_path`` is ``None`` and appended to
    ``out_path`` if the latter has no suffix.

    Args:
        in_path:      Path of the input track file.
        out_path:     Requested output path or ``None``.
        feature_path: Not used by this function.

    Returns:
        Path of the output file.

    Raises:
        SystemExit: With code 10 if the parent directory of the output path
                    does not exist.
    """
    default_fname = '{}.hmm'.format(pathlib.Path(in_path).stem)
    if out_path is None:
        result = pathlib.Path(default_fname)
    else:
        result = pathlib.Path(out_path)
        if not result.suffix:
            # A path without suffix is taken to be a directory.
            result = result.joinpath(default_fname)
    if not result.parent.is_dir():
        print('Error. Path "{!s}" does not exist.'.format(result.parent))
        # ``sys.exit`` instead of the ``exit`` helper, which only exists
        # when the ``site`` module has been loaded.
        sys.exit(10)
    return result
def _train_n_hmm(data: _Array, m_states: int, n_trails: int):
    """Train ``n_trails`` HMMs and return the best successful one.

    Every model is a ``PoissonHmm`` created with ``init_gamma='softmax'``.
    The input is rounded and cast to integers before training.

    Args:
        data:     Possibly unprocessed input data set.
        m_states: Number of states.
        n_trails: Number of training runs.

    Returns:
        The successfully fitted model with the smallest absolute
        ``quality.nll``, or ``None`` if no run succeeded.
    """
    counts = data.round().astype(int)
    fitted = []
    for _ in range(n_trails):
        candidate = PoissonHmm(counts, m_states, init_gamma='softmax')
        candidate.fit(counts)
        if candidate.success:
            fitted.append(candidate)
    if not fitted:
        return None
    return min(fitted, key=lambda model: abs(model.quality.nll))
def main(argv=None) -> int:
    """Train one HMM per track file and save it as JSON.

    For every entry of ``argv.track_files`` the feature addressed by
    ``argv.feature_path`` is loaded and used to train HMMs with
    ``argv.mstates`` states in five runs. The best model is written to the
    path derived from ``argv.outpath``. Files for which no model could be
    trained are reported and skipped.

    Args:
        argv: Object providing the attributes ``track_files``,
              ``feature_path``, ``mstates`` and ``outpath``.

    Returns:
        Exit status ``0``.
    """
    # NOTE(review): the fallback is the plain list ``sys.argv``, which has no
    # ``track_files`` attribute -- presumably a parsed namespace is expected.
    options = sys.argv if argv is None else argv
    for track_file in options.track_files:
        features = _parse_feature(_load_track_file(track_file),
                                  options.feature_path)
        model = _train_n_hmm(features, options.mstates, 5)
        if model is None:
            print('Error. Could not train HMM on {}'.format(track_file))
            continue
        destination = _generate_outpath(track_file, options.outpath,
                                        options.feature_path)
        io.dump_json(model.to_dict(), destination)
    return 0


if __name__ == '__main__':
    sys.exit(main())
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
import argparse
import multiprocessing as mp
import sys
from .. import onsets
def _parse_cml(argv):
    """Parse the command line of the onset detection tool.

    Args:
        argv: List of command line arguments handed to ``parse_args``.

    Returns:
        Namespace with the attributes ``amplitude``, ``entropy``, ``flux``,
        ``outpath`` and ``filepath``.
    """
    cli = argparse.ArgumentParser(description='Apollon onset detection engine')
    # One switch per detection method, registered in this order.
    switches = (
        ('--amplitude',
         'Detect onsets based on local extrema in the time domain signal.'),
        ('--entropy',
         'Detect onsets based on time domain entropy maxima.'),
        ('--flux',
         'Detect onsets based on spectral flux.'),
    )
    for flag, description in switches:
        cli.add_argument(flag, action='store_true', help=description)
    cli.add_argument('-o', '--outpath', action='store',
                     help='Output file path.')
    cli.add_argument('filepath', type=str, nargs=1)
    return cli.parse_args(argv)
def _amp(a):
    """Stand-in for the amplitude detector.

    Prints ``Amplitude`` and hands the argument back unchanged.
    """
    print('Amplitude')
    return a
def _entropy(a):
    """Stand-in for the entropy detector.

    Prints ``Entropy`` and hands the argument back unchanged.
    """
    print('Entropy')
    return a
def _flux(a):
    """Stand-in for the spectral flux detector.

    Prints ``Flux`` and hands the argument back unchanged.
    """
    print('Flux')
    return a
def main(argv=None):
    """Entry point of the onset detection command line tool.

    Every detection method selected on the command line is run in a pool of
    three worker processes.

    Args:
        argv: Command line arguments without the program name. If ``None``,
              ``sys.argv[1:]`` is used.

    Returns:
        List with the results of the selected detectors, or ``1`` if no
        detection method was selected.
    """
    if argv is None:
        # Skip the program name; argparse would otherwise consume it as the
        # positional ``filepath`` argument.
        argv = sys.argv[1:]
    args = _parse_cml(argv)

    detectors = {'amplitude': _amp,
                 'entropy': _entropy,
                 'flux': _flux}
    methods = [func for name, func in detectors.items() if getattr(args, name)]
    if not methods:
        print('At least one detection method required. Aborting.')
        return 1

    with mp.Pool(processes=3) as pool:
        results = [pool.apply_async(meth, (i,)) for i, meth in enumerate(methods)]
        out = [res.get() for res in results]
    return out


if __name__ == '__main__':
    _status = main()
    # ``main`` returns the detector results on success. Passing that list to
    # ``sys.exit`` would print it and exit with status 1, so only an integer
    # return value is used as exit status.
    sys.exit(_status if isinstance(_status, int) else 0)
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
"""apollon/fractal.py
Tools for estimating fractal dimensions.
Functions:
correlation_dimension   Estimate correlation dimension.
embedding               Pseudo-phase space embedding.
lorenz_attractor        Simulate Lorenz system.
embedding_entropy       Entropy of pps embedding.
"""
import numpy as _np
from scipy import stats as _stats
from scipy.spatial import distance as _distance
def correlation_dimension(data, tau, m, r, mode='cut', fit_n_points=10):
    r"""Compute an estimate of the correlation dimension D_2.

    For every embedding dimension in ``m`` the data is embedded, the pairwise
    Euclidean distances of the embedded vectors are computed, and the number
    of distances below each radius in ``r`` is counted. The estimate is the
    slope of a straight line fitted to the log counts over the log radii.

    TODO:
        - Implement algo for linear region detection
        - Implement orbital delay parameter \gamma
        - Implement multiprocessing
        - Find a way to use L_\inf norm with distance.pdist

    NOTE(review): the counts are taken over the full square distance matrix
    and scaled by ``data.size``, not by the number of embedded vectors --
    confirm that this is the intended normalization.

    Args:
        data (1d array)    Input time series.
        tau  (int)         Reconstruction delay.
        m    (iterable)    of embedding dimensions
        r    (iterable)    of radii
        mode (str)         See doc of `embedding`.
        fit_n_points (int) Currently not used.

    Returns:
        lCrm (array) Logarithm of correlation sums given r_i,
                     one row per embedding dimension.
        lr   (array) Logarithm of radii.
        d2   (array) Estimate of correlation dimension,
                     one value per embedding dimension.
    """
    n_samples = data.size
    n_dims = len(m)
    log_radii = _np.log(r)

    # output arrays
    log_corr_sums = _np.zeros((n_dims, len(r)))
    dim_estimates = _np.zeros(n_dims)

    for i, dim in enumerate(m):
        emb = embedding(data, tau, dim, mode)

        # we should use L_\inf norm here
        pairwise_distances = _distance.squareform(
            _distance.pdist(emb.T, metric='euclidean'))

        # correlation sum for each radius
        corr_sums = _np.array(
            [_np.sum(pairwise_distances < radius) for radius in r],
            dtype=float)
        corr_sums *= 1 / (n_samples * (n_samples-1))
        log_corr_sums[i] = _np.log(corr_sums)

        # slope of the first-order fit is the estimate for this dimension
        slope, _ = _np.polyfit(log_radii, log_corr_sums[i], 1)
        dim_estimates[i] = slope
    return log_corr_sums, log_radii, dim_estimates
def embedding(inp_sig, tau, m=2, mode='zero'):
    """Generate n-dimensional pseudo-phase space embedding.

    Params:
        inp_sig    (iterable) Input signal.
        tau        (int) Time shift.
        m          (int) Embedding dimensions.
        mode       (str) Either `zero` for zero padding,
                                `wrap` for wrapping the signal around, or
                                `cut`, which cuts the signal at the edges.
                         Note: In cut-mode, each dimension is only
                               len(sig) - tau * (m - 1) samples long.

    Return:
        (np.ndarray) of shape
                        (m, len(inp_sig)) in modes 'wrap' or 'zero', or
                        (m, len(sig) - tau * (m - 1)) in cut-mode.

    Raises:
        ValueError: If `mode` is unknown, or if the embedding parameters
                    leave no samples in cut-mode.
    """
    inp_sig = _np.atleast_1d(inp_sig)
    N = len(inp_sig)

    if mode == 'zero':
        # perform zero padding
        out = _np.zeros((m, N))
        out[0] = inp_sig
        for i in range(1, m):
            shift = tau * i
            # Rows shifted by the whole signal length or more stay zero.
            # Slicing with ``N - shift`` also covers a shift of zero, for
            # which ``inp_sig[:-0]`` would be empty.
            if shift < N:
                out[i, shift:] = inp_sig[:N - shift]

    elif mode == 'wrap':
        # wraps the signal around at the bounds
        out = _np.empty((m, N))
        for i in range(m):
            out[i] = _np.roll(inp_sig, i*tau)

    elif mode == 'cut':
        # cut every index beyond the bounds
        Nm = N - tau * (m-1)    # number of vectors
        if Nm < 1:
            raise ValueError('Embedding params to large for input.')
        out = _np.empty((m, Nm))
        for i in range(m):
            off = N - i * tau
            out[i] = inp_sig[off-Nm:off]

    else:
        raise ValueError('Unknown mode `{}`.'.format(mode))

    return out
def embedding_entropy(emb, bins, extent=(-1, 1)):
    """Compute the normalized entropy of an embedding using log_e.

    The embedding is binned into a multi-dimensional histogram. The entropy
    of the bin counts is divided by the logarithm of the number of bins.

    Args:
        emb    (ndarray)    Embedding, one row per dimension.
        bins   (int)        Number of histogram bins per axis.
        extent (tuple)      Extent per dimension

    Return:
        (float)    Entropy of pps.
    """
    n_dims = emb.shape[0]
    ranges = [extent] * n_dims
    histogram = _np.histogramdd(emb.T, bins, range=ranges)[0]
    raw_entropy = _stats.entropy(histogram.ravel())
    return raw_entropy / _np.log(histogram.size)
def __lorenz_system(x, y, z, s, r, b):
    """Evaluate the right-hand side of the Lorenz equations.

    Params:
        x, y, z    (float) Current system state.
        s, r, b    (float) System parameters.

    Return:
        xyz_dot    (array) Derivatives of current system state.
    """
    x_dot = s * (y - x)
    y_dot = x * (r - z) - y
    z_dot = x * y - b * z
    return _np.array([x_dot, y_dot, z_dot])
def lorenz_attractor(n, sigma=10, rho=28, beta=8/3,
                     init_xyz=(0., 1., 1.05), dt=0.01):
    """Simulate a Lorenz system with explicit Euler steps.

    Params:
        n        (int)   Number of data points to generate.
        sigma    (float) System parameter.
        rho      (float) System parameter.
        beta     (float) System parameter.
        init_xyz (tuple) Initial system state.
        dt       (float) Step size.

    Return:
        xyz    (array) System states, one row per data point.
    """
    xyz = _np.empty((n, 3))
    xyz[0] = init_xyz
    for step in range(1, n):
        previous = xyz[step-1]
        derivative = __lorenz_system(*previous, sigma, rho, beta)
        xyz[step] = previous + derivative * dt
    return xyz
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
from . poisson.poisson_hmm import PoissonHmm
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
from . grapher import draw_matrix, draw_network, save_hmmfig
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# michael.blass@uni-hamburg.de
"""
grapher.py -- Plot graphs from HMMs.
"""
from matplotlib import cm
from matplotlib.patches import ArrowStyle
from matplotlib.patches import Circle
from matplotlib.patches import ConnectionStyle
from matplotlib.patches import FancyArrowPatch
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from scipy.spatial import distance
from apollon import tools
def _prepare_fig(pos):
    """Create a square figure whose axes enclose all node positions.

    The axis limits are derived from the largest pairwise distance between
    the positions plus a fixed margin. The axes are switched off.

    Params:
        pos    (dict) with structure {node_name_i: np.array([pos_x, pos_y])}
                      as returned by nx.layout methods.

    Return:
        (Figure, AxesSubplot)
    """
    coords = np.array(list(pos.values()))
    half_extent = distance.pdist(coords).max() / 2 + 1
    limit = half_extent + 1.5

    fig = plt.figure(figsize=(7, 7), frameon=False)
    ax = fig.add_subplot(111)
    ax.axis([-limit, limit, -limit, limit])
    ax.set_axis_off()
    return fig, ax
def _draw_nodes(G, pos, ax):
    """Draw the nodes of a (small) networkx graph.

    Every node is drawn as a circle of radius ``size`` on top of a
    translucent "flare" circle of radius ``size + fsize``. Both values are
    read from the node attributes. The node is labelled with its running
    number, starting at 1.

    Params:
        G      (nx.classes.*)  a networkx graph.
        pos    (dict)          returned by nx.layout methods.
        ax     (AxesSubplot)   mpl axes.

    Return:
        (dict) of Circle patches, keyed by node label.
    """
    outline = (0, 0, 0, 1)
    text_kwargs = {'color': (0, 0, 0, .8),
                   'verticalalignment': 'center',
                   'horizontalalignment': 'center',
                   'fontdict': {'size': 15, 'weight': 'bold'}}

    nodes = {}
    for i, (label, xy) in enumerate(pos.items()):
        size = G.nodes[label]['size']
        fsize = G.nodes[label]['fsize']
        # Colors are taken from the default color cycle ('C0', 'C1', ...).
        color = 'C{}'.format(i)

        flare = Circle(xy, size+fsize, alpha=0.2,
                       edgecolor=outline, facecolor=color)
        node = Circle(xy, size, alpha=0.8,
                      edgecolor=outline, facecolor=color)
        ax.add_patch(flare)
        ax.add_patch(node)

        ax.text(*xy, i+1, **text_kwargs)
        nodes[label] = node
    return nodes
def _draw_edges(G, pos, nodes, ax):
    """Draw the edges of a (small) networkx graph as curved arrows.

    Each arrow is colored with the ``color`` attribute of its edge and is
    clipped against the patches of the two nodes it connects.

    Params:
        G       (nx.classes.*)  a networkx graph.
        pos     (dict)          returned by nx.layout methods.
        nodes   (dict)          of Circle patches.
        ax      (AxesSubplot)   mpl axes.

    Return:
        (dict) of FancyArrowPatch objects, keyed by ``(a, b)``.
    """
    # Settings that are the same for every arrow.
    shared = {'arrowstyle': ArrowStyle.Fancy(head_width=10, head_length=15),
              'antialiased': True,
              'connectionstyle': ConnectionStyle('arc3', rad=.2),
              'linewidth': 1.0}

    edges = {}
    for a, b, attr in G.edges.data():
        color = attr['color']
        arrow = FancyArrowPatch(pos[a], pos[b],
                                patchA=nodes[a], patchB=nodes[b],
                                shrinkA=5, shrinkB=5,
                                edgecolor=color, facecolor=color,
                                **shared)
        ax.add_patch(arrow)
        edges[(a, b)] = arrow
    return edges
def _legend(G, nodes, ax):
    """Draw the legend for a (small) nx graph.

    There is one entry per node. The node key is formatted as a frequency
    in Hz and paired with the node's patch.

    Params:
        G      (nx.classes.*)   a networkx graph.
        nodes  (dict)           of Circle patches.
        ax     (AxesSubplot)    mpl axes.

    Return:
        The legend object returned by ``ax.legend``.
    """
    template = r'$f_c = {:>9.3f}$ Hz'
    labels = [template.format(key) for key in G.nodes.keys()]
    return ax.legend(nodes.values(), labels,
                     fancybox=True,
                     fontsize=14,
                     bbox_to_anchor=(1.02, 1.0),
                     borderaxespad=0)
def draw_network(labels, tpm, delta):
    """Draw the graph of a HMM's transition probability matrix.

    Every state becomes a node. An edge is added for each transition whose
    probability is not close to zero. The node size grows with the degree of
    the node, the size of its flare with the corresponding entry of
    ``delta``.

    Params:
        labels (iterable)      Labels for each state.
        tpm    (np.ndarray)    A two-dimensional (row) stochastic matrix.
        delta  (iterable)      One value per state; its exponential is
                               stored as the node attribute ``fsize``.

    Return:
        (Figure, AxesSubplot, MultiDiGraph)
    """
    G = nx.MultiDiGraph()

    for i, from_state in enumerate(labels):
        G.add_node(from_state, fsize=np.exp(delta[i]))

        for j, to_state in enumerate(labels):
            if not np.isclose(tpm[i, j], 0.0):
                G.add_edge(from_state, to_state,
                           weight=tpm[i, j],
                           color='k')

    sd = np.sum([np.exp(degree) for node, degree in G.degree()])

    for node, degree in G.degree():
        # ``G.nodes`` is the attribute view that the drawing helpers read;
        # the ``G.node`` alias no longer exists in current networkx.
        G.nodes[node]['size'] = .5 + np.exp(degree) / sd

    pos = nx.layout.spring_layout(G, center=(0.0, 0.0), scale=4)

    fig, ax = _prepare_fig(pos)
    nodes = _draw_nodes(G, pos, ax)
    _draw_edges(G, pos, nodes, ax)
    _legend(G, nodes, ax)
    return fig, ax, G
def draw_matrix(tpm):
    """Draw a heatmap from a transition probability matrix.

    Each cell is annotated with its value. The tick labels are moved to the
    top and a white grid separates the cells.

    Args:
        tpm (np.ndarray)    Two-dimensional, row-stochastic square matrix.

    Returns:
        (fig, ax, img)
    """
    # Sizes of the first and second axis; equal for the square input
    # described above.
    n_first, n_second = tpm.shape

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(1, 1, 1)
    img = ax.imshow(tpm,
                    origin='upper',
                    interpolation='nearest',
                    aspect='equal',
                    cmap='viridis',
                    vmin=0.0, vmax=1.0)

    # colorbar
    cbar = ax.figure.colorbar(img, ax=ax)
    cbar.ax.set_ylabel('Probability', rotation=-90, va="bottom")

    # major ticks, labelled at the top
    ax.set_xticks(np.arange(n_first))
    ax.set_yticks(np.arange(n_second))
    ax.tick_params(which='major', top=True, bottom=False,
                   labeltop=True, labelbottom=False)

    # minor ticks sit between the cells and carry the grid
    ax.set_xticks(np.arange(n_first)-.5, minor=True)
    ax.set_yticks(np.arange(n_second)-.5, minor=True)
    ax.tick_params(which="minor", bottom=False, left=False)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=2)

    # spines
    for side in ('top', 'left'):
        ax.spines[side].set_position(('outward', 10))
    for side in ('bottom', 'right'):
        ax.spines[side].set_visible(False)

    # cell labels
    text_kw = {'ha': 'center', 'va': 'center', 'fontdict': {'fontsize': 14}}
    for i in range(n_first):
        for j in range(n_second):
            val = tpm[i, j].astype(float).round(2)
            cell_color = cm.viridis(val)
            text_color = cm.viridis(1-val)
            # Fall back to black where the text would have the cell's color.
            if np.array_equal(text_color, cell_color):
                text_color = 'k'
            ax.text(j, i, '{:.3f}'.format(val), **text_kw, color=text_color)
    return fig, ax, img
def save_hmmfig(fig, path, **kwargs):
    """Save the figure to file.

    This saves the figure and ensures that the out-of-axes legend
    is completely visible in the saved version.

    All kwargs are passed on to ``fig.savefig``. Unless given explicitly,
    ``bbox_extra_artists`` is set to the legend of the first axes and
    ``bbox_inches`` to ``'tight'``; the extra artists are only taken into
    account when the tight bounding box is computed.

    Params:
        fig      (Figure)    Figure of HMM tpm.
        path     (str)       Path to save file.
        **kwargs             Further keyword arguments for ``fig.savefig``.
    """
    # Work on a copy so that the caller's options stay untouched.
    options = dict(kwargs)
    options.setdefault('bbox_extra_artists', (fig.axes[0].legend_,))
    options.setdefault('bbox_inches', 'tight')
    fig.savefig(fname=path, **options)
apollon/hmm/graph/img/hubert.png

33.3 KiB

#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""poisson_core.py
Core functionality for Poisson HMM.
"""
import numpy as _np
from scipy import stats as _stats
from scipy.special import logsumexp as _logsumexp
import warnings
# NOTE(review): this runs when the module is imported and installs an
# "ignore" filter without category or module restriction, so it is not limited
# to warnings raised by this module.
warnings.filterwarnings("ignore")
def log_poisson_fwbw(x, m, _lambda, _gamma, _delta):
    """Compute log forward and backward probabilities for a Poisson HMM.

    Both recursions rescale the probability vector in every step and keep
    track of the accumulated scale factor in the log domain.

    Note: this algorithm fails if `_delta` has zeros.

    Params:
        x          (np.ndarray)    One-dimensional array of integer values.
        m          (int)           Number of states.
        _lambda    (np.ndarray)    Means of the state-dependent distributions.
        _gamma     (np.ndarray)    Transition probability matrix.
        _delta     (np.ndarray)    Initial distribution.

    Return:
        (lalpha, lbeta, lpprob) Log forward probabilities, log backward
        probabilities and log Poisson probabilities, each with one row per
        observation and one column per state.
    """
    n_obs = len(x)
    lalpha, lbeta = _np.zeros((2, n_obs, m))
    pprob = _stats.poisson.pmf(x[:, None], _lambda)

    # forward pass: first step from the initial distribution
    fwd = _delta * pprob[0]
    total = fwd.sum()
    fwd = fwd / total
    log_scale = _np.log(total)
    lalpha[0] = _np.log(fwd) + log_scale

    for t in range(1, n_obs):
        fwd = fwd @ _gamma * pprob[t]
        total = fwd.sum()
        fwd /= total
        log_scale += _np.log(total)
        lalpha[t] = _np.log(fwd) + log_scale

    # backward pass: start from a uniform vector with matching scale
    lbeta[-1] = 0
    bwd = _np.repeat(1/m, m)
    log_scale = _np.log(m)

    for t in reversed(range(1, n_obs)):
        bwd = _gamma @ (pprob[t] * bwd)
        lbeta[t-1] = _np.log(bwd) + log_scale
        total = bwd.sum()
        bwd /= total
        log_scale += _np.log(total)

    return lalpha, lbeta, _np.log(pprob)
def poisson_EM(x, m, theta, maxiter=1000, tol=1e-6):
    """Estimate the parameters of an m-state PoissonHMM.

    The parameters are re-estimated until the summed absolute change of all
    parameters falls below ``tol`` or ``maxiter`` iterations have been run.

    Params:
        x          (np.ndarray)    One-dimensional array of integer values.
        m          (int)           Number of states.
        theta      (tuple)         Initial guesses (lambda, gamma, delta).
        maxiter    (int)           Maximum number of EM iterations.
        tol        (float)         Convergence criterion.

    Return:
        (theta, log_likelihood, success) Estimated parameters
        (lambda, gamma, delta), log-likelihood of the last iteration, and
        whether the convergence criterion was met.
    """
    n = len(x)

    cur_lambda, cur_gamma, cur_delta = (
        theta[0].copy(), theta[1].copy(), theta[2].copy())
    new_lambda, new_gamma, new_delta = (
        theta[0].copy(), theta[1].copy(), theta[2].copy())

    for _ in range(maxiter):
        lalpha, lbeta, lpprob = log_poisson_fwbw(
            x, m, cur_lambda, cur_gamma, cur_delta)

        last = lalpha[-1]
        peak = max(last)
        log_likelihood = peak + _logsumexp(last - peak)

        # transition probabilities
        for j in range(m):
            for k in range(m):
                joint = (lalpha[:n-1, j] +
                         lbeta[1:n, k] +
                         lpprob[1:n, k] -
                         log_likelihood)
                new_gamma[j, k] *= _np.sum(_np.exp(joint))
        new_gamma /= _np.sum(new_gamma, axis=1, keepdims=True)

        # state-dependent means and initial distribution
        rab = _np.exp(lalpha + lbeta - log_likelihood)
        new_lambda = (rab * x[:, None]).sum(axis=0) / rab.sum(axis=0)
        new_delta = rab[0] / rab[0].sum()

        crit = (_np.abs(cur_lambda - new_lambda).sum() +
                _np.abs(cur_gamma - new_gamma).sum() +
                _np.abs(cur_delta - new_delta).sum())
        if crit < tol:
            return (new_lambda, new_gamma, new_delta), log_likelihood, True

        cur_lambda = new_lambda.copy()
        cur_gamma = new_gamma.copy()
        cur_delta = new_delta.copy()

    return (new_lambda, new_gamma, new_delta), log_likelihood, False
def poisson_viterbi(mod, x):
    """Calculate the Viterbi path (global decoding) of a PoissonHMM
    given some data x.

    ``mod.gamma_[i, j]`` is read as the probability of a transition from
    state ``i`` to state ``j``, in agreement with the forward recursion
    ``a_t @ gamma`` used for fitting.

    Params:
        mod    (HMM-Object)    Model providing ``m``, ``lambda_``,
                               ``gamma_`` and ``delta_``.
        x      (array-like)    observations

    Return:
        (np.ndarray) Most probable sequence of hidden states given x.
    """
    # Make sure that x is an array before its length is taken
    x = _np.atleast_1d(x)
    n = len(x)

    # probability mass of each x_i under each state-dependent mean
    pmf_x = _stats.poisson.pmf(x[:, None], mod.lambda_)

    # forward pass; each row is rescaled to sum to one
    xi = _np.zeros((n, mod.m))
    probs = mod.delta_ * pmf_x[0]
    xi[0] = probs / probs.sum()

    for i in range(1, n):
        # Best predecessor per target state: weight row ``a`` of gamma with
        # xi[i-1, a] and take the maximum over the source states (axis 0).
        best = _np.max(xi[i-1][:, None] * mod.gamma_, axis=0) * pmf_x[i]
        xi[i] = best / best.sum()

    # backward pass
    phi = _np.zeros(n, dtype=int)
    phi[-1] = _np.argmax(xi[-1])

    for i in range(n-2, -1, -1):
        # Transitions *into* the successor state are a column of gamma.
        phi[i] = _np.argmax(mod.gamma_[:, phi[i+1]] * xi[i])
    return phi
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment