import numpy as np
from scipy.stats import rv_discrete, rv_continuous, uniform
import scipy.special as sc
import matplotlib.pyplot as plt


from scipy.stats._distn_infrastructure import (
        rv_discrete, _ncx2_pdf, _ncx2_cdf, get_distribution_names)


class gpd_gen(rv_discrete):
    r"""Generalized Poisson distribution with shape parameters ``mu`` and ``lbda``.

    The probability mass function evaluated by ``_logpmf`` is::

        p(k) = mu * (mu + lbda*k)**(k - 1) * exp(-(mu + lbda*k)) / k!

    for ``k = 0, 1, 2, ...``.  Parameters are accepted when ``mu >= 0`` and
    ``0 <= lbda <= 1``; ``lbda = 0`` gives the ordinary Poisson(mu) pmf.
    """

    def _argcheck(self, mu, lbda):
        # Element-wise ``&``: scipy passes the shape parameters as arrays, and
        # ``and`` raises "truth value ... is ambiguous" for arrays with more
        # than one element.
        return (mu >= 0.0) & (lbda >= 0.0) & (lbda <= 1.0)

    def _rvs(self, mu, lbda, size=None, random_state=None):
        """Draw variates by simulating a Poisson branching process.

        The first generation is Poisson(mu); given ``m`` individuals in one
        generation, the next one is Poisson(lbda * m).  The total number of
        individuals over all generations is returned.

        Current scipy passes ``size`` and ``random_state``.  Older scipy
        versions instead set ``self._size`` / ``self._random_state`` and call
        ``_rvs(mu, lbda)``; those attributes are used as a fallback.
        """
        if random_state is None:
            random_state = self._random_state
        if size is None:
            size = getattr(self, '_size', None)
        if size is None:
            size = np.broadcast(mu, lbda).shape
        offspring = np.asarray(random_state.poisson(mu, size))
        population = offspring.copy()
        while np.any(offspring > 0):
            # numpy's poisson broadcasts over an array of rates, so a whole
            # generation (of any shape) is drawn in a single call.
            offspring = np.asarray(random_state.poisson(lbda * offspring))
            population += offspring
        return population

    def _pmf(self, k, mu, lbda):
        return np.exp(self._logpmf(k, mu, lbda))

    def _logpmf(self, k, mu, lbda):
        # log p(k) = log(mu) + (k-1)*log(mu + lbda*k) - (mu + lbda*k) - log(k!)
        # ``xlogy`` yields 0 for the (k-1)*log(...) term when k - 1 == 0.
        mu_pls_klmb = mu + lbda*k
        return (np.log(mu) + sc.xlogy(k - 1, mu_pls_klmb) - mu_pls_klmb
                - sc.gammaln(k + 1))

    def _munp(self, n, mu, lbda):
        """Return the n-th raw moment ``E[X**n]``.

        Closed forms are used for n = 1 and n = 2 (mean ``mu/(1-lbda)``,
        variance ``mu/(1-lbda)**3``).  Higher orders defer to scipy's generic
        summation over the pmf instead of returning None.
        """
        if n == 1:
            return mu/(1-lbda)
        if n == 2:
            return (mu/(1-lbda))**2 + mu/(1-lbda)**3
        return super()._munp(n, mu, lbda)


gpoisson = gpd_gen(name='gpoisson')

    
class borel_gen(rv_discrete):
    r"""Borel distribution shifted so that its support starts at zero.

    With ``N`` Borel-distributed,
    ``P(N = n) = exp(-mu*n) * (mu*n)**(n-1) / n!`` for ``n = 1, 2, ...``.
    ``_logpmf`` evaluates this at ``n = k + 1``, so this class describes
    ``k = N - 1`` on ``k = 0, 1, 2, ...``: the number of descendants of a
    single ancestor in a Poisson(mu) branching process.  Valid for
    ``0 < mu < 1``.
    """

    def _argcheck(self, mu):
        return (mu > 0) & (mu < 1)

    def _logpmf(self, k, mu):
        # Borel log-pmf evaluated at n = k + 1 (the shift to support 0).
        n = k + 1
        return sc.xlogy(n - 1, mu*n) - sc.gammaln(n + 1) - mu*n

    def _pmf(self, k, mu):
        return np.exp(self._logpmf(k, mu))

    def _rvs(self, mu, size=None, random_state=None, epsilon=1e-4):
        """Draw exact variates by simulating the Poisson(mu) branching process.

        Starting from one ancestor, the first generation is Poisson(mu).
        Given ``m`` individuals, the next generation is Poisson(mu * m).  The
        total number of descendants (ancestor excluded) is returned, which
        matches the shifted support used by ``_logpmf``.

        ``random_state`` is honoured, so seeded calls are reproducible.
        ``epsilon`` is accepted only for backward compatibility.  The former
        inverse-CDF sampler used it to cap the uniforms below ``1 - epsilon``,
        which never produced the extreme tail.  The branching simulation has
        no such truncation, so ``epsilon`` is ignored.
        """
        if random_state is None:
            random_state = self._random_state
        if size is None:
            size = np.shape(mu)
        generation = np.asarray(random_state.poisson(mu, size))
        total = generation.copy()
        while np.any(generation > 0):
            # Rates broadcast element-wise; zero parents give zero children.
            generation = np.asarray(random_state.poisson(mu * generation))
            total += generation
        return total

    def _stats(self, mu):
        """Return mean, variance, skewness and excess kurtosis.

        The moments refer to the shifted variable ``k = N - 1`` that
        ``_logpmf`` describes, so the mean is ``mu/(1-mu)`` (one less than the
        Borel mean ``1/(1-mu)``).  Central moments do not depend on the shift.
        Skewness and kurtosis follow Consul & Famoye, *Lagrangian Probability
        Distributions* (978-0-8176-4365-2), p. 159.  scipy expects Fisher
        (excess) kurtosis, so the Pearson ``3 +`` term is omitted.
        """
        # scipy evaluates _stats only on parameters accepted by _argcheck
        # (0 < mu < 1), so no masking of invalid values is needed.  Only
        # ``scipy.special`` is imported in this module, hence plain numpy.
        mu = np.asarray(mu, dtype=float)
        one_minus = 1.0 - mu
        mean = mu/one_minus
        var = mu/one_minus**3
        g1 = (1 + 2*mu)/np.sqrt(mu*one_minus)
        g2 = (1 + 8*mu + 6*mu**2)/(mu*one_minus)
        return mean, var, g1, g2


borel = borel_gen(name='borel')
    
  
    
class erlang_gen(rv_discrete):
    """Unfinished Erlang/gamma distribution with shape parameter ``a``.

    NOTE(review): ``rv_discrete`` evaluates probabilities through
    ``_pmf``/``_logpmf``, not ``_pdf``/``_logpdf``.  This module also never
    instantiates the class, so scipy does not currently use these methods.
    It presumably should derive from ``rv_continuous`` or gain a ``_pmf``.
    TODO: confirm the intended design before relying on it.
    """

    def _pdf(self, x, a):
        """Return the gamma density ``x**(a-1) * exp(-x) / gamma(a)``."""
        return np.exp(self._logpdf(x, a))

    def _logpdf(self, x, a):
        """Return the log of the gamma density at ``x`` for shape ``a``."""
        # Parameters are (x, a) so they match both the body and the call in
        # ``_pdf``; a (k, mu, nu) signature left ``x`` and ``a`` undefined.
        return sc.xlogy(a - 1.0, x) - x - sc.gammaln(a)

   
    
    
  

#     def _rvs(self, mu, nu, size=None, random_state=None):
#         u = scipy.stats.uniform.rvs(loc=0, scale = 1, size=size)
#         cum = np.cumsum([self._pmf(_k, mu, nu) for _k in range(0, 100)])
#         rnd = [ np.argmax( cum>=_u ) for _u in u ]
#         return rnd

# Snapshot the module namespace now that every definition above exists, and
# let scipy's helper split out the distribution instances and ``*_gen``
# classes based on ``rv_discrete``; together they form this module's __all__.
pairs = [*globals().items()]
_distn_names, _distn_gen_names = get_distribution_names(
    pairs,
    rv_discrete,
)
__all__ = [*_distn_names, *_distn_gen_names]