# -*- coding: utf-8 -*-
"""
Created on Thu May 12 17:17:21 2016

@author: shomroni

Generate figures 3 and 4 in:
I. Shomroni et al. "Optical Backaction-Evading Measurement of a Mechanical Oscillator"
Nature Communications 10(1), 2086 (2019)

The data are stored in individual txt files, one per trace, each file
containing two columns: x (frequency) and y.

Coherent response traces (from the network analyzer) are used to verify the
detuning, cavity linewidth, etc.

Heterodyne power noise spectra are the actual measurements of the mechanical motion.

Each txt file is accompanied by an spe file containing the stored experimental
parameters for that particular dataset.
"""


import numpy as np
import matplotlib.pyplot as plt

from scipy.optimize import curve_fit

import datafiles
import fit
import fig_util

from constants import kHz,MHz,GHz,hc,π

#%%


def prepare_bae_data(folder: str) -> dict:
    """
    Read noise spectra and NA traces (optics) for a single experiment.
    Generate some diagnostic figures.
    Note that some parameters are not needed for the final figures.

    Processing steps:

    1. Load the heterodyne PSD sweep ('Hetero_' files) and its background
       ('HeteroBG' files) from `folder` via `datafiles.concat_sweep_files`.
       Form the background-normalized linear PSD `H` (one column per trace).
    2. Load the network-analyzer sweep ('NA_broad' files) and fit each trace
       with `fit.na_fixed_η`. From the fits, derive the per-point cavity
       detuning/linewidth and the photon numbers ncool, nprobes, nred, nblue.
    3. Fit a constant-width double Lorentzian to every PSD. Estimate the
       occupancy `nm` from the ratio of the larger to the smaller fitted
       peak amplitude. Each amplitude is divided by its probe photon number
       (nblue / nred) and by κ²/4 + (Ωm - |Δ|)².
    4. Draw diagnostic figures and call `plt.show()`:
       - figure 10;
       - one figure per trace, shown with `fig_util.tabbed`;
       - a single overlay figure instead, if `fig_util.tabbed` raises
         RuntimeError.

    Parameters
    ----------
    folder : str
        Directory holding the txt/spe files of one δ sweep.

    Returns
    -------
    dict
        A copy of ``locals()``: every local variable of this function, keyed
        by name (e.g. 'freq', 'H', 'δ', 'Ameas', 'Γmeas', 'nprobes', 'nm',
        'κ', 'ΔLO'). Callers index this dict by variable name, so renaming a
        local variable here is a breaking change.

    Notes
    -----
    Angular frequencies carry an explicit 2π (e.g. ``Ωm = 2*π*params['OmegaM']``).
    Plot labels such as 'δ/2π (MHz)' for ``δ/MHz`` suggest that the
    `kHz`/`MHz`/`GHz` constants from `constants` include the 2π -- TODO confirm.
    """

    #%% Externally measured parameters

    κ = 1.7*GHz         # optical linewidth (also obtained from NA trace)
    ηc = 0.3            # cavity coupling parameter κe/κ
    λ = 1540e-9         # laser wavelength in metres (1540 nm); used for ℏω = hc/λ
    coupling_eff = 0.24 # measured two-way coupling efficiency fiber to cavity, includes intermediate circulator
    η23 = 0.74          # efficiency of circulator 2->3


    #%% Read heterodyne PSD data

    # freq: common frequency axis (Hz); H_dB: one column per sweep point;
    # params: stored experimental parameters; n: number of traces in the sweep.
    freq, H_dB, params, n = datafiles.concat_sweep_files(folder, 'Hetero_')
    _, H_bg_dB, _, _ = datafiles.concat_sweep_files(folder, 'HeteroBG')

    # Linear-scale signal and background (kept in the returned namespace).
    H_raw = 10**(H_dB/10)
    H_bg = 10**(H_bg_dB/10)

    H = 10**((H_dB-H_bg_dB)/10) # H_raw/H_bg


    # Sweep settings read from the parameter files; the 2*π factors convert
    # Hz to angular frequency.
    T = params['tempHigh']
    att = params.get('AttenuationDB')
    ΔPLL1 = 2*π*params['PLLOffset1']
    ΔPLL2 = 2*π*params['PLLOffset2']
    ΔAOM = 2*π*params['FreqAOM'][0] if 'FreqAOM' in params else 0 # for experiments w/o cooling tone, there is no AOM.
    Ωm = 2*π*params['OmegaM']
    ΔLO = 2*π*params['DeltaLO'][0]
    wl = params.get('Wavelength')
    Pmon_factor = params['PmonFactor'][0]
    Pref_factor = params['PrefFactor'][0]

    PmonCool = params['PowerMonCool']

    # Combined power-monitor reading of both probe tones.
    PmonProbes = params['PowerMon1'][0] + params['PowerMon2'][0]


    # LO is locked to cooling beam (AOM on red probe)
    # Probe modulation frequency: half the blue-red probe separation.
    Ωmod = (ΔPLL2+ΔAOM)/2

    # Two-tone detuning from the mechanical frequency (the swept variable).
    δ = Ωmod - Ωm

    δ0 = δ[abs(δ)==min(abs(δ))]     # δ value(s) closest to zero (not used elsewhere in this function)
    idx0 = np.abs(δ).argmin()       # index for middle PSD in sweep

    ℏω = hc/λ                       # photon energy

    # Power at the cavity input estimated from the monitor reading. sqrt() of
    # the two-way efficiency -- presumably a one-way estimate; confirm.
    PinCool = PmonCool * Pmon_factor * np.sqrt(coupling_eff/η23)
    PinProbes = PmonProbes * Pmon_factor * np.sqrt(coupling_eff/η23)


    #%% Coherent response (Network analyzer trace), verify detuning etc.

    freqNA, NA, _, nNA = datafiles.concat_sweep_files(folder, 'NA_broad')

    # Background polynomial passed to fit.na_fixed_η (its meaning is defined there).
    na_poly = [-0.000628390199840,   0.007987637850518,  -0.052804779996302, 1]

    # One cavity-response fit per NA trace; the literal 0.3 equals ηc above.
    NAfit = [fit.na_fixed_η(freqNA, y, 0.3, na_poly) for y in NA.T]

    # Fitted detuning and linewidth per sweep point, in angular units.
    ΔNA = 2*π * np.array([x.Δ for x in NAfit])
    κNA = 2*π * np.array([x.κ for x in NAfit])

    Δcool_nominal = Ωmod - ΔAOM

    # Detunings of the red and blue probes, derived from the fitted NA detuning.
    Δred = -ΔNA - ΔAOM
    Δblue = -ΔNA + ΔPLL2

    # average detuning of probes, for nc calc
    Δprobes = np.array([Δblue, -Δred]).mean(axis=0)

    # Intracavity photon numbers: ηc κ / (Δ² + κ²/4) · P / ℏω.
    ncool = ηc*κNA / (ΔNA**2 + κNA**2/4) * PinCool/ℏω
    nprobes = ηc*κNA / (Δprobes**2 + κNA**2/4) * PinProbes/ℏω

    # Per-probe photon numbers (PinProbes split equally between red and blue).
    nred = ηc*κNA / (Δred**2 + κNA**2/4) * PinProbes/ℏω / 2
    nblue = ηc*κNA / (Δblue**2 + κNA**2/4) * PinProbes/ℏω / 2

    if True:  # toggle for the optical diagnostic figure
        plt.figure(10)
        plt.clf()

        # NA traces with their fits.
        ax1 = plt.subplot(221)
        ax1.plot(freqNA/1e9, NA)
        ax1.plot(freqNA/1e9, np.asarray([x(freqNA) for x in NAfit]).T, 'k')
        ax1.axis('tight')

        ax3 = plt.subplot(223)

        # Deviation of the fitted detuning from nominal, and fitted κ, vs δ.
        ax3a, ax3b = fig_util.plotyy(δ/MHz, (Δcool_nominal-ΔNA)/MHz, κNA/GHz,
                        xlabel='δ/2π (MHz)', y1label='(Δ-Δnominal)/2π (MHz)',
                        y2label='κ/2π (GHz)')

        # Power drift over the sweep (final vs initial monitor reading).
        ax2 = plt.subplot(222, xlabel='δ/2π (MHz)',
                               ylabel='Pmon Final / Pmon Initial',
                               title='red/blue inital = %g' % (params['PowerMon1'][0]/params['PowerMon2'][0]))
        ax2.plot(δ/MHz, params['PowerMonFinal']/(PmonCool+PmonProbes), 'o-')


        ax4 = plt.subplot(224, xlabel='δ/2π (MHz)', ylabel=r'$P_{final} / \bar P_{final}$',)
        ax4.plot(δ/MHz, params['PowerMonFinal']/params['PowerMonFinal'].mean(), 'o-')

        plt.tight_layout()


    #%% Fit Lorentzians to PSDs

    # fit first two individual Lorentzians to get initial guesses, then fit array

    # frequency ranges (tuple) to use to fit spectra (lorentizans)
    # -- one range below and one above the LO frequency |ΔLO|/2π (Hz).
    frange = ( (0,np.abs(ΔLO)/(2*π)), (np.abs(ΔLO)/(2*π),200e6) )

    # Boolean masks selecting each frequency range.
    lidx = [np.logical_and(freq>r[0], freq<r[1]) for r in frange]

    # L[k][i]: single-Lorentzian fit of trace i restricted to range k.
    L = [ [fit.lorentzian(freq[i], y) for y in H[i].T] for i in lidx ]

    # L3[i]: joint two-peak fit of trace i with a shared width, seeded from L.
    L3 = [fit.lorentzian_array_constant_width(freq, H[:,i].T, L[0][i].kappa, [f[i].x0 for f in L], [f[i].a for f in L]) for i in range(n)]

    # Hack - fit middle lorentzian separately
    # (refit the |δ|-minimal trace starting from half the single-peak amplitudes)
    L3[idx0] = fit.lorentzian_array_constant_width(freq, H[:,idx0].T, L[0][idx0].kappa, [f[idx0].x0 for f in L], [f[idx0].a/2 for f in L])


    # Fitted width, centres, amplitudes and areas; Ameas has shape (2, n)
    # (two peaks per trace).
    Γmeas = 2*π * np.array([f.kappa for f in L3]).T
    Ωmeas = 2*π * np.array([f.x0 for f in L3]).T
    Ameas = np.array([f.a for f in L3]).T
    modearea = 2*π * np.array([f.area() for f in L3]).T

    c = np.array([f.c for f in L3])

    # Per trace: (larger peak / nblue / (κ²/4 + (Ωm-|Δblue|)²)) divided by
    # (smaller peak / nred / (κ²/4 + (Ωm-|Δred|)²)).
    asymm = (Ameas.max(axis=0) / nblue / (κ**2/4 + (Ωm-np.abs(Δblue))**2)) / (Ameas.min(axis=0) / nred / (κ**2/4 + (Ωm-np.abs(Δred))**2))

    # Occupancy from the asymmetry ratio r = (n+1)/n  =>  n = 1/(r-1).
    nm = 1/(asymm-1)


    #%% Plot individual spectra in tabbed window

    if True:  # toggle for the per-trace figures

        figs = []

        # Shared y-limits so all tabs are directly comparable.
        lim1 = (H.min(), H.max())
        lim2 = (min(H_dB.min(), H_bg_dB.min()), max(H_dB.max(), H_bg_dB.max()))

        for i in range(n):
            # Built with plt.Figure (not plt.figure), i.e. not registered with pyplot.
            fig = plt.Figure()

            fig.suptitle(folder)

            # Left: raw signal and background in dB.
            ax1 = fig.add_subplot(121, xlabel='Freq. (MHz)')
            ax1.plot(freq/1e6, H_dB[:,i], freq/1e6, H_bg_dB[:,i])
            ax1.set_xlim((freq.min()/1e6, freq.max()/1e6))
            #ax1.axis('tight')
            ax1.set_ylim(lim2)
            ax1.grid()

            # Right: normalized PSD with its two-peak fit and fit summary.
            ax2 = fig.add_subplot(122, title='δ = %g MHz'%(δ[i]/MHz), xlabel='Freq. (MHz)')
            ax2.plot(freq/1e6, H[:,i], 'b', linewidth=2)
            ax2.plot(freq/1e6, L3[i](freq), 'k', linewidth=1)

            ax2.text(0.1, 0.6, 'Γ = %g kHz\nAmp Sum = %.4g\nArea Sum = %.4g\nδ from fit = %.3g MHz\nphonons = %.3g'
                     %(Γmeas[i]/kHz, L3[i].a.sum(), L3[i].area().sum(), (L3[i].x0[1]-L3[i].x0[0])/2e6, nm[i]),
                       transform=ax2.transAxes, linespacing=2)

            ax2.set_xlim((freq.min()/1e6, freq.max()/1e6))
            ax2.set_ylim(lim1)
            ax2.grid()

            figs.append(fig)

        try:
            fig_util.tabbed(figs)

        except RuntimeError:
            # just plot all spectra in one figure
            plt.figure()
            plt.clf()

            ax = plt.axes(xlabel='Freq. (MHz)', title=folder)
            ax.plot(freq/1e6, H, linewidth=1)
            for i in range(n):
                ax.plot(freq/1e6, L3[i](freq), 'k', linewidth=0.5)
            ax.grid(True)

            ax.set_xlim((freq.min()/1e6,freq.max()/1e6))


    plt.show()

    #%% Return copy of local variables
    # (callers look up results by variable name, e.g. data['nm'])
    return locals().copy()




#%% GENERATE FIGURE 3 FROM ARTICLE


# Load and fit the figure-3 δ sweep. The result is the namespace dict returned
# by prepare_bae_data, keyed by that function's local-variable names.
data = prepare_bae_data('fig3_data')

# Quantities reused by the theory curve (sweepδ reads freq, ΔLO, Γ as
# module-level names) and by the plots below.
freq, H, ΔLO, δrange, nprobes, κ, Ameas = [
    data[key] for key in ('freq', 'H', 'ΔLO', 'δ', 'nprobes', 'κ', 'Ameas')
]
# Mean fitted mechanical linewidth (angular units; ≈ 607 kHz per the original note).
Γ = data['Γmeas'].mean()


# theory δ curve

def sweepδ(δrange, C, n_total=6.33):
    """
    Theoretical mean occupancy as a function of the two-tone detuning δ.

    For every δ in `δrange`, build the model spectrum

        S_II(ω) = n|χm(ω+δ)|² + (n+1)|χm(ω-δ)|² + C|χm(ω+δ) - χm(ω-δ)|²

    with n = n_total - C and χm(ω) = 1/(-iω + Γ/2). Fit it with the same
    constant-width double Lorentzian used for the data, then convert the
    summed fitted amplitude back to an occupancy.

    Parameters
    ----------
    δrange : array_like
        Detunings δ, in the same angular units as `ΔLO` and `Γ`.
    C : float
        Measurement cooperativity; weights the δ-dependent backaction term.
    n_total : float, optional
        Occupancy assigned to the large-|δ| points, backaction included.
        The default 6.33 matches the value used to normalize the figure-3
        data.

    Returns
    -------
    numpy.ndarray
        One theoretical occupancy per entry of `δrange`.

    Notes
    -----
    Reads the module-level `freq`, `ΔLO` and `Γ` prepared for figure 3.
    """
    n = n_total - C          # occupancy without the backaction contribution

    ω = 2*π*freq - ΔLO       # angular frequency relative to the LO

    def χm(w):
        """Mechanical susceptibility with energy decay rate Γ."""
        return 1/(-1j*w + Γ/2)

    def Sii(δ):
        """Model spectrum: two thermal sidebands at ±δ plus the backaction term."""
        return (n*np.abs(χm(ω+δ))**2 + (n+1)*np.abs(χm(ω-δ))**2
                + C*np.abs(χm(ω+δ) - χm(ω-δ))**2)

    # Fit the model exactly as the data are fitted: constant width Γ, peaks
    # seeded at ±δ with amplitudes n+1 and n.
    lrz = [fit.lorentzian_array_constant_width(ω, Sii(δ), Γ, [δ, -δ], [n+1, n])
           for δ in δrange]

    # |χm|² peaks at 4/Γ², so multiplying by Γ²/4 expresses the summed amplitude
    # in units of that peak (≈ 2n̄+1); x/2 - 0.5 then maps it to n̄.
    return np.array([l.a.sum() for l in lrz]) * (Γ**2/4) /2 - 0.5


# Measurement cooperativity C = 4 g0² n / (Γ κ), using the photon number of a
# single probe (`nprobes` sums both probe tones, hence the /2).
g0 = 780*kHz

C = 4 * g0**2 * nprobes.mean()/2 / (Γ*κ)

# Dense δ grid over the measured span for the theory curve.
δrange_t = np.linspace(δrange.min(), δrange.max(), 300)

bae_theory = sweepδ(δrange_t, C)


# make figure -- FLIPPED x-axis to account for LO

color_bae = '#803c6c'
color_fit_bae = '#c98db7'
color_ba = '#43a1d0'
color_fit_ba = '#9dcee7'
color_delta = '#BBBBBB'
color_delta_fit = '#b11f24'
color_fill = '#f4f4d2'

# Sweep indices of the two highlighted traces: idx0 is drawn in the BAE colour
# and idx1 in the backaction (BA) colour. Presumably idx0 is the δ≈0 point --
# confirm against data['idx0'].
idx0, idx1 = 8, 0

# Summed peak amplitudes rescaled so the outermost δ points read n = 6.33
# quanta. NOTE: this rescaled `y` is not used by the plots below; the left
# inset recomputes `y` from Ameas.
y = Ameas.sum(axis=0)

n = 6.33
# Average over the four outermost sweep points (first two and last two), as
# for `asymm` below. `[0,1,-1-2]` would evaluate to [0, 1, -3].
y = y / y[[0, 1, -2, -1]].mean() * (2*n+1) /2 - 0.5


# Double-Lorentzian fit of every trace. Width Γ/2π; peaks seeded at the two
# sidebands ΔLO ± δ, converted to Hz to match `freq`.
lrz_data = [fit.lorentzian_array_constant_width(
                freq, trace, Γ/(2*π),
                [(δ+ΔLO)/(2*π), (-δ+ΔLO)/(2*π)],
                [(trace.max()-1)/2, (trace.max()-1)/2])
            for trace, δ in zip(H.T, δrange)]

# Single-Lorentzian fit of the BAE trace, seeded at the LO frequency.
lrz0 = fit.lorentzian(freq, H[:,idx0], Γ/(2*π), ΔLO/(2*π))
lrz1 = lrz_data[idx1]

# Sideband asymmetry a[1]/a[0] = (n+1)/n; average the four outermost points
# to extract the occupancy n = 1/(ratio - 1).
asymm = np.array([f.a[1]/f.a[0] for f in lrz_data])

nbar = 1 / (asymm[[0,1,-2,-1]].mean()-1)


plt.style.use('aps.mplstyle')


fig = plt.figure('fig3')
fig.clf()

ax1 = plt.axes(xlabel=r'$\omega/2\pi$ (MHz)', ylabel=r'$\bar S_{II}(\omega)$')

# Frequency relative to the LO in MHz; the main panel plots -freq2 (flipped).
freq2 = (freq - ΔLO/(2*π))/1e6

# BA reference: one Lorentzian built from the idx1 fit's width, mean centre and
# summed amplitude. The argument order follows fit.lorentzian.func, presumably
# (x, width, centre, amplitude, offset) -- confirm. Evaluated once, reused below.
ba_single = fit.lorentzian.func(freq, lrz_data[idx1].kappa, lrz_data[idx1].x0.mean(), lrz_data[idx1].a.sum(), 1)

ax1.plot(-freq2, H[:,idx1], color=color_ba)
ax1.plot(-freq2, H[:,idx0], color=color_bae)
ax1.plot(-freq2, ba_single, color=color_fit_ba, linestyle='--')
ax1.plot(-freq2, lrz1(freq), color=color_fit_ba)
ax1.plot(-freq2, lrz0(freq), color=color_fit_bae)

# NOTE(review): the x data are -freq2 but the limits use freq2's range; the two
# agree only if freq2 is symmetric about 0 -- confirm.
ax1.set_xlim((freq2.min(), freq2.max()))
ax1.set_ylim((0.95,2.4))


plt.tight_layout(pad=0)


# left inset: δ sweep

axi = fig_util.inset(ax1, [0.014, 0.686, 0.3, 0.3], xlabel=r'$\delta/2\pi$ (MHz)', ylabel=r'$\bar n$')

# Occupancy per sweep point: summed amplitude scaled so the idx1 point equals nbar.
y = Ameas.sum(axis=0)

axi.plot(δrange/MHz, y / y[idx1]*nbar, 'o', markersize=4, color=color_delta)
axi.plot(δrange[idx1]/MHz, y[idx1] / y[idx1]*nbar, 'o', markersize=4, color=color_ba)
axi.plot(δrange[idx0]/MHz, y[idx0] / y[idx1]*nbar, 'o', markersize=4, color=color_bae)
axi.plot(δrange_t/MHz, bae_theory, color=color_delta_fit)

axi.set_ylabel(r'$\bar n$', labelpad=2, fontsize=6)
axi.set_xlabel(r'$\delta/2\pi$ (MHz)', labelpad=0, fontsize=6)

axi.set_xticks([-3,0,3])
axi.set_yticks([5.6,6,6.4])
axi.set_ylim((5.6,6.45))
axi.tick_params(labelsize=6, pad=3)

axi.yaxis.tick_right()
axi.yaxis.set_label_position('right')

# Double arrow between nbar and the idx0 point. The '0.7' label and its
# position are hard-coded, not computed from the data.
axi.annotate('', xy=(2.5,nbar), xytext=(2.5, y[idx0]/y[idx1]*nbar),
             arrowprops = dict(arrowstyle='<|-|>', shrinkA=0, shrinkB=0, linewidth=0.5, color='k',))

axi.text(2.5, 6, r'0.7', horizontalalignment='center', verticalalignment='center', fontsize=6, bbox=dict(facecolor='w', edgecolor='none', pad=0))

# right inset -- zoom of peak

# NOTE(review): f0/f1 (nearest samples to the two BA peak centres) are not used below.
f0 = np.argmin(np.abs(freq - lrz1.x0[0]))
f1 = np.argmin(np.abs(freq - lrz1.x0[1]))

axj = fig_util.inset(ax1, [0.686-0.05, 0.686, 0.35, 0.3])

# NOTE(review): unlike the main panel, the inset is plotted against +freq2
# (not flipped). With the symmetric xlim below, the outline on ax1 still
# matches, but any asymmetry in the data appears mirrored -- confirm intent.
axj.plot(freq2, H[:,idx0], color=color_bae)
axj.plot(freq2, ba_single, color=color_fit_ba, linestyle='--')
axj.plot(freq2, lrz0(freq), color=color_fit_bae)

axj.set_xlim((-0.175,0.175))
axj.set_ylim((2.05,2.32))

axj.tick_params(labelsize=6, pad=3)

# Shade the gap between the BA reference and the BAE fit.
axj.fill_between(freq2, ba_single, lrz0(freq), color=color_fill)


from matplotlib.patches import Rectangle

# Outline the inset's zoom region on the main panel.
ax1.add_patch(Rectangle(xy=(axj.get_xlim()[0], axj.get_ylim()[0]),
                        width=axj.get_xlim()[1]-axj.get_xlim()[0],
                        height=axj.get_ylim()[1]-axj.get_ylim()[0],
                        facecolor='none',
                        edgecolor=[0.7,0.7,0.7]))

plt.savefig('fig3.pdf')

plt.show()


#%% GENERATE FIGURE 4 FROM ARTICLE

import os

# One δ sweep per probe power; the folder names give the power in µW.
data = [prepare_bae_data(os.path.join('fig4_data', '%duW'%i)) for i in [1,2,4,6,8]]


nidx = [0,-1,1,-2]  # indices of four extremal data points (first two and last two of each sweep)

idx0 = 3            # index of BAE data point

# Occupancy from sideband asymmetry: mean and spread over the extremal points.
nm = np.array([i['nm'][nidx].mean() for i in data])
nm_err = np.array([i['nm'][nidx].std() for i in data])
# Photon number of a single probe (`nprobes` sums both tones), first sweep point.
nprobes = np.array([i['nprobes'][0]/2 for i in data])
# Summed fitted peak amplitude of every trace, shape (n_powers, n_sweep).
# Assumes every sweep has the same number of points.
snr = np.array([i['Ameas'].sum(axis=0) for i in data])
Γ = np.array([i['Γmeas'] for i in data])

# NOTE: for these measurements, nc = 420, Γopt = 600kHz, Γm=120kHz, Ccool = 5

Γm, g0, κ = 120*kHz, 780*kHz, 1.7*GHz

# Cooperativity per probe photon, using the mean fitted linewidth.
C0 = 4*g0**2 / (κ*Γ.mean())

snrBAE = snr[:,idx0]        # signal-to-shot noise in the BAE meas, δ=0
snrBA = snr[:,nidx].mean(axis=1)    # signal-to-shot noise in the non-BAE measurements (extremal points, δ≠0)

nprobes_t = np.linspace(1e-3/C0, 1.6/C0, 200)   # x-axis for theory curves


# n_bae = number of quanta measured in the BAE measurement, δ=0
n_bae = snrBAE/snrBA * (nm+0.5) - 0.5
n_bae_err = snrBAE/snrBA * nm_err

# Imprecision quanta and their error, both scaled by the BAE signal-to-noise.
n_imp = (n_bae+0.5) / snrBAE
n_imp_err = n_bae_err / snrBAE

# Fit n_imp = a / nprobes (initial guess a = 1); n_imp_f holds the fitted [a].
n_imp_f, _ = curve_fit(lambda x,a: a/x, nprobes, n_imp, 1)


# Straight-line fit of the measured occupancy vs probe photon number.
n_bae_fit = np.polyfit(nprobes, n_bae, 1)

plt.style.use('aps.mplstyle')

plt.figure('fig4')
plt.clf()

ax1 = plt.axes()
ax2 = ax1.twinx()

# Draw ax1 above the twin axis, with a transparent background so ax2 stays visible.
ax1.set_zorder(10)
ax1.patch.set_visible(False)

# red, thermal quanta (excluding backaction)
ax1.plot(C0*nprobes_t, np.polyval(n_bae_fit, nprobes_t), color='#dc545c', linestyle='--')
e1 = ax1.errorbar(C0*nprobes, n_bae, yerr=n_bae_err, color='#960810', marker='o', markersize=3, linestyle='', elinewidth=1, label=r'$\bar n$')

# blue, number of evaded backaction quanta
ax1.plot(C0*nprobes_t, C0*nprobes_t, color='#7ab4d9')
e2 = ax1.errorbar(C0*nprobes, nm-n_bae, yerr=n_bae_err, color='#1b69af', marker='s', markersize=3, linestyle='', elinewidth=1, label=r'$\bar n_\mathrm{BA}$')

# green, imprecision quanta
ax2.plot(C0*nprobes_t, n_imp_f / nprobes_t, color='#35a12e', linestyle='-.')
e3 = ax2.errorbar(C0*nprobes, n_imp, yerr=n_imp_err, color='darkgreen', marker='o', markersize=3, linestyle='', elinewidth=1, label=r'$\bar n_\mathrm{imp}$')


ax2.set_ylim((1.2,18))
ax2.set_yticks([2,6,10,14,18])

ax1.set_xticks([0,0.4,0.8,1.2,1.6])
ax1.set_ylim((-0.6,9.2))
ax1.set_xlim((0,1.6))
ax1.set_xlabel(r'$\mathcal{C}$', fontsize=10)
ax1.set_ylabel(r'$\bar n$, $\bar n_\mathrm{BA}$')
ax2.set_ylabel(r'$\bar n_\mathrm{imp}$')

ax1.yaxis.labelpad = 0
ax2.yaxis.labelpad = 0

# In-plot labels for the C and 1/(8ηC) theory curves.
ax1.text(1.3, 0.6, r'$\mathcal{C}$', fontsize=10, color='#7ab4d9')
ax1.text(0.275, 6.7, r'$\frac{1}{8\eta\mathcal{C}}$', color='#35a12e', fontsize=14)

# One legend on ax1 for the errorbar series from both axes.
lns = (e1, e2, e3)
labs = [l.get_label() for l in lns]
lg = ax1.legend(lns, labs, loc=(0.775,0.4), frameon=False, handletextpad=0)

plt.tight_layout(pad=0)

plt.savefig('fig4.pdf')

plt.show()