Skip to content
Snippets Groups Projects
Select Git revision
  • 602e550e91c8847d4d27e7a60a9d3591124c29a8
  • main default protected
  • sumlab
  • dev/test_tobias
  • jack.rolph-main-patch-16563
  • jack.rolph-main-patch-96201
  • jack.rolph-main-patch-18340
  • jack.rolph-main-patch-15793
  • jack.rolph-main-patch-74592
  • 1.0.0
10 results

HelperFunctions.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    HelperFunctions.py 4.52 KiB
    from iminuit import Minuit
    from iminuit.util import describe
    from scipy.interpolate import interp1d
    from astropy.stats import bootstrap
    import numpy as np
    import matplotlib.pyplot as plt
    
    
    #HELPER FUNCTIONS
    
    class FakeFuncCode:
        """Minimal stand-in for a function's ``__code__`` object.

        Exposes ``co_varnames`` and ``co_argcount`` built from
        ``describe(f)`` (iminuit).  Leading parameters can be dropped,
        parameters can be renamed from the front, and extra parameter names
        can be appended after the existing ones.

        Parameters
        ----------
        f : callable or tuple
            Object handed to ``iminuit.util.describe`` to obtain the
            parameter names.
        prmt : iterable of str, optional
            New names for the first ``len(prmt)`` remaining parameters.
            These are applied after ``dock`` parameters have been dropped.
        dock : int, optional
            Number of leading parameters to drop (default 0).
        append : str or iterable of str, optional
            Extra parameter names added after the existing arguments.

        Attributes
        ----------
        co_varnames : tuple of str
            Parameter names.  This is always a tuple, like a real code object.
        co_argcount : int
            Number of entries in ``co_varnames``.

        Raises
        ------
        IndexError
            If ``prmt`` holds more names than there are remaining parameters.
        """

        def __init__(self, f, prmt=None, dock=0, append=None):
            # Copy into a list so renaming works whatever sequence type
            # describe() returns, and so describe()'s result is never mutated.
            varnames = list(describe(f))[dock:]

            if prmt is not None:  # rename parameters from the front
                for i, p in enumerate(prmt):
                    varnames[i] = p

            if isinstance(append, str):
                append = [append]

            if append is not None:
                # list() accepts any iterable (tuple, generator, ...); every
                # existing name is an argument, so the extras simply go last.
                varnames.extend(list(append))

            self.co_varnames = tuple(varnames)
            # Derived from the final names, so an oversized `dock` gives 0
            # rather than a negative count.
            self.co_argcount = len(self.co_varnames)
    
    
    
                
    def LatexFormat(value, scirange=(0.01, 1000)):
        r"""Format a number as a LaTeX math-mode string.

        Values whose magnitude lies strictly inside ``scirange`` are written in
        fixed point with three decimals, e.g. ``$1.500$``.  All other values
        are written in scientific notation, e.g. ``$1.000 \times 10^{-3}$``.

        Scientific formatting can fail, for example for NaN or infinity, whose
        ``E`` format has no exponent.  In that case ``str(value)`` is returned
        instead.

        Parameters
        ----------
        value : float
            Number to format.
        scirange : sequence of two floats, optional
            ``(low, high)`` open interval of magnitudes printed in fixed point.
            It is only read, so an immutable tuple is used as the default.

        Returns
        -------
        str
        """
        magnitude = np.abs(value)
        if scirange[0] < magnitude < scirange[1]:
            return r"${:3.3f}$".format(value)
        try:
            base, exponent = "{:3.3E}".format(value).split("E")
            return r"${0} \times 10^{{{1}}}$".format(base, int(exponent))
        except (ValueError, TypeError):
            # e.g. "NAN"/"INF" contain no "E" to split on: fall back to the
            # plain string representation instead of swallowing everything.
            return str(value)
    
    
    def FormatExponent(ax, axis='y', tight_layout=True):
        r"""Replace an axis' automatic offset text with a LaTeX ``x10^{n}`` label.

        The function works in four steps:

        1. Switch the chosen axis to scientific tick labels
           (``scilimits=(-2, 2)``).
        2. Force Matplotlib to compute the offset text.
        3. Hide that offset text.
        4. Draw an equivalent ``x$\mathregular{10^{n}}$`` label at the end of
           the axis instead.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to modify.
        axis : str, optional
            ``'y'`` (default) or ``'x'``.  This value is passed to
            ``ticklabel_format``.  For label placement, any value other than
            ``'y'`` is treated as the x axis.
        tight_layout : bool, optional
            Controls how the offset text is refreshed.  If True (the default),
            ``plt.tight_layout()`` is called, which also rearranges the current
            figure.  If False, only ``ax.figure.canvas.draw()`` is called, which
            leaves the layout untouched.

        Returns
        -------
        ax
            The same Axes, for chaining.
        """
        ax.ticklabel_format(axis=axis, style='sci', scilimits=(-2, 2))

        if axis == 'y':
            ax_axis = ax.yaxis
            x_pos, y_pos = 0.0, 1.0
            horizontalalignment, verticalalignment = 'left', 'bottom'
        else:
            ax_axis = ax.xaxis
            x_pos, y_pos = 1.0, -0.05
            horizontalalignment, verticalalignment = 'right', 'top'

        # The offset text is only filled in when the figure is laid out or
        # drawn, so trigger one of those before reading it.
        if tight_layout:
            plt.tight_layout()
        else:
            ax.figure.canvas.draw()

        offset = ax_axis.get_offset_text().get_text()

        if offset:
            # Matplotlib may render the minus sign as U+2212; normalise it so
            # the exponent after the 'e' parses.  np.float was removed in
            # NumPy 1.24, so use the builtin float.
            minus_sign = '\u2212'
            expo = float(offset.replace(minus_sign, '-').split('e')[-1])
            offset_text = r'x$\mathregular{10^{%d}}$' % expo

            # Hide the automatic offset text and draw our own label instead.
            ax_axis.offsetText.set_visible(False)
            ax.text(x_pos, y_pos, offset_text, transform=ax.transAxes,
                    horizontalalignment=horizontalalignment,
                    verticalalignment=verticalalignment)
        return ax
    
    
    def Bootstrap(data, statistic, n_bootstrap, alpha=0.95):
        """Bias-corrected bootstrap estimate of a statistic with a normal CI.

        ``data`` is resampled ``n_bootstrap`` times with
        ``astropy.stats.bootstrap``, and ``statistic`` is evaluated on each
        resample.

        Parameters
        ----------
        data : sequence
            Measurements; must be non-empty.
        statistic : callable
            Function applied to ``data`` and to each bootstrap resample.
        n_bootstrap : int
            Number of bootstrap resamples.
        alpha : float, optional
            Confidence level in the open interval (0, 1); default 0.95.

        Returns
        -------
        est_stat
            Bias-corrected estimate ``2*statistic(data) - mean(replicates)``.
        std_err
            Standard deviation of the bootstrap replicates.
        conf_interval : ndarray
            ``est_stat + z*(-std_err, +std_err)`` with
            ``z = sqrt(2)*erfinv(alpha)``, the two-sided normal quantile for
            confidence ``alpha``.  For a scalar statistic this has shape (2,).

        Raises
        ------
        ValueError
            If ``alpha`` is not in (0, 1) or ``data`` is empty.
        """
        # erfinv converts the confidence level into a two-sided normal
        # z-score.  It is imported here because the module's top-level
        # imports do not provide it.
        from scipy.special import erfinv

        if not (0 < alpha < 1):
            raise ValueError("confidence level must be in (0, 1)")

        if len(data) < 1:
            raise ValueError("data must contain at least one measurement.")

        boot_stat = bootstrap(data, n_bootstrap, bootfunc=statistic)

        stat_data = statistic(data)
        mean_stat = np.mean(boot_stat)
        # Remove the estimated bias (mean_stat - stat_data) from stat_data.
        est_stat = 2*stat_data - mean_stat
        std_err = np.std(boot_stat)
        z_score = np.sqrt(2.0)*erfinv(alpha)
        conf_interval = est_stat + z_score*np.array((-std_err, std_err))

        return est_stat, std_err, conf_interval
    
    
    def Linear(x, m, c):
        """Evaluate the straight line ``m*x + c`` (slope ``m``, intercept ``c``)."""
        slope_term = m * x
        return slope_term + c
    
    def HistMoment(counts, bin_centres, i, mu):
        """Return the count-weighted ``i``-th moment of a histogram about ``mu``.

        The result is ``sum(counts * (bin_centres - mu)**i)``.  It is an
        unnormalised sum: it is not divided by the total count.
        """
        deviations = (bin_centres - mu) ** i
        weighted = counts * deviations
        return np.sum(weighted)
    
    def HistMean(counts, bin_centres):
        """Return the count-weighted mean of a histogram's bin centres.

        This is the first raw moment divided by the total count
        ``sum(counts)``.
        """
        first_raw_moment = HistMoment(counts, bin_centres, 1, 0)
        total = np.sum(counts)
        return first_raw_moment / total
        
    def HistStd(counts, bin_centres):
        """Return the standard deviation of a histogram.

        The second central moment is divided by ``N - 1`` (Bessel's
        correction), where ``N = sum(counts)``, and the square root is taken.
        """
        total = np.sum(counts)
        centre = HistMean(counts, bin_centres)
        second_central = HistMoment(counts, bin_centres, 2, centre)
        return np.sqrt(second_central / (total - 1))
    
    def HistSkew(counts, bin_centres):
        """Return the sample skewness of a histogram.

        The formula is ``N*sqrt(N-1)/(N-2) * M3 / M2**1.5``, where:

        - ``N = sum(counts)``;
        - ``Mk`` is the unnormalised k-th central moment from ``HistMoment``.

        Algebraically this equals the bias-adjusted (G1) sample skewness.
        """
        n = np.sum(counts)
        centre = HistMean(counts, bin_centres)
        m3 = HistMoment(counts, bin_centres, 3, centre)
        m2 = HistMoment(counts, bin_centres, 2, centre)
        prefactor = n * np.sqrt(n - 1) / (n - 2)
        return prefactor * m3 / m2 ** 1.5
    
    
    
    
    def GP_lbda(mu, sigma, gamma):
        """Estimate lambda from mean, standard deviation and skewness.

        Returns ``0.5 * (mu*gamma/sigma - 1)``.  Presumably this is the
        moment-matching estimate of the Generalized Poisson lambda (the "GP"
        prefix suggests this; confirm).  For a GP distribution,
        ``gamma*mu/sigma == 1 + 2*lambda``.
        """
        ratio = (mu * gamma) / sigma
        return 0.5 * (ratio - 1)
    
    def GP_gain(mu, sigma, gamma):
        """Return ``sigma**2/mu * (1 - lambda)**2``, with lambda from ``GP_lbda``.

        Presumably this is the scale factor (gain) between measured units and
        counts for a Generalized Poisson distribution; verify with the caller.
        The basis for that reading: if the data are G times a GP variable,
        this expression equals G.
        """
        one_minus_lbda = 1 - GP_lbda(mu, sigma, gamma)
        return (sigma ** 2 / mu) * (one_minus_lbda ** 2)
    
    def GP_mu(mu, sigma, gamma):
        """Return ``(mu**2/sigma**2) / (1 - lambda)``, with lambda from ``GP_lbda``.

        Presumably this is the mean number of primary events (the GP theta
        parameter), independent of any overall gain; confirm with the caller.
        """
        lbda = GP_lbda(mu, sigma, gamma)
        mean_to_var = mu ** 2 / sigma ** 2
        return (1 / (1 - lbda)) * mean_to_var