Skip to content
Snippets Groups Projects
Commit cef3d528 authored by Blaß, Michael's avatar Blaß, Michael :speech_balloon:
Browse files

Merge branch 'fix-hmm' into develop

Fixed HMM to fit new interfaces.
parents 71164a60 92c2ae09
Branches
Tags
No related merge requests found
Pipeline #6098 passed
# Licensed under the terms of the BSD-3-Clause license.
# Copyright (C) 2019 Michael Blaß
# mblass@posteo.net
from . poisson.poisson_hmm import PoissonHmm
......@@ -13,8 +13,6 @@ Classes:
PoissonHMM HMM with univariat Poisson-distributed states.
"""
#import json as _json
#import pathlib as _pathlib
import typing as _typing
import warnings as _warnings
......@@ -24,10 +22,10 @@ import chainsaddiction as _ca
import apollon
from apollon import types as _at
from apollon import io as _io
import apollon.io.io as aio
from apollon.types import Array as _Array
from apollon import tools as _tools
from apollon.hmm import hmm_utilities as _utils
import apollon.hmm.utilities as ahu
class PoissonHmm:
......@@ -150,7 +148,7 @@ class _HyperParams:
self.gamma_dp = _tools.assert_and_pass(self._assert_dirichlet_param, gamma_dp)
self.delta_dp = _tools.assert_and_pass(self._assert_dirichlet_param, delta_dp)
self.fill_diag = _tools.assert_and_pass(_utils.assert_st_val, fill_diag)
self.fill_diag = _tools.assert_and_pass(ahu.assert_st_val, fill_diag)
self.init_lambda_meth = self._assert_lambda(init_lambda)
self.init_gamma_meth = self._assert_gamma(init_gamma, gamma_dp, fill_diag)
......@@ -176,7 +174,7 @@ class _HyperParams:
TypeError
"""
if isinstance(_lambda, str):
if _lambda not in _utils.StateDependentMeansInitializer.methods:
if _lambda not in ahu.StateDependentMeansInitializer.methods:
raise ValueError('Unrecognized initialization method `{}`'.format(_lambda))
elif isinstance(_lambda, _np.ndarray):
......@@ -206,7 +204,7 @@ class _HyperParams:
"""
if isinstance(_gamma, str):
if _gamma not in _utils.TpmInitializer.methods:
if _gamma not in ahu.TpmInitializer.methods:
raise ValueError('Unrecognized initialization method `{}`'.format(_gamma))
if _gamma == 'dirichlet' and gamma_dp is None:
......@@ -218,7 +216,7 @@ class _HyperParams:
'`uniform` for parameter `gamma`.'))
elif isinstance(_gamma, _np.ndarray):
_utils.assert_st_matrix(_gamma)
ahu.assert_st_matrix(_gamma)
else:
raise TypeError(('Unrecognized type of argument `init_gamma`. Expected `str` or '
'`numpy.ndarray`, got {}.\n').format(type(_gamma)))
......@@ -241,7 +239,7 @@ class _HyperParams:
"""
if isinstance(_delta, str):
if _delta not in _utils.StartDistributionInitializer.methods:
if _delta not in ahu.StartDistributionInitializer.methods:
raise ValueError('Unrecognized initialization method `{}`'.format(_delta))
if _delta == 'dirichlet' and delta_dp is None:
......@@ -249,7 +247,7 @@ class _HyperParams:
'`dirichlet` for parameter `delta`.'))
elif isinstance(_delta, _np.ndarray):
_utils.assert_st_vector(_delta)
ahu.assert_st_vector(_delta)
else:
raise TypeError(('Unrecognized type of argument `init_delta`. Expected `str` or '
......@@ -308,16 +306,16 @@ class _InitParams:
return hy_params.init_lambda_meth.copy()
if hy_params.init_lambda_meth == 'hist':
return _utils.StateDependentMeansInitializer.hist(X, hy_params.m_states)
return ahu.StateDependentMeansInitializer.hist(X, hy_params.m_states)
if hy_params.init_lambda_meth == 'linear':
return _utils.StateDependentMeansInitializer.linear(X, hy_params.m_states)
return ahu.StateDependentMeansInitializer.linear(X, hy_params.m_states)
if hy_params.init_lambda_meth == 'quantile':
return _utils.StateDependentMeansInitializer.quantile(X, hy_params.m_states)
return ahu.StateDependentMeansInitializer.quantile(X, hy_params.m_states)
if hy_params.init_lambda_meth == 'random':
return _utils.StateDependentMeansInitializer.random(X, hy_params.m_states)
return ahu.StateDependentMeansInitializer.random(X, hy_params.m_states)
raise ValueError("Unknown init method or init_lambda_meth is not an array.")
......@@ -329,13 +327,13 @@ class _InitParams:
return hy_params.init_gamma_meth.copy()
if hy_params.init_gamma_meth == 'dirichlet':
return _utils.TpmInitializer.dirichlet(hy_params.m_states, hy_params.gamma_dp)
return ahu.TpmInitializer.dirichlet(hy_params.m_states, hy_params.gamma_dp)
if hy_params.init_gamma_meth == 'softmax':
return _utils.TpmInitializer.softmax(hy_params.m_states)
return ahu.TpmInitializer.softmax(hy_params.m_states)
if hy_params.init_gamma_meth == 'uniform':
return _utils.TpmInitializer.uniform(hy_params.m_states, hy_params.fill_diag)
return ahu.TpmInitializer.uniform(hy_params.m_states, hy_params.fill_diag)
raise ValueError("Unknown init method or init_gamma_meth is not an array.")
......@@ -345,22 +343,22 @@ class _InitParams:
return hy_params.init_delta_meth.copy()
if hy_params.init_delta_meth == 'dirichlet':
return _utils.StartDistributionInitializer.dirichlet(hy_params.m_states,
return ahu.StartDistributionInitializer.dirichlet(hy_params.m_states,
hy_params.delta_dp)
if hy_params.init_delta_meth == 'softmax':
return _utils.StartDistributionInitializer.softmax(hy_params.m_states)
return ahu.StartDistributionInitializer.softmax(hy_params.m_states)
if hy_params.init_delta_meth == 'stationary':
return _utils.StartDistributionInitializer.stationary(self.gamma_)
return ahu.StartDistributionInitializer.stationary(self.gamma_)
if hy_params.init_delta_meth == 'uniform':
return _utils.StartDistributionInitializer.uniform(hy_params.m_states)
return ahu.StartDistributionInitializer.uniform(hy_params.m_states)
raise ValueError("Unknown init method or init_delta_meth is not an array.")
def __str__(self):
with _io.array_print_opt(precision=4, suppress=True):
with aio.array_print_opt(precision=4, suppress=True):
out = 'Initial Lambda:\n{}\n\nInitial Gamma:\n{}\n\nInitial Delta:\n{}\n'
out = out.format(*self.__dict__.values())
return out
......@@ -394,7 +392,7 @@ class Params:
self.delta_ = delta_
def __str__(self):
with _io.array_print_opt(precision=4, suppress=True):
with aio.array_print_opt(precision=4, suppress=True):
out = 'Lambda:\n{}\n\nGamma:\n{}\n\nDelta:\n{}\n'
out = out.format(*self.__dict__.values())
return out
......
File moved
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""hmm_test.py
(c) Michael Blaß 2016
Unit test for HMM implementation."""
import numpy as np
from scipy.stats import poisson
import unittest
import numpy as _np
from apollon.hmm.hmm_base import is_tpm
from apollon.hmm.poisson import PoissonHmm
class TestHMM_utilities(unittest.TestCase):
def setUp(self):
# Arbitrary transition probability matrix
self.A = _np.array([[1., 0, 0], [.2, .3, .5], [.1, .3, .6]])
self.A = np.array([[1., 0, 0], [.2, .3, .5], [.1, .3, .6]])
# Wrong number of dimensions
self.B1 = _np.array([1., 0, 0, 0])
self.B2 = _np.array([[[1., 0, 0], [.2, .3, .5], [.1, .3, .6]]])
self.B1 = np.array([1., 0, 0, 0])
self.B2 = np.array([[[1., 0, 0], [.2, .3, .5], [.1, .3, .6]]])
# Not quadratic
self.C1 = _np.array([[1., 0, 0], [.2, .3, .5]])
self.C2 = _np.array([[1.0], [.5, .5], [.2, .8]])
self.C1 = np.array([[1., 0, 0], [.2, .3, .5]])
self.C2 = np.array([[1.0], [.5, .5], [.2, .8]])
# Rows do not sum up to one
self.D = _np.array([[.2, .3, .5], [.5, .4, .2], [1., 0, 0]])
self.D = np.array([[.2, .3, .5], [.5, .4, .2], [1., 0, 0]])
def test_success(self):
mus = [20, 40, 80, 120, 40]
m = len(mus)
data = np.concatenate([poisson(mu).rvs(30) for mu in mus])
hmm = PoissonHmm(data, m)
hmm.fit(data)
self.assertTrue(hmm.success)
def test_true_tpm(self):
self.assertTrue(is_tpm(self.A), True)
if __name__ == '__main__':
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment