#!/usr/bin/env python

# Copyright (C) 2018 Ian Harry
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Plot single-detector qscan of strain data
"""

import argparse
import logging
import sys

import numpy
import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm, Normalize

import pycbc.results
import pycbc.strain
import pycbc.version

# https://stackoverflow.com/questions/9978880/python-argument-parser-list-of-list-or-tuple-of-tuples
def t_window(s):
    """Parse one ``start,end`` command-line token into a time window.

    Used as the argparse ``type=`` callable for ``--time-windows``.

    Parameters
    ----------
    s : str
        Two comma-separated numbers, e.g. ``"1.5,0.5"``.

    Returns
    -------
    list of float
        ``[start, end]``. A (mutable) list is returned; the script later
        shrinks its entries in place when too little data is available.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``s`` is not exactly two comma-separated numbers.
    """
    try:
        # ValueError covers both a non-numeric field and the wrong number
        # of fields (tuple unpacking mismatch).
        start, end = map(float, s.split(','))
    except ValueError:
        raise argparse.ArgumentTypeError("Input must be start,end start,end")
    return [start, end]

# Command-line interface: output/plot-layout options, Q-transform tuning
# options, plot styling options, plus the standard pycbc.strain option group.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--output-file', required=True, help='Output plot')
parser.add_argument('--center-time', type=float, required=True,
                    help='Center plot on the given GPS time')
# Each window is parsed by t_window into [before, after]; the script plots
# [center_time - before, center_time + after], one panel per window.
parser.add_argument('--time-windows', required=True, type=t_window,
                    nargs='+',
                    help='Time windows to plot, one panel per window. Each '
                         'window gives the seconds to show before and after '
                         '--center-time. Should be provided as '
                         'start1,end1 start2,end2 ...')

parser.add_argument('--qtransform-delta-t', default=0.001, type=float,
                    help='The time resolution to interpolate to (optional)')
parser.add_argument('--qtransform-delta-f', default=None, type=float,
                    help='Frequency resolution to interpolate to (optional)')
parser.add_argument('--qtransform-logfsteps', type=int, default=200,
                    help='Do a log interpolation (incompatible with '
                         '--qtransform-delta-f option) and set the number '
                         'of steps to take')
# The two frange bounds must be given together or not at all; the default
# range applied when both are omitted is 30 Hz to sample_rate / 4.
parser.add_argument('--qtransform-frange-lower', default=None, type=float,
                    help='Lower frequency at which to compute qtransform. '
                         'Must be given together with '
                         '--qtransform-frange-upper. Optional, default=30')
parser.add_argument('--qtransform-frange-upper', default=None, type=float,
                    help='Upper frequency at which to compute qtransform. '
                         'Must be given together with '
                         '--qtransform-frange-lower. '
                         'Optional, default=Half of Nyquist')
parser.add_argument('--qtransform-qrange-lower', default=4, type=float,
                    help='Lower limit of the range of q to consider, '
                         'default=4')
parser.add_argument('--qtransform-qrange-upper', default=64, type=float,
                    help='Upper limit of the range of q to consider, '
                         'default=64')
parser.add_argument('--qtransform-mismatch', default=0.2, type=float,
                    help='Mismatch between frequency tiles, default=0.2')

parser.add_argument('--linear-y-axis', dest='log_y', default=True,
                    action='store_false',
                    help='Use a linear y-axis. By default a log axis is used.')
parser.add_argument('--linear-colorbar', dest='log_colorbar', default=True,
                    action='store_false',
                    help='Use a linear colorbar scale.')
parser.add_argument('--plot-title',
                    help="If given, use this as the plot title")
parser.add_argument('--plot-caption',
                    help="If given, use this as the plot caption")

pycbc.strain.insert_strain_option_group(parser)
opts = parser.parse_args()

logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

# Read the strain data described by the pycbc.strain command-line options.
strain = pycbc.strain.from_cli(opts, pycbc.DYN_RANGE_FAC)

# --center-time is declared required, so this fallback (mid-point of the
# requested GPS span) only matters if that option is ever made optional.
if opts.center_time is None:
    center_time = (opts.gps_start_time + opts.gps_end_time) / 2.
else:
    center_time = opts.center_time

# Frequency range for the Q-transform: give both bounds or neither.
frange_lower = opts.qtransform_frange_lower
frange_upper = opts.qtransform_frange_upper
if frange_lower is None and frange_upper is None:
    # Default: 30 Hz up to half of the Nyquist frequency.
    curr_frange = (30, opts.sample_rate / 4.)
elif frange_lower is None or frange_upper is None:
    raise ValueError('Must provide either both --qtransform-frange-upper and '
                     '--qtransform-frange-lower or neither option.')
else:
    curr_frange = (frange_lower, frange_upper)

# Only ask whiten() to drop its corrupted edges when at least 2 s of data
# lie on both sides of the center time.
# NOTE(review): presumably whiten(4, 4) trims 2 s from each end, which would
# otherwise cut away the center time -- confirm against pycbc.
rem_corrupted = not ((center_time - strain.start_time) < 2 or
                     (strain.end_time - center_time) < 2)

strain = strain.whiten(4, 4, remove_corrupted=rem_corrupted)

wins = opts.time_windows
# squeeze=False keeps ``axes`` a 2-D (n_windows, 1) array even when only one
# window is requested; otherwise subplots() returns a bare Axes there and the
# indexing / ``axes.ravel()`` uses of ``axes`` fail.
fig, axes = plt.subplots(len(wins), 1, squeeze=False)
times_all = []
freqs_all = []
qvals_all = []

# First pass: compute the Q-transform of every window so a common colour
# scale can be chosen before anything is drawn.
for curr_win in wins:
    # Catch the case that not enough data is available: shrink the window to
    # stop 0.01 s inside the whitened data. The list is edited in place so the
    # x-limits set in the second pass match the data actually transformed.
    if center_time - curr_win[0] < strain.start_time:
        curr_win[0] = float(center_time - strain.start_time - 0.01)
    if center_time + curr_win[1] > strain.end_time:
        curr_win[1] = float(strain.end_time - center_time - 0.01)
    strain_zoom = strain.time_slice(center_time - curr_win[0],
                                    center_time + curr_win[1])

    times, freqs, qvals = strain_zoom.qtransform(
        delta_t=opts.qtransform_delta_t,
        delta_f=opts.qtransform_delta_f,
        logfsteps=opts.qtransform_logfsteps,
        frange=curr_frange,
        qrange=(opts.qtransform_qrange_lower, opts.qtransform_qrange_upper),
        mismatch=opts.qtransform_mismatch)
    times_all.append(times)
    freqs_all.append(freqs)
    qvals_all.append(qvals)

max_qval = max(q.max() for q in qvals_all)

# A single norm shared by every panel, so the one colorbar drawn for the
# figure is valid for all panels (per-panel autoscaling would make it
# describe only the last panel).
if opts.log_colorbar:
    norm = LogNorm(vmin=1, vmax=max_qval)
else:
    norm = Normalize(vmin=min(q.min() for q in qvals_all), vmax=max_qval)

# Second pass: draw each window in its own row, with time measured relative
# to the center time.
for ax, times, freqs, qvals, curr_win in zip(axes[:, 0], times_all,
                                             freqs_all, qvals_all, wins):
    im = ax.pcolormesh(times - center_time, freqs, qvals, norm=norm)
    ax.set_xlim(-curr_win[0], curr_win[1])
    ax.set_ylim(curr_frange[0], curr_frange[1])
    if opts.log_y:
        ax.set_yscale('log')

# Add an invisible, frameless axes covering the whole figure purely to carry
# shared x/y labels for the stacked panels.
# https://stackoverflow.com/questions/6963035/pyplot-axes-labels-for-subplots
fig.add_subplot(111, frameon=False)
# Booleans, not the legacy 'on'/'off' strings: current matplotlib treats any
# non-empty string as true, which would draw stray ticks on this overlay.
plt.tick_params(labelcolor='none', top=False, bottom=False, left=False,
                right=False)
plt.grid(False)
plt.xlabel('Time from {:.3f} (s)'.format(center_time))
plt.ylabel('Frequency (Hz)')

# One colorbar shared by all panels, built from the last panel's image.
# https://stackoverflow.com/questions/13784201/matplotlib-2-subplots-1-colorbar
cb = fig.colorbar(im, ax=axes.ravel().tolist())
cb.set_label('Normalized power')

if opts.plot_title is None:
    opts.plot_title = 'Q-transform plot around {:.3f}'.format(center_time)
if opts.plot_caption is None:
    # FIXME: Someone please improve!
    opts.plot_caption = ("This shows the Q-transform as a function of time and "
                         "frequency")

# Save the figure along with the invoking command line, title and caption.
pycbc.results.save_fig_with_metadata(
    fig, opts.output_file, cmd=' '.join(sys.argv), fig_kwds={'dpi': 150},
    title=opts.plot_title, caption=opts.plot_caption)
