#!/usr/bin/env python3

import sys, argparse, os, re
from collect_data import collect
from collections import defaultdict

def stats_old(job_csv,job_output):
    """Print the legacy per-job summary of solved benchmarks.

    For every job in the collected data this prints

    * the number of benchmarks whose expected status is "Theorem" or
      "Unknown",
    * per configuration, how many of those were reported as "Theorem" or
      "Unsatisfiable",
    * for three fixed pairs of first-order configurations, how many
      benchmarks were solved by both.

    Args:
        job_csv: forwarded unchanged to ``collect``.
        job_output: forwarded unchanged to ``collect``.
    """
    data = collect(job_csv,job_output)

    config_pairs = [
        ("original_fo.sh","appext_fo.sh"),
        ("original_fo.sh","app_fo.sh"),
        ("appext_fo.sh","app_fo.sh"),
    ]

    solved = {}
    total = {}
    for job in data:
        solved[job] = defaultdict(list)
        # Start at 0 so that a job without any configuration can still be
        # reported instead of raising KeyError when printing.
        total[job] = 0
        for config in data[job]:
            relevant = [
                d for d in data[job][config].values()
                if d["expected"] in ("Theorem","Unknown")
            ]
            # The total is recomputed for every configuration, so the printed
            # value is the count of the last configuration iterated.
            # NOTE(review): presumably every configuration of a job runs on
            # the same benchmark set -- confirm.
            total[job] = len(relevant)
            solved[job][config].extend(
                d for d in relevant
                if d["result"] in ("Theorem","Unsatisfiable")
            )

    for job in data:
        print(f"Total {job}: {total[job]}")
        for config in data[job]:
            print(f"{job:4}{config:40}: solved {len(solved[job][config]):5}")

        for c1, c2 in config_pairs:
            # Count matching (c1, c2) record pairs through a name -> count
            # table rather than comparing every pair of records.
            occurrences = defaultdict(int)
            for d2 in solved[job][c2]:
                occurrences[d2["benchmark"]] += 1
            counter = sum(
                occurrences.get(d1["benchmark"], 0) for d1 in solved[job][c1]
            )
            print(f"common {c1} and {c2}: {counter}")

def stats(job_csv,job_output):
    """Print the statistics as LaTeX table rows.

    The data is loaded with ``collect`` and aggregated by ``calc_stats``.
    For each of "int"/"ext" and each of "TFF"/"THF" a heading line is
    printed, followed by one row per mode that has a configuration.  A row
    holds the rpo and kbo values of num_sat, num_unsat, common_average_time
    and common_average_steps; zero values are prefixed with ``\\relax`` and
    values flagged as best with ``\\bf``.

    Args:
        job_csv: forwarded unchanged to ``collect``.
        job_output: forwarded unchanged to ``collect``.
    """
    table = calc_stats(collect(job_csv,job_output))

    # Column order: for every field first the rpo value, then the kbo value.
    columns = [
        (order, field)
        for field in ("num_sat", "num_unsat",
                      "common_average_time", "common_average_steps")
        for order in ("rpo", "kbo")
    ]

    for tensional in ("int", "ext"):
        print(tensional)

        for job in ("TFF", "THF"):
            print(job)

            for mode in ("fo", "app", "supatvars", "purification"):
                if config_name(tensional, job, "rpo", mode) is None:
                    continue

                # The last row of a table body carries no LaTeX line break.
                row_end = "" if mode == "purification" else "\\\\"

                row = f"& {pretty_name(mode):25}"
                for order, field in columns:
                    entry = table[tensional][job][order][mode]
                    value = entry[field]
                    is_best = entry[field + "_bold"]
                    if value == 0:
                        prefix = "\\relax"
                    elif is_best:
                        prefix = "\\bf   "
                    else:
                        prefix = "      "
                    row += f"&{prefix}{value:4} "
                print(row + row_end)

def pretty_name(mode):
    """Return the human-readable label for *mode*, or None if it is unknown."""
    labels = {
        "fo": "first-order mode",
        "app": "applicative encoding",
        "purification": "purifying calculus",
        "supatvars": "nonpurifying calculus",
    }
    return labels.get(mode)


def config_name(tensional,job,order,mode):
    """Return the configuration script name for the given combination.

    Args:
        tensional: "int" selects the ``app_fo`` script in "app" mode, any
            other value selects ``appext_fo``; for the remaining modes the
            value is embedded in the name.
        job: "THF" has no first-order ("fo") configuration.
        order: "kbo" adds a ``_kbo`` suffix, any other value adds nothing.
        mode: "fo", "app", or a mode name that is embedded in the file name.

    Returns:
        The script file name, or None for mode "fo" with job "THF".
    """
    tail = ("_kbo" if order == "kbo" else "") + ".sh"

    if mode == "fo":
        return None if job == "THF" else f"original_fo{tail}"

    if mode == "app":
        stem = "app_fo" if tensional == "int" else "appext_fo"
        return f"{stem}{tail}"

    return f"original_{mode}_{tensional}{tail}"

# Axes of the statistics table.  The mode order also fixes the order in which
# benchmarks are first seen when the commonly solved problems are collected.
_TENSIONALS = ("int", "ext")
_JOBS = ("TFF", "THF")
_ORDERS = ("rpo", "kbo")
_MODES = ("fo", "app", "purification", "supatvars")

# (column, higher_is_better) pairs that receive a "<column>_bold" marker.
_RANKED_COLUMNS = (
    ("num_unsat", True),
    ("average_time", False),
    ("average_steps", False),
    ("common_average_time", False),
    ("common_average_steps", False),
    ("num_sat", True),
)


def _safe_average(total, count, ndigits=None):
    """Return ``total / count`` rounded to *ndigits*, or 0 if *count* is 0."""
    if count == 0:
        return 0
    return round(total / count, ndigits)


def _collect_group(data, job, configs):
    """Build the per-mode result lists and sums for one (job, configs) group."""
    group = {}
    for mode, config in configs.items():
        entry = {
            "unsat": [],
            "unsat_time_sum": 0,
            "unsat_steps_sum": 0,
            "sat": [],
        }
        group[mode] = entry
        if config is None:
            # Mode without configuration: keep only the empty skeleton.
            continue
        for d in data[job][config].values():
            if d["expected"] in ("Theorem", "Unknown") \
            and d["result"] in ("Theorem", "Unsatisfiable"):
                entry["unsat"].append(d)
                entry["unsat_time_sum"] += d["cpu time"]
                entry["unsat_steps_sum"] += d["zp_steps"]
            if d["result"] in ("Satisfiable", "CounterSatisfiable", "GaveUp"):
                entry["sat"].append(d)
        entry["num_sat"] = len(entry["sat"])
        entry["num_unsat"] = len(entry["unsat"])
    return group


def _add_common_sums(data, job, active, group):
    """Store the benchmarks proved by every active mode and their sums."""
    # Number of active modes that proved each benchmark.
    votes = defaultdict(int)
    for mode in active:
        for d in group[mode]["unsat"]:
            votes[d["benchmark"]] += 1

    common = [name for name in votes if votes[name] == len(active)]
    group["commonly_unsat"] = common

    for mode, config in active.items():
        entry = group[mode]
        entry["common_time_sum"] = 0
        entry["common_steps_sum"] = 0
        for name in common:
            d = data[job][config][name]
            entry["common_time_sum"] += d["cpu time"]
            entry["common_steps_sum"] += d["zp_steps"]


def _add_averages(active, group):
    """Derive the per-mode and common averages from the stored sums."""
    num_common = len(group["commonly_unsat"])
    for mode in active:
        entry = group[mode]
        # Averages over the problems proved in this mode.
        entry["average_time"] = _safe_average(
            entry["unsat_time_sum"], entry["num_unsat"], 1)
        entry["average_steps"] = _safe_average(
            entry["unsat_steps_sum"], entry["num_unsat"])
        # Averages over the problems proved in all modes.
        entry["common_average_time"] = _safe_average(
            entry["common_time_sum"], num_common, 1)
        entry["common_average_steps"] = _safe_average(
            entry["common_steps_sum"], num_common)


def _mark_best(active, group):
    """Set "<column>_bold" to True for the best mode(s) of every column."""
    num_common = len(group["commonly_unsat"])
    for column, higher_is_better in _RANKED_COLUMNS:
        best_number = None
        best_modes = []
        for mode in active:
            entry = group[mode]
            entry[column + "_bold"] = False

            # An average over zero problems is only a placeholder 0 and must
            # not win a "lower is better" comparison.
            if column.startswith("common_average"):
                samples = num_common
            elif column.startswith("average"):
                samples = entry["num_unsat"]
            else:
                samples = None
            if samples == 0:
                continue

            n = entry[column]
            if best_number is None \
            or (higher_is_better and n > best_number) \
            or (not higher_is_better and n < best_number):
                best_number = n
                best_modes = [mode]
            elif n == best_number:
                best_modes.append(mode)
        for mode in best_modes:
            group[mode][column + "_bold"] = True


def calc_stats(data):
    """Aggregate the collected results into a nested statistics table.

    Args:
        data: nested mapping ``data[job][config][benchmark]`` whose records
            are read through the keys "expected", "result", "cpu time",
            "zp_steps" and "benchmark".

    Returns:
        ``stats[tensional][job][order]`` for every combination of
        "int"/"ext", "TFF"/"THF" and "rpo"/"kbo".  Each such group holds

        * "commonly_unsat": names of the benchmarks proved by every mode
          that has a configuration,
        * one dict per mode ("fo", "app", "purification", "supatvars") with
          the lists "unsat" and "sat" and the sums "unsat_time_sum" and
          "unsat_steps_sum".  Modes with a configuration additionally get
          "num_sat", "num_unsat", "common_time_sum", "common_steps_sum",
          "average_time", "average_steps", "common_average_time",
          "common_average_steps" and a boolean "<column>_bold" per ranked
          column.

        An average over zero problems is reported as 0 and is never marked
        as best.

    Raises:
        KeyError: if ``data`` lacks a job, configuration or record key that
            is looked up.
    """
    stats = {}
    for tensional in _TENSIONALS:
        stats[tensional] = {}
        for job in _JOBS:
            stats[tensional][job] = {}
            for order in _ORDERS:
                configs = {
                    mode: config_name(tensional, job, order, mode)
                    for mode in _MODES
                }
                # Modes that actually have a configuration, in _MODES order.
                active = {
                    mode: config
                    for mode, config in configs.items()
                    if config is not None
                }

                group = _collect_group(data, job, configs)
                _add_common_sums(data, job, active, group)
                _add_averages(active, group)
                _mark_best(active, group)
                stats[tensional][job][order] = group

    return stats


if __name__ == "__main__":
    def _parse_flag(text):
        """Interpret a command-line string as a boolean switch value."""
        return text.strip().lower() not in ("", "0", "false", "f", "no", "n", "off")

    p = argparse.ArgumentParser('analyze result file')
    # --old may be given bare (meaning True) or with an explicit value.  The
    # value is parsed as a boolean so that "--old False" no longer enables
    # the old output merely because a non-empty string is truthy.
    p.add_argument('--old', help='old way to output stats',
                   nargs='?', const=True, default=False, type=_parse_flag)
    p.add_argument('--job_csv', help='csv file from StarExec', default=None)
    p.add_argument('--job_output', help='csv file from StarExec', default=None)
    args = p.parse_args()
    if args.old:
        stats_old(args.job_csv,args.job_output)
    else:
        stats(args.job_csv,args.job_output)
