from collections import defaultdict
import math
import sys

from lab.reports import *
from downward.reports import PlanningReport
try:
    import numpy
except ImportError:
    print 'numpy not availabe, running FlexibleAggregationReport will not work'

#https://stackoverflow.com/questions/2189800/length-of-an-integer-in-python
def num_digits(number):
    """Return the number of decimal digits of the integer part of number.

    The digits are counted on the string form of int(number). Unlike
    int(math.log10(number)) + 1, this is exact for arbitrarily large
    integers (log10 rounds 999999999999999 up to 15.0, which would yield 16).
    Zero has one digit.

    Prints an error message and exits with status 1 for negative numbers.
    """
    if number < 0:
        print("Negative number")
        sys.exit(1)
    return len(str(int(number)))

class TasksWithSpecificValueOfAttributeReport(PlanningReport):
    """List all tasks on which some run has a specific value for one attribute.

    The output contains one line "'domain:problem'," per matching task,
    sorted alphabetically.
    """
    def __init__(self, specific_value, **kwargs):
        # The value that an attribute of a run must equal for the run's task
        # to be listed.
        self.specific_value = specific_value
        PlanningReport.__init__(self, **kwargs)

    def get_text(self):
        """Return the sorted list of matching tasks, one per line.

        Exits with status 1 unless exactly one attribute was passed.
        """
        # "> 1" let an empty attribute list through and crashed with an
        # IndexError on self.attributes[0] below.
        if len(self.attributes) != 1:
            print("please pass exactly one attribute")
            sys.exit(1)

        attribute = self.attributes[0]
        matching_tasks = set()
        for (domain, task), runs in self.problem_runs.items():
            # A single run with the requested value suffices.
            if any(run.get(attribute, None) == self.specific_value
                   for run in runs):
                matching_tasks.add('{}:{}'.format(domain, task))

        return '\n'.join(
            "'{}',".format(task) for task in sorted(matching_tasks))

class DomainsWithDifferentValuesOfAttributeReport(PlanningReport):
    """List domains containing a task on which runs disagree on one attribute.

    Runs that lack the attribute (value None) are ignored. The output
    contains one line "'domain'," per such domain, sorted alphabetically.
    """
    def __init__(self, **kwargs):
        PlanningReport.__init__(self, **kwargs)

    def get_text(self):
        """Return the sorted list of domains with differing values.

        Exits with status 1 unless exactly one attribute was passed.
        """
        # "> 1" let an empty attribute list through and crashed with an
        # IndexError on self.attributes[0] below.
        if len(self.attributes) != 1:
            print("please pass exactly one attribute")
            sys.exit(1)

        attribute = self.attributes[0]
        domains_with_different_values = set()
        for (domain, task), runs in self.problem_runs.items():
            values = [run.get(attribute, None) for run in runs]
            values = [value for value in values if value is not None]
            # Every value is compared against the first one that exists.
            if values and any(value != values[0] for value in values[1:]):
                domains_with_different_values.add(domain)

        return '\n'.join(
            "'{}',".format(domain)
            for domain in sorted(domains_with_different_values))

class FlexibleAttribute(object):
    """Describe how FlexibleAggregationReport aggregates one attribute.

    Args:
        key: name of the attribute in the run properties.
        printable_name: label printed at the start of the table line.
        default_value: value used for runs lacking the attribute; runs whose
            resulting value is None are skipped.
        absolute: if True, all tasks are aggregated; otherwise only tasks on
            which every algorithm of a block has a value for the attribute.
        min_wins: whether the smallest (True) or the largest (False) value
            of a block is highlighted.
        functions: aggregation functions, one table line per function.
            Defaults to [sum].
    """
    def __init__(
            self, key, printable_name, default_value, absolute=False,
            min_wins=True, functions=None):
        self.key = key
        self.printable_name = printable_name
        self.default_value = default_value
        self.absolute = absolute
        self.min_wins = min_wins
        # A fresh list per instance: a mutable default ([sum]) would be
        # shared by every attribute that relies on it.
        self.functions = [sum] if functions is None else list(functions)

# Attributes aggregated by FlexibleAggregationReport when no
# flexible_attributes are given to it.
DEFAULT_FLEXIBLE_ATTRIBUTES = [
    FlexibleAttribute(
        'coverage', 'Coverage', None,
        absolute=True, min_wins=False,
        functions=[sum]),
    FlexibleAttribute(
        'expansions_until_last_jump', 'Exp', float('inf'),
        absolute=False, min_wins=True,
        functions=[percentile_50, percentile_75, percentile_90, percentile_95]),
    FlexibleAttribute(
        'search_time', 'Search time', None,
        absolute=False, min_wins=True,
        functions=[geometric_mean]),
    FlexibleAttribute(
        'total_time', 'Total time', None,
        absolute=False, min_wins=True,
        functions=[geometric_mean]),
    FlexibleAttribute(
        'unsolvable_incomplete', 'Unsolv.', None,
        absolute=True, min_wins=False,
        functions=[sum]),
    # No explicit functions: the FlexibleAttribute default is used.
    FlexibleAttribute(
        'perfect_heuristic', 'Perfect h', None,
        absolute=True, min_wins=False),
]

# Label suffixes for aggregation functions, appended to the line label of
# the 'Exp' attribute by FlexibleAggregationReport.
FUNCTIONS_TO_NAMES = {
    percentile_50: '50th perc',
    percentile_75: '75th perc',
    percentile_90: '90th perc',
    percentile_95: '95th perc',
}

class FlexibleAggregationReport(PlanningReport):
    """Aggregate attributes over blocks of algorithms into LaTeX table lines.

    algorithm_rows must have the form
    [[[first row first block], [first row second block]],
     [[second row first block], [second row second block]]].
    For every row, a header line listing the algorithms is emitted, followed
    by one line per (attribute, aggregation function) pair of
    flexible_attributes (DEFAULT_FLEXIBLE_ATTRIBUTES if none are given).
    Values are aggregated separately per block, and the best value of each
    block is printed in bold.

    With single_aggregation=True, each function is applied once to all
    collected values of an algorithm. Otherwise it is first applied to the
    values of each domain and then again to the per-domain results.
    """
    def __init__(self,
                 algorithm_rows=None,
                 flexible_attributes=None,
                 single_aggregation=False,
                 **kwargs):
        self.algorithm_rows = algorithm_rows or []
        self.flexible_attributes = flexible_attributes or DEFAULT_FLEXIBLE_ATTRIBUTES
        self.single_aggregation = single_aggregation
        # Number of decimal places for float values.
        self.digits = 2

        kwargs.setdefault('attributes', ['coverage'])
        kwargs.setdefault('format', 'txt')
        algorithms = self._collect_algorithms(self.algorithm_rows)
        PlanningReport.__init__(self, filter_algorithm=algorithms, **kwargs)

    @staticmethod
    def _collect_algorithms(algorithm_rows):
        """Return the algorithms of all blocks of all rows, in order."""
        algorithms = []
        for algorithm_row in algorithm_rows:
            for algorithm_block in algorithm_row:
                assert isinstance(algorithm_block, list)
                algorithms.extend(algorithm_block)
        return algorithms

    @staticmethod
    def _aggregate_per_domain(algorithm_attribute_to_domain_to_values, function):
        """Apply function to the values of each domain separately.

        Returns a defaultdict(list) mapping (algorithm, attribute key) to the
        list of per-domain aggregates.
        """
        algorithm_attribute_to_values = defaultdict(list)
        for key, domain_to_values in algorithm_attribute_to_domain_to_values.items():
            for values in domain_to_values.values():
                algorithm_attribute_to_values[key].append(function(values))
        return algorithm_attribute_to_values

    def get_text(self):
        """
        We do not need any markup processing or loop over attributes here,
        so the get_text() method is implemented right here.

        Returns the table lines joined by newlines. Exits with status 1 if an
        algorithm of algorithm_rows does not occur in the data.
        """
        algorithms = self._collect_algorithms(self.algorithm_rows)
        all_algos = set(run['algorithm'] for run in self.props.values())
        all_algo_strings = [x.encode('ascii') for x in all_algos]
        for algo in algorithms:
            if algo not in all_algo_strings:
                print("{} is not in the data set!".format(algo))
                print("known algorithms: {}".format(sorted(all_algo_strings)))
                print("given algorithms: {}".format(sorted(algorithms)))
                # A non-zero status, so that an invalid configuration is not
                # mistaken for success.
                sys.exit(1)

        # For each block and attribute, compute the set of relevant tasks:
        # absolute attributes consider all tasks, all other attributes only
        # the tasks on which every algorithm of the block has a value.
        row_block_to_attribute_to_relevant_tasks = {}
        for row_index, algorithm_row in enumerate(self.algorithm_rows):
            for block_index, algorithm_block in enumerate(algorithm_row):
                attribute_to_relevant_tasks = defaultdict(set)
                row_block_to_attribute_to_relevant_tasks[row_index, block_index] = attribute_to_relevant_tasks
                for attribute in self.flexible_attributes:
                    for (domain, task), runs in self.problem_runs.items():
                        if attribute.absolute:
                            relevant = True
                        else:
                            relevant = all(
                                run.get(attribute.key, None) is not None
                                for run in runs
                                if run['algorithm'] in algorithm_block)
                        if relevant:
                            attribute_to_relevant_tasks[attribute.key].add('{}:{}'.format(domain, task))

        # Collect all values for each algorithm block, attribute and run.
        row_block_to_algorithm_attribute_to_values = {}
        row_block_to_algorithm_attribute_to_domain_to_values = {}
        for row_index, algorithm_row in enumerate(self.algorithm_rows):
            for block_index, algorithm_block in enumerate(algorithm_row):
                algorithm_attribute_to_values = defaultdict(list)
                algorithm_attribute_to_domain_to_values = {}
                row_block_to_algorithm_attribute_to_values[row_index, block_index] = algorithm_attribute_to_values
                row_block_to_algorithm_attribute_to_domain_to_values[row_index, block_index] = algorithm_attribute_to_domain_to_values
                attribute_to_relevant_tasks = row_block_to_attribute_to_relevant_tasks[row_index, block_index]
                for attribute in self.flexible_attributes:
                    for algorithm in self.algorithms:
                        algorithm_attribute_to_domain_to_values[algorithm, attribute.key] = defaultdict(list)

                    relevant_tasks = attribute_to_relevant_tasks[attribute.key]
                    for (domain, task), runs in self.problem_runs.items():
                        if '{}:{}'.format(domain, task) not in relevant_tasks:
                            continue
                        for run in runs:
                            algorithm = run['algorithm']
                            if algorithm not in algorithm_block:
                                continue
                            if attribute.key in ['ms_linear_order', 'fallback_only']:
                                # HACK: these only count for runs that
                                # constructed an abstraction.
                                value = False
                                if run.get('ms_abstraction_constructed', False):
                                    value = run.get(attribute.key, False)
                            else:
                                value = run.get(attribute.key, attribute.default_value)
                            if value is not None:
                                algorithm_attribute_to_values[algorithm, attribute.key].append(value)
                                algorithm_attribute_to_domain_to_values[algorithm, attribute.key][domain].append(value)

        # For each attribute and aggregation function, compute the
        # aggregated value of the values collected previously.
        lines = []
        for row_index, algorithm_row in enumerate(self.algorithm_rows):
            # Header line; it must end the LaTeX row like every body line,
            # otherwise it merges with the first body line.
            algorithm_header_line = ['algorithm']
            for algorithm_block in algorithm_row:
                algorithm_header_line.extend(algorithm_block)
            lines.append(self.format_values(algorithm_header_line, end_line=True))

            # Body lines for each function of each attribute
            for attribute in self.flexible_attributes:
                for function in attribute.functions:
                    line_string = attribute.printable_name
                    if attribute.printable_name == 'Exp' and function in FUNCTIONS_TO_NAMES:
                        # HACK: distinguish the percentile lines of 'Exp'.
                        line_string += ' {}'.format(FUNCTIONS_TO_NAMES[function])
                    line_string += ' & '
                    # Aggregate values separately for each block
                    for block_index, algorithm_block in enumerate(algorithm_row):
                        if self.single_aggregation:
                            algorithm_attribute_to_values = row_block_to_algorithm_attribute_to_values[row_index, block_index]
                        else:
                            algorithm_attribute_to_values = self._aggregate_per_domain(
                                row_block_to_algorithm_attribute_to_domain_to_values[row_index, block_index],
                                function)
                        block_values = []
                        for algorithm in algorithm_block:
                            values = algorithm_attribute_to_values[algorithm, attribute.key]
                            if len(values) == 0:
                                print("WARNING! No values found for {algorithm} and {attribute.printable_name}!".format(
                                    algorithm=algorithm, attribute=attribute))

                            if attribute.key in ['ms_linear_order', 'fallback_only']:
                                # HACK: report the percentage of runs with a
                                # constructed abstraction.
                                ms_abstraction_constructed_value = sum(algorithm_attribute_to_values[algorithm, 'ms_abstraction_constructed'])
                                attribute_value = sum(values)
                                if ms_abstraction_constructed_value == 0:
                                    value = 0
                                else:
                                    value = float(attribute_value) * 100 / ms_abstraction_constructed_value
                            else:
                                value = function(values)
                            if attribute.key == 'expansions_until_last_jump':
                                # HACK!
                                if value == float('inf'):
                                    # Since we only aggregate over tasks for
                                    # which expansions is not None, this
                                    # should not happen.
                                    assert False, 'infinite expansions for {}'.format(algorithm)
                                    value = 'inf'
                                else:
                                    # int() alone would truncate.
                                    value = int(round(value))
                            block_values.append(value)
                        line_string += self.format_values(
                            block_values,
                            end_line=(block_index == len(algorithm_row) - 1),
                            min_wins=attribute.min_wins)
                    lines.append(line_string)

        return '\n'.join(lines)

    def format_values(self, values, end_line=False, min_wins=None):
        """Join values into a ' & '-separated LaTeX table fragment.

        If min_wins is None, values are used as they are. Otherwise:
        - floats are printed with self.digits decimals;
        - ints with at least 5 digits are printed in thousands with a 'k' suffix;
        - the minimum (min_wins true) or maximum (min_wins false) value is bold.
        Every value is followed by ' & ', except that the last one is
        followed by the LaTeX row terminator if end_line is True.
        """
        if min_wins is None:
            formatted_values = list(values)
        else:
            min_value = float('inf')
            max_value = 0
            for value in values:
                if value < min_value:
                    min_value = value
                if value > max_value:
                    max_value = value

            formatted_values = []
            for value in values:
                if isinstance(value, float):
                    formatted_value = '{0:.{1}f}'.format(value, self.digits)
                elif num_digits(value) >= 5:
                    # HACK: display ints with at least 5 digits as 10k rather
                    # than 10000 (floor division, i.e. truncated thousands).
                    formatted_value = '{}k'.format(value // 1000)
                else:
                    formatted_value = str(value)
                highlight = (min_wins and value == min_value) or (not min_wins and value == max_value)
                if highlight:
                    formatted_value = '\\textbf{{{}}}'.format(formatted_value)
                formatted_values.append(formatted_value)

        separators = [' & '] * len(formatted_values)
        if end_line and separators:
            separators[-1] = ' \\\\'
        return ''.join(
            '{}{}'.format(value, separator)
            for value, separator in zip(formatted_values, separators))

class AlgorithmMatrixReport(PlanningReport):
    """Print aggregated attribute values of a matrix of algorithms as LaTeX lines.

    Either pass algorithm_matrix (a list of rows, each a list of algorithm
    names) directly, or let get_text() compute it from row_names and
    column_names:
    - If row_names_start_algorithm_names is True, each row name is a string
      that must be a prefix of the algorithm name. Otherwise each row name
      (a string or a list of strings) must be contained in the algorithm name.
    - Column names work the same way, with suffixes instead of prefixes
      (column_names_end_algorithm_names).
    attribute_function_pairs lists (attribute, aggregation function) pairs.
    Each pair produces one line per matrix row.
    """
    def __init__(self,
                 algorithm_matrix=None,
                 row_names=None,
                 column_names=None,
                 row_names_start_algorithm_names=False,
                 column_names_end_algorithm_names=False,
                 attribute_function_pairs=None,
                 **kwargs):
        self.algorithm_matrix = algorithm_matrix
        self.row_names = row_names or []
        self.column_names = column_names or []
        self.row_names_start_algorithm_names = row_names_start_algorithm_names
        self.column_names_end_algorithm_names = column_names_end_algorithm_names
        self.attribute_function_pairs = attribute_function_pairs or []
        kwargs.setdefault('format', 'txt')
        PlanningReport.__init__(self, **kwargs)

    @staticmethod
    def _contains_all(names, algorithm):
        """Return whether every name (a single string counts as one name) is a substring of algorithm."""
        if isinstance(names, str):
            names = [names]
        return all(name in algorithm for name in names)

    def _compute_algorithm_matrix(self):
        """Return one row per row name, holding all algorithms that match the row name and some column name.

        Algorithms are appended column by column.
        """
        matrix = []
        for row_names in self.row_names:
            row = []
            for column_names in self.column_names:
                for algorithm in self.algorithms:
                    if self.row_names_start_algorithm_names:
                        assert isinstance(row_names, str)
                        row_ok = algorithm.startswith(row_names)
                    else:
                        row_ok = self._contains_all(row_names, algorithm)
                    if self.column_names_end_algorithm_names:
                        assert isinstance(column_names, str)
                        column_ok = algorithm.endswith(column_names)
                    else:
                        column_ok = self._contains_all(column_names, algorithm)
                    if row_ok and column_ok:
                        row.append(algorithm)
            matrix.append(row)
        return matrix

    def get_text(self):
        """
        We do not need any markup processing or loop over attributes here,
        so the get_text() method is implemented right here.

        Exits with status 1 if an algorithm of the matrix does not occur in
        the data.
        """
        if self.algorithm_matrix is None:
            self.algorithm_matrix = self._compute_algorithm_matrix()
            print("Computed algorithm matrix:")
            for row in self.algorithm_matrix:
                print(row)

        # Check that each algorithm of the matrix is present in the data
        algorithms = [algorithm for algorithm_row in self.algorithm_matrix
                      for algorithm in algorithm_row]
        all_algos = set(run['algorithm'] for run in self.props.values())
        all_algo_strings = [x.encode('ascii') for x in all_algos]
        for algo in algorithms:
            if algo not in all_algo_strings:
                print("{} is not in the data set!".format(algo))
                print("known algorithms: {}".format(sorted(all_algo_strings)))
                print("given algorithms: {}".format(sorted(algorithms)))
                # A non-zero status, so that an invalid configuration is not
                # mistaken for success.
                sys.exit(1)

        # Collect values of each algorithm for each distinct attribute. This
        # loops over the distinct attributes, not over the pairs: an attribute
        # listed with several functions would otherwise collect every value
        # once per pair.
        algo_values = dict(
            (attribute, defaultdict(list))
            for attribute, _ in self.attribute_function_pairs)
        for run in self.props.values():
            algorithm = run['algorithm']
            for attribute, values in algo_values.items():
                value = run.get(attribute, None)
                if value is not None:
                    values[algorithm].append(value)

        lines = []

        # header line
        algorithm_header_line = ['algorithm'] + list(self.column_names)
        lines.append(self.format_line(algorithm_header_line))

        # row lines
        for row_index, row in enumerate(self.algorithm_matrix):
            for attribute, function in self.attribute_function_pairs:
                values = ["{}-{}".format(self.row_names[row_index], attribute)]
                for algo in row:
                    aggregated_value = function(algo_values[attribute][algo])
                    if aggregated_value == float('inf'):
                        aggregated_value = 'inf'
                    values.append(aggregated_value)
                lines.append(self.format_line(values, min_wins=False))

        return '\n'.join(lines)

    def format_line(self, values, min_wins=None):
        """Return values as one LaTeX table row ('a & b & c \\\\').

        values[0] is a label. If min_wins is not None, the minimum
        (min_wins true) or maximum (min_wins false) of the remaining values
        is wrapped in \\textbf. The wrapping is done in place in values.
        """
        if min_wins is not None:
            min_value = float('inf')
            max_value = 0
            for value in values[1:]:
                if value < min_value:
                    min_value = value
                if value > max_value:
                    max_value = value

            for index in range(1, len(values)):
                value = values[index]
                if (min_wins and value == min_value) or (not min_wins and value == max_value):
                    values[index] = '\\textbf{{{}}}'.format(value)

        if not values:
            return ''
        return ' & '.join('{}'.format(value) for value in values) + ' \\\\'

# TODO: copied from Jendrik; needs to be cleaned up
class DomainwiseReport(PlanningReport):
    """Compare algorithms by their coverage per domain.

    get_markup() does two things:
    - It prints a LaTeX table to stdout. The table counts, for every pair of
      algorithms, the domains in which the first has higher coverage, plus
      each algorithm's total coverage.
    - It returns the string form of four lab Tables: coverage per domain, the
      pairwise comparison, the number of domains in which each algorithm is
      best, and a summary against a reference algorithm ("diverse200" or
      "scp-div") if present.
    """
    def __init__(self, sstddev=None, **kwargs):
        PlanningReport.__init__(self, **kwargs)
        # Optional mapping from algorithm name to a number that is printed in
        # an extra "Stddev." column of the LaTeX table.
        self.sstddev = sstddev or {}

    def get_markup(self):
        # Runs without a coverage attribute are reported and counted as
        # unsolved (this mutates the run dicts).
        for (domain, problem), runs in sorted(self.problem_runs.items()):
            for run in runs:
                if "coverage" not in run:
                    print("Missing coverage: {} {} {}".format(
                        run["domain"], run["problem"], run["run_dir"]))
                    run["coverage"] = 0

        domain_and_config_to_coverage = defaultdict(int)
        for runs in self.problem_runs.values():
            for run in runs:
                domain_and_config_to_coverage[(run["domain"], run['algorithm'])] += run["coverage"]
        num_best = defaultdict(int)

        coverage_table = Table("Coverage")
        algorithms = getattr(self, 'configs', None) or self.algorithms
        coverage_table.set_column_order(algorithms)
        domain_groups = sorted(set(domain for domain, _ in domain_and_config_to_coverage))
        for domain in domain_groups:
            max_coverage = max(
                domain_and_config_to_coverage[(domain, config)]
                for config in algorithms)
            for config in algorithms:
                coverage = domain_and_config_to_coverage[(domain, config)]
                coverage_table.add_cell(domain, config, coverage)
                if coverage == max_coverage:
                    num_best[config] += 1

        winner_table = Table("Best for domain")
        for config, value in sorted(num_best.items(), key=lambda item: item[1], reverse=True):
            winner_table.add_cell(config, "#best", value)
            print("{}: {}".format(config, value))

        comparison_table = Table()
        comparison_table.set_row_order(algorithms)
        comparison_table.set_column_order(algorithms + ["Coverage"])
        comparison_table.row_min_wins["Coverage"] = False
        for config1 in algorithms:
            for config2 in algorithms:
                # Number of domains in which config1 solves more than config2.
                num_better = sum(
                    1 for domain in domain_groups
                    if domain_and_config_to_coverage[(domain, config1)] >
                    domain_and_config_to_coverage[(domain, config2)])
                comparison_table.add_cell(config1, config2, num_better)

        def print_line(cells):
            print(" & ".join(str(c) for c in cells) + r" \\")

        def convert_config(config):
            # LaTeX macros for known algorithm names; others are kept as is.
            mapping = {
                "hillclimbing-scp-div200": r"\hhillclimbingscp",
                "systematic2-scp-div200": r"\hsystematicscp",
                "cartesian-scp-div200": r"\hcartesianscp",
                "bjolp-scp-2orders": r"\hbjolpscp",
                "symba2": r"$\symba_2$",
            }
            return mapping.get(config, config)

        def get_coverage(config):
            return sum(domain_and_config_to_coverage[(domain, config)] for domain in domain_groups)

        include_sstddev = bool(self.sstddev)
        best_config = max(algorithms, key=get_coverage)
        print(r"\renewcommand{\arraystretch}{\tablearraystretch}")
        print(r"\setlength{\tabcolsep}{3pt}")
        print(r"\begin{center}")
        print(r"\begin{tabular}{l" + "r" * len(algorithms) + "crr}")
        line = [""] + [r"\rot{%s}" % convert_config(c) for c in algorithms] + ["", r"\rot{Coverage}"]
        if include_sstddev:
            line.append(r"\rot{Stddev.}")
        print_line(line)
        offsets = tuple(offset + len(algorithms) for offset in (1, 3, 4 if include_sstddev else 3))
        # Raw string: "\c" and "\l" are not valid Python escape sequences.
        print(r"\cmidrule[\lightrulewidth]{1-%d} \cmidrule[\lightrulewidth]{%d-%d}" % offsets)
        for config1 in algorithms:
            total_coverage = get_coverage(config1)
            if config1 == best_config:
                total_coverage = r"\bc{{{:d}}}".format(total_coverage)
            else:
                total_coverage = "{:d}".format(total_coverage)
            sstddev = self.sstddev.get(config1)
            line = [convert_config(config1)]
            for config2 in algorithms:
                num_config1_better = 0
                num_config2_better = 0
                for domain in domain_groups:
                    coverage1 = domain_and_config_to_coverage[(domain, config1)]
                    coverage2 = domain_and_config_to_coverage[(domain, config2)]
                    if coverage1 > coverage2:
                        num_config1_better += 1
                    elif coverage2 > coverage1:
                        num_config2_better += 1

                if config1 == config2:
                    entry = "--"
                elif num_config1_better >= num_config2_better:
                    entry = "\\bc{{{}}}".format(num_config1_better)
                else:
                    entry = str(num_config1_better)
                line.append(entry)
            line += ["", total_coverage]
            if include_sstddev:
                line.append("{:.2f}".format(sstddev) if sstddev is not None else "--")
            print_line(line)
        print(r"\end{tabular}")
        print(r"\end{center}")

        print("Algorithms: {}".format(self.algorithms))
        compared_config = None
        for name in ["diverse200", "scp-div"]:
            if name in self.algorithms:
                compared_config = name
                break
        summary_table = Table("Summary")
        summary_table.set_column_order(algorithms)
        if compared_config:
            for config in algorithms:
                num_better = 0
                num_worse = 0
                num_equal = 0
                for domain in domain_groups:
                    coverage1 = domain_and_config_to_coverage[(domain, compared_config)]
                    coverage2 = domain_and_config_to_coverage[(domain, config)]
                    if coverage1 > coverage2:
                        num_better += 1
                    elif coverage1 < coverage2:
                        num_worse += 1
                    else:
                        assert coverage1 == coverage2
                        num_equal += 1
                assert num_better + num_equal + num_worse == len(domain_groups)
                summary_table.add_cell("Coverage", config, get_coverage(config))
                summary_table.add_cell("Better", config, num_better)
                summary_table.add_cell("Worse", config, num_worse)
            summary_table.set_row_order(["Coverage", "Better", "Worse"])

        return "\n\n\n".join(str(table) for table in [
            coverage_table, comparison_table, winner_table, summary_table])
