import numpy as np
from scipy.signal import fftconvolve as convolve
from scipy.stats import norm, binom, poisson,expon
from scipy.interpolate import interp1d
from AdditionalPDFs import borel, gpoisson
import matplotlib.pyplot as plt
    
    
def SigmaK(k, sigma_0, sigma_1):
    """Return the Gaussian width used for the peak with index ``k``.

    The two width parameters add in quadrature, with ``sigma_1**2`` counted
    ``k`` times: ``sqrt(sigma_0**2 + k*sigma_1**2)``. Works element-wise for
    NumPy array inputs.
    """
    variance = sigma_0**2 + k*sigma_1**2
    return np.sqrt(variance)



def dPApdPH(x, x_0, G, mu, lbda, tauAp,  pAp, tau, t_gate, k_low, k_hi):
    """Return the secondary-pulse ("Ap", presumably afterpulse) density on ``x``.

    For every primary count ``_n_pg`` in ``1..int(k_hi)`` (weight
    ``gpoisson.pmf(_n_pg, mu, lbda)``) and every secondary count ``_n_ap`` in
    ``1..min(4, _n_pg)``, the pulse-height excess above peak ``_n_pg`` (in
    units of ``G``, divided by ``_n_ap``) is mapped to a time ``t_b0`` in
    ``[0, t_gate/2]`` by inverting

        PH(t_b0) = 1 + exp(-t_gate/tau)
                   - 2*exp(-t_gate/(2*tau))*cosh((t_gate/2 - t_b0)/tau),

    which runs from 0 at ``t_b0 = 0`` to ``PH_max = (1 - exp(-t_gate/(2*tau)))**2``
    at ``t_b0 = t_gate/2``. The density in ``t_b0`` is

        gpoisson weight * binom.pmf(_n_ap, _n_pg, pAp*(1 - exp(-t_b0/tau)))
                        * expon.pdf(t_b0, scale=tauAp)

    and is converted to a pulse-height density by dividing by
    ``|dPH/dt_b0|`` (floored at 1e-3). It is then scaled by
    ``PH_max/(G_lin*_n_ap)`` and accumulated.

    Parameters
    ----------
    x : array_like
        Sample positions. Spacing is taken from the first two samples, so
        this presumably expects uniformly spaced ``x`` (TODO confirm).
    x_0, G : float
        Pedestal position and peak spacing, in the units of ``x``.
    mu, lbda : float
        Parameters forwarded to ``gpoisson.pmf``.
    tauAp : float
        Scale of the exponential time distribution of ``t_b0``.
    pAp : float
        Scale factor on the per-discharge binomial probability.
    tau, t_gate : float
        Time constants of the PH(t_b0) relation above.
    k_low : unused
        Kept for signature compatibility.
    k_hi : number
        Highest primary count included (inclusive, truncated with ``int``).

    Returns
    -------
    numpy.ndarray
        Density with the same length as ``x``, divided by the bin width ``dx``.
    """
    # Integer bin grid, padded by 100 bins on each side.
    x_lin = np.arange(0, len(x))
    x_pad_lin = np.arange(-100, len(x)+100)
    f_lin = interp1d(x, x_lin, fill_value='extrapolate')
    f_lin_inv = interp1d(x_lin, x, fill_value='extrapolate')
    dx = abs(f_lin_inv(1) - f_lin_inv(0))
    # Pedestal and gain expressed in bins.
    G_lin = f_lin(x_0+G) - f_lin(x_0)
    x_0_lin = f_lin(x_0)
    y_ap = np.zeros_like(x_pad_lin).astype(float)

    # Mask of padded bins that belong to the original grid (replaces the
    # deprecated np.in1d/intersect1d/where construction; identical result).
    idx_orig = np.isin(x_pad_lin, x_lin)

    exp_tgate_2tau = np.exp(t_gate/(2*tau))
    exp_mtgate_2tau = np.exp(-t_gate/(2*tau))
    PH_norm = (x_pad_lin - x_0_lin)/G_lin
    PH_max = (1 - exp_mtgate_2tau)**2
    norm_factor = PH_max/G_lin

    for _n_pg in range(1, int(k_hi)+1):
        _p_pg = gpoisson.pmf(_n_pg, mu, lbda)
        for _n_ap in range(1, min(5, _n_pg+1)):

            PH = (PH_norm - _n_pg)/_n_ap

            # The arccosh argument below equals 1 exactly at PH == PH_max and
            # drops below 1 (NaN) beyond it; cap just under PH_max and zero
            # those bins afterwards.
            cond_PH_max = (PH > 0.95*PH_max)
            PH[cond_PH_max] = 0.95*PH_max

            t_b0 = 0.5*t_gate - tau*np.arccosh(0.5*((1 - PH)*exp_tgate_2tau + exp_mtgate_2tau))

            # PH <= 0 maps to t_b0 <= 0; keep only t_b0 inside (0, t_gate/2).
            cond_gate = (t_b0 < 0) | (t_b0 > t_gate/2)

            dpApdt_b0 = _p_pg*binom.pmf(_n_ap, _n_pg, pAp*(1-np.exp(-t_b0/tau)))*expon.pdf(t_b0, loc=0, scale=tauAp)
            # Out-of-gate bins may be NaN (negative binomial probability); clear them.
            dpApdt_b0[cond_gate] = 0

            # |dPH/dt_b0|, floored to avoid blowing up where the slope vanishes.
            dPHdt_b0 = abs(-2*np.sinh((t_b0-0.5*t_gate)/tau)*exp_mtgate_2tau/tau)
            cond_dPHdt = (dPHdt_b0 < 1e-3)
            dPHdt_b0[cond_dPHdt] = 1e-3

            dpApdPH_b0 = dpApdt_b0/dPHdt_b0
            dpApdPH_b0[cond_PH_max] = 0
            dpApdPH_b0[cond_gate] = 0

            y_ap += dpApdPH_b0*(norm_factor/_n_ap)

    y_ap = y_ap[idx_orig]/dx

    return y_ap


    



def DRM_basic(x, x_0, G, mu, lbda, sigma_0, sigma_1, DCR, tau, t_gate, t_0, tauAp, pAp,  k_low, k_hi):
    """Return the light-only spectrum: Gaussian peak series plus ``dPApdPH``.

    The pedestal (k = 0) is a Gaussian at ``x_0`` with width ``sigma_0``.
    Peak ``k`` (k = 1..int(k_hi), inclusive) is a Gaussian at ``x_0 + k*G``
    with width ``SigmaK(k, sigma_0, sigma_1)``. Every term is weighted by
    ``gpoisson.pmf(k, mu, lbda)``. The ``dPApdPH`` term, computed with the
    same ``k_hi``, is added on top.

    Parameters
    ----------
    x : array_like
        Sample positions.
    x_0, G : float
        Pedestal position and peak spacing.
    mu, lbda : float
        Parameters of ``gpoisson``.
    sigma_0, sigma_1 : float
        Width parameters passed to ``norm.pdf`` / ``SigmaK``.
    DCR, t_0 : unused
        Kept for signature compatibility.
    tau, t_gate, tauAp, pAp, k_low : float
        Forwarded unchanged to ``dPApdPH``.
    k_hi : number
        Highest peak index included (inclusive).

    Returns
    -------
    Sum of the Gaussian peak series and ``dPApdPH`` evaluated on ``x``.
    """
    f_light = gpoisson.pmf(0, mu, lbda)*norm.pdf(x, x_0, sigma_0)

    # Peaks 1..k_hi inclusive: the same range dPApdPH sums over, so every
    # secondary-pulse term has its parent peak. sum() of an empty generator
    # is 0, so k_hi < 2 no longer crashes (np.vstack of an empty peak array
    # with f0 raised a shape-mismatch ValueError).
    f_light = f_light + sum(
        gpoisson.pmf(_k, mu, lbda)*norm.pdf(x, x_0 + _k*G, SigmaK(_k, sigma_0, sigma_1))
        for _k in range(1, int(k_hi) + 1)
    )

    f_ap = dPApdPH(x, x_0, G, mu, lbda, tauAp,  pAp, tau, t_gate, k_low, k_hi)

    return f_light + f_ap




def DC_PH_range(t, t_0, r_fast, tau, t_gate):
    """Return ``(PH_min, PH_max)`` for the time regime that ``t`` falls into.

    * ``t_0 < t < 0``:
      ``PH_min = (1 - r_fast)*exp(t_0/tau)*(1 - exp(-t_gate/tau))``,
      ``PH_max = (1 - r_fast)*(1 - exp(-t_gate/tau))``.
    * ``0 < t < t_gate``:
      ``PH_min = r_fast``, ``PH_max = 1 - (1 - r_fast)*exp(-t_gate/tau)``.
    * otherwise, including ``t`` exactly equal to ``t_0``, 0 or ``t_gate``:
      ``(0, 0)``.

    ``t`` only selects the regime; the returned bounds do not depend on its
    exact value. Nothing is evaluated for the fall-through case.
    """
    if t_0 < t < 0:
        return (
            (1-r_fast)*np.exp(t_0/tau)*(1 - np.exp(-t_gate/tau)),
            (1-r_fast)*(1 - np.exp(-t_gate/tau)),
        )

    if 0 < t < t_gate:
        return r_fast, (1 - (1-r_fast)*np.exp(-t_gate/tau))

    return 0, 0


def dpDCRdPH(x, x_0, G, tau, t_gate, t_0):
    """Return the unnormalised dark-count pulse-height shape sampled on ``x``.

    ``x`` is rescaled to ``ph = (x - x_0)/G``. Two contributions are summed:

    * ``1/ph`` where ``ph`` lies in the ``DC_PH_range`` bounds of the
      ``-abs(t_0) < t < 0`` regime,
    * ``1/(1 - ph)`` where ``ph`` lies in the bounds of the
      ``0 < t < t_gate`` regime.

    Both ranges use ``r_fast = 0`` and are open below, closed above. Bins in
    both ranges receive both terms, and bins in neither are 0.
    """
    t_start = -abs(t_0)
    # Representative times that pick each branch of DC_PH_range.
    lo_before, hi_before = DC_PH_range(t_start/2, t_start, 0, tau, t_gate)
    lo_during, hi_during = DC_PH_range(t_gate/2, t_start, 0, tau, t_gate)

    ph = (x - x_0)/G

    in_before = (ph > lo_before) & (ph <= hi_before)
    in_during = (ph > lo_during) & (ph <= hi_during)

    # Divide only inside each mask; everything else stays at zero.
    before = np.divide(1, ph, out=np.zeros_like(ph), where=in_before)
    during = np.divide(1, 1 - ph, out=np.zeros_like(ph), where=in_during)

    return before + during




def ApplyDCR(x, y_light, x_0, G, DCR, t_gate, t_0, tau, lbda, k_dcr_low, k_dcr_hi):
    """Convolve the light spectrum ``y_light`` with a dark-count spectrum.

    The dark spectrum is built on an integer bin grid padded by 100 bins on
    each side:

    * ``n = 0``: a unit spike at the pedestal bin, weight ``poisson.pmf(0, mu_dcr)``.
    * ``n = 1``: the ``dpDCRdPH`` shape, normalised to unit area.
    * ``n >= 2``: the ``n = 1`` shape convolved with the ``n - 1`` shape,
      normalised, weight ``poisson.pmf(n, mu_dcr)*borel.pmf(0, lbda)**n``.
      An extra term is also added: the ``dpDCRdPH`` shape with gain ``n*G``,
      normalised, weight ``poisson.pmf(1, mu_dcr)*borel.pmf(n - 1, lbda)/n``
      (presumably one dark count plus ``n - 1`` correlated discharges —
      confirm against the ``borel`` definition in AdditionalPDFs).

    Here ``mu_dcr = DCR*(abs(t_0) + t_gate)`` and ``n`` runs over
    ``0..max(2, int(k_dcr_hi)) - 1``. The summed dark spectrum is normalised,
    convolved with the zero-padded ``y_light``, and shifted by the pedestal
    bin so that the ``n = 0`` spike acts as the identity. The result is
    normalised to unit area (divided by ``dx``) and cropped back to the
    original grid.

    Parameters
    ----------
    x : array_like
        Sample positions. Spacing is taken from the first two samples, so
        this presumably expects uniformly spaced ``x`` (TODO confirm).
    y_light : array_like
        Light spectrum sampled on ``x`` (same length).
    x_0, G : float
        Pedestal position and peak spacing, in the units of ``x``.
    DCR : float
        Rate multiplying the window length ``abs(t_0) + t_gate``.
    t_gate, t_0, tau : float
        Forwarded to ``dpDCRdPH``.
    lbda : float
        Parameter of ``borel.pmf``.
    k_dcr_low : unused
        Kept for signature compatibility.
    k_dcr_hi : number
        Upper bound (exclusive, at least 2) on the number of dark counts.

    Returns
    -------
    numpy.ndarray
        Model density with the same length as ``x``.

    NOTE(review): if a ``dpDCRdPH`` shape is identically zero on the grid,
    its normalisation divides by zero and the output becomes NaN.
    """
    mu_dcr = DCR*(abs(t_0) + t_gate)

    x_lin = np.arange(0, len(x))
    x_pad_lin = np.arange(-100, len(x)+100)
    f_lin = interp1d(x, x_lin, fill_value='extrapolate')
    f_lin_inv = interp1d(x_lin, x, fill_value='extrapolate')
    dx = abs(f_lin_inv(1) - f_lin_inv(0))
    G_lin = f_lin(x_0+G) - f_lin(x_0)
    x_0_lin = f_lin(x_0)

    # Mask of padded bins that belong to the original grid (replaces the
    # deprecated np.in1d/intersect1d/where construction; identical result).
    idx_orig = np.isin(x_pad_lin, x_lin)

    # Padded-grid bin closest to the pedestal; used as the convolution shift.
    idx_x_0_lin = np.argmin(abs(x_pad_lin - x_0_lin))

    fs = []
    pfs = []
    hs = []
    phs = []

    for _n_dcr in range(0, max(2, int(k_dcr_hi))):

        if _n_dcr == 0:
            f = np.zeros(len(x_pad_lin))
            f[idx_x_0_lin] = 1
            pf = poisson.pmf(0, mu_dcr)
            h = np.zeros(len(x_pad_lin))
            ph = 0

        else:
            if _n_dcr == 1:
                f = dpDCRdPH(x_pad_lin, x_0_lin, G_lin, tau, t_gate, t_0)
            else:
                # Both operands are offset by the pedestal bin; slicing from
                # idx_x_0_lin removes one of the two offsets.
                f = convolve(fs[1], fs[-1])[idx_x_0_lin:len(x_pad_lin)+idx_x_0_lin]

            f = f/ np.trapz(f, dx = 1)

            pf = poisson.pmf(_n_dcr, mu_dcr)*(borel.pmf(0, lbda)**_n_dcr)

            if _n_dcr == 1:
                h = np.zeros(len(x_pad_lin))
                ph = 0
            else:
                h = dpDCRdPH(x_pad_lin, x_0_lin, G_lin*_n_dcr, tau, t_gate, t_0)
                h = h/ np.trapz(h, dx = 1)
                ph = poisson.pmf(1, mu_dcr)*borel.pmf((_n_dcr-1), lbda)/_n_dcr

        fs.append(f)
        pfs.append(pf)
        hs.append(h)
        phs.append(ph)

    fs = np.asarray(fs)
    hs = np.asarray(hs)
    pfs = np.expand_dims(np.asarray(pfs), -1)
    phs = np.expand_dims(np.asarray(phs), -1)
    y_dark = np.sum(fs*pfs, axis=0) + np.sum(hs*phs, axis=0)
    y_dark = y_dark/np.trapz(y_dark, dx = 1)

    y_model = convolve(y_dark,
                       np.pad(y_light,
                              (100, 100),
                              "constant",
                              constant_values=(0.0, 0.0)),
                       )[idx_x_0_lin:len(x_pad_lin)+idx_x_0_lin]

    y_model = y_model/np.trapz(y_model, dx = 1)/dx

    y_model = y_model[idx_orig]

    return y_model


def DRM(
        x,
        mu,
        x_0,
        G,
        sigma_0,
        sigma_1,
        tauAp,
        pAp,
        lbda,
        k_low,
        k_hi,
        k_dcr_low,
        k_dcr_hi,
        DCR,
        tau,
        t_gate,
        t_0
       ):
    """Evaluate the full detector response model on ``x``.

    First builds the light spectrum with ``DRM_basic``. Then passes it
    through ``ApplyDCR`` to fold in the dark-count spectrum. Every parameter
    is forwarded unchanged to whichever of the two steps uses it.
    """
    light_spectrum = DRM_basic(
        x=x, x_0=x_0, G=G, mu=mu, lbda=lbda,
        sigma_0=sigma_0, sigma_1=sigma_1,
        DCR=DCR, tau=tau, t_gate=t_gate, t_0=t_0,
        tauAp=tauAp, pAp=pAp,
        k_low=k_low, k_hi=k_hi,
    )

    return ApplyDCR(
        x=x, y_light=light_spectrum, x_0=x_0, G=G,
        DCR=DCR, t_gate=t_gate, t_0=t_0, tau=tau, lbda=lbda,
        k_dcr_low=k_dcr_low, k_dcr_hi=k_dcr_hi,
    )