#!/usr/bin/env python

# Copyright (C) 2017 Vaibhav Tiwari

# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
This script estimates black hole merger rates for two astrophysical models:
1) Uniform in the log of the component mass
2) Power-law distribution of the masses

It performs weighted Monte Carlo integration to calculate the sensitive volume.
"""

__author__ = "Vaibhav Tiwari"
__email__ = "vaibhav.tiwari@ligo.org"
__version__ = "0.0"
__date__ = "31.10.2017"

import logging
import argparse
import pycbc.version

import os
import h5py
import pylab
import numpy as np
import scipy.stats as ss
from matplotlib.pyplot import cm
from pycbc.population import scale_injections as si
from pycbc.population import rates_functions as rf

from numpy import logaddexp, log, newaxis, expm1

import matplotlib as mpl

# Build the command-line interface.
# Destinations are derived by argparse from the long option name (dashes
# become underscores), so ``dest`` is only spelled out where it differs.
parser = argparse.ArgumentParser(
    description=("Estimate rates for flat in log and power-law mass "
                 "distribution."))

# Simulation files and the distributions they were generated with.
# Each of these options accepts one or more values.
parser.add_argument(
    '--sim-files', nargs='+', required=True,
    help="List of simulation files to estimate the sensitive volume")
parser.add_argument(
    '--m_dist', nargs='+', required=True,
    help="Specify the mass distribution for the simulations")
parser.add_argument(
    '--s_dist', nargs='+', required=True,
    help="Specify the spin distribution for the simulations")
parser.add_argument(
    '--d_dist', nargs='+', required=True,
    help="Specify the distance distribution for the simulations")

# Search products, prior samples and output location.
parser.add_argument(
    '--bank-file', required=True,
    help="File containing template bank used in the search.")
parser.add_argument(
    '--statmap-file', required=True,
    help="File containing trigger information")
parser.add_argument(
    '--prior-samples', required=True,
    help=("File storing samples of prior for the analysis - "
          "posterior from the previous analysis"))
parser.add_argument(
    '--output-folder', required=True,
    help="Output folder where files are to be saved.")

# Detection threshold (optional, with defaults).
parser.add_argument(
    '--thr-var', default='stat',
    help="Variable used to define threshold.")
parser.add_argument(
    '--thr-val', type=float, default=8.0,
    help="Value of threshold.")

# Mass limits of the population.
parser.add_argument(
    '--min-mass', type=float, required=True,
    help="Minimum mass of the compact object.")
parser.add_argument(
    '--max-mass', type=float, required=True,
    help="Maximum mass of the compact object.")
parser.add_argument(
    '--max-mtotal', type=float, required=True,
    help="Maximum total mass of the binary.")

# Mass limits of the templates used to select triggers.
parser.add_argument(
    '--min-tmplt-mass', type=float, required=True,
    help=("Minimum mass of the template considered for trigger "
          "identification."))
parser.add_argument(
    '--max-tmplt-mass', type=float, required=True,
    help=("Maximum mass in the template considered for trigger "
          "identification."))

# Calibration error; stored as ``cal_err`` rather than the derived name.
parser.add_argument(
    '--calibration-error', dest='cal_err', type=float, default=3.0,
    help="Percentage calibration errors in measurement of distance.")
    
# Read the options given on the command line.
opts = parser.parse_args()

# Every product of this script is written below <output-folder>/astro_output;
# create that directory on first use.
path = opts.output_folder + '/astro_output'
if not os.path.exists(path):
    os.makedirs(path)

# Load the simulations together with the mass, spin and distance
# distributions they were generated with.
injections = si.read_injections(opts.sim_files, opts.m_dist,
                                opts.s_dist, opts.d_dist)

# The two population models handled by this script, keyed by short name:
#   'lnm' -> flat in log mass, 'imf' -> power-law.
# For each model rates_functions supplies a chirp-mass sampler and a
# probability function; ``label`` is the name used in plot legends.
distrs = ['lnm', 'imf']
mchirp_sampler = {'lnm': rf.mchirp_sampler_lnm,
                  'imf': rf.mchirp_sampler_imf}
prob = {'lnm': rf.prob_lnm,
        'imf': rf.prob_imf}
label = {'lnm': 'Flat in log',
         'imf': 'Power-law'}

# Estimate VT for every model and keep the pieces used further down:
# the estimate itself, its error, and the 'thr_falloff' array.
vt, vol_time, sig_vt, inj_falloff = {}, {}, {}, {}
for dist in distrs:
    estimate = si.estimate_vt(injections, mchirp_sampler[dist], prob[dist],
                              thr_var=opts.thr_var,
                              thr_val=opts.thr_val,
                              min_mass=opts.min_mass,
                              max_mass=opts.max_mass,
                              max_mtotal=opts.max_mtotal)
    vt[dist] = estimate
    vol_time[dist] = estimate['VT']
    sig_vt[dist] = estimate['VT_err']
    inj_falloff[dist] = estimate['thr_falloff']

# Include the calibration uncertainty: combine it in quadrature with the
# fractional VT error of each model.
# --calibration-error is a percentage, hence the division by 100.
# NOTE(review): the factor 3 presumably reflects volume scaling as distance
# cubed -- confirm.
frac_cal_err = 3 * opts.cal_err / 100.
sigma_w_cal_uncrt = {
    dist: np.sqrt(frac_cal_err**2 + (sig_vt[dist] / vol_time[dist])**2)
    for dist in distrs
}

# Save background data and coincidences
rf.save_bkg_falloff(opts.statmap_file, opts.bank_file, path,
                    opts.thr_val, opts.min_tmplt_mass, opts.max_tmplt_mass)

# Load background data and coincidences from the output directory.
# NOTE(review): these two text files are presumably the ones written by
# save_bkg_falloff above -- confirm.
coincs = np.loadtxt(path + "/coincs.txt")
bg_lower, bg_upper, bg_counts = np.loadtxt(path + "/background_bins.txt",
                                           unpack=True)
# Bin edges: every lower edge followed by the upper edge of the last bin.
bg_edges = np.append(bg_lower, bg_upper[-1])

# Pool the 'thr_falloff' values of both models.
fg_stats = np.concatenate([inj_falloff[dist] for dist in distrs])

# 100 log-spaced bins from the threshold up to the largest pooled value.
# The edges are computed BEFORE the threshold cut is applied to the samples.
lo_exp = np.log10(opts.thr_val)
hi_exp = np.log10(fg_stats.max())
fg_edges = np.logspace(lo_exp, hi_exp, 101)
fg_stats = fg_stats[fg_stats > opts.thr_val]

# Log foreground/background ratio for each coincidence: foreground term
# first, then the background term subtracted in place.
log_fg_ratios = rf.log_rho_fg_mc(coincs, fg_stats, fg_edges)
log_fg_ratios -= rf.log_rho_bg(coincs, bg_edges, bg_counts)

# Load the prior rate samples of both models from the --prior-samples file.
# The datasets are copied into arrays before the file is closed.
with h5py.File(opts.prior_samples, "r") as f:
    prior = {'lnm': np.array(f['flat/Rf']),
             'imf': np.array(f['power-law/Rf'])}

# Fit each set of prior samples; rf.fit returns the three parameters that are
# later passed to skew_lognormal_samples and scipy's skewnorm pdf.
alpha, mu, sigma = {}, {}, {}
# NOTE(review): minrp/maxrp are created here but never filled or read in this
# script -- kept only so the set of module-level names is unchanged.
minrp, maxrp = {}, {}
for dist in ('lnm', 'imf'):
    alpha[dist], mu[dist], sigma[dist] = rf.fit(prior[dist])

# Estimate rates: draw posterior samples for each model and summarise the
# rate posterior by its median and its 95th / 5th percentiles.
rate_samples = {}
r50, r95, r05 = {}, {}, {}

for dist in distrs:
    log_R = log(prior[dist])

    # Log of VT and its fractional uncertainty (calibration included).
    # NOTE(review): the 1e9 presumably converts Mpc^3 to Gpc^3 -- confirm.
    mu_log_vt = np.log(vol_time[dist]/1e9)
    sigma_log_vt = sigma_w_cal_uncrt[dist]

    # Rate samples from the fitted prior, limited to the range spanned by
    # the log of the original prior samples.
    Rf_samp = rf.skew_lognormal_samples(alpha[dist], mu[dist], sigma[dist],
                                        min(log_R), max(log_R))

    Rf, Lf, Lb = rf.fg_mc(log_fg_ratios, mu_log_vt, sigma_log_vt,
                          Rf_samp, max(fg_stats))
    rate_samples[dist] = {'Rf': Rf, 'Lf': Lf, 'Lb': Lb}

    r50[dist], r95[dist], r05[dist] = np.percentile(Rf, [50, 95, 5])
    
# Save rate posteriors.
# Layout of rate-posterior.hdf5:
#   power-law/{Lf,Rf}  samples for the 'imf' model
#   flat/{Lf,Rf}       samples for the 'lnm' model
#   data/log_fg_bg_ratio, data/newsnr
# All datasets are gzip-compressed.
with h5py.File(path+'/rate-posterior.hdf5', 'w') as out:

    for group_name, model in (('power-law', 'imf'), ('flat', 'lnm')):
        grp = out.create_group(group_name)
        for key in ('Lf', 'Rf'):
            grp.create_dataset(key, data=rate_samples[model][key],
                               compression='gzip')

    data_grp = out.create_group('data')
    data_grp.create_dataset('log_fg_bg_ratio', data=log_fg_ratios,
                            compression='gzip')
    data_grp.create_dataset('newsnr', data=coincs, compression='gzip')

# Make prior/posterior plot: for each model, the fitted prior (dashed) and a
# fit to the posterior rate samples (solid), both in the same colour.
pylab.figure()
model_colors = cm.rainbow(np.linspace(0, 1, len(distrs)))
for dist, c in zip(distrs, model_colors):
    post_alpha, post_mu, post_sigma = rf.fit(rate_samples[dist]['Rf'])

    # Evaluate both densities on a grid spanning the log of the prior samples.
    log_R = np.log(prior[dist])
    xs = np.linspace(min(log_R), max(log_R), 200)
    prior_pdf = ss.skewnorm.pdf(xs, alpha[dist], mu[dist], sigma[dist])
    post_pdf = ss.skewnorm.pdf(xs, post_alpha, post_mu, post_sigma)

    pylab.plot(xs, prior_pdf, '--', label=label[dist] + ' Prior', color=c)
    pylab.plot(xs, post_pdf, label=label[dist] + ' Posterior', color=c)

# NOTE(review): the plotted x values are log(R) while the axis label reads R
# -- confirm whether the label or the axis scale should change.
pylab.xlabel(r'$R$ ($\mathrm{Gpc}^{-3} \, \mathrm{yr}^{-1}$)')
pylab.ylabel(r'$RP(R)$')

pylab.legend(loc='best')
pylab.savefig(path+'/rate_prior_posterior.png')

# Estimate p_astro for every entry of log_fg_ratios (no cut to a fixed number
# of events is applied here) and plot 1 - p_astro against the log
# foreground/background ratio.
#
# For posterior sample i and event j the foreground probability is
#     Lf_i * r_j / (Lf_i * r_j + Lb_i),   r_j = exp(log_fg_ratios[j]),
# and p_astro[j] is its mean over the samples. Everything is evaluated in
# log space with logaddexp.
p_astro = []
pylab.figure()
model_colors = cm.rainbow(np.linspace(0, 1, len(distrs)))
for dist, c in zip(distrs, model_colors):
    samples = rate_samples[dist]
    n_samples = samples['Lf'].shape[0]

    # Column vectors over samples; log_fg_ratios is broadcast along rows.
    log_Lf = log(samples['Lf'][:, newaxis])
    log_Lb = log(samples['Lb'][:, newaxis])
    log_fg = log_Lf + log_fg_ratios[newaxis, :]

    # log of the sample mean: log-sum-exp over axis 0 minus log(n_samples).
    log_pastros = logaddexp.reduce(log_fg - logaddexp(log_fg, log_Lb),
                                   axis=0) - log(n_samples)

    # 1 + expm1(x) == exp(x); stored in descending order of p_astro.
    p_astro.append(1 + expm1(np.sort(log_pastros)[::-1]))

    # -expm1(log p) == 1 - p.
    pylab.plot(log_fg_ratios, -expm1(log_pastros), '.',
               label=label[dist], color=c)

pylab.xlabel(r'$\log p(x\mid f)/p(x\mid b)$')
pylab.ylabel(r'$1-p_\mathrm{astro}$')
pylab.legend(loc='best')
pylab.yscale('log')
pylab.savefig(path+'/p_astro.png')

# Call form of print: identical output on Python 2 for a single argument, and
# valid syntax on Python 3 (the bare print statement is not).
print(p_astro)
