#! /usr/bin/env python

from __future__ import division
import logging
import os
import re
import zipfile


from lab.parser import Parser


# Regex for a number with optional sign, optional fractional part and optional
# exponent introduced by "e". NOTE: the pattern contains three capturing
# groups of its own (whole number, fraction, exponent), so every occurrence
# embedded in a larger regex shifts the numbering of later groups by three.
PATTERN_FLOAT = r"([+-]?\d+(\.\d+)?(e[+-]?\d+)?)"
# Property name used for the "Time since training begin" value.
TIT = "time_incl_training"


def parse_regex(regex, group, type):
    """Build a function that extracts one converted regex group from a text.

    Args:
        regex: Compiled regular expression; its ``search`` method is used.
        group: Group (number or name) handed to ``match.group``.
        type: Callable applied to the matched group text.

    Returns:
        A function ``f(content)`` that returns ``type(<group text>)`` for the
        first match of *regex* in *content*, or ``None`` if nothing matches.
    """
    def _parse_regex(content):
        found = regex.search(content)
        if found is None:
            return None
        return type(found.group(group))

    return _parse_regex


# Extractor for the first "Time since training begin: <seconds>" entry of a
# text; returns the value as float, or None if the text has no such entry.
parse_time_incl_training = parse_regex(re.compile(r"Time since training begin: (\d+(\.\d+)?)"), 1, float)


class CommonParser(Parser):
    """Parser extended with helpers for derived and multi-match properties.

    All helpers register a parsing function through ``self.add_function``;
    nothing is parsed at registration time.
    """

    def add_difference(self, diff, val1, val2):
        """Register a function storing ``props[val1] - props[val2]``.

        The result is written to ``props[diff]``. ``None`` is stored when
        either operand is missing from *props* or is ``None``.
        """
        def diff_func(content, props):
            minuend = props.get(val1)
            subtrahend = props.get(val2)
            if minuend is not None and subtrahend is not None:
                props[diff] = minuend - subtrahend
            else:
                props[diff] = None

        self.add_function(diff_func)

    def _get_flags(self, flags_string):
        """Combine the ``re`` flags named by each character of *flags_string*.

        Every character is looked up as an attribute of the ``re`` module
        (e.g. ``"M"`` -> ``re.M``); the empty string yields ``0``.
        """
        combined = 0
        for flag_name in flags_string:
            combined = combined | getattr(re, flag_name)
        return combined

    def add_repeated_pattern(
            self, name, regex, file="run.log", required=False, type=int,
            flags="", group=None, only_count=False, add_final=False):
        """Register a function collecting every match of *regex* in *file*.

        Args:
            name: Property under which the result is stored.
            regex: Pattern handed to ``re.findall``; the ``M`` flag is always
                added to *flags*.
            file: File whose content is searched.
            required: Log an error if the pattern never matches.
            type: Callable converting each match.
            flags: String of ``re`` flag names, one character per flag.
            group: Index into each ``re.findall`` result (0-based position
                within the result tuple, not an ``re`` group number);
                ``None`` converts the whole result.
            only_count: Store only the number of matches.
            add_final: Additionally store the last converted match under
                ``final_<name>`` if there is at least one match.

        ``only_count`` and ``add_final`` are mutually exclusive.
        """
        flags = flags + "M"
        assert not only_count or not add_final

        def find_all_occurrences(content, props):
            found = re.findall(regex, content, flags=self._get_flags(flags))
            if required and len(found) == 0:
                logging.error("Pattern {0} not found in file {1}".format(
                    regex, file))
            if only_count:
                props[name] = len(found)
                return
            converted = []
            for hit in found:
                raw = hit if group is None else hit[group]
                converted.append(type(raw))
            props[name] = converted
            if add_final and converted:
                props["final_%s" % name] = converted[-1]

        self.add_function(find_all_occurrences, file=file)

    def add_bottom_up_pattern(self, name, regex, file="run.log",
                              required=False, type=int, flags="", group=1):
        """Register a function storing the last matching line's value.

        The lines of the file content are put in reverse order before
        ``re.search`` is applied, so the first hit is the occurrence closest
        to the end of the file. On a match, ``type(match.group(group))`` is
        stored under *name*; without a match the property is left untouched
        and an error is logged if *required* is set.
        """
        def search_from_bottom(content, props):
            lines = content.splitlines()
            lines.reverse()
            bottom_up_text = "\n".join(lines)
            found = re.search(regex, bottom_up_text,
                              flags=self._get_flags(flags))
            if found is None:
                if required:
                    logging.error("Pattern {0} not found in file {1}".format(
                        regex, file))
                return
            props[name] = type(found.group(group))

        self.add_function(search_from_bottom, file=file)


def expansions_per_second(content, props):
    """Store the expansion rate of the search in *props*.

    Sets ``props["expansions_per_second"]`` to ``expansions / search_time``.
    Nothing is stored when either value is missing or ``None``, or when the
    search time is zero (the rate is undefined and the division would raise
    ``ZeroDivisionError``).

    Args:
        content: Unused; present to match the parser function signature.
        props: Property dictionary that is read and updated in place.
    """
    expansions = props.get("expansions")
    search_time = props.get("search_time")
    if expansions is None or search_time is None:
        return
    # Very short searches can report a time of 0; skip instead of crashing.
    if search_time == 0:
        return
    props["expansions_per_second"] = expansions / search_time


def samples_req_gen_ratio(content, props):
    """Store ``samples_requested / samples_generated`` in *props*.

    The ratio is written to ``props["samples_req_gen_ratio"]`` as a float.
    Nothing is stored if either count is missing or ``None``, or if no
    samples were generated.

    Args:
        content: Unused; present to match the parser function signature.
        props: Property dictionary that is read and updated in place.
    """
    generated = props.get("samples_generated")
    requested = props.get("samples_requested")
    if generated is None or requested is None or generated == 0:
        return
    props["samples_req_gen_ratio"] = float(requested) / generated


def run_training_parser():
    """Configure a ``CommonParser`` for the training log and run it.

    The registration order below is significant for derived values: the
    functions computing ratios are added after the patterns that provide
    their inputs.
    """
    def count_entries(spec):
        # "<a>-<b>" is treated as a range and yields b - a; anything else is
        # treated as a comma-separated list and yields its number of items.
        if "-" in spec:
            bounds = spec.split("-")
            return int(bounds[1]) - int(bounds[0])
        return spec.count(",") + 1

    # In the two timing regexes the wanted value is wrapped in an extra pair
    # of parentheses. Since every PATTERN_FLOAT brings three groups of its
    # own, the wrapper is group 4 for the second number (train) and group 10
    # for the fourth number (avg wait).
    train_time_regex = (
        r"- time\(total, train, wait, avg wait\):"
        r"\s*{PATTERN_FLOAT}s,\s*({PATTERN_FLOAT})s,\s*{PATTERN_FLOAT}s,"
        r"\s*{PATTERN_FLOAT}s").format(PATTERN_FLOAT=PATTERN_FLOAT)
    avg_wait_regex = (
        r"- time\(total, train, wait, avg wait\):"
        r"\s*{PATTERN_FLOAT}s,\s*{PATTERN_FLOAT}s,\s*{PATTERN_FLOAT}s,"
        r"\s*({PATTERN_FLOAT})s").format(PATTERN_FLOAT=PATTERN_FLOAT)

    parser = CommonParser()

    # Per-step training values (lists, one entry per matching log line).
    parser.add_repeated_pattern(
        "loss",
        r"- loss: ({PATTERN_FLOAT}) -".format(PATTERN_FLOAT=PATTERN_FLOAT),
        type=float, group=1)
    parser.add_repeated_pattern(
        "different_inputs",
        r"- inputs:\s*(\d+-\d+|(\d+:\d+(,\s*\d+:\d+)*))\s*-",
        type=count_entries, group=0, add_final=True)
    parser.add_repeated_pattern(
        "different_predictions",
        r"- predictions:\s*(\d+-\d+|(\d+:\d+(,\s*\d+:\d+)*))\s*-",
        type=count_entries, group=0, add_final=True)

    # Values taken from the last matching line of the log.
    parser.add_bottom_up_pattern(
        "final_training_time", train_time_regex, type=float, group=4)
    parser.add_bottom_up_pattern(
        "average_waiting_time", avg_wait_regex, type=float, group=10)
    parser.add_bottom_up_pattern(
        "samples_generated",
        r"Sample stats: (\d+) \(generated\), \d+ \(trained on\)",
        type=int, group=1)
    parser.add_bottom_up_pattern(
        "samples_requested",
        r"Sample stats: \d+ \(generated\), (\d+) \(trained on\)",
        type=int, group=1)

    # Needs samples_generated / samples_requested from the patterns above.
    parser.add_function(samples_req_gen_ratio, file="run.log")

    # Pure occurrence counters.
    parser.add_repeated_pattern(
        "model_updates", r"Converted \d+ variables to const ops.",
        only_count=True)
    parser.add_repeated_pattern(
        "training_and_evaluation_finished",
        r"RL Training and Evaluation finished.",
        only_count=True)

    parser.add_bottom_up_pattern(
        "max_scrambles",
        r"Sampling> Increased max scrambles to (\d+)",
        type=int, group=1)
    parser.add_bottom_up_pattern(
        "final_epochs", r"epoch: (\d+) -", type=int)

    parser.add_function(expansions_per_second, file="run.log")
    parser.add_bottom_up_pattern(
        "time_experiment", r"Total experiment time: (\d+(\.\d+)?)s",
        type=float)

    def add_time_incl_training(content, props):
        # Always sets the property; the value is None if the log lacks it.
        props[TIT] = parse_time_incl_training(content)

    parser.add_function(add_time_incl_training)

    parser.parse()


# Extractors for the basic statistics of a single search log. Every value is
# a callable that takes the log text and returns the parsed statistic, or
# None if the corresponding line is missing ("coverage" yields 0 instead).
BASIC_SEARCH_STATS = {
    "coverage": lambda x: 1 if x.find("Solution found.") > -1 else 0,
    "expansions": parse_regex(re.compile(r"Expanded (\d+) state\(s\)\."), 1, int),
    "plan_length": parse_regex(re.compile(r"Plan length: (\d+) step\(s\)\."), 1, int),
    # Accept any number of integer digits and an optional fraction: the
    # former pattern r"(\d\.\d+)" allowed exactly one digit before the
    # decimal point and therefore dropped search times of 10s or more.
    "search_time": parse_regex(re.compile(r"Search time: (\d+(?:\.\d+)?)s"), 1, float),
    TIT: parse_time_incl_training
}


def parse_basic_search_stats(content):
    """Apply every extractor of ``BASIC_SEARCH_STATS`` to *content*.

    Returns:
        A dict mapping each statistic name to the value its extractor
        returned for *content*.
    """
    stats = {}
    for stat_name, extract in BASIC_SEARCH_STATS.items():
        stats[stat_name] = extract(content)
    return stats


# Relative path of the log of the initial search; parsed only if it exists.
INIT_LOG = "init.log"
# Relative path of the zip archive holding the logs of the intermediate
# searches; parsed only if it exists.
INTER_LOG = "inter.zip"


def parse_initial_and_intermediate_evaluations(content, props):
    """Add the search statistics of the initial and intermediate searches.

    If ``INIT_LOG`` exists, its basic search statistics are stored in *props*
    under ``init_<name>``.

    If ``INTER_LOG`` exists, it is opened as a zip archive whose member names
    are ``run<i>`` with a 1-based index. For every statistic a list with one
    entry per intermediate search announced in *content* is stored under
    ``inter_<name>``. Searches without an archive member are recorded as
    unsolved (``0`` for coverage, ``None`` for everything else).
    ``sum_inter_coverages`` is the sum of the ``inter_coverage`` entries.

    Args:
        content: Text of the parsed log. It has to contain
            ``Intermediate searches started: <n>`` if ``INTER_LOG`` exists.
        props: Property dictionary that is updated in place.

    Raises:
        AssertionError: If ``INTER_LOG`` exists but *content* does not state
            the number of intermediate searches, or if an archive member has
            an index larger than that number.
    """
    if os.path.isfile(INIT_LOG):
        # Use a context manager so the file handle is closed right away.
        with open(INIT_LOG, "r") as init_file:
            init_stats = parse_basic_search_stats(init_file.read())
        for key, value in init_stats.items():
            props["init_{}".format(key)] = value

    if not os.path.isfile(INTER_LOG):
        return

    intermediate_searches = parse_regex(
        re.compile(r"Intermediate searches started: (\d+)"), 1, int)(content)
    assert intermediate_searches is not None and intermediate_searches >= 0
    accumulated_props = {key: [] for key in BASIC_SEARCH_STATS}

    def add_unsolved_entries(count):
        # A non-positive count adds nothing.
        for _ in range(count):
            for key, values in accumulated_props.items():
                values.append(0 if key == "coverage" else None)

    with zipfile.ZipFile(INTER_LOG, "r") as zfile:
        # Member names start at run1 and go up to at most
        # run${intermediate_searches}.
        zinfos = {int(info.filename[3:]): info for info in zfile.infolist()}
        # An empty archive is fine: all searches then count as unsolved.
        assert not zinfos or max(zinfos) <= intermediate_searches
        last_search_index = 0
        for search_index in sorted(zinfos):
            # Fill the gap of searches that left no log in the archive.
            add_unsolved_entries(search_index - last_search_index - 1)
            with zfile.open(zinfos[search_index], "r") as member:
                raw = member.read()
            # ZipFile.open() yields bytes on Python 3, but the extractors
            # operate on text; decode so both Python versions behave alike.
            if not isinstance(raw, str):
                raw = raw.decode("utf-8", "replace")
            for key, value in parse_basic_search_stats(raw).items():
                accumulated_props[key].append(value)
            last_search_index = search_index
        add_unsolved_entries(intermediate_searches - last_search_index)

    for key, values in accumulated_props.items():
        props["inter_{}".format(key)] = values

    props["sum_inter_coverages"] = sum(props["inter_coverage"])


def parse_first_solution_found(content, props):
    """Store when the first solution was found across all searches.

    The final search (``coverage``), the initial search (``init_coverage``)
    and every intermediate search (``inter_coverage``) are considered. For
    each search that found a solution, its "time including training" value
    is a candidate.

    Sets:
        ``time_first_solution``: Smallest known candidate time, or ``None``
            if no search found a solution or no solved search has a time.
        ``total_coverage``: ``1`` if any search found a solution, else ``0``.

    Missing coverage properties are treated as "not solved", and solved
    searches whose time is missing or ``None`` are ignored for the minimum,
    so the function never compares ``None`` with a number.

    Args:
        content: Unused; present to match the parser function signature.
        props: Property dictionary that is read and updated in place.
    """
    init_tit = "init_{}".format(TIT)
    inter_tit = "inter_{}".format(TIT)

    # Times of all solved searches; an entry is None if its time is unknown.
    solved_times = []
    if props.get("coverage"):
        solved_times.append(props.get(TIT))
    if props.get("init_coverage", 0):
        solved_times.append(props.get(init_tit))
    inter_times = props.get(inter_tit, [])
    for index, covered in enumerate(props.get("inter_coverage", [])):
        if covered:
            solved_times.append(
                inter_times[index] if index < len(inter_times) else None)

    known_times = [time for time in solved_times if time is not None]
    props["time_first_solution"] = min(known_times) if known_times else None
    props["total_coverage"] = 1 if solved_times else 0

if __name__ == "__main__":
    # First pass: extract the training properties with the CommonParser.
    run_training_parser()
    # Second pass: a plain Parser for the evaluation logs.
    # parse_first_solution_found reads the init_*/inter_* properties written
    # by parse_initial_and_intermediate_evaluations, so it is added after it.
    # NOTE(review): it also reads "coverage" and the time_incl_training
    # property, presumably loaded from properties written by earlier
    # parsers -- confirm against lab's Parser.
    p = Parser()
    p.add_function(parse_initial_and_intermediate_evaluations)
    p.add_function(parse_first_solution_found)
    p.parse()

