import click
import os, sys
import copy
import tempfile

sys.path.append(os.path.join(os.path.expanduser("~"), "AD-Fuzzer"))

import dataparser as _Parser
from run_vse import VSERunner


class Trimmer:
    """Shrink a collision-triggering scenario by removing unneeded objects.

    The seed is parsed into a scenario. Candidate reductions of one object
    category are replayed in the simulator through ``VSERunner``. A reduction
    is kept only when it reproduces the same collision as the reference run.

    Attributes:
        scenario: the parsed seed scenario. Its ``elements`` mapping holds the
            per-category object lists ("npc", "pedestrian", "obstacle").
        collision_object: objects involved in the reference collision. It
            starts as an empty set and must be filled in from a baseline
            ``run_once`` before ``trimming`` is called.
        collision_time: time reported for the reference collision
            (units as reported by VSERunner, presumably seconds).
        count: number of objects removed so far, per category.
        strategy: "normal", "bisection" or "bisection_limit".
    """

    def __init__(self, _seed, _strategy):
        self.scenario = _Parser.scenario_parser(_seed)
        self.collision_object = set()
        self.collision_time = 0
        self.count = dict(npc=0, pedestrian=0, obstacle=0)
        self.strategy = _strategy

    def run_once(self, scenario, run_time=30):
        """Simulate ``scenario`` for ``run_time`` and report its collision.

        The run uses a throwaway temporary directory containing a
        ``collision`` subdirectory. The directory is deleted when the run
        finishes.

        Returns:
            ``(collision_object, collision_time)`` as reported by the runner.
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            os.mkdir(os.path.join(tmpdirname, "collision"))
            vse_runner = VSERunner(scenario, tmpdirname, "none")
            vse_runner.run(run_time)

            print("[+] collision get:", vse_runner.collision_object)
            print("[+] collision time:", vse_runner.collision_time)
            return vse_runner.collision_object, vse_runner.collision_time

    def run_obj_once(self, obj_list, type, run_time=30):
        """Replay the seed with category ``type`` replaced by ``obj_list``.

        ``type`` shadows the builtin. The name is kept for backward
        compatibility with keyword callers. The seed scenario is deep-copied,
        so ``self.scenario`` is never modified.
        """
        tmp_scenario = copy.deepcopy(self.scenario)
        tmp_scenario.elements[type] = obj_list
        return self.run_once(tmp_scenario, run_time)

    def same_collision(self, tmp_collision_object):
        """Return True if exactly two objects collided and both are in the reference collision.

        The argument is left untouched. The previous version popped its
        elements, which emptied the caller's collection.
        """
        return len(tmp_collision_object) == 2 and all(
            obj in self.collision_object for obj in tmp_collision_object
        )

    def _remove_one_by_one(self, obj_list, _type, run_time):
        """Greedily drop each object whose removal still yields the same collision."""
        kept = list(obj_list)
        for obj in obj_list:
            candidate = list(kept)
            candidate.remove(obj)
            tmp_collision_object, _ = self.run_obj_once(candidate, _type, run_time)
            if self.same_collision(tmp_collision_object):
                print("[+] Trimming", _type, obj)
                kept.remove(obj)
        return kept

    def _bisect(self, obj_list, _type, run_time):
        """Repeatedly keep whichever half alone reproduces the collision (left preferred).

        Returns:
            ``(remaining, failed)``. ``failed`` is True when neither half
            reproduced the collision; ``remaining`` then still holds two or
            more objects.
        """
        remaining = list(obj_list)
        while len(remaining) > 1:
            print("[+] bisection object length: ", len(remaining))
            mid = len(remaining) // 2
            left, right = remaining[:mid], remaining[mid:]
            # Each replay is a full simulation, so try the right half only
            # when the left half does not already reproduce the collision.
            if self.same_collision(self.run_obj_once(left, _type, run_time)[0]):
                remaining = left
            elif self.same_collision(self.run_obj_once(right, _type, run_time)[0]):
                remaining = right
            else:
                return remaining, True
        return remaining, False

    def trimming(self, _type):
        """Return a reduced object list for category ``_type``.

        Objects are tried in reverse order (later objects first). The result
        is returned in the original order. ``self.count[_type]`` is updated
        with the number of removed objects. An unrecognised strategy returns
        an unchanged copy of the list.
        """
        obj_list = self.scenario.elements[_type][::-1]
        # Replay 15 time units past the reference collision, but never longer
        # than run_once's default run_time of 30.
        recommend_run_time = min(self.collision_time + 15, 30)

        if self.strategy == "normal":
            obj_new_list = self._remove_one_by_one(obj_list, _type, recommend_run_time)
            self.count[_type] += len(obj_list) - len(obj_new_list)

        elif self.strategy in ("bisection", "bisection_limit"):
            # NOTE(review): "bisection_limit" has always executed exactly the
            # same code as "bisection"; no limit is applied. Confirm the intent.
            obj_new_list, failed = self._bisect(obj_list, _type, recommend_run_time)
            # A failed bisection leaves several objects to prune one by one.
            # Bisection also never tries the empty list, so a single survivor
            # may itself be unnecessary and must be checked.
            if failed or len(obj_new_list) == 1:
                print("[+] turning to normal mode")
                obj_new_list = self._remove_one_by_one(obj_new_list, _type, recommend_run_time)
            self.count[_type] = len(self.scenario.elements[_type]) - len(obj_new_list)

        else:
            obj_new_list = list(obj_list)

        return obj_new_list[::-1]

    def report(self):
        """Print how many objects of each category were removed."""
        print(
            "[+] Remove {} npc, {} pedestrian, {} obstacle".format(
                self.count["npc"], self.count["pedestrian"], self.count["obstacle"]
            )
        )


@click.command()
@click.option(
    "-i",
    "--input",
    type=click.Path(dir_okay=False, exists=True),
    required=True,
    help="input test case to be shrunk by the tool",
)
@click.option(
    "-o",
    "--output",
    type=click.Path(dir_okay=False, exists=False),
    required=True,
    help="final output location for the minimized data",
)
@click.option(
    "--type",
    type=click.Choice(["npc", "pedestrian", "obstacle", "all"]),
    multiple=True,
    required=True,
    help="the object type for trimming",
)
@click.option(
    "--strategy",
    type=click.Choice(["normal", "bisection", "bisection_limit"]),
    default="normal",
    help="trimming strategy",
)
def main(input, output, type, strategy):
    """A tool that trims corpus for smallest seed."""
    # (The docstring above is click's --help text; keep it verbatim.)
    trimmer = Trimmer(input, strategy)

    # Baseline replay: record which objects collide and when.
    baseline = trimmer.run_once(trimmer.scenario)
    trimmer.collision_object, trimmer.collision_time = baseline
    if not len(trimmer.collision_object):
        print("[+] The seed doesn't trigger collision.")
        return

    # NOTE(review): copy.copy is shallow, so new_scenario presumably shares
    # its ``elements`` container with trimmer.scenario. Each trimmed list
    # stored below would then also be seen by later trimming passes
    # (pedestrians get trimmed against the already-trimmed npcs). Confirm
    # this is intended before switching to deepcopy.
    new_scenario = copy.copy(trimmer.scenario)

    # Categories are always processed in this fixed order, whatever order
    # the --type options were given in.
    selected = ("npc", "pedestrian", "obstacle")
    if "all" not in type:
        selected = tuple(kind for kind in selected if kind in type)
    for kind in selected:
        new_scenario.elements[kind] = trimmer.trimming(kind)

    trimmer.report()
    new_scenario.to_json()
    new_scenario.store(output)


if __name__ == "__main__":
    # click parses sys.argv and supplies the option values.
    main()
