#!/usr/bin/python

# This file is part of CoVeriTeam, a tool for on-demand composition of cooperative verification systems:
# https://gitlab.com/sosy-lab/software/coveriteam
#
# SPDX-FileCopyrightText: 2020 Dirk Beyer <https://www.sosy-lab.org>
#
# SPDX-License-Identifier: Apache-2.0

import xml.etree.ElementTree as ET
from pathlib import Path
import sys
import os

"""
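Assumption: the benchexec files folder (which contains the execution_trace.xml
files) lies in the same directory as the results XML; it is searched recursively from there.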
"""


def extract_bz2(path):
    """Decompress a bzip2 file into a sibling ``inner-measurement-*`` file.

    The input file is left untouched. The decompressed data is written next
    to it as ``inner-measurement-<name>``, where ``<name>`` is the input file
    name without its trailing ``.bz2`` (e.g. ``results.xml.bz2`` ->
    ``inner-measurement-results.xml``).

    Decompression uses Python's bz2 module instead of shelling out to
    ``cp``/``bzip2``. Paths with spaces or shell metacharacters therefore work,
    and failures raise an exception instead of being silently ignored.

    :param path: path (str or Path) of the ``.bz2`` file to decompress
    :return: absolute Path of the decompressed file
    :raises ValueError: if ``path`` does not end in ``.bz2``
    :raises OSError: if the file cannot be read/written or is not valid bzip2
    """
    import bz2
    import shutil

    source = Path(path).resolve()
    if source.suffix != ".bz2":
        raise ValueError("Expected a .bz2 file, got: " + str(path))

    # Path.stem drops only the final ".bz2" suffix.
    target = source.parent / ("inner-measurement-" + source.stem)
    with bz2.open(source, "rb") as compressed, open(target, "wb") as plain:
        shutil.copyfileobj(compressed, plain)
    return target


def compress_bz2(path):
    """Compress ``path`` to ``<path>.bz2`` and delete the uncompressed file.

    Like the ``bzip2`` command-line tool in its default mode, this removes the
    input file after successful compression. It uses Python's bz2 module, so
    paths with spaces or shell metacharacters work and errors raise an
    exception instead of being ignored.

    :param path: path (str or Path) of the file to compress
    :return: Path of the created ``.bz2`` file
    """
    import bz2
    import shutil

    source = Path(path)
    target = source.with_name(source.name + ".bz2")
    with open(source, "rb") as plain, bz2.open(target, "wb") as compressed:
        shutil.copyfileobj(plain, compressed)
    # Remove the input only once the compressed copy has been fully written.
    source.unlink()
    return target


def get_exec_trace(task, candidates=None):
    """Return the execution trace file that belongs to ``task``.

    A trace matches when the task's file name (its last path component)
    occurs as a substring of the trace path.

    :param task: task path, as stored in a run's "name" attribute
    :param candidates: iterable of trace paths (str) to search; defaults to
        the module-level ``traces`` list
    :return: the single matching trace path, or "" if none matches
    Exits the program via sys.exit if more than one trace matches.
    """
    if candidates is None:
        candidates = traces

    task_name = Path(task).name
    # NOTE(review): plain substring matching may also hit unrelated paths
    # that merely contain the task name (e.g. "a.c" inside "data.c").
    matches = [f for f in candidates if task_name in f]
    if len(matches) > 1:
        sys.exit("More than 1 traces found!!! for " + task_name)

    return matches[0] if matches else ""


def extract_measurements(trace):
    """Sum the CPU and wall time of all <measurements> elements of a trace.

    :param trace: path (or file object) of an execution_trace.xml
    :return: tuple ``(cputime, walltime)`` of strings with an "s" suffix,
        e.g. ``("1.5s", "2.0s")``; ``("0.0s", "0.0s")`` if there are none
    """
    import math

    root = ET.parse(trace).getroot()
    entries = list(root.iter("measurements"))
    # math.fsum avoids accumulated float rounding error leaking into the
    # output, e.g. ten entries of 0.1 give "1.0s", not "0.9999999999999999s".
    cputime = math.fsum(float(m.get("cputime")) for m in entries)
    walltime = math.fsum(float(m.get("walltime")) for m in entries)

    return (str(cputime) + "s", str(walltime) + "s")


def process_xml(xml_path):
    """Rewrite the time columns of a benchexec result XML in place.

    Every <result> element gets its "tool" attribute set to
    "INNER MEASUREMENT". For every <run> that has an execution trace, the
    values of its "cputime" and "walltime" <column> elements are replaced by
    the totals summed from that trace. Runs without a trace are left
    unchanged. The modified tree is written back to ``xml_path``.
    """
    document = ET.parse(xml_path)
    root = document.getroot()

    for result in root.iter("result"):
        result.set("tool", "INNER MEASUREMENT")

    for run in root.iter("run"):
        trace = get_exec_trace(run.get("name"))
        if not trace:
            continue

        cputime, walltime = extract_measurements(trace)
        new_values = {"cputime": cputime, "walltime": walltime}
        for column in run.findall("column"):
            title = column.get("title")
            if title in new_values:
                column.set("value", new_values[title])

    document.write(xml_path)


# Paths of all execution_trace.xml files found below the directory of the
# results XML; filled in by the script entry point below and read by
# get_exec_trace().
traces = []

if __name__ == "__main__":
    # Expect exactly one argument: the bz2-compressed results XML of benchexec.
    if len(sys.argv) != 2:
        sys.exit("Missing the xml file produced by benchexec!!")

    # Work on a decompressed copy so that the original results file is kept.
    xml_path = extract_bz2(sys.argv[1])

    traces = [str(p) for p in xml_path.parent.glob("**/execution_trace.xml")]

    process_xml(str(xml_path))
    compress_bz2(xml_path)
