# -*- coding: utf-8 -*-
"""
Created on Mon Mar 14 23:59:05 2016

@author: itayshom
"""

import warnings

from scipy.optimize import curve_fit

from math import pi
import numpy as np


"""
=====================================================
Curve fitting for coherent response (NA measurement)

Common variables in these functions:

    Δ - detuning from cavity
    κ - cavity bandwidth
    η - cavity coupling parameter, η = κ_e/κ
    β - modulation depth (just overall scale)
    c - constant floor
=====================================================
"""


class na_fixed_η:
    """
    Fit to NA using a fixed cavity coupling parameter η.

    The free parameters are κ, Δ, β and c (see the module notes); η is held
    at the value passed in.  If ``poly`` is given, β is multiplied by
    ``np.polyval(poly, x/1e9)`` so the overall scale can vary with x
    (presumably x is in Hz and ``poly`` is expressed in GHz -- confirm with
    callers).

    After construction the instance holds κ, Δ, β, c, η, poly and pcov, plus
    the ASCII aliases ``delta`` and ``kappa``.  If curve_fit raises
    RuntimeError, a warning is emitted, the starting estimates from
    ``na.get_initial`` are kept, and pcov is a 4x4 array of inf (covariance
    undetermined).
    """
    def __init__(self, freq, amp, η=1, poly=None):

        # η estimate from get_initial is discarded: η is fixed by the caller.
        self.κ, self.Δ, self.β, _, self.c = na.get_initial(freq, amp)
        self.η = η

        self.poly = poly

        try:
            # Starting guess doubles the width estimate; c starts at 0.
            (self.κ, self.Δ, self.β, self.c), self.pcov = \
                curve_fit(self.func, freq, amp, (2*self.κ, self.Δ, self.β, 0))
        except RuntimeError:
            warnings.warn('NA fitting failed!\n')
            # Without this, reading self.pcov after a failed fit would raise
            # AttributeError; inf mirrors curve_fit's own convention for an
            # undetermined covariance.
            self.pcov = np.full((4, 4), np.inf)
        else:
            # The model depends on Δ and κ only through their squares.
            self.Δ = np.abs(self.Δ)
            self.κ = np.abs(self.κ)

        # compatibility
        self.delta = self.Δ
        self.kappa = self.κ

    def func(self, x, κ, Δ, β, c):
        """Model used for fitting: na.func with η fixed and optional β(x) scaling."""
        if self.poly is not None:
            β = β * np.polyval(self.poly, x/1e9)

        return na.func(x, κ, Δ, β, self.η, c)

    def __call__(self, x):
        """Evaluate the fitted model at x."""
        return self.func(x, self.κ, self.Δ, self.β, self.c)



class na:
    """
    'pure' NA trace: no corrections for frequency response.
    This is not usually used for fitting because the parameter η weakly affects
    the result, and can be fixed. Instead, the static methods are frequently
    called by other fitting functions.

    After construction the instance holds the fitted κ, Δ, β, η and c, and
    pcov (covariance of (κ, Δ, β, η, c) from curve_fit).  If curve_fit raises
    RuntimeError, a warning is emitted, the starting estimates from
    get_initial are kept, and pcov is a 5x5 array of inf.  ASCII aliases
    ``delta``, ``kappa`` and ``eta`` are also set.
    """
    def __init__(self, freq, amp):

        self.κ, self.Δ, self.β, self.η, self.c = self.get_initial(freq, amp)

        # Starting guess doubles the width estimate from get_initial.
        p0 = (2*self.κ, self.Δ, self.β, self.η, self.c)

        try:
            popt, self.pcov = curve_fit(self.func, freq, amp, p0)
        except RuntimeError:
            warnings.warn('NA fitting failed!\n')
            # Keep the starting estimates; covariance is undetermined.
            self.pcov = np.full((5, 5), np.inf)
        else:
            # p0 has five entries, so curve_fit returns five parameters
            # (including c); unpacking into fewer names would raise ValueError.
            self.κ, self.Δ, self.β, self.η, self.c = popt
            # The model depends on Δ and κ only through their squares.
            self.Δ = np.abs(self.Δ)
            self.κ = np.abs(self.κ)

        # compatibility
        self.delta = self.Δ
        self.kappa = self.κ
        self.eta = self.η

    @staticmethod
    def get_initial(x, y):
        """
        Estimate starting values (κ, Δ, β, η, c) for fitting.

        κ is the x-span of the samples with y >= max(y)/2 (an FWHM-style
        estimate; assumes the floor is near zero), Δ is the x of the maximum,
        β is half the maximum, and η = 0.5, c = 0 are fixed defaults.
        """
        m, i = y.max(0), y.argmax(0)

        w = np.flatnonzero(y >= m/2)

        β = 0.5*m
        Δ = x[i]
        κ = x[w[-1]] - x[w[0]]
        η = 0.5
        c = 0

        return κ, Δ, β, η, c

    @staticmethod
    def func(x, κ, Δ, β, η=1, c=0):
        """
        Evaluate the NA model at x.

        With a = (2Δ/κ)**2 and f = (2x/κ)**2 this returns
        2βη/(a+1) * sqrt(a f (4(η-1)**2 + f) / ((a+1)**2 - 2(a-1) f + f**2)) + c.
        """
        # normalized detuning and frequency scales
        a = (2*Δ/κ)**2
        f = (2*x/κ)**2

        return 2*β*η/(a+1)*np.sqrt( a*f*(4*(η-1)**2+f) / ((a+1)**2 - 2*(a-1)*f + f**2)) + c

    def __call__(self, x):
        """Evaluate the fitted model at x."""
        return self.func(x, self.κ, self.Δ, self.β, self.η, self.c)



"""
=====================================================
Lorentzian fitting functions
=====================================================
"""

class lorentzian:
    """
    Fit a single Lorentzian, a/(1+(2(x-x0)/kappa)**2) + c, to (xdata, ydata).

    Parameters
    ----------
    xdata, ydata : numpy arrays of equal length; samples where either is NaN
        are dropped before fitting.
    kappa : starting full width at half maximum; estimated from data if None.
    x0 : starting centre; taken at the extremum of ydata if None.
    inverted : fit a dip (minimum) rather than a peak.

    After construction the instance holds kappa (non-negative), x0, a, c and
    pcov.  If curve_fit raises RuntimeError, a warning is emitted, the
    starting estimates are kept, and pcov is a 4x4 array of inf.
    """
    def __init__(self, xdata, ydata, kappa=None, x0=None, inverted=False):

        # Discard samples where either coordinate is NaN.
        nan_mask = np.logical_or(np.isnan(xdata), np.isnan(ydata))
        xdata = xdata[~nan_mask]
        ydata = ydata[~nan_mask]

        if x0 is None:
            if inverted:
                x0 = xdata[ydata.argmin(0)]
            else:
                x0 = xdata[ydata.argmax(0)]

        # Baseline guess: mean of the first 10 samples (presumes the trace
        # starts off-resonance).
        c = ydata[:10].mean()

        # Amplitude relative to the baseline, and the samples beyond half of
        # it, whose x-span seeds the width.
        if inverted:
            a = ydata.min(0) - c
            w = np.flatnonzero(ydata <= c+a/2)
        else:
            a = ydata.max(0) - c
            w = np.flatnonzero(ydata >= c+a/2)

        if kappa is None:
            kappa = xdata[w[-1]] - xdata[w[0]]

        try:
            (self.kappa, self.x0, self.a, self.c), self.pcov = \
                curve_fit(self.func, xdata, ydata, (kappa, x0, a, c))
        except RuntimeError:
            warnings.warn('Lorentzian fitting failed!')
            self.a = a
            self.c = c
            self.kappa = kappa
            self.x0 = x0
            # Without this, reading self.pcov after a failed fit would raise
            # AttributeError; inf mirrors curve_fit's convention for an
            # undetermined covariance.
            self.pcov = np.full((4, 4), np.inf)

        # The model depends on kappa only through kappa**2.
        self.kappa = np.abs(self.kappa)

    @staticmethod
    def func(x, kappa=1, x0=0, a=1, c=0):
        """Lorentzian with FWHM kappa, centre x0, height a above baseline c."""
        return a/(1+(2*(x-x0)/kappa)**2) + c

    def __call__(self, x):
        """Evaluate the fitted Lorentzian at x."""
        return self.func(x, self.kappa, self.x0, self.a, self.c)

    def area(self):
        """Integrated area above the baseline: (pi/2) * a * kappa."""
        return 0.5*pi*self.a*self.kappa



class lorentzian_array:
    """
    Fit a sum of independent Lorentzians (each with its own width, centre
    and amplitude) sitting on one shared constant baseline.

    ``kappa``, ``x0`` and ``a0`` are equal-length sequences of starting
    widths, centres and amplitudes, one entry per peak.  After construction
    the instance holds arrays kappa (non-negative), x0 and a, the scalar
    baseline c, and pcov from curve_fit.
    """
    def __init__(self, xdata, ydata, kappa, x0, a0):

        n = len(kappa)
        if len(x0) != n or len(a0) != n:
            raise ValueError('kappa, x0 and a0 must have the same length')

        # Discard samples where either coordinate is NaN.
        nan_mask = np.logical_or(np.isnan(xdata), np.isnan(ydata))
        xdata = xdata[~nan_mask]
        ydata = ydata[~nan_mask]

        # Baseline guess: mean of the first 10 samples.
        c = ydata[:10].mean()

        params, self.pcov = curve_fit(self.func, xdata, ydata, (*kappa, *x0, *a0, c))

        # The model depends on each width only through its square.
        self.kappa = np.abs(params[:n])
        self.x0, self.a = params[n:2*n], params[2*n:3*n]
        self.c = params[-1]

    @staticmethod
    def func(x, *args):
        """
        Sum of Lorentzians plus baseline.

        ``args`` is (kappa_1..kappa_n, x0_1..x0_n, a_1..a_n, c): 3n + 1 values.
        """
        # Three per-peak parameters plus one shared baseline.
        n = (len(args)-1)//3
        k, x0, a, c = args[:n], args[n:2*n], args[2*n:3*n], args[-1]

        # Builtin sum: np.sum(generator) is deprecated and falls back to it.
        return sum(lorentzian.func(x, k[i], x0[i], a[i]) for i in range(n)) + c

    def __call__(self, x):
        """Evaluate the fitted sum of Lorentzians at x."""
        return self.func(x, *self.kappa, *self.x0, *self.a, self.c)

    def area(self):
        """Per-peak integrated areas above the baseline: (pi/2) * a * kappa."""
        return 0.5*pi*self.a*self.kappa



class lorentzian_array_constant_width:
    """
    Fit to several Lorentzians of the same width (and baseline)

    Construct as either
        lorentzian_array_constant_width(xdata, ydata, kappa, x0, a0)
    with a starting shared width ``kappa`` and per-peak starting centres
    ``x0`` and amplitudes ``a0``, or
        lorentzian_array_constant_width(xdata, ydata, ranges)
    where ``ranges`` is a sequence of (low, high) x-intervals, each fitted
    with a single ``lorentzian`` to seed the joint fit (the width seed comes
    from the first interval).

    After construction the instance holds kappa (shared width), x0 and a
    (per-peak), c (baseline) and pcov.  If curve_fit raises RuntimeError, a
    warning is emitted, the starting values are kept, and pcov is a zero
    matrix of size (2n+2, 2n+2).

    Raises TypeError if the extra positional arguments match neither form.
    """
    def __init__(self, xdata, ydata, *args):

        if len(args) == 3:
            kappa, x0, a0 = args

        elif len(args) == 1:
            # Open intervals (low, high) selecting each peak's samples.
            lidx = [np.logical_and(xdata > r[0], xdata < r[1]) for r in args[0]]

            L = [lorentzian(xdata[i], ydata[i]) for i in lidx]

            kappa, x0, a0 = L[0].kappa, [f.x0 for f in L], [f.a for f in L]

        else:
            # Previously this fell through to an obscure UnboundLocalError.
            raise TypeError(
                'expected (kappa, x0, a0) or a single sequence of (low, high) '
                'ranges after xdata, ydata; got %d extra arguments' % len(args))

        n = len(x0)

        # Discard samples where either coordinate is NaN.
        nan_mask = np.logical_or(np.isnan(xdata), np.isnan(ydata))
        xdata = xdata[~nan_mask]
        ydata = ydata[~nan_mask]

        # Baseline guess: mean of the first 10 samples.
        c = ydata[0:10].mean()

        try:
            params, self.pcov = curve_fit(self.func, xdata, ydata, (kappa, *x0, *a0, c))

            # The model depends on the width only through its square.
            self.kappa = np.abs(params[0])
            self.x0, self.a = params[1:(n+1)], params[(n+1):(2*n+1)]
            self.c = params[-1]
        except RuntimeError:
            warnings.warn('Lorentzian fitting failed!')
            self.a = np.array(a0)
            self.c = c
            self.kappa = np.array(kappa)
            self.x0 = np.array(x0)
            self.pcov = np.zeros((2*n+2, 2*n+2))

    @staticmethod
    def func(x, k, *args):
        """
        Sum of Lorentzians sharing width k, plus baseline.

        ``args`` is (x0_1..x0_n, a_1..a_n, c): 2n + 1 values.
        """
        n = (len(args)-1)//2
        x0, a, c = args[:n], args[n:-1], args[-1]

        # Builtin sum: np.sum(generator) is deprecated and falls back to it.
        return sum(lorentzian.func(x, k, x0[i], a[i]) for i in range(n)) + c

    def __call__(self, x):
        """Evaluate the fitted sum of Lorentzians at x."""
        return self.func(x, self.kappa, *self.x0, *self.a, self.c)

    def area(self):
        """Per-peak integrated areas above the baseline: (pi/2) * a * kappa."""
        return 0.5*pi*self.a*self.kappa

    def separate(self, x):
        """One row per peak: that peak's Lorentzian at x, each including c."""
        return np.array([lorentzian.func(x, self.kappa, x0, a, self.c) for (x0, a) in zip(self.x0, self.a)])

