"""
Utilities for processing dynamic coupon data.

Includes loading *.unv files exported from a Polytec LDV system and
calculating damping from time-series signals.
"""

import os
import numpy as np

from scipy.signal import butter, filtfilt, find_peaks

import matplotlib.pyplot as plt


def load_unv(fname):
    """
    Load a unv file (as exported from the Polytec LDV system).

    The file is split into blocks that are delimited by lines containing only
    ``-1``; each consecutive pair of delimiter lines encloses one block. An
    unpaired trailing delimiter is ignored.

    Parameters
    ----------
    fname : str
        File path of the unv file that is to be loaded

    Returns
    -------
    data_list : list
        One entry per block in the file:

        * entry 0 -- dict from `dict_block0` (file header)
        * entry 1 -- dict from `dict_block1` (second header block)
        * entry 2 -- always None; this block is not parsed
        * entries 3+ -- dicts from `parse_data_block` (time series data)

    Raises
    ------
    ValueError
        If the file contains fewer than two delimited blocks.

    """

    with open(fname, 'r') as file:
        lines = file.readlines()

    # Line indices of the '-1' delimiters.
    delim_inds = [ind for ind, line in enumerate(lines) if line.strip() == '-1']

    # Pair delimiters (0,1), (2,3), ... and keep the lines strictly between
    # each pair. zip() drops an unpaired trailing delimiter.
    block_lines = [lines[start + 1:end]
                   for start, end in zip(delim_inds[0::2], delim_inds[1::2])]

    if len(block_lines) < 2:
        raise ValueError(
            "Expected at least two '-1' delimited blocks in {!r}, "
            "found {}.".format(fname, len(block_lines)))

    data_list = [None] * len(block_lines)

    data_list[0] = dict_block0(block_lines[0])

    data_list[1] = dict_block1(block_lines[1])

    # Block 2 is intentionally left as None (not parsed).
    for ind in range(3, len(block_lines)):
        data_list[ind] = parse_data_block(block_lines[ind])

    return data_list

def dict_block0(data_lines):
    """
    Parse the first (file header) block of a unv file.

    Blank lines are discarded and the remaining lines are stripped of
    surrounding whitespace; fields are then read by position.

    Parameters
    ----------
    data_lines : list of str
        Raw lines of the header block of the unv file.

    Returns
    -------
    data_dict : dict
        Header fields: ``unknown`` (int from the 1st line), ``filename``,
        ``psv_version`` (3rd line with the ``"PSV Version "`` text removed),
        ``timestamp1``, ``export_tool`` and ``export_tool_info`` (6th line
        split on ``" - "``; only the first two pieces are kept) and
        ``timestamp2``. The 5th non-blank line is not stored.

    """

    cleaned = []
    for raw in data_lines:
        text = raw.strip()
        if text:
            cleaned.append(text)

    header = {}
    header["unknown"] = int(cleaned[0])
    header["filename"] = cleaned[1]
    header["psv_version"] = cleaned[2].replace("PSV Version ", "")
    header["timestamp1"] = cleaned[3]
    # cleaned[4] is intentionally skipped.

    export_parts = cleaned[5].split(" - ")
    header["export_tool"] = export_parts[0]
    header["export_tool_info"] = export_parts[1]
    header["timestamp2"] = cleaned[6]

    return header

def dict_block1(data_lines):
    """
    Parse the second header block of a unv file.

    Blank lines are discarded and the remaining lines are stripped of
    surrounding whitespace; fields are then read by position.

    Parameters
    ----------
    data_lines : list of str
        Raw lines of the second header block of the unv file.

    Returns
    -------
    data_dict : dict
        ``unknown0`` (int from the 1st line), ``unit_note`` and ``unknown1``
        (both the first token of the 2nd line) and ``values`` (a flat list of
        floats from the 3rd line followed by those from the 4th line).

    """

    cleaned = [text for text in (raw.strip() for raw in data_lines) if text]

    # Numeric fields are spread over two lines; join them into one list.
    values = [float(tok)
              for row in (cleaned[2], cleaned[3])
              for tok in row.split()]

    first_token = cleaned[1].split()[0]

    # NOTE(review): "unit_note" and "unknown1" both take the FIRST token of
    # the second line; "unknown1" may have been intended to be the second
    # token. Kept unchanged to preserve existing output -- confirm against
    # the unv file format.
    return {
        "unknown0": int(cleaned[0]),
        "unit_note": first_token,
        "unknown1": first_token,
        "values": values,
    }

def parse_data_block(data_lines):
    """
    Parse one time series data block from a unv file.

    Fields are read by (0-based) line position within the block:

    * line 2 -- channel name (returned whitespace-stripped)
    * line 7 -- whitespace-separated numbers; the 5th is the time step ``dt``
    * line 9 -- the last two tokens are the data type and the units
    * lines 12 to the end -- data values, several per line; the final line
      may hold fewer values than the others and is parsed separately.

    Parameters
    ----------
    data_lines : list of str
        File lines from a time series data block in the unv file

    Returns
    -------
    data_dict : dict
        Keys ``'type'``, ``'channel'``, ``'data'`` (1-D float array in file
        order), ``'dt'``, ``'units'`` and ``'time'``
        (``dt * arange(len(data))``, i.e. always starting at 0).

    """

    # Strip so the channel name carries no trailing newline / padding,
    # consistent with the other text fields.
    channel = data_lines[2].strip()

    # Every line before the last has the same number of columns, so they can
    # be read as one table and flattened row by row (C order).
    body_lines = data_lines[12:-1]
    if body_lines:
        data = np.loadtxt(body_lines).reshape(-1, order='C')
    else:
        # Only one data line: avoid np.loadtxt's "no data" warning.
        data = np.empty(0)

    # The last line may be shorter, so parse it on its own.
    data1 = np.array([float(val) for val in data_lines[-1].split()])

    full_data = np.hstack((data, data1))

    dt = np.float64(data_lines[7].split()[4])

    type_units = data_lines[9].split()
    units = type_units[-1]
    data_type = type_units[-2]

    data_dict = {
        'type' : data_type,
        'channel' : channel,
        'data' : full_data,
        'dt' : dt,
        'units' : units,
        'time' : dt * np.arange(full_data.shape[0])
    }

    return data_dict

def filter_log_dec(time, signal, low_freq=2.0, high_freq=200.0, filter_order=3,
                   min_amp=0.001, max_amp=0.02, max_amp_scale=0.9,
                   show_plots=False, title='', start_offset=0.0,
                   save_name=None,
                   units='m/s',
                   fig_ext='.eps'):
    """
    Filter and process damping (log dec and least squares) from time series.

    The signal is band-pass filtered and its peaks are located. A window of
    decaying peaks is then selected between the amplitude thresholds. Damping
    is estimated in two ways:

    * by the log decrement between the first and last selected peaks;
    * by a least squares line fit through the log of all selected peak
      amplitudes.

    Parameters
    ----------
    time : (N,) numpy.ndarray
        Time points for the `signal`.
    signal : (N,) numpy.ndarray
        Time series measurements.
    low_freq : float, optional
        Low frequency cut off for filtering data, in Hz. The default is 2.0.
    high_freq : float, optional
        High frequency cut off for filtering data, in Hz. The default is 200.0.
    filter_order : int, optional
        Filter order for butterworth bandpass filter. The default is 3.
    min_amp : float, optional
        Stop processing signal when amplitude is below this threshold.
        The default is 0.001.
    max_amp : float, optional
        Only start processing samples when amplitude drops below this
        threshold. The start threshold is also limited by the steady-state
        amplitude.
        The default is 0.02.
    max_amp_scale : float, optional
        Only start processing samples when the amplitude drops below
        `max_amp_scale` times the steady-state amplitude.
        The default is 0.9.
    show_plots : bool, optional
        Flag to show plots of the response and processing.
        The default is False.
    title : str, optional
        Title to put on plots. The default is ''.
    start_offset : float, optional
        Start processing this amount of time after the max amplitude
        thresholds are met.
        The default is 0.0.
    save_name : str or None, optional
        Name to be used for this signal when saving some figures.
        Should not include the file extension.
        Figures are only saved if `save_name is not None`.
        Additional figures may be produced when this option is given.
        The default is None.
    units : str, optional
        Units of the signal, only used in plot labels. The default is 'm/s'.
    fig_ext : str, optional
        File extension when saving figures (only if `save_name is not None`).
        The default is '.eps'.

    Returns
    -------
    freq_id : float
        Identified frequency of the response (Hz). This is the damped
        frequency. It is calculated by counting peaks over the identified time
        region. If the data is poorly filtered, this could be a bad estimate.
    zeta_log_dec : float
        Damping factor (as fraction of critical damping) calculated via
        log decrement between the first and last cycles.
        Prefer usage of `zeta_lsq`.
    zeta_lsq : float
        Damping factor (as fraction of critical damping) calculated via
        least squares on positive peaks over the signal region used in
        processing.
    lsq_rsq : float
        R^2 value for the least squares fitting used in calculating
        `zeta_lsq`.
    norm_area_error : float
        Normalized area error metric that can be used to check measurement
        quality.

    Raises
    ------
    ValueError
        If no steady-state amplitude can be identified from the peaks of the
        filtered signal, or if fewer than two peaks remain in the processing
        window.

    Notes
    -----

    The last 0.5 seconds is always eliminated after filtering to avoid filter
    end effects.

    The sample rate is computed from ``time[1] - time[0]``, so `time` is
    expected to be uniformly spaced.

    """

    ########################
    # Filter data

    # Sample rate from the first time step (uniform sampling assumed).
    fs = 1 / (time[1] - time[0])

    # `butter` expects cut-offs normalized by the Nyquist frequency.
    low = low_freq / (0.5 * fs)
    high = high_freq / (0.5 * fs)

    b, a = butter(filter_order, [low, high], btype='band')

    # Forward-backward filtering gives zero phase shift, so peak times are
    # not delayed by the filter.
    filtered = filtfilt(b, a, signal)

    ############
    # Diagnostic plots of the original and filtered signal (time and FFT)

    if show_plots:

        freq = np.fft.rfftfreq(len(time), d=1/fs)
        fft_orig = np.abs(np.fft.rfft(signal))
        fft_filtered = np.abs(np.fft.rfft(filtered))

        # normalize fft by signal length (single-sided amplitude scaling)
        fft_orig[0] = fft_orig[0] / signal.shape[0]
        fft_orig[1:] = 2 * fft_orig[1:] / signal.shape[0]

        fft_filtered[0] = fft_filtered[0] / filtered.shape[0]
        fft_filtered[1:] = 2 * fft_filtered[1:] / filtered.shape[0]

        plt.figure(figsize=(12, 6))

        plt.subplot(2, 1, 1)
        plt.plot(time, signal, label='Original')
        plt.plot(time, filtered, label='Filtered', linewidth=2)
        plt.title(title)
        plt.xlabel('Time [s]')
        plt.ylabel('Amplitude [' + units + ']')
        plt.legend()


        plt.subplot(2, 1, 2)
        plt.plot(freq, fft_orig, label='Original')
        plt.plot(freq, fft_filtered, '--', label='Filtered')
        plt.title('FFT')
        plt.xlabel('Frequency [Hz]')
        plt.ylabel('Magnitude [' + units + ']')
        plt.legend()

        xmax = high_freq+20

        plt.xlim((0, xmax))

        mask = freq <= xmax
        ymin = np.minimum(fft_orig[mask].min(), fft_filtered[mask].min())
        ymax = np.maximum(fft_orig[mask].max(), fft_filtered[mask].max())
        plt.ylim((ymin, ymax))

        ax = plt.gca()
        ax.set_yscale('log')

        plt.tight_layout()
        plt.show()


        if save_name is not None:

            os.makedirs('figures', exist_ok=True)

            plt.plot(time, signal)
            plt.xlabel('Time [s]')
            plt.ylabel('Amplitude [' + units + ']')

            plt.xlim((0, time[-1]))
            # Symmetric y-limits, reused for the filtered and decay plots.
            ymax = 1.05*np.abs(signal).max()
            plt.ylim((-ymax, ymax))

            ax = plt.gca()
            ax.tick_params(which='major', left=True, right=True,
                       top=True, bottom=True,
                       direction='in')

            # Show ticks on top of block data
            ax.set_axisbelow(False)
            ax.tick_params(zorder=10)

            plt.savefig('figures/' + save_name + 'raw_timeseries'+fig_ext,
                        dpi=300, bbox_inches='tight')
            plt.show()

            plt.plot(time, filtered, label='Filtered')
            plt.xlabel('Time [s]')
            plt.ylabel('Amplitude [' + units + ']')

            plt.xlim((0, time[-1]))
            plt.ylim((-ymax, ymax))

            ax = plt.gca()
            ax.tick_params(which='major', left=True, right=True,
                       top=True, bottom=True,
                       direction='in')

            # Show ticks on top of block data
            ax.set_axisbelow(False)
            ax.tick_params(zorder=10)

            # using the filtered data with the fit as the reference plot.
            # plt.savefig('figures/'+save_name + 'filtered_timeseries'+fig_ext,
            #             dpi=300, bbox_inches='tight')
            plt.show()


    ########################

    # Peaks
    peak_inds = find_peaks(filtered)[0]

    if peak_inds.shape[0] == 0:
        raise ValueError('No peaks found in the filtered signal; cannot '
                         'estimate the steady-state amplitude.')

    # Do a histogram of the peaks to approximate the steady-state amplitude
    hist, bin_edges = np.histogram(filtered[peak_inds], bins=30)

    # want the highest amplitude peak in the histogram (excluding bins with
    # only a few points)
    hist[hist < 8] = 0

    # Pad a zero bin on BOTH ends so the first and last bins can also be
    # found as peaks (find_peaks never reports the end samples); subtract 1
    # to undo the leading pad. The last peak is the highest amplitude one.
    hist_peaks = find_peaks(np.hstack((0, hist, 0)))[0] - 1
    if hist_peaks.shape[0] == 0:
        raise ValueError('Could not identify a steady-state amplitude: no '
                         'histogram bin of peak amplitudes contains at '
                         'least 8 peaks.')
    peak_hist = hist_peaks[-1]

    peak_amp = filtered[peak_inds]
    lower_edge = bin_edges[peak_hist]
    upper_edge = bin_edges[peak_hist+1]
    mask = np.logical_and(peak_amp > lower_edge, peak_amp < upper_edge)

    if not np.any(mask):
        # All peaks of the bin sit exactly on its edges (e.g. identically
        # sampled steady-state peaks equal to the histogram maximum);
        # include the edges rather than taking the median of nothing (NaN).
        mask = np.logical_and(peak_amp >= lower_edge, peak_amp <= upper_edge)

    steady_amp = np.median(peak_amp[mask])

    max_amp = np.minimum(max_amp, steady_amp*max_amp_scale)

    # Set Final Time: drop peaks that have decayed below min_amp
    peak_inds = peak_inds[filtered[peak_inds] > min_amp]

    # Set Start Time: drop peaks still above the start threshold
    peak_inds = peak_inds[filtered[peak_inds] < max_amp]

    if peak_inds.shape[0] < 2:
        raise ValueError('Fewer than two peaks lie between min_amp and the '
                         'start amplitude threshold; cannot estimate '
                         'damping.')

    # eliminate outlier cases where the period between points is way too
    # high. periods[i] is the time from peak i to peak i+1; the last peak
    # reuses the previous period.
    periods = np.hstack((np.diff(time[peak_inds]),
                        np.diff(time[peak_inds[-2:]]) ))

    mask = periods < 3 * np.median(periods)

    false_inds = np.where(np.logical_not(mask))[0]

    # Keep only the longest run of peaks without a large gap.
    segment_starts = np.hstack((0, false_inds+1))
    segment_ends = np.hstack((false_inds, mask.shape[0]))

    segment_ind = np.argmax(segment_ends - segment_starts)

    mask[:segment_starts[segment_ind]] = False
    mask[segment_ends[segment_ind]:] = False

    peak_inds = peak_inds[mask]

    # extra fine tuning to eliminate some points from start of decay.
    if start_offset > 0.0 and peak_inds.shape[0] > 0:
        start_time = time[peak_inds[0]]

        peak_inds = peak_inds[time[peak_inds] > start_time + start_offset]

    # Never use the last 0.5 seconds of data - potential filter end effects
    peak_inds = peak_inds[time[peak_inds] < time[-1] - 0.5]

    if peak_inds.shape[0] < 2:
        raise ValueError('Fewer than two peaks remain in the processing '
                         'window after applying start_offset and the 0.5 s '
                         'end trim; cannot estimate damping.')


    ########################
    # Log dec for properties

    freq_id = (peak_inds.shape[0]-1) / (time[peak_inds[-1]]
                                        - time[peak_inds[0]])

    delta = 1/(peak_inds.shape[0]-1) \
                * np.log(filtered[peak_inds[0]] / filtered[peak_inds[-1]])

    zeta_log_dec = delta / np.sqrt(4*np.pi**2 + delta**2)

    ########################
    # Exponential Fit for Alternative Zeta
    # Derivation follows the same as for log dec between two points versus
    # slope of a line fit through log(X(nT))

    lsq_coef = np.vstack((time[peak_inds], np.ones_like(time[peak_inds]))).T

    lsq_amp = np.log(filtered[peak_inds])

    # approx solution is lsq_coef @ lsq_fit = lsq_amp
    lsq_fit = np.linalg.lstsq(lsq_coef, lsq_amp, rcond=None)

    slope = lsq_fit[0][0]

    # Damped natural frequency
    omega_d = freq_id * 2 * np.pi

    zeta_lsq = np.abs(slope) / np.sqrt(omega_d**2 + slope**2)

    # R^2 value for the least squares fit. The residual sum of squares is
    # computed explicitly because lstsq returns an EMPTY residual array when
    # the system is exactly determined (only two peaks).
    ss_res = np.sum((lsq_amp - lsq_coef @ lsq_fit[0])**2)
    lsq_rsq = 1 - ss_res / np.sum((lsq_amp - np.mean(lsq_amp))**2)


    ########################
    # Area Error Estimation Based on area between linear fit and peaks

    # np.trapz was renamed np.trapezoid in NumPy 2.0 (np.trapz deprecated).
    trapezoid = getattr(np, 'trapezoid', None)
    if trapezoid is None:
        trapezoid = np.trapz

    # Envelope fit of exponential decay (in log space)
    exp_fit = np.vstack((time, np.ones_like(time))).T @ lsq_fit[0]

    ref_area = trapezoid(exp_fit[peak_inds] - exp_fit[peak_inds[-1]],
                         x=time[peak_inds])

    error_area = trapezoid(
        np.abs(exp_fit[peak_inds] - np.log(filtered[peak_inds])),
                        x=time[peak_inds])

    norm_area_error = np.abs(error_area / ref_area)

    ########################
    # Visualize data processing

    if show_plots:

        end_time = time[peak_inds[-1]]
        start_time = time[peak_inds[0]]

        plt.plot(time, filtered, label='Filtered')

        plt.axvline(start_time, color='k', label='Fit Region Bounds')
        plt.axvline(end_time, color='k')

        plt.plot(time, np.exp(exp_fit), '--', color='#D55E00',
                 label='Loss Factor='+'{:.2e}'.format(2*zeta_lsq))


        plt.ylim((1.05*filtered.min(), 1.05*filtered.max()))

        ax = plt.gca()
        ax.tick_params(which='major', left=True, right=True,
                       top=True, bottom=True,
                       direction='in')

        plt.ylabel('Amplitude [' + units + ']')
        plt.xlabel('Time [s]')

        # Show ticks on top of block data
        ax.set_axisbelow(False)
        ax.tick_params(zorder=10)
        plt.xlim((0, time[-1]))


        # Save plot
        if save_name is not None:
            # ymax was set from the raw signal in the first plotting section
            # (also guarded by show_plots and save_name).
            plt.ylim((-ymax, ymax))

            legend = plt.legend(loc='lower center', bbox_to_anchor=(0.5, 1.05))
            legend.get_frame().set_edgecolor('black')

            os.makedirs('figures', exist_ok=True)
            plt.savefig('figures/' + save_name + 'linscale_decay'+fig_ext,
                        dpi=300, bbox_inches='tight')


        plt.title(title)
        plt.plot(time[peak_inds], filtered[peak_inds], 'o')

        # replot on top
        plt.plot(time, np.exp(exp_fit), '--', color='#D55E00')
        plt.show()

        ###################
        # Exponential Fit

        ylim_vals = [np.log(min_amp)-1, np.log(max_amp)+0.5]

        plt.plot(time[filtered > 0],
                 np.log(np.abs(filtered[filtered > 0])),
                 label='Positive Velocity')

        plt.axvline(start_time, color='k', label='Fit Region Bounds')
        plt.axvline(end_time, color='k')

        plt.plot(time, exp_fit, '--', color='#D55E00',
                 label='Fit (Area Error {:.1f}%)'.format(norm_area_error*100))

        plt.ylim(ylim_vals)
        plt.xlim((start_time-5, end_time+5))
        plt.ylabel('Log Amplitude [ln(' + units + ')]')
        plt.xlabel('Time [s]')


        ax = plt.gca()
        ax.tick_params(which='major', left=True, right=True,
                       top=True, bottom=True, direction='in')

        # Show ticks on top of block data
        ax.set_axisbelow(False)
        ax.tick_params(zorder=10)

        legend = plt.legend(loc='lower center', bbox_to_anchor=(0.5, 1.05))
        legend.get_frame().set_edgecolor('black')

        if save_name is not None:
            os.makedirs('figures', exist_ok=True)
            plt.savefig('figures/' + save_name + 'logscale_decay'+fig_ext,
                        dpi=300, bbox_inches='tight')

        plt.plot(time[filtered < 0],
                 np.log(np.abs(filtered[filtered < 0])), ':',
                 label='Negative Velocity',
                 color='#009E73')

        # replot on top
        plt.plot(time, exp_fit, '--', color='#D55E00', label='Fit')

        plt.legend()
        plt.title(title)
        plt.show()

    return freq_id, zeta_log_dec, zeta_lsq, lsq_rsq, norm_area_error
