import os
import sys
import json
from math import *
import copy
import numpy as np

sys.path.append(os.path.join(os.path.expanduser("~"), "AD-Fuzzer"))
import dataparser as _Parser
from typing import Tuple, List, Set
from queue import PriorityQueue
import pandas as pd


class NPCPosition:
    """Record of one NPC's identity, position and angular extent.

    Attributes:
        id: NPC identifier, as found in the simulation log.
        pos: Position tuple; ``tmin_main`` passes ``(x, z)`` coordinates.
        min_gradient: Lower bound of the NPC's angular extent (radians).
        max_gradient: Upper bound of the NPC's angular extent (radians).
    """

    def __init__(self, _id, _pos, _min_gradient, _max_gradient):
        self.id, self.pos = _id, _pos
        self.min_gradient, self.max_gradient = _min_gradient, _max_gradient


def distance(p1, p2):
    """Return the Euclidean distance between two points in the x-z plane.

    ``p1`` and ``p2`` are subscriptable by ``"x"`` and ``"z"``; any other
    component (e.g. ``"y"``) is ignored.
    """
    dx = p1["x"] - p2["x"]
    dz = p1["z"] - p2["z"]
    return sqrt(dx ** 2 + dz ** 2)


def get_wheel_position(position: Tuple, rotation, width, length) -> List[Tuple]:
    """Return the four corner points of a rotated rectangular footprint.

    Despite the name, the points are the corners of the rectangle, not wheel
    hubs. The rectangle is centred on ``position``. Before rotation it spans
    ``width`` along the first axis and ``length`` along the second axis.
    ``rotation`` is a heading in degrees, measured from the second axis
    towards the first. A centre-to-corner offset at angle ``t`` is therefore
    ``(l * sin(t), l * cos(t))``.

    Args:
        position: ``(x, y)`` centre of the rectangle. ``tmin_main`` passes a
            vehicle's ``(x, z)`` ground-plane coordinates.
        rotation: Heading in degrees.
        width: Full extent along the first axis (halved internally).
        length: Full extent along the second axis (halved internally).

    Returns:
        Four ``(x, y)`` tuples. With ``rotation == 0`` and positive sizes they
        are, relative to the centre: ``(+w, +l)``, ``(-w, +l)``, ``(-w, -l)``,
        ``(+w, -l)``, where ``w``/``l`` are the half extents. A zero-size
        footprint yields the centre point four times.
    """
    half_width = width / 2
    half_length = length / 2
    # Distance from the centre to every corner (half the diagonal).
    half_diag = np.sqrt(half_length * half_length + half_width * half_width)
    x = position[0]
    y = position[1]
    if half_diag == 0:
        # Zero-size footprint: the asin(w / l) below would be 0/0 and every
        # corner NaN. All corners coincide with the centre instead.
        return [(x, y)] * 4
    # Angle between the length axis and the centre-to-corner diagonal.
    a = asin(half_width / half_diag)
    b = np.radians(rotation)
    return [
        (x + half_diag * sin(a + b), y + half_diag * cos(a + b)),
        (x - half_diag * sin(a - b), y + half_diag * cos(a - b)),
        (x - half_diag * sin(a + b), y - half_diag * cos(a + b)),
        (x + half_diag * sin(a - b), y - half_diag * cos(a - b)),
    ]


def _to_waypoint(state):
    """Extract the fields kept per frame from one logged vehicle state."""
    return dict(
        Position=state["Position"],
        Rotation=state["Rotation"],
        Scale=state["Scale"],
        Speed=state["LinearVelocity"]["x"],
    )


def get_data(data_file_path):
    """Parse a JSON-lines simulation log into ego and NPC trajectories.

    Each non-blank line of the file holds one JSON value:

    * a JSON array is one frame of NPC states. Every element must carry
      ``Id``, ``Position``, ``Rotation``, ``Scale`` and ``LinearVelocity``
      and is appended to that NPC's trajectory. An empty array adds nothing.
    * any other JSON value is one ego state with the same keys (no ``Id``).

    Each stored waypoint is a dict with keys ``Position``, ``Rotation``,
    ``Scale`` and ``Speed``, where ``Speed`` is the ``x`` component of
    ``LinearVelocity``.

    Args:
        data_file_path: Path to the log file (read as UTF-8).

    Returns:
        ``(ego_list, npc_dict)``: ``ego_list`` holds the ego waypoints in file
        order; ``npc_dict`` maps each NPC id to its waypoints in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a state lacks one of the required keys.
    """
    ego_list = []
    npc_dict = {}
    with open(data_file_path, "r", encoding="utf-8") as log_file:
        for line in log_file:
            # Blank lines (e.g. a trailing newline) would make json.loads raise.
            if not line.strip():
                continue
            record = json.loads(line)
            if isinstance(record, list):
                for npc in record:
                    npc_dict.setdefault(npc["Id"], []).append(_to_waypoint(npc))
            else:
                ego_list.append(_to_waypoint(record))
    return ego_list, npc_dict


def get_line_gradient(base, pos):
    """Return the direction angle, in radians, from ``base`` to ``pos``.

    Both points are ``(x, y)`` pairs. The angle is the ``atan2`` of the
    offset, measured from the positive first axis towards the positive second
    axis. Negative results are shifted by a full turn so the returned angle is
    non-negative (nominally in ``[0, 2*pi)``).
    """
    dx = pos[0] - base[0]
    dy = pos[1] - base[1]
    angle = atan2(dy, dx)
    return angle + 2 * pi if angle < 0 else angle


# print(get_line_gradient((0,0),(-1,0)))


def _angular_extent(ego_xy, corners):
    """Return ``(start, end)`` of the smallest arc, seen from ``ego_xy``, covering ``corners``.

    Angles come from ``get_line_gradient`` (non-negative radians). If the
    footprint straddles the 0/2*pi seam, the arc is returned with
    ``end > 2*pi`` so that ``start <= end`` always holds. In the common
    non-straddling case the result is exactly ``(min(angles), max(angles))``.
    """
    angles = sorted(get_line_gradient(ego_xy, corner) for corner in corners)
    # The covering arc is the complement of the widest empty gap between
    # angularly adjacent corners. The gap that wraps past 2*pi is the default,
    # which yields (min, max); a wider interior gap means the seam is crossed.
    best_gap = angles[0] + 2 * pi - angles[-1]
    start, end = angles[0], angles[-1]
    for lo, hi in zip(angles, angles[1:]):
        if hi - lo > best_gap:
            best_gap = hi - lo
            start, end = hi, lo + 2 * pi
    return start, end


def _arc_covers(outer, inner):
    """Return True if arc ``inner`` lies entirely inside arc ``outer``.

    Arcs are ``(start, end)`` pairs as produced by ``_angular_extent``.
    """
    outer_start, outer_end = outer
    inner_start, inner_end = inner
    if inner_start >= outer_start and inner_end <= outer_end:
        return True
    # ``outer`` may wrap past 2*pi while ``inner`` sits just after angle 0;
    # compare again with ``inner`` moved up by one full turn.
    return inner_start + 2 * pi >= outer_start and inner_end + 2 * pi <= outer_end


def _visible_npc_ids(ego_list, npc_dict, max_distance):
    """Return ids of NPCs that are, in some frame, near the ego and not occluded.

    For each frame index ``i`` of ``ego_list``, NPCs that have an ``i``-th
    waypoint within ``max_distance`` of the ego are visited nearest-first. An
    NPC is dropped for that frame if its angular extent (seen from the ego)
    lies entirely inside the extent of a nearer NPC that was kept. Ids are
    returned in first-kept order, without duplicates.
    """
    related = {}  # dict used as an insertion-ordered set
    for frame, ego in enumerate(ego_list):
        ego_pos = ego["Position"]
        ego_xy = (ego_pos["x"], ego_pos["z"])
        candidates = []
        for npc_id, track in npc_dict.items():
            # NOTE(review): assumes ego and NPC logs are sampled on the same
            # frame index; NPCs with shorter tracks are skipped.
            if frame >= len(track):
                continue
            state = track[frame]
            npc_pos = state["Position"]
            npc_d = distance(ego_pos, npc_pos)
            if npc_d > max_distance:
                continue
            corners = get_wheel_position(
                (npc_pos["x"], npc_pos["z"]),
                state["Rotation"]["y"],
                state["Scale"]["x"],
                state["Scale"]["z"],
            )
            start, end = _angular_extent(ego_xy, corners)
            candidates.append(
                (npc_d, NPCPosition(npc_id, (npc_pos["x"], npc_pos["z"]), start, end))
            )
        # Stable sort on distance only: equal distances keep log order instead
        # of comparing NPCPosition objects, which are not orderable.
        candidates.sort(key=lambda item: item[0])
        kept = []
        for _, npc in candidates:
            arc = (npc.min_gradient, npc.max_gradient)
            if not any(_arc_covers((s.min_gradient, s.max_gradient), arc) for s in kept):
                kept.append(npc)
        for npc in kept:
            related.setdefault(npc.id, None)
    return list(related)


def tmin_main(data_path, seed_file_path, output_path, max_distance=70, match_tolerance=2):
    """Minimise a seed scenario to the NPCs the ego could actually see.

    Reads the simulation log at ``data_path`` (see ``get_data``) and collects
    the NPCs that, in at least one frame, are within ``max_distance`` of the
    ego and not angularly occluded by a nearer NPC. Each such NPC is then
    matched back to the seed scenario: a seed NPC is kept when its
    ``transform.position`` is closer than ``match_tolerance`` to that NPC's
    first logged position. The reduced scenario is written to ``output_path``.

    Args:
        data_path: JSON-lines simulation log.
        seed_file_path: Seed scenario file, parsed by ``dataparser``.
        output_path: Destination passed to ``scenario.store``.
        max_distance: NPCs farther than this from the ego in a frame are
            ignored for that frame (same units as the logged positions).
        match_tolerance: Maximum x-z distance for matching a logged NPC to a
            seed-scenario NPC.
    """
    ego_list, npc_dict = get_data(data_path)
    related_npc = _visible_npc_ids(ego_list, npc_dict, max_distance)

    scenario = _Parser.scenario_parser(seed_file_path)
    npcs = copy.deepcopy(scenario.elements["npc"])
    # Every related id was seen in some frame, so its track is non-empty.
    start_positions = [npc_dict[npc_id][0]["Position"] for npc_id in related_npc]
    # Keep matched NPCs in their original scenario order.
    new_npcs = [
        npc
        for npc in npcs
        if any(distance(pos, npc.transform.position) < match_tolerance for pos in start_positions)
    ]
    print("tmin NPC from {} to {}".format(len(npcs), len(new_npcs)))
    scenario.elements["npc"] = new_npcs
    scenario.to_json()
    scenario.store(output_path)

if __name__ == "__main__":
    # CLI: <data_path> <seed_file_path> <output_path>; extra arguments are ignored.
    if len(sys.argv) < 4:
        sys.exit("usage: {} <data_path> <seed_file_path> <output_path>".format(sys.argv[0]))
    data_path, seed_file_path, output_path = sys.argv[1:4]
    tmin_main(data_path, seed_file_path, output_path)

