import numpy as np
from scipy.interpolate import interp1d
from iminuit.util import describe
from Model import epsilon
from HelperFunctions import FakeFuncCode
import matplotlib.pyplot as plt
    

           
                                              
class BinnedLH:
    """Binned Poisson likelihood cost function for fitting a density to a histogram.

    Calling an instance with the fit parameters returns the Poisson negative
    log-likelihood ratio

        sum_i [E_i - h_i + h_i * ln(h_i / E_i)]

    where h_i are the observed counts and E_i = f(x_i, *params) * N * bw are
    the expected counts (N is the total number of counts). This is half of
    the Baker-Cousins likelihood chi-square. Bins with h_i == 0 contribute
    E_i, since h * ln(h / E) -> 0 as h -> 0.

    Parameters
    ----------
    f : callable
        Model ``f(x, *params)``. It is scaled by the total count and the bin
        width to give expected counts, so it is assumed to be a probability
        density.
    bcs : array_like
        Points at which ``f`` is evaluated (presumably the bin centres).
    counts : array_like
        Observed counts per bin.
    bw : float or array_like
        Bin width(s).
    """

    def __init__(self, f, bcs, counts, bw):
        self.f = f
        self.x = bcs
        self.dx = bw
        # Store as an array so the boolean masking in __call__ also works
        # when counts is passed as a plain list.
        self.counts = np.asarray(counts)
        self.N = np.sum(self.counts)
        self.last_arg = None
        # Parameter signature of f with its first (x) argument docked off;
        # presumably consumed by iminuit to discover the fit parameters.
        self.func_code = FakeFuncCode(f, dock=True)
        # Number of times the cost function has been evaluated.
        self.n_calls = 0
        # Floor applied to the model values in __call__ so log(E) is finite.
        self.eps = epsilon()

    def __call__(self, *arg):
        """Return the negative log-likelihood ratio for the parameters ``arg``."""
        self.last_arg = arg
        y_hat = self.f(self.x, *arg)
        # Replace NaN/+-inf by eps and clip anything below eps, so every
        # expected count is strictly positive and its logarithm is finite.
        y_hat = np.nan_to_num(y_hat, nan=self.eps, posinf=self.eps, neginf=self.eps)
        y_hat = np.where(y_hat < self.eps, self.eps, y_hat)

        h = self.counts
        # Broadcast so a scalar model output still lines up with the bins.
        E = np.broadcast_to(y_hat * self.N * self.dx, np.shape(h))

        # Only the h*ln(h/E) term must be restricted to non-empty bins; the
        # (E - h) term is summed over ALL bins so that the model is penalised
        # for predicting events where none were observed.
        mask = h > 0
        log_term = np.sum(h[mask] * (np.log(h[mask]) - np.log(E[mask])))
        nlogL = log_term + np.sum(E - h)

        self.n_calls += 1
        return nlogL
    
    
    
    
class Chi2Regression:
    """Least-squares (chi-square) cost function for fitting ``f`` to (x, y) data.

    Calling an instance with the fit parameters returns
    ``sum(((f(x, *params) - y) / y_err) ** 2)``.

    Parameters
    ----------
    f : callable
        Model ``f(x, *params)``.
    x, y : array_like
        Data points.
    y_err : array_like or float
        Uncertainties on ``y``. Entries below a small floor (10 * float64
        machine epsilon) are raised to that floor to avoid division by zero.
        The caller's array is not modified.
    epsilon : float, optional
        Currently unused by this class.
    """

    def __init__(self, f, x, y, y_err, epsilon=1.35):
        self.f = f
        self.x = x
        self.y = y
        self.eps = np.finfo(np.float64).eps * 10
        self.y_err = self._clip_errors(y_err, self.eps)
        self.last_arg = None
        # Parameter signature of f with its first (x) argument docked off;
        # presumably consumed by iminuit to discover the fit parameters.
        self.func_code = FakeFuncCode(f, dock=True)
        # NOTE(review): if FakeFuncCode(..., dock=True) already excludes x from
        # co_argcount (i.e. co_argcount == number of fit parameters), the
        # "- 1" below makes ndof one too large -- confirm against HelperFunctions.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    @staticmethod
    def _clip_errors(y_err, eps):
        """Return a float copy of ``y_err`` with every entry below ``eps`` set to ``eps``."""
        # Cast to float first: assigning eps into an integer array would
        # truncate it back to 0 and leave the division by zero in place.
        y_err = np.asarray(y_err, dtype=float)
        return np.where(y_err < eps, eps, y_err)

    def __call__(self, *arg):
        """Return the chi-square value for the parameters ``arg``."""
        self.last_arg = arg
        loss = ((self.f(self.x, *arg) - self.y) / self.y_err) ** 2
        return np.sum(loss)
    
    
    
    
class HuberRegression:
    """Huber-loss cost function for robust fitting of ``f`` to (x, y) data.

    With normalised residuals ``r = |y - f(x, *params)| / y_err``, calling an
    instance returns ``sum(L(r))`` where

        L(r) = 0.5 * r**2                    if r <= delta
        L(r) = delta * (r - 0.5 * delta)     if r >  delta

    i.e. quadratic for small residuals and linear for large ones.

    Parameters
    ----------
    f : callable
        Model ``f(x, *params)``.
    x, y : array_like
        Data points.
    y_err : array_like or float
        Uncertainties on ``y``. Entries below a small floor (10 * float64
        machine epsilon) are raised to that floor to avoid division by zero.
        The caller's array is not modified.
    delta : float, optional
        Transition point between the quadratic and linear regimes, in units
        of ``y_err``.
    """

    def __init__(self, f, x, y, y_err, delta=1.345):
        self.f = f
        self.x = x
        self.y = y
        self.delta = delta
        self.eps = np.finfo(np.float64).eps * 10
        self.y_err = self._clip_errors(y_err, self.eps)
        self.last_arg = None
        # Parameter signature of f with its first (x) argument docked off;
        # presumably consumed by iminuit to discover the fit parameters.
        self.func_code = FakeFuncCode(f, dock=True)
        # NOTE(review): if FakeFuncCode(..., dock=True) already excludes x from
        # co_argcount (i.e. co_argcount == number of fit parameters), the
        # "- 1" below makes ndof one too large -- confirm against HelperFunctions.
        self.ndof = len(self.y) - (self.func_code.co_argcount - 1)

    @staticmethod
    def _clip_errors(y_err, eps):
        """Return a float copy of ``y_err`` with every entry below ``eps`` set to ``eps``."""
        # Cast to float first: assigning eps into an integer array would
        # truncate it back to 0 and leave the division by zero in place.
        y_err = np.asarray(y_err, dtype=float)
        return np.where(y_err < eps, eps, y_err)

    def __call__(self, *arg):
        """Return the summed Huber loss for the parameters ``arg``."""
        self.last_arg = arg
        r = np.abs((self.y - self.f(self.x, *arg)) / self.y_err)
        # np.where selects per element instead of multiplying by boolean
        # masks, which would turn an infinite residual into 0 * inf = nan.
        loss = np.where(
            r > self.delta,
            self.delta * (r - 0.5 * self.delta),
            0.5 * r ** 2,
        )
        return np.sum(loss)