import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import numpy as np
import pandas as pd
import argparse
from PeakOTron import PeakOTron
from joblib import dump
import time

## Command-line options; every fit setting can be overridden from the shell.
parser = argparse.ArgumentParser(description='Fit SiPM data')

## Numeric options as (flag, default, help), registered in this exact order.
## NOTE(review): the help for -truncate_nsigma0_up mentions "4 sigma_0" while
## its default is 2.0 -- confirm which one is intended.
_float_options = (
    ('-V_bd_hmt', 51.570574 + 0.307, 'V_bd_hmt value'),
    ('-V_0_hmt', 2.906, 'V_0_hmt value'),
    ('-tau', 21.953, 'SLOW COMPONENT OF SIPM PULSE'),
    ('-t_0', 100.0, 'PRE-INTEGRATION TIME'),
    ('-t_gate', 104.0, 'GATE LENGTH'),
    ('-bin_0', -100.0, 'SELECT FIRST BIN OF SPECTRUM'),
    ('-truncate_nsigma0_up', 2.0, 'SCAN SPECTRUM FROM Q < Q_0 - 4 sigma_0'),
    ('-truncate_nsigma0_do', 2.0, 'EVALUATE SPECTRUM CHI2 IN Q_0 - x*sigma_0 < Q < Q_0 + 2*sigma_0'),
)
for _flag, _default, _help in _float_options:
    parser.add_argument(_flag, type=float, default=_default, help=_help)

## Non-numeric options.
parser.add_argument('-prefit_only', action='store_true', help='FIT THE WHOLE SPECTRUM')
parser.add_argument('-folder', type=str, default='data/hamamatsu_pcb6/Light',
                    help='Directory containing the data files')

args = parser.parse_args()

def C_tau(V, V_bd, V_0):
    """Return ``(V - V_bd) / V_0``: the excess of ``V`` over ``V_bd`` in units of ``V_0``.

    ``V_bd`` is presumably the breakdown voltage and ``V_0`` a voltage scale
    (the script passes ``V_bd_hmt`` / ``V_0_hmt``) -- confirm against the
    device model. Works element-wise on NumPy arrays.
    """
    return (V - V_bd) / V_0


def f_tau(V, V_bd, V_0):
    """Return the dimensionless factor the fit loop multiplies ``tau`` by to get ``tau_R``.

    With ``C = C_tau(V, V_bd, V_0)`` this evaluates::

        f = -1 / ln[(1 - exp(C / e)) / (1 - exp(C))]

    (``np.exp(-1)`` is ``1/e``). Works element-wise on NumPy arrays. The
    expression is undefined (NaN, with a NumPy runtime warning) at ``V == V_bd``
    because both numerator and denominator vanish there.
    """
    # Evaluate C once; the original lambda recomputed it for each exp() term.
    c = C_tau(V, V_bd, V_0)
    return -1 / np.log((1 - np.exp(c * np.exp(-1))) / (1 - np.exp(c)))

## Copy the parsed CLI options into the module-level names used by the fit loop.
V_bd_hmt, V_0_hmt = args.V_bd_hmt, args.V_0_hmt  # passed as V_bd / V_0 to f_tau
tau = args.tau  ## SLOW COMPONENT OF SIPM PULSE
t_0, t_gate = args.t_0, args.t_gate  ## PRE-INTEGRATION TIME, GATE LENGTH
bin_0 = args.bin_0  ## SELECT FIRST BIN OF SPECTRUM (CAN BE AUTOMATIC)
## Spectrum truncation widths, in units of sigma_0 (see the CLI help text).
truncate_nsigma0_up, truncate_nsigma0_do = (
    args.truncate_nsigma0_up,
    args.truncate_nsigma0_do,
)
## NOTE(review): prefit_only is read here but never forwarded to the fit call
## in this script -- confirm whether PeakOTron.Fit should receive it.
prefit_only = args.prefit_only




## Start-of-run banner.
_rule = "--------------------"
for _line in (_rule, 'EXAMPLE SIPM CALIBRATION RUN', _rule):
    print(_line)


out_dict = {}  # column name -> list of per-file values, filled by the fit loop
files_to_fit = []  # [file_name, full_path] pairs for every .txt histogram found

## Find all histograms in directory
folder = args.folder
# os.walk silently yields nothing for a missing directory, which would end the
# run with an empty CSV; fail loudly instead so a mistyped -folder is obvious.
if not os.path.isdir(folder):
    sys.exit("Data folder not found: {:s}".format(folder))

for root, dirs, files in os.walk(folder):
    # os.walk returns names in filesystem-dependent order. Sorting `dirs` in
    # place makes the (top-down) walk descend in sorted order, and sorting
    # `files` fixes the order within each directory, so fit indices and CSV
    # row order are reproducible across machines.
    dirs.sort()
    for file in sorted(files):
        if file.endswith(".txt"):
            files_to_fit.append([file, os.path.join(root, file)])

## List the discovered histograms, numbered in the order they will be fitted.
print("Files to fit:")
for idx, entry in enumerate(files_to_fit):
    print('File {0}: {1}'.format(idx, entry[0]))

## Device label written into every result row and used in the output CSV name.
SiPM = "PM1125NS_SBO"


def _append_fit_row(table, row):
    """Append one fit's results to the column-oriented ``table`` in place.

    ``table`` maps column name -> list of values (one entry per fit so far);
    ``row`` maps column name -> value for the current fit. Columns seen for the
    first time are back-filled with NaN for earlier fits, and columns missing
    from ``row`` receive NaN, so every list always has the same length and the
    final DataFrame construction cannot fail on ragged or unknown columns.
    """
    n_rows = len(next(iter(table.values()), []))
    for key in row:
        table.setdefault(key, [np.nan] * n_rows)
    for key, column in table.items():
        column.append(row.get(key, np.nan))


## Output directory for per-file plots and pickled fit objects. Created up
## front: nothing in this script creates it, and joblib.dump does not create
## missing parent directories.
results_dir = "./Results"
os.makedirs(results_dir, exist_ok=True)

## Loop through files
for i, (file, path) in enumerate(files_to_fit):
    stem = os.path.splitext(file)[0]

    ## The bias voltage is the third '_'-separated field of the file name, with
    ## a trailing 'V' and 'p' standing for the decimal point (e.g. "55p5V" -> 55.5).
    items = file.split('_')
    try:
        V = float(items[2].replace('V', '').replace('p', '.'))
    except (IndexError, ValueError) as err:
        raise ValueError(
            "Cannot parse bias voltage from file name {!r}".format(file)
        ) from err
    print(V)
    f_tau_hmt = f_tau(V, V_bd_hmt, V_0_hmt)

    print("\n\n")
    print("===============================================================")
    print("FIT {:d} - {:s}".format(i, file))
    print("===============================================================")
    print("\n\n")

    ## Load files. The first 8 lines are skipped (presumably a text header --
    ## confirm against the acquisition format).
    data = np.loadtxt(path, skiprows=8)

    ## Create a PeakOTron Fit Object.
    f_data = PeakOTron(verbose=False)

    ## Perform fit.
    ## NOTE(review): prefit_only is parsed from the CLI but not passed here.
    ## Per an earlier note in this script, a binning rule ("knuth", "freedman",
    ## "scott") or bw= can be used to override binning; binning is still
    ## required for the DCR calculation.
    f_data.Fit(data,
          tau=tau,  # SLOW PULSE COMPONENT TIME CONSTANT (ns)
          t_gate=t_gate,  # GATE LENGTH (ns)
          t_0=t_0,  # INTEGRATION TIME BEFORE GATE (ns)
          tau_R=f_tau_hmt*tau,
          bin_0=bin_0,
          truncate_nsigma0_up=truncate_nsigma0_up,
          truncate_nsigma0_do=truncate_nsigma0_do
    )

    f_data.PlotFit(plot_in_bins=True, display=False,
                   save_directory="{0}/{1}_fit.png".format(results_dir, stem))

    ## Persist the whole fit object so it can be reloaded with joblib.load.
    dump(f_data, "{0}/{1}.joblib".format(results_dir, stem))

    fit_val, fit_err = f_data.GetFitResults()
    for key, val in fit_val.items():
        print("{:s} : {:3.3E}".format(key, val))

    ## One output row: device, bias voltage, uncertainties (prefixed "d_"),
    ## then the fitted values.
    fit_out = {"SiPM": SiPM, "V": V}
    for key, value in fit_err.items():
        fit_out["d_{:s}".format(key)] = value
    fit_out.update(fit_val)

    _append_fit_row(out_dict, fit_out)

    print("===============================================================")
    print("\n\n")
        



## Write the collected results (one row per fitted file, one column per
## out_dict key) to a CSV named after the device.
results_table = pd.DataFrame(out_dict)
csv_path = "./fit_results_{:s}.csv".format(SiPM)
results_table.to_csv(csv_path)