#!/usr/bin/env /usr/bin/python
import logging
import sys
import numpy
from optparse import OptionParser, OptionGroup

import pycbc.vetoes
import pycbc.psd
import pycbc.waveform
import pycbc.events
import pycbc.noise

from pycbc.waveform import TemplateBank
from pycbc.filter import resample_to_delta_t, highpass, make_frequency_series
from pycbc.filter import  matched_filter_core, sigmasq
from pycbc.scheme import CUDAScheme, CPUScheme, OpenCLScheme
from pycbc.types import TimeSeries, FrequencySeries, float32, complex64, zeros, Array
from pycbc.frame import read_frame
from pycbc.inject import InjectionSet

# Option Parsing ##############################################################
# Command-line interface. After parse_args() each option is available as an
# attribute of `opt`, with dashes in the option name replaced by underscores.
# Options without an explicit default are None when not given.
parser = OptionParser(
    usage   = "%prog [OPTIONS]",
    description = "Find single detector gravitational-wave triggers." )

parser.add_option("-V", "--verbose", action="store_true", help="print extra debugging information", default=False )

# Input files and output tagging
parser.add_option("--bank-file", type=str)
parser.add_option("--bank-veto-bank-file", type=str)
parser.add_option("--injection-file", type=str)
# A value of 0 (the default) leaves the corresponding trigger-time limit off
parser.add_option("--trig-start-time", type=int, default=0)
parser.add_option("--trig-end-time", type=int, default=0)
parser.add_option("--ifo-tag", type=str)
parser.add_option("--user-tag", type=str)

# Filtering configuration
parser.add_option("--approximant", help="Approximant to use for filtering.", type=str)
parser.add_option("--order", type=str)
parser.add_option("--gps-start-time", help="The gps start time of the data", type=int)
parser.add_option("--gps-end-time", help="The gps end time of the data", type=int)
parser.add_option("--snr-threshold", help="The minimum snr threshold", type=float)
parser.add_option("--strain-high-pass", type=float)
# The chisq cut is only applied when both the bin count and the threshold are non-zero
parser.add_option("--chisq-bins", type=int, default=0)
parser.add_option("--chisq-threshold", type=float, default=0)
parser.add_option("--chisq-delta", type=float, default=0)
parser.add_option("--low-frequency-cutoff", help="The low frequency cutoff to use for filtering (Hz)", type=float)
parser.add_option("--pad-data", help="Extra padding to remove highpass corruption (s)", type=int)
parser.add_option("--sample-rate", help="The sample rate to use for filtering (Hz)", type=int)
parser.add_option("--cluster-method", choices=["template", "window"])
parser.add_option("--cluster-window", type=int, help="Length of clustering window in seconds", default = -1)

# Processing backend
parser.add_option("--processing-scheme", help="The processing scheme to use", choices=["cpu", "cuda", "opencl"], default="cpu")
parser.add_option("--processing-device-id", help="ID of gpu to use for accelerated processing", default=0, type=int)

# Segmentation of the data
parser.add_option("--segment-length", type=int, help="The length of a segment (s)")
parser.add_option("--segment-start-pad", type=int, help="Length of corruption at the beginning of a segment(s)")
parser.add_option("--segment-end-pad", type=int, help="Length of corruption at the end of a segment (s)")

# Strain data source: frames on disk or simulated noise
parser.add_option("--channel-name", help="The channel containing the gravitational strain data", type=str)
parser.add_option("--frame-cache",help="Cache file containing the frame locations.", type=str)
parser.add_option("--fake-strain", help="Name of psd for generating fake gaussian noise", type=str)
parser.add_option("--fake-strain-seed", help="Seed value for the generation of fake colored gaussian noise", type=int, default=0)
# A value of 0 (the default) disables maximization over the bank
parser.add_option("--maximization-interval", help="Maximize triggers over the template bank (ms)", type=int, default=0)

# Banner printed on every invocation (it runs before argument parsing, so it
# also appears for --help). The parenthesised single-argument form prints the
# same text under Python 2's print statement and Python 3's print function.
print("THIS IS THE TESTING VERSION")

# Add options for PSD generation
pycbc.psd.insert_psd_option_group(parser)

opt, argv = parser.parse_args()

# Check that the PSD options make sense
pycbc.psd.verify_psd_options(opt, parser)

# --verbose enables DEBUG output. Without it only warnings and above are
# emitted, so the logging.info progress messages used throughout are hidden.
if opt.verbose:
    log_level = logging.DEBUG
else:
    log_level = logging.WARN

logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# Build the processing scheme object for the requested backend; it is entered
# later with `with ctx:` around the filtering code. Only the CUDA scheme is
# given the --processing-device-id value.
scheme_name = opt.processing_scheme
if scheme_name not in ("cuda", "opencl"):
    logging.info("Running with CPU support only")
    ctx = CPUScheme()
elif scheme_name == "cuda":
    logging.info("Running with CUDA support")
    ctx = CUDAScheme(opt.processing_device_id)
else:
    logging.info("Running with OpenCL support")
    ctx = OpenCLScheme()

# Data Preconditioning ########################################################
# `strain` is only assigned when --frame-cache or --fake-strain is given.
# Without either, the segmentation step further down would fail with a
# NameError, so report the problem as a usage error instead.
if not opt.frame_cache and not opt.fake_strain:
    parser.error("No strain data source given: specify --frame-cache "
                 "or --fake-strain")

if opt.frame_cache:
    logging.info("Reading Frames")
    # Read --pad-data extra seconds on both sides; they are cut off again
    # after filtering and resampling (see "Remove Padding" below).
    strain = read_frame(opt.frame_cache, opt.channel_name,
                        start_time=opt.gps_start_time-opt.pad_data,
                        end_time=opt.gps_end_time+opt.pad_data)

    if opt.injection_file:
        logging.info("Applying injections")
        injections = InjectionSet(opt.injection_file)
        # The first two characters of the channel name are passed along with
        # the strain (presumably the detector prefix -- confirm against
        # InjectionSet.apply).
        injections.apply(strain, opt.channel_name[0:2])

    logging.info("Highpass Filtering")
    strain = highpass(strain, frequency=opt.strain_high_pass)

    logging.info("Converting to float32")
    # Scale by DYN_RANGE_FAC before the cast to single precision
    strain = (strain * pycbc.DYN_RANGE_FAC).astype(float32)
    # Debugging dump of the scaled, not yet resampled data
    strain.save( "raw_pycbc.npy")

    logging.info("Resampling data")
    strain = resample_to_delta_t(strain, 1.0/opt.sample_rate, method='broken_butterworth')
    # Debugging dump of the resampled data
    strain.save("resampled_pycbc.npy")

    # Second highpass pass, now applied to the resampled data
    logging.info("Highpass Filtering")
    strain = highpass(strain, frequency=opt.strain_high_pass)

    logging.info("Remove Padding")
    # Drop pad_data seconds (counted at the new sample rate) from each end
    start = opt.pad_data*opt.sample_rate
    end = len(strain)-opt.sample_rate*opt.pad_data
    strain = strain[start:end]
    # Debugging dump of the fully conditioned data
    strain.save("conditioned_pycbc.npy")
    
# Simulated data: replace (or create) `strain` with coloured Gaussian noise
# drawn from the PSD model named by --fake-strain.
if opt.fake_strain:
    logging.info("Generating Fake Strain")
    duration = opt.gps_end_time - opt.gps_start_time
    # Number of time samples to generate
    tlen = duration * opt.sample_rate
    # Frequency resolution of the PSD used to colour the noise
    pdf = 1.0/128
    # Number of PSD samples; `//` keeps the explicit integer (floor) division
    # that Python 2's `/` performed on ints here.
    plen = int(opt.sample_rate / pdf) // 2 + 1

    logging.info("Making PSD for strain")
    strain_psd = pycbc.psd.from_string(opt.fake_strain, plen,
                                       pdf, opt.low_frequency_cutoff)

    logging.info("Making colored noise")
    strain = pycbc.noise.noise_from_psd(tlen, 1.0/opt.sample_rate, strain_psd,
                                        seed=opt.fake_strain_seed)

    if opt.injection_file:
        # The first two characters of the channel name are handed to the
        # injection code, so slicing a missing (None) channel name would
        # raise a TypeError; turn that into a clear usage error.
        if opt.channel_name is None:
            parser.error("--channel-name is required when --injection-file "
                         "is given")
        logging.info("Applying injections")
        injections = InjectionSet(opt.injection_file)
        injections.apply(strain, opt.channel_name[0:2])

    logging.info("Converting to float32")
    # Scale by DYN_RANGE_FAC before the cast to single precision
    strain = (pycbc.DYN_RANGE_FAC * strain).astype(float32)

# Main search: everything inside runs under the selected processing scheme.
# Steps: cut the strain into overlapping frequency-domain segments, estimate
# the PSD, overwhiten the segments, then matched-filter every template against
# every segment and record, veto-rank and cluster the resulting triggers.
with ctx:

    logging.info("Making frequency-domain data segments")
    segments = []
    # Samples per segment that are searched for triggers, i.e. the segment
    # length minus the start and end pads.
    seg_width_idx = (opt.segment_length - opt.segment_start_pad -
                     opt.segment_end_pad) * opt.sample_rate
    # Full segment length in samples
    seg_len_idx = opt.segment_length * opt.sample_rate
    seg_start_idx = 0
    # Segments start seg_width_idx samples apart but are seg_len_idx long, so
    # neighbouring segments overlap by the combined pad length.
    while seg_start_idx + seg_width_idx < len(strain):
        seg_end_idx = seg_start_idx + seg_len_idx
        seg = make_frequency_series(strain[seg_start_idx:seg_end_idx])
        segments.append(seg)
        seg_start_idx += seg_width_idx

    logging.info("Computing noise PSD")
    psd = pycbc.psd.from_cli(opt, len(segments[0]), segments[0].delta_f,
                             opt.low_frequency_cutoff, strain,
                             pycbc.DYN_RANGE_FAC).astype(float32)

    # Debugging dump of the PSD
    numpy.savetxt( "psd_pycbc.txt", psd.numpy())

    bank_chisq = pycbc.vetoes.SingleDetBankVeto(opt.bank_veto_bank_file,
                                          opt.approximant, psd, segments,
                                          opt.low_frequency_cutoff,
                                          phase_order=opt.order)
    power_chisq = pycbc.vetoes.SingleDetPowerChisq(opt.chisq_bins)

    logging.info("Overwhitening frequency-domain data segments")
    # In-place division; must come after the bank veto has been built from
    # the unweighted segments above.
    for seg in segments:
        seg /= psd

    # Scratch buffers reused for every template/segment pair. Their length is
    # the time-domain length matching a frequency series of len(segments[0]).
    template_work_mem = zeros((len(segments[0])-1)*2, dtype=complex64)
    snr_work_mem = zeros((len(segments[0])-1)*2, dtype=complex64)
    corr_work_mem = zeros((len(segments[0])-1)*2, dtype=complex64)

    # One dtype per trigger column (the original list carried a fifth,
    # unmatched float32 entry).
    event_mgr = pycbc.events.EventManager(opt,
                         ['time_index', 'snr', 'chisq', 'bank_chisq'],
                         [int, complex64, float32, float32], psd=psd)

    logging.info("Read in template bank")
    try:
        bank = TemplateBank(opt.bank_file, opt.approximant, len(segments[0]),
                        segments[0].delta_f, opt.low_frequency_cutoff,
                        dtype=complex64, phase_order=opt.order,
                        psd = psd, out=template_work_mem)
    except Exception:
        # `except Exception` (rather than a bare except) lets Ctrl-C and
        # SystemExit propagate instead of being reported as an empty bank.
        # NOTE(review): any other failure while loading the bank is still
        # treated as "no templates"; narrow this once the exception that
        # TemplateBank raises for an empty file is confirmed.
        logging.info("No templates found in template bank, so writing process "
                     "table and exiting.")
        event_mgr.write_events()
        sys.exit(0)

    # Without a cluster method no cluster window is ever computed and the
    # clustering call below would fail with a NameError only after a whole
    # template had been filtered; fail early with a usage error instead.
    if opt.cluster_method is None:
        parser.error("--cluster-method is required (choose 'template' "
                     "or 'window')")

    for t_num, template in enumerate(bank):
        event_mgr.new_template(tmplt=template.params, sigmasq=template.sigmasq)
        # Index into `strain` of the end of the searched part of the current
        # segment; advanced by seg_width_idx at the top of the segment loop.
        cumulative_index = opt.segment_start_pad * opt.sample_rate

        for s_num, stilde in enumerate(segments):
            cumulative_index += seg_width_idx

            logging.info("Filtering template %d/%d segment %d/%d" % \
                         (t_num + 1, len(bank), s_num + 1, len(segments)))

            # Searched sample range within the segment's SNR series: the
            # segment without its start and end pads.
            snr_start = (opt.segment_start_pad) * opt.sample_rate
            snr_end  =  (opt.segment_length - opt.segment_end_pad) * opt.sample_rate

            # Shrink the searched range so that no trigger falls outside
            # [trig_start_time, trig_end_time]. Times are whole seconds, with
            # strain[0] taken to correspond to --gps-start-time. `//` keeps
            # the integer division explicit.
            if opt.trig_start_time:
                seg_start = (cumulative_index - seg_width_idx) // opt.sample_rate + opt.gps_start_time
                if seg_start < opt.trig_start_time:
                    snr_start = (opt.trig_start_time - seg_start) * opt.sample_rate + snr_start
            if opt.trig_end_time:
                seg_end = cumulative_index // opt.sample_rate + opt.gps_start_time
                if seg_end > opt.trig_end_time:
                    snr_end = snr_end - (seg_end - opt.trig_end_time) * opt.sample_rate
            # Segment lies entirely outside the requested trigger window
            if snr_start >= snr_end:
                continue

            snr, corr, norm = matched_filter_core(template, stilde, h_norm=template.sigmasq,
                                                  low_frequency_cutoff=opt.low_frequency_cutoff,
                                                  out=snr_work_mem, corr_out=corr_work_mem)
            # `snr` is compared before being scaled by `norm`, hence the
            # threshold is divided by `norm`; surviving values are multiplied
            # by `norm` further down.
            idx, snrv = pycbc.events.threshold(snr[snr_start:snr_end], opt.snr_threshold / norm)

            if len(idx) == 0:
                continue

            # Debugging dump; overwritten for every template/segment pair
            # that produced at least one point above threshold.
            numpy.savetxt( "snr_pycbc.txt", snr[0:10000].squared_norm().numpy() * norm * norm)

            logging.info("%s points above threshold" % str(len(idx)))
            # idx is relative to the thresholded slice, so add snr_start to
            # address the full SNR series of this segment.
            bank_chisqv = bank_chisq.values(template, s_num, snr, norm, idx+snr_start)
            power_chisqv, chisqdof = power_chisq.values(corr, snr, norm, psd, idx+snr_start, template, bank, opt.low_frequency_cutoff)

            snrv *= norm
            # Shift the slice-relative indices by the start of this segment's
            # searched region plus any extra offset introduced by the
            # trig-start-time cut.
            idx += (cumulative_index - seg_width_idx + (snr_start - opt.segment_start_pad * opt.sample_rate))
            event_mgr.add_template_events( ["snr", "time_index", "chisq", "bank_chisq"], [snrv, idx, power_chisqv, bank_chisqv])

        # Cluster window in samples: either the fixed --cluster-window or the
        # duration of the current template.
        if opt.cluster_method == "window":
            cluster_window = opt.cluster_window * opt.sample_rate
        elif opt.cluster_method == "template":
            cluster_window = template.length_in_time * opt.sample_rate

        event_mgr.cluster_template_events("time_index", "snr", cluster_window)
        event_mgr.finalize_template_events()

# Post-processing of the collected triggers, then output.
num_events = len(event_mgr.events)
logging.info("Found %s triggers" % str(num_events))

# The chisq cut is applied only when both a threshold and a bin count were
# given (both default to 0).
if opt.chisq_threshold and opt.chisq_bins:
    logging.info("Removing triggers with poor chisq")
    event_mgr.chisq_threshold(opt.chisq_threshold, opt.chisq_bins, opt.chisq_delta)
    logging.info("%s remaining triggers" % str(len(event_mgr.events)))

# Optional maximization over the bank, skipped when the interval is 0.
if opt.maximization_interval:
    logging.info("Maximizing triggers over %s ms window" % opt.maximization_interval)
    # Convert the window from milliseconds to samples. `//` keeps the integer
    # (floor) division that Python 2's `/` performed on these ints.
    event_mgr.maximize_over_bank("time_index", "snr", opt.maximization_interval * opt.sample_rate // 1000)
    logging.info("%s remaining triggers" % str(len(event_mgr.events)))

logging.info("Writing out triggers")
event_mgr.write_events()
logging.info("Finished")
