import numpy as np
from scipy.interpolate import interp1d
from iminuit.util import describe
from Model import epsilon
from HelperFunctions import FakeFuncCode
import matplotlib.pyplot as plt
   
        
        
                                              
class BinnedLH:
    """Binned Poisson negative log-likelihood cost.

    The model ``f(bin_centres, *params)`` is multiplied by the bin width to
    obtain the expected count ``E`` in each bin. Calling the instance returns
    ``-sum(log L_i / L_i^sat)``, i.e. the negative Poisson log-likelihood
    ratio with respect to the saturated model (``E_i == h_i``):

    * populated bin (``h > 0``): ``h * (log E - log h) + (h - E)``
    * empty bin (``h == 0``):    ``-E``

    Parameters
    ----------
    f : callable
        Model evaluated as ``f(bin_centres, *params)``.
    bin_centres : array-like
        Bin centres passed to ``f``.
    counts : array-like
        Observed (possibly weighted) counts per bin.
    bw : float or array-like
        Bin width(s); multiplies the model output.
    """

    def __init__(self, f, bin_centres, counts, bw):
        self.f = f
        self.x = bin_centres
        self.len_x = len(self.x)
        self.dx = bw
        # asarray so that list input supports the elementwise `> 0` below
        self.counts = np.asarray(counts)
        self.N = np.sum(self.counts)

        self.last_arg = None
        # NOTE(review): presumably exposes f's fit parameters (without x) to
        # the minimiser via a func_code-like object; confirm in HelperFunctions.
        self.func_code = FakeFuncCode(f, dock=True)
        self.n_calls = 0
        # NOTE(review): epsilon() comes from Model; assumed to be a small
        # positive float used as a floor for the model prediction.
        self.eps = epsilon()
        # Bins with data; these get the full Poisson log-ratio term.
        self.mask = self.counts > 0

    def __call__(self, *arg):
        """Return the negative Poisson log-likelihood ratio at ``arg``."""
        self.last_arg = arg

        y_hat = self.f(self.x, *arg)
        h = self.counts
        populated = self.mask

        # In populated bins, floor the prediction at eps. Previously such a bin
        # with y_hat <= eps was treated like an empty bin and cost only E ~ 0,
        # so the minimiser was rewarded for driving the model to zero where it
        # fit poorly. Flooring keeps log() finite while imposing a large
        # penalty. For y_hat > eps the value is unchanged.
        y_used = np.where(populated, np.maximum(y_hat, self.eps), y_hat)
        E = y_used * self.dx

        terms = np.zeros(self.len_x)
        h_pop = h[populated]
        E_pop = E[populated]
        terms[populated] = h_pop * (np.log(E_pop) - np.log(h_pop)) + (h_pop - E_pop)
        terms[~populated] = -E[~populated]

        nlogL = -np.sum(terms)

        self.n_calls += 1
        return nlogL
    
    
    
    
class Chi2Regression:
    """Chi-square cost ``sum(((f(x, *params) - y) / y_err) ** 2)``.

    Parameters
    ----------
    f : callable
        Model evaluated as ``f(x, *params)``.
    x : array-like
        Independent variable passed to ``f``.
    y : array-like
        Observed values.
    y_err : array-like or float
        Uncertainties on ``y``. Stored as a new float64 array with entries
        below ``self.eps`` raised to ``self.eps``; the caller's array is not
        modified.
    epsilon : float, optional
        Unused; kept only so existing call sites keep working.
    """

    def __init__(self, f, x, y, y_err, epsilon=1.35):
        self.f = f
        self.x = x
        self.y = y
        self.eps = np.finfo(np.float64).eps * 10
        self.y_err = self._floor_errors(y_err, self.eps)
        self.last_arg = None
        self.func_code = FakeFuncCode(f, dock=True)
        # NOTE(review): assumes co_argcount counts x plus the fit parameters,
        # so ndof = n_points - n_params; confirm against FakeFuncCode.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    @staticmethod
    def _floor_errors(y_err, eps):
        """Return a float64 copy of ``y_err`` with values below ``eps`` set to ``eps``.

        Converting to float first matters: clipping an integer array in place
        would truncate ``eps`` back to 0 and cause division by zero.
        """
        y_err = np.asarray(y_err, dtype=np.float64)
        return np.where(y_err < eps, eps, y_err)

    def __call__(self, *arg):
        """Return the chi-square value for parameters ``arg``."""
        self.last_arg = arg

        y_hat = self.f(self.x, *arg)

        loss = ((y_hat - self.y) / self.y_err) ** 2

        return np.sum(loss)
    
    
    
    
class HuberRegression:
    """Huber-loss cost on normalised residuals ``a = |(y - f(x, *params)) / y_err|``.

    Per point the loss is ``0.5 * a**2`` for ``a <= delta`` and
    ``delta * (a - 0.5 * delta)`` otherwise (quadratic core, linear tails).

    Parameters
    ----------
    f : callable
        Model evaluated as ``f(x, *params)``.
    x : array-like
        Independent variable passed to ``f``.
    y : array-like
        Observed values.
    y_err : array-like or float
        Uncertainties on ``y``. Stored as a new float64 array with entries
        below ``self.eps`` raised to ``self.eps``; the caller's array is not
        modified.
    delta : float, optional
        Threshold (in units of ``y_err``) between quadratic and linear regimes.
    """

    def __init__(self, f, x, y, y_err, delta=1.345):
        self.f = f
        self.x = x
        self.y = y
        self.delta = delta
        self.eps = np.finfo(np.float64).eps * 10
        self.y_err = self._floor_errors(y_err, self.eps)
        self.last_arg = None
        self.func_code = FakeFuncCode(f, dock=True)
        # NOTE(review): assumes co_argcount counts x plus the fit parameters,
        # so ndof = n_points - n_params; confirm against FakeFuncCode.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    @staticmethod
    def _floor_errors(y_err, eps):
        """Return a float64 copy of ``y_err`` with values below ``eps`` set to ``eps``.

        Converting to float first matters: clipping an integer array in place
        would truncate ``eps`` back to 0 and cause division by zero.
        """
        y_err = np.asarray(y_err, dtype=np.float64)
        return np.where(y_err < eps, eps, y_err)

    def __call__(self, *arg):
        """Return the summed Huber loss for parameters ``arg``."""
        self.last_arg = arg

        a = np.abs((self.y - self.f(self.x, *arg)) / self.y_err)
        outer = a > self.delta

        # np.where instead of multiplying by boolean masks: with masks, an
        # infinite residual produced 0 * inf = nan rather than inf.
        loss = np.where(outer, self.delta * (a - 0.5 * self.delta), 0.5 * a ** 2)

        return np.sum(loss)