from Bootstrapping import BootstrapKDE, Bootstrap
from HelperFunctions import GP_gain, GP_lbda, GP_muGP, GetStats
from HelperFunctions import LatexFormat, Linear, SelectRangeNumba
from LossFunctions import *
from Model_AP1 import DRM
from iminuit import Minuit
from scipy.signal import find_peaks, peak_widths
from scipy.integrate import cumtrapz, quad
from scipy.interpolate import interp1d
from scipy.stats import poisson
from scipy.stats import skew as sp_skew
from scipy.stats import kurtosis as sp_kurt
from astropy.stats import knuth_bin_width, freedman_bin_width, scott_bin_width
from matplotlib import cm
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.ticker import MaxNLocator, ScalarFormatter
from itertools import chain



import matplotlib.pyplot as plt



class PeakOTron:
    
    def __init__(self,
                 verbose=False,
                 n_call_minuit=20000,
                 n_iterations_minuit = 50,
                 n_bootstrap=1000
                ):
        """Create a fitter with default settings and empty result containers.

        Parameters
        ----------
        verbose : bool
            When true, setters such as ``SetMaxNPeaks`` print what they set.
        n_call_minuit : int
            Passed as ``ncall`` to ``Minuit.migrad`` in ``Fit``.
        n_iterations_minuit : int
            Passed as ``iterate`` to ``Minuit.migrad`` in ``Fit``.
        n_bootstrap : int
            Stored as ``self._n_bootstrap`` (its use is not visible in this
            part of the class).
        """

        # Template for histogram-derived quantities; copied by Init().
        self._default_hist_data={
            "count":None,
            "density":None,
            "bin_numbers":None,
            "bin_centers":None,
            "bw":None,
            "peaks":{
                "peak_position":None,
                "peak_position_lower":None,
                "peak_position_upper":None,
                "peak_height":None,
                "peak_height_lower":None,
                "peak_height_upper":None,
                "peak_mean":None,
                "peak_mean_error":None,
                "peak_variance":None,
                "peak_variance_error":None,
                "peak_std_deviation":None,
                "peak_std_deviation_error":None,
                "peak_skewness":None,
                "peak_skewness_error":None,
                "peak_kurtosis":None,
                "peak_kurtosis_error":None,
            },
            "bc2bn":None,
            "bn2bc":None,
            "bn2kde":None,
            "bn2kde_err":None,
            "bn2bg_sub":None
        }

        # Template for model parameters (values or errors); these names are the
        # keyword arguments Fit passes to Minuit for the DRM model.
        self._default_fit_dict={
            "x_0":None,
            "G":None,
            "lbda":None,
            "mu":None,
            "sigma_0":None,
            "sigma_1":None,
            "DCR":None,
            "tau":None,
            "t_gate":None,
            "tauAp":None,
            "pAp":None,
            "t_0":None,
            "k_low":None,
            "k_hi":None,
            "k_dcr_low":None,
            "k_dcr_hi":None
        }

        self._plot_figsize= (10,10)
        self._plot_fontsize= 25
        self._plot_legendfontsize= 18
        # matplotlib.cm.get_cmap was removed in matplotlib 3.9;
        # pyplot.get_cmap is available in both old and new releases.
        self._cmap = plt.get_cmap('viridis')
        # Small positive margin used for Minuit parameter limits in Fit.
        self._eps = 1e-5

        self._n_bootstrap=n_bootstrap
        self._len_DCR_pad = 100

        self._verbose=verbose
        self._n_call_minuit = n_call_minuit
        self._n_iterations_minuit = n_iterations_minuit
        self._is_histogram = False
        self._failed = False

        # Keyword arguments accepted by Fit(); anything else is rejected there.
        self._default_fit_kwargs={
                "tau":None,
                "tau_err":None,
                "t_0":None,
                "t_0_err":None,
                "t_gate":None,
                "t_gate_err":None,
                "bw":None,
                "peak_dist_factor":0.8,
                "peak_width_factor":0.5,
                "ppf_mainfit":1 - 1e-6,
                "bin_method":"knuth",
                "alpha_peaks":0.99,
                "alpha_fit":1-1e-6,
        }

        self.Init()
        

    

    ###HELPER FUNCTIONS
        
    def Init(self):
        self._fit_kwargs = self._default_fit_kwargs.copy()
        self._prefit_values = self._default_fit_dict.copy()
        self._prefit_errors = self._default_fit_dict.copy()
        self._fit_values= self._default_fit_dict.copy()
        self._fit_errors = self._default_fit_dict.copy()
        
        self._hist_data = self._default_hist_data.copy()
    

        self._fit = None
        self._failed = False


    
    def SetMaxNPeaks(self, max_n_peaks):
        max_n_peaks = int(max_n_peaks)
        if(max_n_peaks<1):
            raise Exception("Maximum number of peaks must be greater than 2.")
        

        self._max_n_peaks =  max_n_peaks
        if(self._verbose):
            print("Set maximum number of peaks to {:d}.".format(self._max_n_peaks))
                
        
    def SetMaxNDCRPeaks(self, max_n_dcr_peaks):
        max_n = int( max_n_dcr_peaks)
        if(max_n_dcr_peaks<2):
            raise Exception("Maximum number of peaks must be greater than 2.")
        

        self._max_n_dcr_peaks =  max_n_dcr_peaks
        if(self._verbose):
            print("Set maximum number of dark peaks to {:d}.".format(self._max_n_dcr_peaks))


    

    
    def GetBins(self, data, bw, bin_method):
        if(data.ndim == 1):
#             print("Data is assumed 1D. Assuming list of charges.")

            if(bw is None):
                print("Bin width not set. Binning with {:s}".format(bin_method))
                if(bin_method == "knuth"):
                    bw = knuth_bin_width(data)
                elif(bin_method == "freedman"):
                    bw = freedman_bin_width(data)
                elif(bin_method == "scott"):
                    bw = scott_bin_width(data)
                else:
                    raise Exception("Binning method not recognised. Please select bin_method from 'knuth', 'freedman' or 'scott'")

            x_min, x_max = data.min(), data.max()

        elif(data.ndim == 2):
#             print("Data is assumed 2D. Assuming input is a histogram, with columns [bin_centre, counts].")
            _x, _y = data[:,0], data[:,1]
            data = np.array(list(chain.from_iterable([[__x]*int(__y) for __x,__y in zip(_x,_y)])))
            f_bin = interp1d(np.arange(0, len(_x)), _x)
            bw = abs(f_bin(1) - f_bin(0))
            idx_nonzero = np.squeeze(np.argwhere((_y>0)))
            x_min, x_max = _x[np.min(idx_nonzero)]-bw/2,  _x[np.max(idx_nonzero)]+bw/2

        
        nbins = np.ceil((x_max - x_min) / bw)
        nbins = max(1, nbins)
        bins = x_min + bw * np.arange(nbins + 1)
        return data, bw, nbins, bins
    
    
      
 
    def GetEstPedestal(self, data, gain_fft):
        """Estimate the pedestal position ``x_0`` from summary statistics.

        Finds the ``x_0`` for which ``GP_gain(mu - x_0, sigma, gamma)``
        matches *gain_fft*, by minimising the squared difference with Minuit
        (migrad then hesse), starting from ``x_0 = mu``. ``mu, sigma, gamma``
        are the three values returned by ``GetStats(data)`` (presumably mean,
        standard deviation and skewness — TODO confirm).

        Returns
        -------
        float
            The fitted ``x_0``.
        """
        mu_q, sigma_q, gamma_q = GetStats(data)

        def squared_gain_mismatch(x_0):
            # Minuit derives the parameter name "x_0" from this signature.
            return (GP_gain(mu_q - x_0, sigma_q, gamma_q) - gain_fft)**2

        minimiser = Minuit(squared_gain_mismatch, x_0=mu_q)
        minimiser.migrad()
        minimiser.hesse()

        best = minimiser.values
        return best["x_0"]
    
    
    
    def PlotOriginalHistogram(self, ax):
        """Draw the raw histogram counts against bin number on the axes *ax*."""
        hist = self._hist_data
        counts = hist["counts"]
        bin_numbers = hist["bin_numbers"]

        fontsize = self._plot_fontsize
        ax.plot(bin_numbers, counts, label="Histogram", color="C0", lw=5)
        ax.grid(which="both")
        ax.set_ylabel("$f_{Q}(Q)$ [$Q^{-1}$]", fontsize=fontsize)
        ax.set_xlabel("Bin Number", fontsize=fontsize)
        ax.tick_params(axis="x", labelsize=fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=fontsize)
        ax.legend(fontsize=self._plot_legendfontsize)
        
        
        
        
    def GetChi2(self, prefit=False):
        
        if(any(_ is None for _ in [self._hist_data["bin_centres"],
                                self._hist_data["density_orig"],
                                self._hist_data["density_orig_error"],
                                self._fit_values
                               ])):
            return np.nan, np.nan
        else:
            x = self._hist_data["bin_centres"]
            y = self._hist_data["density_orig"]
            min_error = 1/self._hist_data["bw"]/np.sum(self._hist_data["counts"])
            y_err = np.where(self._hist_data["density_orig_error"]<min_error, 
                             min_error, 
                             self._hist_data["density_orig_error"])
            
            y_hat = self.GetModel(x, prefit)


            chi2 = np.nansum(((y_hat - y)/y_err)**2)
            ndof = len(x) - 9
            return chi2, ndof
        
    
    
    def PlotDensity(self, ax):
        """Plot the KDE against bin number with a ±1.96·error band on *ax*.

        The band is labelled "KDE, 95% Confidence" and spans
        ``density ± 1.96 * density_error``.
        """
        density = self._hist_data["density"]
        density_error = self._hist_data["density_error"]
        bin_numbers = self._hist_data["bin_numbers"]

        ax.plot(bin_numbers, density, label="KDE", color="purple", lw=5)
        # Draw on the given axes; plt.fill_between would target whichever
        # axes happens to be current.
        ax.fill_between(bin_numbers,
                        density - 1.96*density_error,
                        density + 1.96*density_error,
                        label="KDE, 95% Confidence",
                        alpha=0.3,
                        color="purple", lw=0)
        ax.grid(which="both")
        ax.set_ylabel("$f_{Q}(Q)$ [$Q^{-1}$]", fontsize=self._plot_fontsize)
        ax.set_xlabel("Bin Number", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        ax.legend(fontsize=self._plot_legendfontsize)

        
    def PlotFFT(self, ax):
        """Plot the squared FFT amplitude of the KDE on log-log axes.

        The region ``fft_freq < 1/G_fft`` is hatched and labelled "High Pass
        Filter Range", and a vertical line marks ``1/G_fft``, with ``G_fft``
        quoted in bins in the legend.
        """
        fft_freq = self._hist_data["fft_freq"]
        fft_amplitude = self._hist_data["fft_amplitude"]
        G_fft = self._hist_data["G_fft"]
        inv_G_fft = 1/G_fft
        below_cutoff = fft_freq < inv_G_fft

        ax.plot(fft_freq,
                fft_amplitude,
                color="purple",
                label="Absolute Square of FFT\n of KDE",
                lw=5)

        ax.fill_between(fft_freq[below_cutoff],
                        fft_amplitude[below_cutoff], 0,
                        edgecolor="red",
                        facecolor='none',
                        lw=2,
                        hatch="//",
                        label="High Pass Filter Range")

        ax.axvline(x=inv_G_fft, color="green", lw=2, label="$G_{{FFT}}$ = {:3.3f} bins".format(G_fft))

        ax.set_yscale("log")
        ax.set_xscale("log")
        ax.grid(which="both")
        # Use the configured font size like the other plot helpers
        # (previously hard-coded to 25).
        ax.set_xlabel("Inverse Bin Number", fontsize=self._plot_fontsize)
        ax.set_ylabel("$|\\mathcal{F}(f_{Q}(Q))|^{2}$", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        ax.legend(fontsize=self._plot_legendfontsize)
        
    def PlotBGSub(self, ax):
        """Overlay the KDE and its filtered version (``density_bgsub``) on *ax*."""
        density = self._hist_data["density"]
        density_bgsub = self._hist_data["density_bgsub"]
        bin_numbers = self._hist_data["bin_numbers"]

        ax.plot(bin_numbers,
                density,
                label="KDE",
                color="purple",
                lw=5)
        ax.plot(bin_numbers,
                density_bgsub,
                label="KDE,\n Filtered",
                color="red",
                linestyle="--",
                lw=5)
        ax.grid(which="both")
        # Use the configured font size (previously hard-coded to 25).
        ax.set_xlabel("Bin Number", fontsize=self._plot_fontsize)
        ax.set_ylabel("$f_{Q}(Q)$ [$Q^{-1}$]", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        ax.legend(fontsize=self._plot_legendfontsize)
    
    
    def PlotPeaks(self, ax):
        """Plot the KDE with vertical markers for every detected peak.

        For each peak a solid line is drawn at ``peak_position`` up to
        ``peak_height``, and dashed lines at the lower/upper positions up to
        their respective heights. The first peak is red and labelled as the
        pedestal; the others take colours from ``self._cmap``. A green
        vertical line marks ``x_0_est``.
        """
        hist = self._hist_data
        density = hist["density"]
        bin_numbers = hist["bin_numbers"]
        x_0_est = hist["x_0_est"]

        peaks = hist["peaks"]
        centres = peaks["peak_position"]
        lowers = peaks["peak_position_lower"]
        uppers = peaks["peak_position_upper"]
        heights = peaks["peak_height"]
        heights_lower = peaks["peak_height_lower"]
        heights_upper = peaks["peak_height_upper"]

        ax.plot(bin_numbers, density, label="KDE", color="purple", lw=5)
        palette = [self._cmap(frac) for frac in np.linspace(0, 0.75, len(centres))]

        rows = zip(centres, lowers, uppers, heights, heights_lower, heights_upper)
        for idx, (centre, lo, hi, h_centre, h_lo, h_hi) in enumerate(rows):
            is_pedestal = idx == 0
            colour = "red" if is_pedestal else palette[idx]
            name = "$Q_{{Ped}}$" if is_pedestal else "Peak {:d}".format(idx)

            # Solid line at the peak itself carries the legend entry.
            ax.vlines(x=[centre], ymin=0, ymax=h_centre, color=colour,
                      linestyle="-", lw=2.5, label=name)
            # Dashed, unlabelled lines at the lower and upper positions.
            for edge, edge_height in ((lo, h_lo), (hi, h_hi)):
                ax.vlines(x=[edge], ymin=0, ymax=edge_height, color=colour,
                          lw=2.5, linestyle="--")

        ax.axvline(x=x_0_est, color="green", lw=2, label="$Q_{{Ped}}$ Estimate")

        fontsize = self._plot_fontsize
        ax.set_xlabel("Bin Number", fontsize=fontsize)
        ax.set_ylabel("$f_{Q}(Q)$ [$Q^{-1}$]", fontsize=fontsize)
        ax.tick_params(axis="x", labelsize=fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=fontsize)
        ax.grid(which="both")
        ax.legend(fontsize=self._plot_legendfontsize)


    def PlotVarianceFit(self, ax):
        """Plot peak variances against peak index with the prefit linear model.

        Points are ``peaks["peak_variance"]`` (with ``peak_variance_error``
        bars) at ``peaks["peak_position_norm"]``. The dashed line is
        ``Linear(k, sigma_1**2, sigma_0**2)``, with ``sigma_0``/``sigma_1``
        taken from the prefit values divided by the bin width ``bw``. The
        legend quotes sigma_0^2 and sigma_1^2 with first-order propagated
        errors ``2*sigma*dsigma``.
        """
        peaks = self._hist_data["peaks"]
        peak_position_norm = peaks["peak_position_norm"]
        peak_variance = peaks["peak_variance"]
        peak_variance_error = peaks["peak_variance_error"]

        _scale = self._hist_data["bw"]

        sigma_0 = self._prefit_values["sigma_0"]/_scale
        dsigma_0 = self._prefit_errors["sigma_0"]/_scale
        sigma_1 = self._prefit_values["sigma_1"]/_scale
        dsigma_1 = self._prefit_errors["sigma_1"]/_scale

        ax.errorbar(peak_position_norm,
                    peak_variance,
                    yerr=peak_variance_error,
                    fmt="o",
                    lw=3,
                    markersize=10,
                    color="C0",
                    label="Peak Variances"
                   )

        ax.plot(peak_position_norm,
                Linear(peak_position_norm,
                       sigma_1**2,
                       sigma_0**2
                       ),
                color="C0",
                linestyle="--",
                lw=3,
                # Quote sigma_1**2 with error 2*sigma_1*dsigma_1 (previously
                # sigma_1**1 and 2*sigma_0*dsigma_1 were shown).
                label = "Linear Fit: \n $\\sigma^{{2}}_{{0}}$ = {:s} $\\pm$ {:s} bins \n $\\sigma^{{2}}_{{1}}$ = {:s} $\\pm$ {:s} bins".format(
                                                              LatexFormat(sigma_0**2),
                                                              LatexFormat(2*sigma_0*dsigma_0),
                                                              LatexFormat(sigma_1**2),
                                                              LatexFormat(2*sigma_1*dsigma_1),
                ),
               )

        ax.set_xlabel("$k$ [n.p.e]", fontsize=self._plot_fontsize)
        ax.set_ylabel("$\\sigma^{2}_{k}$ [bins]", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        # Act on the given axes only; plt.legend/plt.show() targeted the
        # current axes and displayed the figure before the caller finished it.
        ax.legend(fontsize=self._plot_legendfontsize)
        
        
        
    def PlotMeanFit(self, ax):
        """Plot peak means against peak index with the prefit linear model.

        Points are ``peaks["peak_mean"]`` (with ``peak_mean_error`` bars) at
        ``peaks["peak_position_norm"]``. The dashed line is
        ``Linear(k, G, x_0)``, where ``x_0 = (prefit x_0 - bc_min)/bw`` and
        ``G = prefit G / bw``, i.e. the prefit values are assumed to be in
        charge units, as ``Fit`` leaves them (TODO confirm for other callers).
        """
        peaks = self._hist_data["peaks"]
        peak_position_norm = peaks["peak_position_norm"]
        peak_mean = peaks["peak_mean"]
        peak_mean_error = peaks["peak_mean_error"]

        _offset = self._hist_data["bc_min"]
        _scale = self._hist_data["bw"]

        x_0 = (self._prefit_values["x_0"] - _offset)/_scale
        dx_0 = self._prefit_errors["x_0"]/_scale
        G = self._prefit_values["G"]/_scale
        dG = self._prefit_errors["G"]/_scale

        ax.errorbar(peak_position_norm,
                    peak_mean,
                    yerr=peak_mean_error,
                    fmt="o",
                    markersize=10,
                    lw=3,
                    color="C0",
                    label="Peak Means"
                   )

        ax.plot(peak_position_norm,
                Linear(peak_position_norm,
                       G,
                       x_0,
                       ),
                color="C0",
                linestyle="--",
                lw=3,
                # "\\pm" avoids the invalid "\p" escape sequence (SyntaxWarning
                # on recent Python); the rendered text is unchanged.
                label = "Linear Fit: \n $Q_{{Ped}}$ = {:s} $\\pm$ {:s} bins \n $G$ = {:s} $\\pm$ {:s} bins".format(
                                                              LatexFormat(x_0),
                                                              LatexFormat(dx_0),
                                                              LatexFormat(G),
                                                              LatexFormat(dG),
                ),
               )

        ax.set_xlabel("$k$ [n.p.e]", fontsize=self._plot_fontsize)
        ax.set_ylabel("$\\langle Q_{k} \\rangle$ [bins]", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        ax.legend(fontsize=self._plot_legendfontsize)
        
        
        
        
    def PlotDCR(self, ax):
        """Plot the KDE of the normalised spectrum used for the DCR estimate.

        The KDE (``k2PDF``) and its error (``k2PDF_err``) are evaluated on
        1000 points in [0, 1]. The region ``[0, k_DC_min)`` is hatched and
        labelled P_0, and ``[k_DC_min, min(2*k_DC_min, 1))`` is labelled P_1.
        The legend title shows the prefit ``DCR`` and its error multiplied by
        1e3 and labelled MHz (presumably DCR is stored per ns — TODO confirm).
        """
        DCR = self._prefit_values["DCR"]
        dDCR = self._prefit_errors["DCR"]

        bin_numbers_norm = np.linspace(0, 1, 1000)
        density = self._hist_data["k2PDF"](bin_numbers_norm)
        density_errors = self._hist_data["k2PDF_err"](bin_numbers_norm)
        k_DC_min = self._hist_data["k_DC_min"]

        cond_02Min = (bin_numbers_norm >= 0) & (bin_numbers_norm < k_DC_min)
        cond_Min22Min = (bin_numbers_norm >= k_DC_min) & (bin_numbers_norm < min(2*k_DC_min, 1))

        ax.fill_between(
                     bin_numbers_norm[cond_02Min],
                     density[cond_02Min],
                     color='none',
                     hatch="///",
                     edgecolor="b",
                     label = "$P_{0}$")

        ax.fill_between(
                     bin_numbers_norm[cond_Min22Min],
                     density[cond_Min22Min],
                     color='none',
                     hatch="///",
                     edgecolor="r",
                     label = "$P_{1}$")

        ax.axvline(x=k_DC_min, color="green", lw=2, label="$k_{min}$")

        ax.plot(bin_numbers_norm,
                density,
                label="KDE",
                color="purple",
                lw=5)

        ax.fill_between(bin_numbers_norm,
                        y1=density - 1.96*density_errors,
                        y2=density + 1.96*density_errors,
                        alpha=0.3,
                        color="purple",
                        label="KDE, 95% Confidence")

        ax.set_xlabel("$k$ [n.p.e]", fontsize=self._plot_fontsize)
        ax.grid(which="both")

        ax.set_ylabel("$f_{k}(k)$ [$k^{-1}$]", fontsize=self._plot_fontsize)
        ax.tick_params(axis="x", labelsize=self._plot_fontsize, rotation=45)
        ax.tick_params(axis="y", labelsize=self._plot_fontsize)
        # "\\pm" avoids the invalid "\p" escape sequence; rendered text unchanged.
        ax.legend(title="$DCR$ = {:s} \n $\\pm$ {:s} MHz".format(LatexFormat(DCR*1e3),
                                                                 LatexFormat(dDCR*1e3)),
                  fontsize=self._plot_legendfontsize,
                  title_fontsize=self._plot_legendfontsize
                 )
        ax.set_xlim([0, 1])
        ax.set_yscale("log")

            
            
    def PlotSummary(self, display=True, save_directory=None):
        """Draw a multi-panel summary figure of the intermediate results.

        Parameters
        ----------
        display : bool
            Show the figure; otherwise it is closed after (optional) saving.
        save_directory : str or None
            If given, passed to ``fig.savefig``.

        NOTE(review): ``PlotGenPois``, ``PlotSigmaEst`` and ``PlotDCREst`` are
        not defined in the visible part of this class; presumably they are
        defined further down — confirm.
        """
        fig = plt.figure(figsize=(20,40))
        gs = gridspec.GridSpec(4, 2)
        gs.update(wspace=0.5, hspace=0.5)

        ax0 = fig.add_subplot(gs[0,0])
        self.PlotOriginalHistogram(ax0)

        ax1 = fig.add_subplot(gs[0,1])
        self.PlotDensity(ax1)

        ax3 = fig.add_subplot(gs[1,0])
        self.PlotGenPois(ax3)

        # Previously gs[1,:], which spans row 1 and overlaps ax3; use the
        # free cell next to it instead.
        ax2 = fig.add_subplot(gs[1,1])
        self.PlotPeaks(ax2)

        ax4 = fig.add_subplot(gs[2,1])
        self.PlotSigmaEst(ax4)

        ax6 = fig.add_subplot(gs[3,1])
        self.PlotDCREst(ax6)

        if save_directory is not None:
            print("Saving figure to {0}...".format(save_directory))
            fig.savefig(save_directory)
        if display:
            plt.pause(0.01)
            fig.show()
        else:
            plt.close(fig)
                
                
    def GetFitResults(self):
        return self._fit_values, self._fit_errors
    
    
    def GetPrefitResults(self):
        return self._prefit_values, self._prefit_errors
            
          
    def PlotFit(self,
                figsize=(10.0,10.0),
                labelsize=None,
                ticksize=None,
                titlesize=None,
                legendsize=None,
                title=None,
                xlabel = None,
                scaled=False,
                save_directory=None,
                data_label=None,
                y_limits = None,
                residual_limits = (-5, 5),
                residualticksize=None,
                linewidth_main = 5,
                x_limits = None,
                display=True,
                prefit=False
               ):
        """Plot the original density, the model and the normalised residuals.

        The upper panel shows ``density_orig`` and the model (log y-scale)
        with chi2/NDF from ``GetChi2`` in the legend; the lower panel shows
        ``(y - y_hat) / y_err`` with dashed guides at ±1.

        Parameters
        ----------
        figsize : tuple
            Figure size passed to ``plt.figure``.
        labelsize, ticksize, titlesize, legendsize, residualticksize : int or None
            Font sizes; ``None`` falls back to the instance defaults
            (residual ticks default to 80% of the plot font size).
        title, xlabel, data_label : str or None
            Panel title, x-axis label (default ``"Q"``) and data legend label
            (default ``"Histogram"``).
        scaled : bool
            Accepted for backward compatibility; not used.
        save_directory : str or None
            If given, passed to ``fig.savefig``.
        y_limits : sequence of two or None
            Upper-panel y-limits; a ``None`` lower limit (the default) uses
            the smallest positive density. The argument is never mutated.
        residual_limits : sequence of two
            Residual-panel y-limits; integer ticks span both limits inclusive.
        linewidth_main : float
            Line width of the data and model curves.
        x_limits : sequence of two or None
            Shared x-limits; ``None`` keeps matplotlib's automatic limits.
        display : bool
            Show the figure; otherwise it is closed after (optional) saving.
        prefit : bool
            Plot the prefit model instead of the final fit.
        """
        if labelsize is None:
            labelsize = self._plot_fontsize
        if ticksize is None:
            ticksize = self._plot_fontsize
        if titlesize is None:
            titlesize = self._plot_fontsize
        if legendsize is None:
            legendsize = self._plot_legendfontsize
        if residualticksize is None:
            residualticksize = int(0.8*self._plot_fontsize)

        # Work on copies: the old list defaults were mutated below, so the
        # first call's lower y-limit leaked into every later call.
        y_limits = [None, None] if y_limits is None else list(y_limits)
        x_limits = [None, None] if x_limits is None else list(x_limits)

        x = self._hist_data["bin_centres"]
        y = self._hist_data["density_orig"]
        y_err = self._hist_data["density_orig_error"]
        y_hat = self.GetModel(x, prefit)
        chi2, ndf = self.GetChi2(prefit)

        if xlabel is None:
            xlabel = "Q"
        if data_label is None:
            data_label = "Histogram"

        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(2, 1, height_ratios=[2, 1])

        ax0 = plt.subplot(gs[0])
        ax0.plot(x, y, label=data_label, lw=linewidth_main, color="C0")
        ax0.plot(x, y_hat, linestyle="--", label="Model", lw=linewidth_main, color="C1")
        # Invisible artist so chi2/NDF appears as a legend entry.
        ax0.plot([], [], ' ', label="$\\frac{{\\chi^{{2}}}}{{NDF}}$ = {:s}".format(LatexFormat(chi2/ndf)))

        ax0.set_ylabel("$f_{Q}(Q)$ [$Q^{-1}$]", fontsize=labelsize)
        ax0.tick_params(axis="y", labelsize=ticksize)

        if y_limits[0] is None:
            y_limits[0] = np.min(y[y>0])
        ax0.set_ylim(y_limits)
        ax0.set_xlim(x_limits)
        if title is not None:
            ax0.set_title(title, fontsize=titlesize)

        ax0.grid(which="both")
        ax0.set_yscale("log")
        ax0.legend(fontsize=legendsize)

        ax1 = plt.subplot(gs[1], sharex=ax0)
        ax1.scatter(x, (y - y_hat)/(y_err), color="C1")

        ax1.tick_params(axis="x", labelsize=ticksize, rotation=45)
        ax1.tick_params(axis="y", labelsize=residualticksize)
        ax1.set_ylabel("$Residuals$", fontsize=labelsize)
        ax1.set_xlabel(xlabel, fontsize=labelsize)
        ax1.set_ylim(residual_limits)
        # +1 so the upper residual limit also receives a tick.
        ax1.set_yticks(np.arange(int(np.floor(np.min(residual_limits))),
                                 int(np.ceil(np.max(residual_limits))) + 1))
        ax1.axhline(y=-1.0, color="purple", lw=2.5, linestyle="--")
        ax1.axhline(y=1.0, color="purple", lw=2.5, linestyle="--")
        ax1.grid(which="both")

        fig.tight_layout()

        mf = ScalarFormatter(useMathText=True)
        mf.set_powerlimits((-2,2))
        ax1.xaxis.set_major_formatter(mf)
        offset = ax1.xaxis.get_offset_text()
        offset.set_size(int(0.8*ticksize))

        fig.subplots_adjust(hspace=.0)
        if save_directory is not None:
            print("Saving figure to {0}...".format(save_directory))
            fig.savefig(save_directory)
        if display:
            plt.pause(0.01)
            fig.show()
        else:
            plt.close(fig)
        
        
            
    def Fit(self, data, **kwargs):
        
        self.Init()
        
        if not set(kwargs.keys()).issubset(self._fit_kwargs.keys()):
            wrong_keys = list(set(kwargs.keys()).difference(set(self._fit_kwargs.keys())))
            
            raise Exception ("Arguments {:s} not recognised.".format(",".join(wrong_keys)))
        
        if not "tau" in kwargs.keys():
            raise Exception ("Please provide tau (slow component of SiPM pulse) in nanoseconds.")
            
        if not "t_0" in kwargs.keys():
            raise Exception ("Please provide t_0 (integration region before start of SiPM pulse integration gate) in nanoseconds.")
            
        if not "t_gate" in kwargs.keys():
            raise Exception ("Please provide t_gate (length of  SiPM pulse integration gate) in nanoseconds.")
            
        
        self._fit_kwargs.update(kwargs)
        
        
        self.GetPreFit(data, **self._fit_kwargs)
        
        
        
        data_bn = self._hist_data["bc2bn"](self._hist_data["data"])
        f_DRM = UnbinnedLH(DRM, data_bn, 1)
        m_DRM = Minuit(f_DRM,
                          x_0 = self._prefit_values["x_0"],
                          G = self._prefit_values["G"],
                          mu = self._prefit_values["mu"],
                          lbda = self._prefit_values["lbda"],
                          sigma_0 = self._prefit_values["sigma_0"],
                          sigma_1 = self._prefit_values["sigma_1"],
                          tauAp =  self._prefit_values["tauAp"],
                          pAp =  self._prefit_values["pAp"],
                          DCR =  self._prefit_values["DCR"],
                          tau = self._prefit_values["tau"],
                          t_gate = self._prefit_values["t_gate"],
                          t_0 = self._prefit_values["t_0"],
                          k_low = self._prefit_values["k_low"],
                          k_hi = self._prefit_values["k_hi"],
                          k_dcr_low = self._prefit_values["k_dcr_low"],
                          k_dcr_hi = self._prefit_values["k_dcr_hi"],
                      )
        
        

        m_DRM.errors["x_0"]  = self._prefit_errors["x_0"]
        m_DRM.errors["G"] =  self._prefit_errors["G"]
        m_DRM.errors["mu"]  = self._prefit_errors["mu"]
        m_DRM.errors["lbda"]  = self._prefit_errors["lbda"]
        m_DRM.errors["sigma_0"]  = self._prefit_errors["sigma_0"]
        m_DRM.errors["sigma_1"]  = self._prefit_errors["sigma_1"]
        m_DRM.errors["tauAp"]  = self._prefit_errors["tauAp"]
        m_DRM.errors["pAp"]  = self._prefit_errors["pAp"]
        m_DRM.errors["DCR"]  = self._prefit_errors["DCR"]
        
        
        m_DRM.limits["x_0"]  = (self._eps, float(self._hist_data["nbins"]))
        
        m_DRM.limits["G"] =  (self._eps, float(self._hist_data["nbins"]))
        
        m_DRM.limits["mu"]  = (self._eps, self._prefit_values["k_hi"])
        
        m_DRM.limits["lbda"]  = (self._eps, 1-self._eps)
        
        m_DRM.limits["sigma_0"]  = (self._eps, float(self._hist_data["nbins"]))
        
        m_DRM.limits["sigma_1"]  = (self._eps, float(self._hist_data["nbins"]))
        
        m_DRM.limits["tauAp"]  = (self._eps, 
                                  self._prefit_values["tau"]-self._eps)
        
        m_DRM.limits["pAp"]  = (self._eps, 1-self._eps)
        
        m_DRM.limits["DCR"]  = (self._eps, None)
        

        if(self._fit_kwargs["tau_err"] is None):
            m_DRM.fixed["tau"]=True
        else:
            m_DRM.fixed["tau"]=False
            m_DRM.errors["tau"] = abs(self._prefit_errors["tau"])
            m_DRM.limits["tau"] = (max(self._eps, self._prefit_values["tau"]-self._prefit_errors["tau"]),
                                   max(self._eps, self._prefit_values["tau"]+self._prefit_errors["tau"])
                                  )
            
        if(self._fit_kwargs["t_gate_err"] is None):
            m_DRM.fixed["t_gate"]=True
        else:
            
            m_DRM.fixed["t_gate"]=False
            m_DRM.errors["t_gate"] = abs(self._prefit_errors["t_gate"])
            m_DRM.limits["t_gate"] = (max(self._eps, self._prefit_values["t_gate"]-self._prefit_errors["t_gate"]),
                                      max(self._eps, self._prefit_values["t_gate"]+self._prefit_errors["t_gate"])
                                     )
            
                                      
        if(self._fit_kwargs["t_0_err"] is None):
            m_DRM.fixed["t_0"]=True
        else:
            m_DRM.fixed["t_0"]=False
            m_DRM.errors["t_0"] = abs(self._prefit_errors["t_0"])
            m_DRM.limits["t_0"] = (max(self._eps, self._prefit_values["t_0"]-self._prefit_errors["t_0"]),
                                      max(self._eps, self._prefit_values["t_0"]+self._prefit_errors["t_0"])
                                      )
          
        

        m_DRM.fixed["k_low"]=True
        m_DRM.fixed["k_hi"]=True
        m_DRM.fixed["k_dcr_low"]=True
        m_DRM.fixed["k_dcr_hi"]=True
        
        m_DRM.strategy=2
        m_DRM.errordef=m_DRM.LIKELIHOOD
        m_DRM.simplex()
        m_DRM.migrad(ncall = self._n_call_minuit,
                     iterate= self._n_iterations_minuit)
        m_DRM.hesse()
        
        
        self._fit_minuit = m_DRM
        
        _offset = self._hist_data["bc_min"]
        _scale =  self._hist_data["bw"]
        
        for key, value in m_DRM.values.to_dict().items():
            if(key=="x_0"):
                self._fit_values[key] = _offset + value*_scale
            elif(key=="G"):
                self._fit_values[key] = value*_scale
            elif(key=="sigma_0"):
                self._fit_values[key] = value*_scale
            elif(key=="sigma_1"):
                self._fit_values[key] = value*_scale
            else:
                self._fit_values[key] = value
                


        for key, value in m_DRM.errors.to_dict().items():
            if(key=="x_0"):
                self._fit_errors[key] = value*_scale
            elif(key=="G"):
                self._fit_errors[key] = value*_scale
            elif(key=="sigma_0"):
                self._fit_errors[key] = value*_scale
            elif(key=="sigma_1"):
                self._fit_errors[key] = value*_scale
            else:
                self._fit_errors[key] = value
 
        
    
        _prefit_values_temp = self._prefit_values.copy()
        _prefit_errors_temp = self._prefit_errors.copy()
    
        for key, value in _prefit_values_temp.items():
            if(key=="x_0"):
                _prefit_values_temp[key] = _offset + value*_scale
            elif(key=="G"):
                _prefit_values_temp[key] = value*_scale
            elif(key=="sigma_0"):
                _prefit_values_temp[key] = value*_scale
            elif(key=="sigma_1"):
                _prefit_values_temp[key] = value*_scale
            else:
                _prefit_values_temp[key] = value
        
        for key, value in _prefit_errors_temp.items():
            if(key=="x_0"):
                _prefit_errors_temp[key] = value*_scale
            elif(key=="G"):
                _prefit_errors_temp[key] = value*_scale
            elif(key=="sigma_0"):
                _prefit_errors_temp[key] = value*_scale
            elif(key=="sigma_1"):
                _prefit_errors_temp[key] = value*_scale
            else:
                _prefit_errors_temp[key] = value
                
        self._prefit_values.update(_prefit_values_temp)
        self._prefit_errors.update(_prefit_errors_temp)
        
#         plt.figure()
#         plt.plot(bin_centres, density_orig)
#         plt.plot(bin_centres, DRM(bin_centres, **self._fit_values))
#         plt.show()
        
        
        
        
    def GetModel(self, x, prefit=False):
        """Evaluate the DRM model at ``x`` with the stored parameter set.

        Parameters
        ----------
        x : array-like
            Points at which ``DRM`` is evaluated; passed through unchanged.
        prefit : bool, optional
            When truthy, evaluate with the pre-fit estimates held in
            ``self._prefit_values``; otherwise (default) evaluate with the
            fitted parameters held in ``self._fit_values``.

        Returns
        -------
        Whatever ``DRM(x, **params)`` returns for the selected parameters.

        Notes
        -----
        NOTE(review): the units of the stored parameters depend on whether
        the fit has already run (the fit step rescales both dicts with the
        histogram offset/bin width). Confirm that ``x`` is in matching units
        before calling.
        """
        # Pick the parameter dictionary first, then make one call to DRM.
        params = self._prefit_values if prefit else self._fit_values
        return DRM(x, **params)
        
        
        
            

    def GetPreFit(self, data, **kwargs):
        
        
        
        ###Process Bins
        data, bw, nbins, bins = self.GetBins(data, kwargs["bw"], kwargs["bin_method"])
        
#         data = np.random.choice(data, size=10000)
        
        counts, _ = np.histogram(data, bins=bins)
        counts_error = np.sqrt(counts)
        density_orig = counts/np.sum(counts)/bw
        density_orig_error = density_orig*np.sqrt(1/counts + 1/np.sum(counts))
        
        min_error = 1/np.sum(counts)/bw
        bin_centres =  (bins[:-1] + bins[1:])/2.
        bin_numbers = np.arange(0, nbins)

        
        f_bc2bn = interp1d(bin_centres,bin_numbers, fill_value="extrapolate")
        
        f_bn2bc = interp1d(bin_numbers,bin_centres,fill_value="extrapolate")
        
        data_bn = f_bc2bn(data)
        
        
        ###Perform KDE smoothing
        bin_numbers_kde, density_kde, density_kde_err, _ = BootstrapKDE(data_bn,
                                                                        n_bootstrap=self._n_bootstrap,
                                                                        bw_limits=(0.01,None)
                                                                       )
        
        cumint = cumtrapz(density_kde/bw, bin_numbers_kde, initial=0)
        cumint = cumint/np.max(cumint)
        


        f_bn2CDF = interp1d( bin_numbers_kde,  cumint, fill_value = (0,1), bounds_error=False )
        f_CDF2bn = interp1d( cumint, bin_numbers_kde, fill_value = (bin_numbers_kde[0], bin_numbers_kde[-1]), bounds_error=False)
        f_bn2PDF = interp1d(bin_numbers_kde, density_kde/bw, fill_value=(0, 0), bounds_error=False)
        f_bn2PDF_err = interp1d(bin_numbers_kde, np.where(density_kde_err/bw>min_error, density_kde_err/bw, min_error),fill_value=(min_error, min_error),bounds_error=False)
        
        
        density = f_bn2PDF(bin_numbers)
        density_error = f_bn2PDF_err(bin_numbers)
        

        ###Perform FFT, Get Gain Estimate and Background Subtraction
        
        fhat = np.fft.fft(density) #computes the fft
        psd = fhat * np.conj(fhat)/nbins
      
        
        
        fft_freq_orig = (1/nbins) * np.arange(nbins) #frequency array
        idxs_half = np.arange(1, np.floor(nbins/2), dtype=np.int32)
                
        

        fft_amplitude = np.abs(psd[idxs_half])
        fft_freq = fft_freq_orig[idxs_half]
        idxs_fft_peaks, _= find_peaks(fft_amplitude)
        fft_freq_peaks = fft_freq[idxs_fft_peaks]
        fft_amplitude_peaks = fft_amplitude[idxs_fft_peaks]
        _idx_min = np.argmax(fft_amplitude_peaks)
        
        inv_G_fft = fft_freq_peaks[_idx_min]
        G_fft = 1/inv_G_fft

    
        density_bgsub = np.fft.ifft(np.where(fft_freq_orig<inv_G_fft, 0, fhat))


        ###Find peaks
            
        idxs_peaks, _ = find_peaks(density_bgsub,
                                   distance=kwargs["peak_dist_factor"]*G_fft)
        idxs_peak_widths = np.vstack(peak_widths(density_bgsub,
                                                 idxs_peaks,
                                                 rel_height=kwargs["peak_width_factor"])[2:]).T.astype(int)
        

        
        x_0_est = self.GetEstPedestal(data_bn, G_fft)
        
        Q_min = bin_numbers[idxs_peaks][np.argmin(abs(bin_numbers[idxs_peaks] - x_0_est))]
        if(kwargs["alpha_peaks"] is not None):
            Q_max = f_CDF2bn(kwargs["alpha_peaks"])
        else:
            Q_max = np.max(bin_numbers)
        
        _idxs_sel = (bin_numbers[idxs_peaks]>=Q_min) &  (bin_numbers[idxs_peaks]<=Q_max)
    
        idxs_peaks = idxs_peaks[_idxs_sel]
        idxs_peak_widths = idxs_peak_widths[_idxs_sel]
        
        _idxs_sort = np.argsort(idxs_peaks)
        
        idxs_peaks = idxs_peaks[_idxs_sort]
        idxs_peak_widths = idxs_peak_widths[_idxs_sort]
        
        
        peak_position = bin_numbers[idxs_peaks]
        peak_position_norm = (peak_position - peak_position[0])/G_fft
        peak_position_lower = bin_numbers[idxs_peak_widths[:,0]]
        peak_position_upper = bin_numbers[idxs_peak_widths[:,1]]
        peak_height = density[idxs_peaks]
        peak_height_lower = density[idxs_peak_widths[:,0]]
        peak_height_upper = density[idxs_peak_widths[:,1]]
        peak_mean = np.empty(len(peak_position))
        peak_mean_error =  np.empty(len(peak_position))
        peak_variance =  np.empty(len(peak_position))
        peak_variance_error =  np.empty(len(peak_position))
        peak_std_deviation =  np.empty(len(peak_position))
        peak_std_deviation_error =  np.empty(len(peak_position))
        peak_skewness =  np.empty(len(peak_position))
        peak_skewness_error =  np.empty(len(peak_position))
        peak_kurtosis = np.empty(len(peak_position))
        peak_kurtosis_error =  np.empty(len(peak_position))
        
  
        for peak_i, (pp, ppl, ppu) in enumerate(zip(
                           peak_position,
                           peak_position_lower,
                           peak_position_upper
                        )):
            
            
            
            data_peak = SelectRangeNumba(data_bn, ppl, ppu)

            try:
                mean, dmean, _ = Bootstrap(data_peak, np.mean, self._n_bootstrap)
            except:
                mean, dmean = np.nan, np.nan
                
            try:
                var, dvar, _ = Bootstrap(data_peak, np.var, self._n_bootstrap)
                std, dstd =   np.sqrt(var), (0.5/np.sqrt(var))*var
            except:
                var, dvar = np.nan, np.nan
                std, dstd = np.nan, np.nan
                
            try:
                skw, dskw, _ = Bootstrap(data_peak, sp_skew, self._n_bootstrap)
            except:
                skw, dskw = np.nan, np.nan
                    
            try:
                krt, dkrt, _ = Bootstrap(data_peak, sp_kurt, self._n_bootstrap)
            except:
                krt, dkrt = np.nan, np.nan
                
            
            peak_mean[peak_i] = mean
            peak_mean_error[peak_i] = dmean
            peak_variance[peak_i] = var
            peak_variance_error[peak_i] = dvar
            peak_std_deviation[peak_i] = std
            peak_std_deviation_error[peak_i] = dstd
            peak_skewness[peak_i]=skw
            peak_skewness_error[peak_i] = dskw
            peak_kurtosis[peak_i] = krt
            peak_kurtosis_error[peak_i] = dkrt
            
        
        cond_sel_peak_var = (~np.isnan(peak_variance) & ~np.isnan(peak_variance_error))
        cond_sel_peak_mean = (~np.isnan(peak_mean) & ~np.isnan(peak_mean_error))      
            
        f_var = HuberRegression(
                              Linear,
                              peak_position_norm[cond_sel_peak_var],
                              peak_variance[cond_sel_peak_var],
                              peak_variance_error[cond_sel_peak_var],
        )
        
        f_mean = HuberRegression(
                              Linear,
                              peak_position_norm[cond_sel_peak_mean],
                              peak_mean[cond_sel_peak_mean],
                              peak_mean_error[cond_sel_peak_mean],
        )
        
        
        m_var = Minuit(f_var,
                       m=np.mean(np.gradient(peak_variance[cond_sel_peak_var])),
                       c=peak_variance[cond_sel_peak_var][0])
        m_var.migrad()
        m_var.hesse()
        
      
        m_mean = Minuit(f_mean,
                        m=np.mean(np.gradient(peak_mean[cond_sel_peak_mean])),
                        c=peak_mean[cond_sel_peak_mean][0])
        m_mean.migrad()
        m_mean.hesse()
        

        
        x_0 = m_mean.values["c"]
        G = m_mean.values["m"]
        dx_0 = m_mean.errors["c"]
        dG = m_mean.errors["m"]
        
        sigma2_0 = m_var.values["c"]
        sigma2_1 = m_var.values["m"]
        dsigma2_0 = m_var.errors["c"]
        dsigma2_1 = m_var.errors["m"]
        
        
        sigma_0 = np.sqrt(sigma2_0)
        sigma_1 = np.sqrt(sigma2_1)
        dsigma_0 = (0.5/sigma_0)*dsigma2_0
        dsigma_1 = (0.5/sigma_1)*dsigma2_1
        
 
      
        mu, dmu, _ = Bootstrap((data_bn - x_0)/G,
                               lambda _data: GP_muGP(*GetStats(_data)),
                               self._n_bootstrap)
        lbda, dlbda, _ = Bootstrap((data_bn - x_0)/G,
                                   lambda _data: GP_lbda(*GetStats(_data)),
                                   self._n_bootstrap)
        
        
        k_low = 0
        k_hi = poisson.ppf(kwargs["alpha_fit"], mu)
        
        f_k2PDF = interp1d((bin_numbers_kde-x_0)/G, density_kde*G, fill_value=(0,0), bounds_error=False)
        f_k2PDF_err = interp1d((bin_numbers_kde-x_0)/G, density_kde_err*G, fill_value=(0,0), bounds_error=False)

        
        f_k2PDF = interp1d((bin_numbers_kde-x_0)/G, density_kde*G, fill_value=(0,0), bounds_error=False)
        f_k2PDF_err = interp1d((bin_numbers_kde-x_0)/G, density_kde_err*G, fill_value=(0,0), bounds_error=False)

        
    
        m_k2PDF = Minuit(lambda k: f_k2PDF(k), k=0.5)
        m_k2PDF.limits["k"] = (0,1)
        m_k2PDF.migrad()
        m_k2PDF.hesse()
        
        k_DC_min = m_k2PDF.values["k"]
        
        fk_DC_min = f_k2PDF(k_DC_min)
        dfk_DC_min = f_k2PDF_err(k_DC_min)
        
        P_int_02Min = quad(lambda k: f_k2PDF(k), 0, k_DC_min)[0]
        dP_int_02Min = quad(lambda k: 2*f_k2PDF_err(k), 0, k_DC_min)[0]
        
        
        P_int_Min22Min = quad(lambda k: f_k2PDF(k), k_DC_min, min(2*k_DC_min, 1))[0]
        dP_int_Min22Min = quad(lambda k: 2*f_k2PDF_err(k), k_DC_min, min(2*k_DC_min, 1))[0]
        
        
        P_sum = P_int_02Min + 0.5*P_int_Min22Min
        dP_sum = np.sqrt(P_int_02Min**2 + 0.5*P_int_Min22Min**2)

              
        DCR = fk_DC_min/(max(P_sum, self._eps)*4*kwargs["tau"])
        dDCR = DCR*np.sqrt((dfk_DC_min/max(fk_DC_min, self._eps))**2  + (dP_sum/max(P_sum, self._eps))**2)
        mu_DCR = DCR*(kwargs["t_0"] + kwargs["t_gate"])
        

        k_dcr_low = 0
        k_dcr_hi = poisson.ppf(kwargs["alpha_fit"], mu_DCR)

        self._hist_data["data"] = data
        self._hist_data["bw"] = bw
        self._hist_data["bc_min"] = f_bn2bc(0)
        self._hist_data["bins"]=bins
        self._hist_data["nbins"]=nbins
        self._hist_data["counts"] = counts
        self._hist_data["counts_error"] = counts_error
        self._hist_data["bin_centres"] = bin_centres
        self._hist_data["bin_numbers"] = bin_numbers
        self._hist_data["density_orig"] = density_orig
        self._hist_data["density_orig_error"] = density_orig_error
        self._hist_data["density"] = density
        self._hist_data["density_error"] = density_error
        self._hist_data["density_bgsub"] = density_bgsub
        self._hist_data["fft_amplitude"] = fft_amplitude
        self._hist_data["fft_freq"] = fft_freq
        self._hist_data["G_fft"] = G_fft
        self._hist_data["k_DC_min"] = k_DC_min
        self._hist_data["x_0_est"] = x_0_est
        self._hist_data["bc2bn"] = f_bc2bn
        self._hist_data["bn2bc"] = f_bn2bc
        self._hist_data["bn2PDF"] = f_bn2PDF
        self._hist_data["bn2PDF_err"] = f_bn2PDF_err
        self._hist_data["k2PDF"] =  f_k2PDF
        self._hist_data["k2PDF_err"] =  f_k2PDF_err
        
        
        self._hist_data["peaks"]["peak_position"] = peak_position
        self._hist_data["peaks"]["peak_position_norm"] = peak_position_norm
        self._hist_data["peaks"]["peak_position_lower"] = peak_position_lower
        self._hist_data["peaks"]["peak_position_upper"] = peak_position_upper
        self._hist_data["peaks"]["peak_height"] = peak_height
        self._hist_data["peaks"]["peak_height_lower"] = peak_height_lower
        self._hist_data["peaks"]["peak_height_upper"] = peak_height_upper
        self._hist_data["peaks"]["peak_mean"] = peak_mean
        self._hist_data["peaks"]["peak_mean_error"] = peak_mean_error
        self._hist_data["peaks"]["peak_variance"] = peak_variance
        self._hist_data["peaks"]["peak_variance_error"] = peak_variance_error
        self._hist_data["peaks"]["peak_std_deviation"] = peak_std_deviation
        self._hist_data["peaks"]["peak_std_deviation_error"] = peak_std_deviation_error
        self._hist_data["peaks"]["peak_skewness"] = peak_skewness
        self._hist_data["peaks"]["peak_skewness_error"] = peak_skewness_error
        self._hist_data["peaks"]["peak_kurtosis"] = peak_kurtosis
        self._hist_data["peaks"]["peak_kurtosis_error"] = peak_kurtosis_error
        

        self._prefit_values["x_0"] = x_0
        self._prefit_values["G"] = G
        self._prefit_values["mu"] = mu
        self._prefit_values["lbda"] = lbda
        self._prefit_values["sigma_0"] = sigma_0
        self._prefit_values["sigma_1"] = sigma_1
        self._prefit_values["DCR"] = DCR
        self._prefit_values["pAp"] = 0.1
        self._prefit_values["tauAp"] = 0.5*kwargs["tau"]
        self._prefit_values["tau"] = kwargs["tau"]
        self._prefit_values["t_0"] = kwargs["t_0"]
        self._prefit_values["t_gate"] = kwargs["t_gate"]
        self._prefit_values["k_low"]=k_low
        self._prefit_values["k_hi"]=max(2,k_hi)
        self._prefit_values["k_dcr_low"]=k_dcr_low
        self._prefit_values["k_dcr_hi"]=max(2, k_dcr_hi)
        
        
        
        self._prefit_errors["x_0"] = dx_0
        self._prefit_errors["G"] = dG
        self._prefit_errors["mu"] = dmu
        self._prefit_errors["lbda"] = dlbda
        self._prefit_errors["sigma_0"] = dsigma_0
        self._prefit_errors["sigma_1"] = dsigma_1
        self._prefit_errors["pAp"] = 0.1
        self._prefit_errors["tauAp"] = 1
        self._prefit_errors["tau"] = kwargs["tau_err"]
        self._prefit_errors["t_0"] = kwargs["t_0_err"]
        self._prefit_errors["t_gate"] = kwargs["t_gate_err"]
        self._prefit_errors["DCR"] = dDCR