'''
This module contains the code to create tables that compare the heuristics using several metrics.
'''

import os
from numpy import Inf
import csv
import os

from heuristics_guess import choose_order_given_projections
from heuristic_tools import finding_time_limit, get_dataset, compute_markups, not_heuristics, all_heuristics, computing_time, compute_real_timings
from extra_metrics import compute_extra_metrics

def _mean_or_nan(values):
    '''Return the arithmetic mean of ``values``, or NaN when it is empty.

    Lets a metric with no contributing samples be reported as NaN instead of
    raising ZeroDivisionError.
    '''
    values = list(values)
    if not values:
        return float('nan')
    return sum(values) / len(values)


def _count_timeouts(timings, guesses, choice_timings, limit):
    '''Count problems whose chosen ordering did not finish within ``limit`` seconds.

    A problem is counted when the chosen ordering's recorded timing is the
    timeout marker ``'Over <limit>'``, or when the chosen ordering did finish
    (non-string timing) but the problem's time limit is ``limit`` and the
    chosen timing (which may include the heuristic cost) exceeds it.
    '''
    timeout_label = 'Over ' + str(limit)
    count = 0
    for timing, guess, choice_timing in zip(timings, guesses, choice_timings):
        chosen = timing[guess]
        if chosen == timeout_label:
            count += 1
        elif (not isinstance(chosen, str)
              and finding_time_limit(timings=timing) == limit
              and choice_timing > limit):
            count += 1
    return count


def _fraction_found(timings, guesses, number_no_timedout, n_options):
    '''Fraction of problems with exactly ``n_options`` finishing orderings
    where the chosen ordering finished (NaN if there are no such problems).'''
    return _mean_or_nan(
        0 if isinstance(timing[guess], str) else 1
        for timing, guess, number_options in zip(timings, guesses, number_no_timedout)
        if number_options == n_options
    )


def compute_metrics(virtual_best_timings, choice_timings, timings, guesses, targets, number_no_timedout, ncells, only_without_cost=False):
    '''
    Compute the comparison metrics for one set of heuristic guesses.

    Parameters
    ----------
    virtual_best_timings : list
        Per problem, the timing of the target (best) ordering.
    choice_timings : list of numbers
        Per problem, the timing attributed to the chosen ordering; depending
        on the caller it may or may not include the heuristic's cost.
    timings : list of lists
        Per problem, the recorded timing of every ordering; timeouts are
        stored as strings such as 'Over 30'.
    guesses : list of int
        Per problem, the index of the ordering chosen by the heuristic.
    targets : list of int
        Per problem, the index of the best ordering.
    number_no_timedout : list of int
        Per problem, how many orderings did not time out.
    ncells : list of lists
        Per problem, the number of cells of each ordering; string entries are
        treated as unavailable (presumably timeouts -- confirm).
    only_without_cost : bool
        If True, return only the metrics that are also reported without the
        heuristic cost.

    Returns
    -------
    tuple
        ``(total_time, markups, ncells_markup_without_cost, terminating)`` when
        ``only_without_cost`` is True, otherwise ``(total_time, markups,
        no_samples, accuracy_heuristic, terminating, timeouts_30, timeouts_60,
        perc_found_1, perc_found_2, perc_found_3)``. Averages over an empty
        group of problems are returned as NaN instead of raising.
    '''
    # TotalTime: all decimals removed
    total_time = round(sum(choice_timings))

    # Markup of the chosen timings with respect to the virtual best
    markups = compute_markups(virtual_best_timings, choice_timings)

    # NcellsMarkup: relative excess of cells of the chosen ordering over the
    # fewest cells of any ordering with a numeric count; a chosen ordering
    # without a numeric count gets a fixed markup.
    ncell_markup_default = 10
    ncells_markups = [
        ncell[guess] / min(nc for nc in ncell if not isinstance(nc, str)) - 1
        if not isinstance(ncell[guess], str) else ncell_markup_default
        for ncell, guess in zip(ncells, guesses)
    ]
    ncells_markup_without_cost = _mean_or_nan(ncells_markups)

    # Number of problems whose chosen ordering's timing was not a str (timeout)
    terminating = sum(0 if isinstance(timing[guess], str) else 1
                      for timing, guess in zip(timings, guesses))

    if only_without_cost:
        # Only the metrics that have a "without cost" variant are needed,
        # and those are already computed.
        return total_time, markups, ncells_markup_without_cost, terminating

    # NoSamples
    no_samples = len(virtual_best_timings)

    # Accuracy: fraction of problems where the guess is the target ordering
    accuracy_heuristic = _mean_or_nan(int(target == guess) for target, guess in zip(targets, guesses))

    # 30sTimeouts / 60sTimeouts: either the chosen ordering timed out, or the
    # timing including the heuristic cost exceeds the problem's time limit.
    timeouts_30 = _count_timeouts(timings, guesses, choice_timings, 30)
    timeouts_60 = _count_timeouts(timings, guesses, choice_timings, 60)

    # In this path "terminating" accounts for the choice timing exceeding the
    # limit, so it replaces the raw count computed above.
    terminating = no_samples - timeouts_30 - timeouts_60

    # FoundOutOf1/2/3: how often the guess finished when exactly 1/2/3
    # orderings finished
    perc_found_1 = _fraction_found(timings, guesses, number_no_timedout, 1)
    perc_found_2 = _fraction_found(timings, guesses, number_no_timedout, 2)
    perc_found_3 = _fraction_found(timings, guesses, number_no_timedout, 3)

    return total_time, markups, no_samples, accuracy_heuristic, terminating, timeouts_30, timeouts_60, perc_found_1, perc_found_2, perc_found_3


def study_heuristic_guesses(dataset, heuristic='sotd', max_penalization_if_not_finished=Inf):
    '''
    Evaluate a single heuristic on ``dataset`` and return one row of metrics.

    ``dataset`` is unpacked as ``(projections, targets, timings,
    heuristics_costs, ncells)``. Timed-out orderings (string timings) are
    replaced by ``min(max_penalization_if_not_finished * virtual_best_time,
    2 * time_limit)`` to build the timings that ignore the heuristic cost.

    Names in ``not_heuristics`` are delegated to ``compute_extra_metrics``.
    Every other heuristic picks one ordering per problem via
    ``choose_order_given_projections`` and is scored by ``compute_metrics``
    both with and without its cost. Non-string values in the returned row
    are rounded to 2 decimals.
    '''
    projections, targets, timings, heuristics_costs, ncells = dataset

    def finished_count(orderings_timings):
        # Orderings whose timing is not a str did not time out
        return sum(1 for t in orderings_timings if type(t) is not str)

    def penalised(timing, best_time):
        # Replace each timeout by the penalised estimate; the time limit is
        # only looked up when a timeout is actually present.
        # NOTE(review): with the default Inf penalisation and a best time of
        # exactly 0, Inf * 0 is nan and min(nan, x) returns nan -- confirm
        # timings are never 0.
        row = []
        for t in timing:
            if type(t) is str:
                t = min(max_penalization_if_not_finished * best_time, 2 * finding_time_limit(timing))
            row.append(t)
        return row

    def as_row(values):
        return [value if type(value) is str else round(value, 2) for value in values]

    number_no_timedout = list(map(finished_count, timings))
    virtual_best_timings = [timing[target] for target, timing in zip(targets, timings)]
    useful_timings = [penalised(timing, best_time) for timing, best_time in zip(timings, virtual_best_timings)]

    label = str(heuristic).replace('_', '-')

    if heuristic in not_heuristics:
        return as_row(compute_extra_metrics(heuristic, virtual_best_timings, timings, number_no_timedout, useful_timings))

    guesses = [choose_order_given_projections(projection, heuristic=heuristic) for projection in projections]
    # Chosen timings ignoring the heuristic cost
    choice_timings_without_cost = [row[guess] for row, guess in zip(useful_timings, guesses)]

    # Chosen timings including the heuristic cost (relevant for expensive heuristics)
    timings_with_str = computing_time(heuristic, timings, guesses, heuristics_costs)
    choice_timings = compute_real_timings(
        timings, timings_with_str, virtual_best_timings,
        max_penalization_if_not_finished=max_penalization_if_not_finished,
    )

    (total_time, markups, no_samples, accuracy, terminating,
     timeouts_30, timeouts_60, found_1, found_2, found_3) = compute_metrics(
        virtual_best_timings, choice_timings, timings, guesses, targets,
        number_no_timedout, ncells, only_without_cost=False)

    (total_time_nc, markups_nc, ncells_markup_nc, terminating_nc) = compute_metrics(
        virtual_best_timings, choice_timings_without_cost, timings, guesses, targets,
        number_no_timedout, ncells, only_without_cost=True)

    return as_row([
        label, accuracy, terminating_nc, terminating, timeouts_30, timeouts_60,
        markups_nc, markups, ncells_markup_nc, total_time_nc, total_time,
        found_1, found_2, found_3, no_samples,
    ])


def create_csv_with_heuristics_metrics(
        heuristics=all_heuristics,
        without_repetition=True,
        max_penalization_if_not_finished=Inf,
        minimum_time_to_consider=0):
    '''
    Evaluate every heuristic in ``heuristics`` and write the metrics table.

    The dataset is loaded with ``get_dataset`` and each heuristic is scored
    with ``study_heuristic_guesses``. The table is written to
    ``../Datasets/`` (relative to this file), with the repetition mode, the
    penalisation and the minimum time encoded in the file name.
    '''
    dataset = get_dataset(without_repetition=without_repetition, minimum_time_to_consider=minimum_time_to_consider)
    repetition_tag = 'without_repetition' if without_repetition else 'with_repetition'

    rows = [
        study_heuristic_guesses(dataset, heuristic=h, max_penalization_if_not_finished=max_penalization_if_not_finished)
        for h in heuristics
    ]

    file_name = (f'study_heuristics_guess__{repetition_tag!s}'
                 f'__max_penalisation_{max_penalization_if_not_finished!s}'
                 f'__min_time_{minimum_time_to_consider!s}.csv')
    csv_location = os.path.join(os.path.dirname(__file__), '..', 'Datasets', file_name)

    header = ['Name', 'Accuracy', 'TerminatingWithoutCost', 'Terminating', '30sTimeouts', '60sTimeouts', 'MarkupWithoutCost', 'Markup', 'NcellsMarkup', 'TotalTimeWithoutCost', 'TotalTime','FoundOutOf1','FoundOutOf2','FoundOutOf3', 'NoSamples']
    with open(csv_location, 'w', encoding='UTF8', newline='') as csv_file:
        table_writer = csv.writer(csv_file)
        table_writer.writerow(header)
        table_writer.writerows(rows)