from iminuit import Minuit
from iminuit.util import describe
from scipy.interpolate import interp1d
from astropy.stats import bootstrap
import numpy as np
import matplotlib.pyplot as plt


#HELPER FUNCTIONS

class FakeFuncCode:
    """Stand-in for a code object, exposing ``co_varnames`` / ``co_argcount``.

    The parameter names come from ``iminuit.util.describe(f)``. They can then
    be trimmed from the front (``dock``), renamed from the front (``prmt``) and
    extended with extra names (``append``).
    NOTE(review): presumably attached as ``func_code`` on a wrapper so that
    iminuit reads the rewritten signature -- confirm against callers.
    """

    def __init__(self, f, prmt=None, dock=0, append=None):
        """Build the fake code object.

        Parameters
        ----------
        f : object accepted by ``iminuit.util.describe``
            Source of the parameter names. (An older comment said this may be
            a tuple or a function object -- TODO confirm the tuple case.)
        prmt : sequence of str, optional
            New names assigned, in order, to the first parameters that remain
            after ``dock`` has been applied.
        dock : int, optional
            Number of leading parameters to drop.
        append : str or sequence of str, optional
            Extra name(s) inserted after the first ``co_argcount`` names. When
            given, ``co_varnames`` becomes a tuple.
        """
        all_names = describe(f)
        self.co_argcount = len(all_names) - dock
        # Copy into a list so the index assignment below works regardless of
        # the concrete sequence type ``describe`` hands back.
        varnames = list(all_names)[dock:]

        if prmt is not None:  # rename parameters from the front
            for i, name in enumerate(prmt):
                varnames[i] = name

        if isinstance(append, str):
            append = [append]

        if append is None:
            self.co_varnames = varnames
        else:
            old_count = self.co_argcount
            extra = list(append)  # also accept tuples / other iterables
            self.co_argcount += len(extra)
            self.co_varnames = tuple(
                varnames[:old_count] + extra + varnames[old_count:])



            
def LatexFormat(value, scirange=(0.01, 1000)):
    """Format a number as a LaTeX math-mode string.

    Values whose magnitude lies strictly between ``scirange[0]`` and
    ``scirange[1]`` (and zero) are shown in fixed-point with three decimals,
    e.g. ``$0.500$``. All other values are shown in scientific notation, e.g.
    ``$1.200 \\times 10^{4}$``.

    Parameters
    ----------
    value : number
        The value to format.
    scirange : pair of float, optional
        Open interval of magnitudes that are printed in fixed-point.

    Returns
    -------
    str
        The LaTeX string. If the value cannot be split into mantissa and
        exponent (e.g. ``inf`` or ``nan``), ``str(value)`` is returned instead.
    """
    magnitude = np.abs(value)
    # Zero would otherwise be rendered as "0.000 x 10^0".
    if value == 0 or scirange[0] < magnitude < scirange[1]:
        return r"${:3.3f}$".format(value)
    try:
        float_str = "{:3.3E}".format(value)
        # Non-finite values format as "INF"/"NAN": no "E" to split on.
        base, exponent = float_str.split("E")
        return r"${0} \times 10^{{{1}}}$".format(base, int(exponent))
    except (ValueError, TypeError):
        return str(value)


def FormatExponent(ax, axis='y'):
    """Replace an axis' automatic offset text with a LaTeX ``x10^n`` label.

    Switches the chosen axis to scientific tick labels, reads the exponent
    matplotlib put in the offset text, hides that offset text, and draws an
    equivalent ``x$10^{n}$`` label in axes coordinates instead.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes to modify.
    axis : {'y', 'x'}, optional
        Which axis to format. Any value other than ``'y'`` is treated as ``'x'``.

    Returns
    -------
    matplotlib Axes
        The same ``ax``, for chaining.
    """
    # Change the ticklabel format to scientific format
    ax.ticklabel_format(axis=axis, style='sci', scilimits=(-2, 2))

    # Pick the axis and where (in axes coordinates) the new label goes
    if axis == 'y':
        ax_axis = ax.yaxis
        x_pos = 0.0
        y_pos = 1.0
        horizontalalignment = 'left'
        verticalalignment = 'bottom'
    else:
        ax_axis = ax.xaxis
        x_pos = 1.0
        y_pos = -0.05
        horizontalalignment = 'right'
        verticalalignment = 'top'

    # The offset text is only refreshed when the figure is laid out, so force
    # a layout pass on the figure that owns ``ax``. plt.tight_layout() would
    # act on the *current* figure, which need not be ax's figure.
    # NOTE: this applies tight_layout to that figure as a side effect; if
    # anyone has a better way of updating the offset text, please replace it.
    ax.figure.tight_layout()

    # Get the offset value, e.g. "1e\u22125"
    offset = ax_axis.get_offset_text().get_text()

    if len(offset) > 0:
        # Matplotlib uses a Unicode minus; normalise it before parsing.
        # (np.float no longer exists in NumPy >= 1.24, so use the builtin.)
        minus_sign = u'\u2212'
        expo = float(offset.replace(minus_sign, '-').split('e')[-1])
        offset_text = r'x$\mathregular{10^{%d}}$' % expo

        # Turn off the offset text that's calculated automatically
        ax_axis.offsetText.set_visible(False)

        # Add the label at the top-left of a y axis / bottom-right of an x axis
        ax.text(x_pos, y_pos, offset_text, transform=ax.transAxes,
                horizontalalignment=horizontalalignment,
                verticalalignment=verticalalignment)
    return ax


def Bootstrap(data, statistic, n_bootstrap, alpha=0.95):
    """Bias-corrected bootstrap estimate of a statistic with a normal CI.

    Parameters
    ----------
    data : array-like
        The measurements; resampled with ``astropy.stats.bootstrap``.
    statistic : callable
        Applied to ``data`` and, as ``bootfunc``, to every bootstrap resample.
    n_bootstrap : int
        Number of bootstrap resamples.
    alpha : float, optional
        Two-sided confidence level, strictly between 0 and 1.

    Returns
    -------
    est_stat
        ``2*statistic(data) - mean(bootstrap statistics)``, i.e. the estimate
        with the bootstrap bias estimate subtracted.
    std_err
        Standard deviation of the bootstrap statistics.
    conf_interval : ndarray, shape (2,)
        ``est_stat -/+ z*std_err`` with ``z`` the two-sided normal quantile.

    Raises
    ------
    ValueError
        If ``alpha`` is not in (0, 1) or ``data`` is empty.
    """
    # The original referenced ``sc.erfinv`` but ``sc`` is never imported in
    # this module; import the function explicitly.
    from scipy.special import erfinv

    if not (0 < alpha < 1):
        raise ValueError("confidence level must be in (0, 1)")

    if len(data) < 1:
        raise ValueError("data must contain at least one measurement.")

    boot_stat = bootstrap(data, n_bootstrap, bootfunc=statistic)

    stat_data = statistic(data)
    mean_stat = np.mean(boot_stat)
    # Subtract the bootstrap bias estimate (mean_stat - stat_data).
    est_stat = 2*stat_data - mean_stat
    std_err = np.std(boot_stat)
    # Two-sided standard-normal quantile: alpha=0.95 -> z ~= 1.96.
    z_score = np.sqrt(2.0)*erfinv(alpha)
    conf_interval = est_stat + z_score*np.array((-std_err, std_err))

    return est_stat, std_err, conf_interval


def Linear(x, m, c):
    """Evaluate the straight line ``m*x + c`` (works element-wise on arrays).

    Parameters
    ----------
    x : number or array-like
        Abscissa value(s).
    m : number
        Slope.
    c : number
        Intercept.
    """
    slope_term = m * x
    return slope_term + c

def HistMoment(counts, bin_centres, i, mu):
    """Return the un-normalised ``i``-th moment of a histogram about ``mu``.

    Computes ``sum(counts * (bin_centres - mu)**i)``. Callers divide by the
    number of entries themselves.
    """
    deviations = (bin_centres - mu) ** i
    weighted = counts * deviations
    return np.sum(weighted)

def HistMean(counts, bin_centres):
    """Return the count-weighted mean of the bin centres of a histogram."""
    first_moment = HistMoment(counts, bin_centres, 1, 0)
    total_entries = np.sum(counts)
    return first_moment / total_entries
    
def HistStd(counts, bin_centres):
    """Return the sample standard deviation of a histogram.

    Uses the ``N - 1`` (Bessel-corrected) denominator, where ``N`` is the
    total number of entries ``sum(counts)``.
    """
    total_entries = np.sum(counts)
    centre = HistMean(counts, bin_centres)
    second_moment = HistMoment(counts, bin_centres, 2, centre)
    return np.sqrt(second_moment / (total_entries - 1))

def HistSkew(counts, bin_centres):
    """Return the sample skewness of a histogram.

    Computes ``N*sqrt(N-1)/(N-2) * M3 / M2**1.5``, where ``M_k`` are the
    un-normalised central moments and ``N = sum(counts)``. This equals the
    adjusted Fisher-Pearson estimator ``sqrt(N(N-1))/(N-2) * g1``.
    """
    total_entries = np.sum(counts)
    centre = HistMean(counts, bin_centres)
    prefactor = total_entries * np.sqrt(total_entries - 1) / (total_entries - 2)
    third_moment = HistMoment(counts, bin_centres, 3, centre)
    second_moment = HistMoment(counts, bin_centres, 2, centre)
    return prefactor * third_moment / second_moment ** 1.5




def GP_lbda(mu, sigma, gamma):
    """Return ``lambda = 0.5*(mu*gamma/sigma - 1)`` from mean, std and skewness.

    NOTE(review): "GP" presumably means generalised Poisson, with ``lambda``
    its dispersion parameter -- confirm against the analysis notes.
    """
    skew_ratio = (mu * gamma) / sigma
    return 0.5 * (skew_ratio - 1)

def GP_gain(mu, sigma, gamma):
    """Return the gain ``(sigma**2/mu) * (1 - lambda)**2``.

    ``lambda`` comes from ``GP_lbda(mu, sigma, gamma)``.
    """
    one_minus_lambda = 1 - GP_lbda(mu, sigma, gamma)
    return (sigma**2 / mu) * (one_minus_lambda**2)

def GP_mu(mu, sigma, gamma):
    """Return ``(1/(1 - lambda)) * (mu**2/sigma**2)``.

    ``lambda`` comes from ``GP_lbda(mu, sigma, gamma)``.
    NOTE(review): presumably the generalised-Poisson mean parameter -- confirm.
    """
    one_minus_lambda = 1 - GP_lbda(mu, sigma, gamma)
    return (1 / one_minus_lambda) * (mu**2 / sigma**2)