#!/usr/bin/env python
"""Generate a bank of templates using a brute force stochastic method.
"""
import numpy, h5py, logging, argparse, numpy.random, sys
import pycbc.waveform, pycbc.filter, pycbc.types, pycbc.psd, pycbc.fft, pycbc.pnutils

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--verbose', action='store_true')
parser.add_argument('--output-file', required=True,
    help='Output file name for template bank.')
parser.add_argument('--input-file',
    help='Bank to use as a starting point.')
parser.add_argument('--params', required=True,
    help='list of parameters to use', nargs='+')
parser.add_argument('--min', required=True,
    help='list of the minimum parameter values', nargs='+', type=float)
parser.add_argument('--max',  required=True,
    help='list of the maximum parameter values', nargs='+', type=float)
parser.add_argument('--approximant',  required=True,
    help='The waveform approximant to place')
parser.add_argument('--minimal-match', default=0.97, type=float,
    help='Match above which a proposed template counts as already '
         'covered by the bank.')
parser.add_argument('--buffer-length', default=2, type=float,
    help='size of waveform buffer in seconds')
parser.add_argument('--sample-rate', default=2048, type=int,
    help='sample rate in Hz')
parser.add_argument('--low-frequency-cutoff', default=20.0, type=float,
    help='Low-frequency cutoff used for waveform generation and matching.')
parser.add_argument('--enable-sigma-bound', action='store_true',
    help='Skip match calculations that the template norm ratio already '
         'rules out.')
parser.add_argument('--tau0-threshold', type=float,
    help='Only compute matches against templates whose tau0 differs by '
         'less than this value. Requires mass1 and mass2 in --params.')
parser.add_argument('--permissive', action='store_true',
    help='Allow waveform generator to fail.')
parser.add_argument('--seed', type=int, default=0,
    help='Seed for the numpy random number generator.')
parser.add_argument('--tolerance', type=float,
    help='Stop once the fraction of proposals added in a round is at most '
         'this value (default: (1 - minimal-match) / 10).')
parser.add_argument('--max-mtotal', type=float,
    help='Discard proposals with mass1 + mass2 at or above this value. '
         'Requires mass1 and mass2 in --params.')
pycbc.psd.insert_psd_option_group(parser)
args = parser.parse_args()

# The sampling loop pairs these lists with zip(), which would silently drop
# unmatched entries, so require them to line up.
if not len(args.params) == len(args.min) == len(args.max):
    parser.error('--params, --min and --max must have the same number of '
                 'entries')

# Both filters read mass1/mass2 from each randomly drawn proposal, and the
# proposals only contain the parameters listed in --params.
if (args.max_mtotal or args.tau0_threshold) and \
        not set(['mass1', 'mass2']).issubset(args.params):
    parser.error('--max-mtotal and --tau0-threshold require mass1 and mass2 '
                 'in --params')

pycbc.init_logging(args.verbose)
numpy.random.seed(args.seed)

class Shrinker(object):
    """Pop items off the end of a sequence by re-slicing it.

    ``data`` may be a list or a numpy array. Callers are free to replace
    ``data`` with a filtered version between pops.
    """
    def __init__(self, data):
        self.data = data

    def pop(self):
        """Remove and return the last item, or None when nothing is left."""
        remaining = len(self.data)
        if remaining == 0:
            return None
        last = self.data[remaining - 1]
        self.data = self.data[:remaining - 1]
        return last

class TriangleBank(object):
    """ A bank of templates that uses the triangle inequality to estimate
    matches based on prior ones.

    Stored templates are expected to carry the attributes this class reads:
    ``params`` (dict of generation arguments), ``s`` (per-template norm used
    by the sigma bound), ``gen``/``threshold`` (set by ``check_params``), and,
    once accepted, ``matches``/``indices`` (set by ``__contains__``), plus
    ``tau0`` when ``--tau0-threshold`` is used.
    """
    def __init__(self, p=None):
        self.waveforms = p if p is not None else []
        # Lazily built per-template arrays; each is rebuilt whenever its
        # length no longer matches the bank size.
        self.sigma = None
        self.r = None
        self.t0 = None

    def __len__(self):
        return len(self.waveforms)

    def insert(self, hp):
        """Append template ``hp`` to the bank."""
        self.waveforms.append(hp)

    def __getitem__(self, index):
        return self.waveforms[index]

    def key(self, k):
        """Return the generation parameter ``k`` of every template as an array."""
        return numpy.array([p.params[k] for p in self.waveforms])

    def sigma_match_bound(self, sig):
        """Return min(sig / sigma_i, sigma_i / sig) for every bank template.

        ``__contains__`` uses this as a cheap bound: templates whose value is
        not above the threshold are never matched explicitly.
        """
        if self.sigma is None or len(self.sigma) != len(self):
            # Use this bank's own templates (previously read a global ``bank``).
            self.sigma = numpy.array([h.s for h in self.waveforms])
        return numpy.minimum(sig / self.sigma, self.sigma / sig)

    def range(self):
        """Return (cached) ``numpy.arange(len(self))``."""
        if self.r is None or len(self.r) != len(self):
            self.r = numpy.arange(0, len(self))
        return self.r

    def tau0(self):
        """Return (cached) array of every template's ``tau0`` attribute."""
        if self.t0 is None or len(self.t0) != len(self):
            self.t0 = numpy.array([h.tau0 for h in self])
        return self.t0

    def __contains__(self, hp):
        """Return True if ``hp`` matches a bank template above ``hp.threshold``.

        When False is returned, ``hp.matches`` (match bounds against the
        candidate templates) and ``hp.indices`` (their bank indices) are set,
        so later proposals can reuse them through the triangle-inequality
        update below.
        """
        mmax = 0
        mnum = 0
        # Apply sigmas maximal match.
        if args.enable_sigma_bound:
            matches = self.sigma_match_bound(hp.s)
            r = self.range()[matches > hp.threshold]
        else:
            matches = numpy.ones(len(self))
            r = self.range()

        msig = len(r)

        # Apply tau0 threshold
        if args.tau0_threshold:
            # NOTE(review): tau0 is evaluated at a fixed 15 Hz rather than at
            # --low-frequency-cutoff; it is only compared between templates.
            hp.tau0, _ = pycbc.pnutils.mass1_mass2_to_tau0_tau3(
                                            hp.params['mass1'],
                                            hp.params['mass2'], 15.0)
            r = r[abs(self.tau0()[r] - hp.tau0) < args.tau0_threshold]
        mtau = len(r)

        # Candidates whose bound falls below this looser (doubled-mismatch)
        # threshold are dropped without computing a match.
        skip_threshold = 1 - (1 - hp.threshold) * 2.0

        # Try to do some actual matches, popping candidates off the end of a
        # copy of ``r``.
        inc = Shrinker(r * 1)
        while True:
            j = inc.pop()
            if j is None:
                hp.matches = matches[r]
                hp.indices = r
                logging.info("Template is Added MaxMatch:%0.3f BankSize:%i "
                             "AfterSigma:%i AfterTau0:%i Matches:%i",
                             mmax, len(self), msig, mtau, mnum)
                return False

            hc = self[j]
            m = hp.gen.match(hp, hc)
            matches[j] = m
            mnum += 1

            # Update bounding match values using hc's stored matches against
            # its own candidates. NOTE(review): the +1.10 offset appears to
            # add slack beyond a strict triangle-inequality bound; confirm.
            maxmatches = hc.matches - m + 1.10
            update = numpy.where(maxmatches < matches[hc.indices])[0]
            matches[hc.indices[update]] = maxmatches[update]

            # Update where to calculate matches
            inc.data = inc.data[matches[inc.data] > skip_threshold]

            if m > hp.threshold:
                return True
            if m > mmax:
                mmax = m

    def check_params(self, gen, params, threshold):
        """Generate a template per parameter point and add uncovered ones.

        Parameters
        ----------
        gen : object
            Provides ``generate(**kwds)`` and ``match(hp, hc)``.
        params : dict
            Maps parameter names to equal-length arrays; entry ``i`` of every
            array forms the keyword arguments of proposal ``i``.
        threshold : float
            Match above which a proposal counts as covered by the bank.

        Returns
        -------
        tuple
            ``(self, fraction)`` where ``fraction`` is the share of proposals
            that were added. It is 0.0 when there were no proposals.

        Raises
        ------
        Exception
            Whatever ``gen.generate`` raises, unless ``--permissive`` is set,
            in which case the failing point is logged and skipped.
        """
        num_tried = 0
        num_added = 0
        # dict.values() is not indexable on Python 3; any column gives the
        # number of proposals.
        num_points = len(next(iter(params.values()))) if params else 0
        for i in range(num_points):
            num_tried += 1

            try:
                hp = gen.generate(**{key: params[key][i] for key in params})
            except Exception as err:
                # Honour --permissive: only tolerate failures when asked to.
                if not args.permissive:
                    raise
                logging.warning("Waveform generation failed, skipping: %s",
                                err)
                continue

            hp.gen = gen
            hp.threshold = threshold
            if hp not in self:
                num_added += 1
                self.insert(hp)

        # max() guards against an empty proposal set (e.g. all filtered out).
        return self, num_added / float(max(num_tried, 1))

class GenUniformWaveform(object):
    """Generate whitened, unit-norm frequency-domain templates and compute
    their matches with a reusable inverse FFT.

    Parameters
    ----------
    buffer_length : float
        Length of the time-domain buffer in seconds; sets ``delta_f``.
    sample_rate : int
        Sample rate of the time-domain buffer.
    f_lower : float
        Low-frequency cutoff; frequency bins below it are not used.
    """
    def __init__(self, buffer_length, sample_rate, f_lower):
        self.f_lower = f_lower
        self.sample_rate = sample_rate
        self.delta_f = 1.0 / buffer_length
        tlen = int(buffer_length * sample_rate)
        # Integer division: with ``/`` Python 3 makes flen a float, which
        # breaks the allocations, resizes and slices that use it.
        self.flen = tlen // 2 + 1
        psd = pycbc.psd.from_cli(args, self.flen, self.delta_f, self.f_lower)
        self.kmin = int(f_lower * buffer_length)
        # Whitening weights 1/sqrt(PSD) for bins kmin .. flen-2 (last bin
        # excluded), matching the ``hp.view`` slice used in ``generate``.
        self.w = ((1.0 / psd[self.kmin:-1]) ** 0.5).astype(numpy.float32)
        qtilde = pycbc.types.zeros(tlen, numpy.complex64)
        q = pycbc.types.zeros(tlen, numpy.complex64)
        self.qtilde_view = qtilde[self.kmin:self.flen - 1]
        self.ifft = pycbc.fft.IFFT(qtilde, q)
        # Views on the last/first 100 samples of the IFFT output, so the
        # match is maximised only over small time shifts around zero lag.
        # Assumes the IFFT writes into q's existing buffer (TODO confirm).
        self.md = q._data[-100:]
        self.md2 = q._data[0:100]

    def generate(self, **kwds):
        """Return a whitened, unit-norm frequency series for ``kwds``.

        The result carries ``params`` (the keyword arguments), ``view`` (the
        bins used by ``match``) and ``s`` (the whitened sigmasq before
        normalisation).
        """
        if kwds['approximant'] in pycbc.waveform.fd_approximants():
            hp, hc = pycbc.waveform.get_fd_waveform(delta_f=self.delta_f,
                                                f_lower=self.f_lower, **kwds)
            if 'fratio' in kwds:
                # Blend the two polarisations with weight ``fratio``.
                hp = hc * kwds['fratio'] + hp * (1 - kwds['fratio'])
        else:
            # Use the rate this generator was built with, not a global.
            dt = 1.0 / self.sample_rate
            hp = pycbc.waveform.get_waveform_filter(
                        pycbc.types.zeros(self.flen, dtype=numpy.complex64),
                        delta_f=self.delta_f, delta_t=dt,
                        f_lower=self.f_lower, **kwds)

        hp.resize(self.flen)
        hp = hp.astype(numpy.complex64)
        hp[self.kmin:-1] *= self.w
        s = float(1.0 / pycbc.filter.sigmasq(hp, low_frequency_cutoff=self.f_lower) ** 0.5)
        hp *= s
        hp.params = kwds
        hp.view = hp[self.kmin:-1]
        hp.s = (1.0 / s) ** 2.0
        return hp

    def match(self, hp, hc):
        """Return the match of two ``generate`` outputs, maximised over the
        time shifts covered by ``md``/``md2``."""
        pycbc.filter.correlate(hp.view, hc.view, self.qtilde_view)
        self.ifft.execute()
        m = max(abs(self.md).max(), abs(self.md2).max())
        return m * 4.0 * self.delta_f

# Convergence target: stop once a round adds at most this fraction of its
# proposals. Each round draws int(1 / tolerance) proposals.
if not args.tolerance:
    tolerance = (1 - args.minimal_match) / 10
else:
    tolerance = args.tolerance
if tolerance <= 0:
    parser.error('tolerance must be positive (check --tolerance or '
                 '--minimal-match)')
conv = 1.0
size = int(1.0 / tolerance)

gen = GenUniformWaveform(args.buffer_length,
    args.sample_rate, args.low_frequency_cutoff)
bank = TriangleBank()

if args.input_file:
    with h5py.File(args.input_file, 'r') as f:
        params = {}
        for k in f:
            values = f[k][:]
            # HDF5 returns fixed-length strings as bytes; decode them so that
            # e.g. the approximant compares equal to pycbc's (str) names.
            if values.dtype.kind == 'S':
                values = values.astype('U')
            params[k] = values
    bank, _ = bank.check_params(gen, params, args.minimal_match)

round_num = 0
while conv > tolerance:
    round_num += 1
    params = {}
    for name, pmin, pmax in zip(args.params, args.min, args.max):
        params[name] = numpy.random.uniform(pmin, pmax, size=size)
    params['approximant'] = numpy.array([args.approximant] * size)

    # Drop proposals above the total-mass limit from every column.
    if args.max_mtotal:
        keep = params['mass1'] + params['mass2'] < args.max_mtotal
        for k in params:
            params[k] = params[k][keep]

    blen = len(bank)
    bank, conv = bank.check_params(gen, params, args.minimal_match)
    logging.info("Round: %s Size: %s conv: %s added: %s",
                 round_num, len(bank), conv, len(bank) - blen)

# One dataset per sampled parameter plus the approximant. h5py cannot store
# numpy unicode ('U') arrays, so text columns are written as byte strings.
with h5py.File(args.output_file, 'w') as o:
    for k in set(args.params) | set(['approximant']):
        values = bank.key(k)
        if values.dtype.kind == 'U':
            values = values.astype('S')
        o[k] = values
