#!/usr/bin/python
import os
import sys
import sqlite3
import csv
import logging
import argparse
from pathlib import Path
from contextlib import closing

import pprint
pp = pprint.PrettyPrinter(indent=4, stream=sys.stderr)

# Command-line interface. -i/-o/-s/-f combine action='append' with nargs='*',
# so each yields a list of lists (one inner list per occurrence of the flag),
# or None when the flag is never given.
parser = argparse.ArgumentParser(description='Compute results of a paracooba benchmark set')
parser.add_argument('dbfile', type=str, help='a database to process')
parser.add_argument('plotname', type=str, help='a name for the plot. The dbfile per default.', nargs='?')
# const: a bare `--maxy` (no value) falls back to 700 instead of None, which
# would otherwise render as an invalid "set yrange [ 0 : None ]".
parser.add_argument('--maxy', help='maximum y value in plots', dest='maxy', type=int, nargs='?', default=700, const=700)
parser.add_argument('--summary', help='print summary of results', dest='summary', action='store_true')
parser.add_argument('-i', '--ignore', help='ignore all result categories with names containing one of the given arguments', nargs='*', action='append',  dest='ignore')
parser.add_argument('-o', '--only', help='ignore everything but the ones selected here', nargs='*', action='append', dest='only')

parser.add_argument('-s', '--add-suite', help='add a suite to the qbflib db', nargs='*', action='append', dest='suites')
parser.add_argument('-f', '--add-family', help='add a family to the qbflib db', nargs='*', action='append', dest='families')

plot_group = parser.add_argument_group('plot')
plot_group.add_argument('--plot-families', help='generate gnuplot families output', dest='plot_families', action='store_true')
plot_group.add_argument('--plot-suites', help='generate gnuplot suites output', dest='plot_suites', action='store_true')
plot_group.add_argument('--plot', help='generate gnuplot of complete results', dest='plot', action='store_true')

args = parser.parse_args()

# Normalise never-given -f/-s options to empty lists so they can be iterated
# unconditionally.
if args.families is None:
    args.families = []
if args.suites is None:
    args.suites = []

dbfile = args.dbfile

# The plot name defaults to the database file name.
if args.plotname is None:
    args.plotname = dbfile

# Maps a problem-file stem to a dict with 'SUITE' and 'FAMILY' entries.
# NOTE(review): nothing in this script as shown populates it, so suites and
# families currently come only from the -s/-f patterns.
qbflib = {}

class ResultsCategory:
    """Solve times of several benchmark tables, grouped under one category.

    A category is e.g. the overall result set, one suite or one family.
    ``results`` maps a benchmark (table) name to the list of solve times of
    the instances it solved.
    """

    def __init__(self, name):
        self.name = name
        self.results = {}

    def add_result(self, benchmark_name, wtime):
        """Record one solve time *wtime* for *benchmark_name*."""
        self.results.setdefault(benchmark_name, []).append(wtime)

    def print_data(self):
        """Print each benchmark name followed by its solve times.

        The stored lists are sorted in place.
        """
        for name, results in self.results.items():
            results.sort()
            print(name)
            print(results)

    def generate_gnuplot(self, plotname=None, yrange_max=None):
        """Return a gnuplot (cairolatex) script plotting this category.

        Each benchmark becomes one data block and one line: its solve times,
        sorted ascending, on the y axis against their rank on the x axis
        (a cactus plot). The stored result lists are sorted in place.

        Args:
            plotname: prefix used in the title and the ``.tex`` output file
                name; defaults to the ``plotname`` command-line argument.
            yrange_max: upper bound of the y range; defaults to the
                ``--maxy`` command-line argument.
        """
        if plotname is None:
            plotname = args.plotname
        if yrange_max is None:
            yrange_max = args.maxy

        blockname = self.name.replace('-', '_')
        blocknameescaped = blockname.replace('_', '\\_')

        blocks = []
        # Largest solve time seen; only emitted in a commented-out line.
        maxy = 0
        for name, results in self.results.items():
            results.sort()
            # THE ONE BACKSLASH TO RULE THEM ALL
            # https://xkcd.com/1638/
            lines = [name.replace('_', '\\\\_')]
            lines.extend(f"{r} {i}" for i, r in enumerate(results))
            blocks.append("\n".join(lines) + "\n")
            maxy = max([maxy, *results])

        # Two blank lines between blocks let gnuplot's `index i` address
        # each benchmark's data separately.
        data = "\n\n".join(blocks)

        s = f'''set terminal cairolatex standalone size 15cm,5.5cm
set title '{plotname}: {blocknameescaped}' noenhance
set output '{plotname}_{blockname}.tex'
set xtics 1
#set xlabel "#solved instances"
set ylabel "solve time [s]"
#set yrange [ 0 : {maxy} ]
set yrange [ 0 : {yrange_max} ]
set xtics autofreq
set key right outside

$DATA << EOD
{data}EOD

plot for [i=0:*] $DATA using 2:1 index i with lines title columnheader(1) noenhanced
        '''
        return s

# Per-suite and per-family result categories, keyed by suite / family name;
# Benchmark.parse_tbl creates entries on demand.
suites, families = {}, {}
# Category that receives every solved instance of every parsed table.
overall = ResultsCategory('OverAll')

def get_or_create_resultscategory_from_assoc_arr(assocarr, name) -> ResultsCategory:
    """Return ``assocarr[name]``, first storing a new ResultsCategory if absent.

    Mutates *assocarr* when *name* is not yet a key.
    """
    if name not in assocarr:
        assocarr[name] = ResultsCategory(name)
    return assocarr[name]


class Benchmark:
    """Statistics of one benchmark configuration stored in one SQLite table.

    The table is queried for the columns ``status``, ``problem``, ``real``
    and ``utilization``; rows with ``status = 'ok'`` count as solved.
    """

    def __init__(self, tbl):
        # Name of the SQLite table holding this configuration's runs.
        self.tbl = tbl

    def __str__(self):
        # Only meaningful after parse_tbl() has set the counters and averages.
        okcount = self.okcount
        total = self.totalcount
        percentage_solved = int(okcount / total * 100)
        avg_util = self.averages[0]
        avg_real = self.averages[1]
        return f"Solved: {okcount}, Total: {total}, %solved: {percentage_solved}, avg_util: {avg_util}, avg_real: {avg_real}"

    def __repr__(self):
        return self.__str__()

    def parse_tbl(self, conn):
        """Read counts and averages from the table and distribute solve times.

        Sets ``totalcount`` (rows with a non-NULL status), ``okcount`` (rows
        with status 'ok') and ``averages`` (AVG(utilization), AVG(real) over
        solved rows). Every solved instance's ``real`` value is added to the
        module-level ``overall`` category. Depending on the problem file stem,
        it is also added to suite/family categories: taken from ``qbflib``
        when the stem is known there, otherwise one per ``-f``/``-s`` pattern
        that the stem contains.

        Raises:
            ValueError: if the table has no rows with a status.
            sqlite3.Error: if the table or a queried column does not exist.
        """
        # Quote the identifier so names that are not valid bare SQL
        # identifiers (e.g. containing '-') still work.
        quoted = '"' + self.tbl.replace('"', '""') + '"'
        with closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(status) FROM {quoted}")
            self.totalcount = cursor.fetchone()[0]
            if self.totalcount == 0:
                # Must raise an exception object: raising a str is itself a
                # TypeError in Python 3.
                raise ValueError(f"Error with table {self.tbl}: totalcount == 0!")

            cursor.execute(f"SELECT COUNT(status) FROM {quoted} WHERE status = 'ok'")
            self.okcount = cursor.fetchone()[0]

            cursor.execute(f"SELECT problem,real FROM {quoted} WHERE status = 'ok'")
            for problem, wtime in cursor:
                overall.add_result(self.tbl, wtime)

                # Get fitting family and suite, then insert there.
                stem = Path(problem).stem

                if stem in qbflib:
                    entry = qbflib[stem]
                    suite = get_or_create_resultscategory_from_assoc_arr(suites, entry['SUITE'])
                    family = get_or_create_resultscategory_from_assoc_arr(families, entry['FAMILY'])
                    suite.add_result(self.tbl, wtime)
                    family.add_result(self.tbl, wtime)
                    continue

                # Not in qbflib: use the -f / -s patterns. Each option is a
                # list of lists (one per flag occurrence). Every pattern is
                # checked, and every matching category receives the result.
                for group in args.families:
                    for pattern in group:
                        if pattern in stem:
                            family = get_or_create_resultscategory_from_assoc_arr(families, pattern)
                            family.add_result(self.tbl, wtime)
                for group in args.suites:
                    for pattern in group:
                        if pattern in stem:
                            suite = get_or_create_resultscategory_from_assoc_arr(suites, pattern)
                            suite.add_result(self.tbl, wtime)

            cursor.execute(f"SELECT AVG(utilization) as avg_util, AVG(real) as avg_real FROM {quoted} WHERE status = 'ok'")
            self.averages = cursor.fetchone()

def one_vs_others(conn, one, others, benchmarktree):
    """Compare configuration *one* against each configuration in *others*.

    ``benchmarktree[one]`` may be a leaf (an object with a ``.tbl`` table
    name) or a dict whose values are such leaves. ``benchmarktree[o]`` for
    each *o* in *others* must be a leaf.

    For every source table and every other table, print:

    * the problems solved by *one* that have a non-'ok' status in the other
      table;
    * the problems both solved but with a different ``result``.

    After reporting any such disagreement, it exits the process with
    ``sys.exit(-1)``.

    Only source tables whose names contain "16cores_5treedepth" and "pcnf",
    or "16cores_4treedepth" and "bloqqer", are considered (hard-coded
    filter). Problems absent from the other table are not reported, because
    the WHERE clause requires a matching row.
    """
    node = benchmarktree[one]
    # A leaf has no len(), so dispatch on the type instead of the size.
    if isinstance(node, dict):
        sourcetables = [child.tbl for child in node.values()]
    else:
        sourcetables = [node.tbl]

    with closing(conn.cursor()) as cursor:
        for tbl in sourcetables:
            wanted = (("16cores_5treedepth" in tbl and "pcnf" in tbl) or
                      ("16cores_4treedepth" in tbl and "bloqqer" in tbl))
            if not wanted:
                continue

            for o in others:
                otbl = benchmarktree[o].tbl
                print(f"{tbl} > {otbl}")
                for row in cursor.execute(f"SELECT {tbl}.problem, {tbl}.status, {otbl}.status FROM {tbl} LEFT JOIN {otbl} ON {tbl}.problem = {otbl}.problem WHERE {tbl}.status = 'ok' AND {otbl}.status != 'ok' AND {tbl}.problem = {otbl}.problem"):
                    problem = row[0]
                    print(f"    Solved {problem} that was not solved by {o}")

                badresults = False
                for problem, thisresult, oresult in cursor.execute(f"SELECT {tbl}.problem, {tbl}.result, {otbl}.result FROM {tbl} LEFT JOIN {otbl} ON {tbl}.problem = {otbl}.problem WHERE {tbl}.status = 'ok' AND {otbl}.status = 'ok' AND {tbl}.problem = {otbl}.problem AND {tbl}.result != {otbl}.result"):
                    badresults = True
                    print(f"    Solved {problem} DIFFERENTLY THAN {o}! This result: {thisresult}, theirs: {oresult}")

                if badresults:
                    sys.exit(-1)


def transform_prefixlist_to_tree(lst, conn, layer=0):
    """Build a nested dict from table names split into '__'-separated parts.

    Each nesting level corresponds to one name component. Once every
    component of a name has been consumed, the node is a ``Benchmark`` for
    the re-joined table name, already parsed via ``parse_tbl(conn)``. If
    parsing fails, the node is ``None`` and a warning is logged.

    Args:
        lst: list of name-part lists, e.g. ``[['paraqs', '16cores'], ...]``.
            It is not modified.
        conn: open sqlite3 connection passed to ``Benchmark.parse_tbl``.
        layer: index of the name component handled at this recursion level.

    Returns:
        A dict (possibly empty) for interior levels, or a ``Benchmark`` /
        ``None`` for a leaf.
    """
    remaining = list(lst)
    out = {}
    while remaining:
        if len(remaining[0]) <= layer:
            # Leaf: all components consumed. Fails (assert / IndexError) if
            # one table name is a strict '__'-prefix of another.
            assert len(remaining) == 1
            tbl_name = '__'.join(remaining[0])
            b = Benchmark(tbl_name)
            try:
                b.parse_tbl(conn)
            except Exception as err:
                # Best effort: skip unparsable tables, but say why instead of
                # silently swallowing everything (incl. KeyboardInterrupt).
                logging.warning("Skipping table %s: %s", tbl_name, err)
                return None
            return b

        pivot = remaining[0][layer]
        subtree = [e for e in remaining if e[layer] == pivot]
        remaining = [e for e in remaining if e[layer] != pivot]
        out[pivot] = transform_prefixlist_to_tree(subtree, conn, layer + 1)
    return out

# Table-name rows read from sqlite_master; rebound once the database is opened.
dbtables = list()

def str_matches_arr(s, arr):
    """Return True if any pattern in *arr* is a substring of *s*.

    *arr* is the value of an ``action='append', nargs='*'`` option: a list of
    lists of patterns, one inner list per occurrence of the flag, or None
    when the option was never given. Every pattern of every group is checked,
    and an empty group (a bare flag) simply matches nothing.
    """
    if arr is None:
        return False
    return any(pattern in s for group in arr for pattern in group)

# Read the database and build the configuration tree. Benchmark.parse_tbl
# fills the result categories as a side effect.
# A sqlite3 connection used as a context manager only commits / rolls back on
# exit; closing() makes sure the connection is actually closed.
with closing(sqlite3.connect(dbfile)) as conn:
    # Generate tree of all tested configurations
    with closing(conn.cursor()) as cursor:
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # Names starting with 'sqlite_' are reserved for SQLite's internal
        # tables (e.g. sqlite_sequence) and hold no benchmark results.
        dbtables = [v for v in cursor.fetchall() if not v[0].startswith('sqlite_')]
        dbtables = [v for v in dbtables if not str_matches_arr(v[0], args.ignore)]
        if args.only is not None:
            dbtables = [v for v in dbtables if str_matches_arr(v[0], args.only)]
        # Table names encode a configuration as '__'-separated components.
        tables = [row[0].split('__') for row in dbtables]
        ttables = transform_prefixlist_to_tree(tables, conn)

    # Print Overview
    if args.summary:
        pp.pprint(ttables)
        sys.stderr.flush()

    # Print Details (What did paraqs solve that no other solver solved?)
    # for key in ttables:
    #    solvers = [ s for s in ttables[key] ]
    #    notparaqs = list(filter(lambda s: s != 'paraqs', solvers))
    #    one_vs_others(conn, 'paraqs', notparaqs, ttables[key])

# Emit the requested gnuplot scripts on stdout.
if args.plot:
    print(overall.generate_gnuplot())

if args.plot_suites:
    for s in suites.values():
        print(s.generate_gnuplot())

if args.plot_families:
    for f in families.values():
        print(f.generate_gnuplot())
