import numpy as np
from scipy.interpolate import interp1d
from iminuit.util import describe
from HelperFunctions import FakeFuncCode
import matplotlib.pyplot as plt
    

class UnbinnedLH:
    """Negative log-likelihood cost for an unbinned maximum-likelihood fit.

    The model ``f(x, *params)`` is evaluated on an evenly spaced grid that
    spans ``[min(data), max(data)]``. Its logarithm is linearly interpolated
    at every data point, and the negated sum is returned. Model values at or
    below ``eps`` (and NaNs) are replaced by ``log(eps)`` so the cost stays
    finite.

    NOTE(review): the cost is only a proper likelihood if ``f`` is a density
    normalised over the data range; nothing here enforces that -- confirm
    against callers.
    """

    def __init__(self, f, data, bw):
        """
        Parameters
        ----------
        f : callable
            Model ``f(x, *params)``; expected to return an array shaped like
            ``x``.
        data : array_like
            Unbinned observations.
        bw : float
            Reference bin width; the evaluation grid has ``10 * n_bins``
            points where ``n_bins = ceil((max(data) - min(data)) / bw)``.
        """
        self.f = f
        self.x_min = np.min(data)
        self.x_max = np.max(data)
        # Keep at least one "bin" so the evaluation grid is never empty
        # (happens when every data point has the same value).
        self.n_bins = max(int(np.ceil((self.x_max - self.x_min) / bw)), 1)
        self.x = np.linspace(self.x_min, self.x_max, 10 * self.n_bins)
        self.data = data
        self.last_arg = None
        # NOTE(review): FakeFuncCode comes from the project's HelperFunctions;
        # presumably it exposes f's signature (minus x) to iminuit -- confirm.
        self.func_code = FakeFuncCode(f, dock=True)
        # Data points minus (co_argcount - 1); presumably co_argcount - 1 is
        # the number of fit parameters -- verify against FakeFuncCode.
        self.ndof = len(self.data) - (self.func_code.co_argcount - 1)
        self.eps = np.finfo(np.float64).eps * 10
        self.eps_inv = 1 / self.eps
        self.log_eps = np.log(self.eps)
        self.n_calls = 0

    def Logify(self, y):
        """Return ``log(y)`` elementwise, using ``log(eps)`` where ``y <= eps`` or is NaN."""
        y = np.asarray(y)
        # np.where evaluates np.log on every element, including zeros and
        # negatives; those results are discarded, so silence the warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.where(y > self.eps, np.log(y), self.log_eps)

    def __call__(self, *arg):
        """Return the negative log-likelihood for parameter values ``arg``.

        Records ``arg`` in ``last_arg`` and increments ``n_calls``.
        """
        self.last_arg = arg
        y_hat = np.nan_to_num(self.f(self.x, *arg), nan=self.eps)

        # Linear interpolation of log(f) at the data points. Points outside
        # the grid (possible only through round-off, since the grid spans the
        # data) get log(eps), the same floor Logify uses.
        log_y_hat = np.interp(
            self.data,
            self.x,
            self.Logify(y_hat),
            left=self.log_eps,
            right=self.log_eps,
        )

        self.n_calls += 1
        return -np.sum(log_y_hat)
        
           
                                              
class BinnedLH:
    """Poisson negative log-likelihood cost for fitting binned counts.

    Observed counts are ``y`` and the model prediction is
    ``y_hat = f(x, *params)``. Each bin contributes
    ``y * (log(y_hat) - log(y)) + (y - y_hat)``, which is the Poisson
    log-likelihood ratio against the saturated model (``y_hat == y``). The
    cost is minus the sum over bins, i.e. half the Poisson deviance.
    Logarithms of values at or below ``eps`` are floored at ``log(eps)``, so
    empty bins (``y == 0``) contribute only ``-y_hat``.
    """

    def __init__(self, f, x, y):
        """
        Parameters
        ----------
        f : callable
            Model ``f(x, *params)``; expected to return predicted counts
            shaped like ``y``.
        x : array_like
            Bin positions passed to ``f``.
        y : array_like
            Observed counts per bin.
        """
        self.f = f
        self.x = x
        # Array conversion so the elementwise comparison in Logify also works
        # when counts are given as a plain list.
        self.y = np.asarray(y)
        self.last_arg = None
        # NOTE(review): FakeFuncCode comes from the project's HelperFunctions;
        # presumably it exposes f's signature (minus x) to iminuit -- confirm.
        self.func_code = FakeFuncCode(f, dock=True)
        # Bins minus (co_argcount - 1); presumably co_argcount - 1 is the
        # number of fit parameters -- verify against FakeFuncCode.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)
        self.eps = np.finfo(np.float64).eps * 10
        self.eps_inv = 1 / self.eps
        self.log_eps = np.log(self.eps)

    def Logify(self, y):
        """Return ``log(y)`` elementwise, using ``log(eps)`` where ``y <= eps`` or is NaN."""
        y = np.asarray(y)
        # np.where evaluates np.log on every element; empty bins (y == 0)
        # would otherwise raise divide-by-zero warnings for discarded values.
        with np.errstate(divide="ignore", invalid="ignore"):
            return np.where(y > self.eps, np.log(y), self.log_eps)

    def __call__(self, *arg):
        """Return the Poisson negative log-likelihood ratio for parameters ``arg``."""
        self.last_arg = arg
        y_hat = self.f(self.x, *arg)
        logL = self.y * (self.Logify(y_hat) - self.Logify(self.y)) + (self.y - y_hat)
        return -np.sum(logL)
    
    
    
    
class Chi2Regression:
    """Chi-square (weighted least-squares) cost for fitting ``f(x, *params)`` to ``y``.

    Returns ``sum(((f(x, *params) - y) / y_err) ** 2)``.
    """

    def __init__(self, f, x, y, y_err, epsilon=1.35):
        """
        Parameters
        ----------
        f : callable
            Model ``f(x, *params)``.
        x, y : array_like
            Data points.
        y_err : array_like or float
            Uncertainties on ``y``. Values below ``eps`` (including zeros and
            negatives) are floored at ``eps`` to avoid division by zero.
        epsilon : float, optional
            Unused; kept only so existing call sites stay valid.
        """
        self.f = f
        self.x = x
        self.y = y
        self.eps = np.finfo(np.float64).eps * 10
        # Build a new float array instead of assigning in place, so the
        # caller's y_err is left untouched and integer input is not silently
        # truncated (eps would round to 0 in an int array).
        self.y_err = np.maximum(np.asarray(y_err, dtype=np.float64), self.eps)
        self.last_arg = None
        # NOTE(review): FakeFuncCode comes from the project's HelperFunctions;
        # presumably it exposes f's signature (minus x) to iminuit -- confirm.
        self.func_code = FakeFuncCode(f, dock=True)
        # Points minus (co_argcount - 1); presumably co_argcount - 1 is the
        # number of fit parameters -- verify against FakeFuncCode.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    def __call__(self, *arg):
        """Return the chi-square value for parameter values ``arg``."""
        self.last_arg = arg
        residuals = (self.f(self.x, *arg) - self.y) / self.y_err
        return np.sum(residuals ** 2)
    
    
    
    
class HuberRegression:
    """Huber-loss cost for robust fitting of ``f(x, *params)`` to ``y``.

    With standardised residuals ``a = |y - f(x, *params)| / y_err``, each
    point contributes ``0.5 * a**2`` when ``a <= delta`` and
    ``delta * (a - 0.5 * delta)`` otherwise. The cost is the sum over points.
    """

    def __init__(self, f, x, y, y_err, delta=1.345):
        """
        Parameters
        ----------
        f : callable
            Model ``f(x, *params)``.
        x, y : array_like
            Data points.
        y_err : array_like or float
            Uncertainties on ``y``. Values below ``eps`` (including zeros and
            negatives) are floored at ``eps`` to avoid division by zero.
        delta : float, optional
            Threshold (in units of ``y_err``) between the quadratic and
            linear regimes.
        """
        self.f = f
        self.x = x
        self.y = y
        self.delta = delta
        self.eps = np.finfo(np.float64).eps * 10
        # Build a new float array instead of assigning in place, so the
        # caller's y_err is left untouched and integer input is not silently
        # truncated (eps would round to 0 in an int array).
        self.y_err = np.maximum(np.asarray(y_err, dtype=np.float64), self.eps)
        self.last_arg = None
        # NOTE(review): FakeFuncCode comes from the project's HelperFunctions;
        # presumably it exposes f's signature (minus x) to iminuit -- confirm.
        self.func_code = FakeFuncCode(f, dock=True)
        # Points minus (co_argcount - 1); presumably co_argcount - 1 is the
        # number of fit parameters -- verify against FakeFuncCode.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    def __call__(self, *arg):
        """Return the summed Huber loss for parameter values ``arg``."""
        self.last_arg = arg
        a = np.abs((self.y - self.f(self.x, *arg)) / self.y_err)
        # Select per element rather than masking by multiplication, which
        # turns an infinite residual into 0 * inf = NaN.
        loss = np.where(
            a > self.delta,
            self.delta * (a - 0.5 * self.delta),
            0.5 * a ** 2,
        )
        return np.sum(loss)