#!/usr/bin/env python3
#
# Copyright (c) 2020-2021 LG Electronics, Inc.
#
# This software contains code licensed as described in LICENSE.
#

import json
import logging
import os
import re
import sys
import math
import time
import lgsvl
from datetime import datetime

import timeout_decorator

# Log line layout: timestamp, level, originating module, then the message.
FORMAT = '%(asctime)-15s [%(levelname)s][%(module)s] %(message)s'

# Configure the root logger at import time and expose this module's logger.
logging.basicConfig(format=FORMAT, level=logging.INFO)
log = logging.getLogger(__name__)

class MyException(Exception):
    """Raised when a scenario cannot be set up or executed by the runner."""

class VSERunner:
    def __init__(self, _seed, _workdir, _feedback, _sensor_conf=None):
        # with open(json_file) as f:
        #     self.VSE_dict = json.load(f)
        self.workdir = _workdir
        self.seed = _seed
        self.VSE_dict = _seed.to_json()
        self.sim = None
        self.ego_agents = []
        self.npc_agents = []
        self.npc_count = 0
        self.pedestrian_agents = []
        self.collision_object = set()
        self.feedback = _feedback
        self.maxint = 130
        self.is_collision = False
        self.sensor_conf = _sensor_conf

        self.start_time = 0
        self.collision_time = 0

    def reset(self):
        """Forget the agent descriptions collected by ``load_agents``.

        The three lists are emptied in place; ``npc_count`` and the collision
        state are left untouched.
        """
        log.debug("Reset VSE runner")
        for roster_name in ("ego_agents", "npc_agents", "pedestrian_agents"):
            getattr(self, roster_name).clear()
    
    def close(self):
        """Clear the cached agent lists and shut down the simulator connection.

        Does nothing beyond ``reset()`` when no simulator has been created
        yet. Otherwise the simulator is reset and closed, and ``self.sim`` is
        set back to ``None`` so that a later ``setup_sim()`` opens a fresh
        connection instead of reusing the closed handle.
        """
        self.reset()
        if not self.sim:
            return

        try:
            self.sim.reset()
        finally:
            # Release the connection even if resetting the scene failed.
            self.sim.close()
            self.sim = None

    def setup_sim(self, default_host="127.0.0.1", default_port=8181):
        if not self.sim:
            simulator_host = os.getenv('LGSVL__SIMULATOR_HOST', default_host)
            simulator_port = int(os.getenv('LGSVL__SIMULATOR_PORT', default_port))
            log.debug("simulator_host is {}, simulator_port is {}".format(simulator_host, simulator_port))
            self.sim = lgsvl.Simulator(simulator_host, simulator_port)

    def connect_bridge(self, ego_agent, ego_index=0, default_host="127.0.0.1", default_port=9090):
        autopilot_host_env = "LGSVL__AUTOPILOT_{}_HOST".format(ego_index)
        autopilot_port_env = "LGSVL__AUTOPILOT_{}_PORT".format(ego_index)
        bridge_host = os.environ.get(autopilot_host_env, default_host)
        bridge_port = int(os.environ.get(autopilot_port_env, default_port))
        ego_agent.connect_bridge(bridge_host, bridge_port)

        return bridge_host, bridge_port

    def load_scene(self):
        """Load the map named by the scenario, or reset it if already loaded.

        Raises:
            MyException: if the scenario has no ``"map"`` entry.
        """
        if "map" not in self.VSE_dict:
            log.error("No map specified in the scenario.")
            raise MyException

        map_name = self.VSE_dict["map"]["name"]
        log.info("Loading {} map.".format(map_name))

        already_loaded = self.sim.current_scene == map_name
        if already_loaded:
            # Same map as the current scene: a reset is used instead of a reload.
            self.sim.reset()
        else:
            self.sim.load(map_name, seed=650387)

        log.info("Loaded.")

    def load_agents(self):
        """Sort the scenario's agent descriptions into per-type lists.

        Entries are appended to ``ego_agents``, ``npc_agents`` or
        ``pedestrian_agents`` according to their ``"type"``; entries of any
        other type are skipped with a warning. ``npc_count`` is updated to the
        resulting number of NPC entries.
        """
        if "agents" not in self.VSE_dict:
            log.warning("No agents specified in the scenario")
            return

        # Checked in this order; the first matching type receives the entry.
        buckets = (
            (lgsvl.AgentType.EGO, self.ego_agents),
            (lgsvl.AgentType.NPC, self.npc_agents),
            (lgsvl.AgentType.PEDESTRIAN, self.pedestrian_agents),
        )

        for entry in self.VSE_dict["agents"]:
            log.debug("Adding agent {}, type: {}".format(entry["variant"], entry["type"]))
            type_id = entry["type"]
            for agent_type, bucket in buckets:
                if type_id == agent_type.value:
                    bucket.append(entry)
                    break
            else:
                log.warning("Unsupported agent type {}. Skipping agent.".format(entry["type"]))

        self.npc_count = len(self.npc_agents)
        log.info("Loaded {} ego agents".format(len(self.ego_agents)))
        log.info("Loaded {} NPC agents".format(len(self.npc_agents)))
        log.info("Loaded {} pedestrian agents".format(len(self.pedestrian_agents)))
    def set_weather(self):
        """Apply the scenario's weather to the simulator, if one is given.

        Nothing is changed unless ``VSE_dict["weather"]`` exists and contains
        a ``"rain"`` entry.
        """
        has_weather = "weather" in self.VSE_dict and "rain" in self.VSE_dict["weather"]
        if not has_weather:
            log.debug("No weather specified in the scenarios")
            return

        weather_data = self.VSE_dict["weather"]
        fields = ("rain", "fog", "wetness", "cloudiness", "damage")
        settings = {name: weather_data[name] for name in fields}
        self.sim.weather = lgsvl.WeatherState(**settings)

    def set_time(self):
        """Apply the scenario's date and time to the simulator, if given.

        Nothing is changed unless ``VSE_dict["time"]`` exists and contains a
        ``"year"`` entry. The time is set with ``fixed=False``.
        """
        has_time = "time" in self.VSE_dict and "year" in self.VSE_dict["time"]
        if not has_time:
            log.debug("No time specified in the scenarios")
            return

        time_data = self.VSE_dict["time"]
        parts = ("year", "month", "day", "hour", "minute", "second")
        moment = datetime(**{part: time_data[part] for part in parts})
        self.sim.set_date_time(moment, fixed=False)

    def add_controllables(self):
        """Spawn or configure the controllable objects listed in the scenario.

        An entry is treated as *spawned* when it has a ``"name"`` (kept as a
        backwards-compatibility check) or when its own ``"spawned"`` flag is
        truthy. Spawned entries are added to the simulator at their
        transform; all other entries are looked up by ``"uid"``. In both
        cases a non-empty ``"policy"`` is applied to the controllable.

        Failures while adding a spawned controllable are logged and the entry
        is skipped; they do not abort the scenario.
        """
        if "controllables" not in self.VSE_dict:
            log.debug("No controllables specified in the scenarios")
            return

        for controllable_data in self.VSE_dict["controllables"]:
            # Name checking for backwards compatibility. The "spawned" flag is
            # read from the entry itself (not from the enclosing list).
            spawned = "name" in controllable_data or (
                "spawned" in controllable_data and controllable_data["spawned"]
            )
            if spawned:
                if "name" not in controllable_data:
                    # Without a name there is nothing to ask the simulator for.
                    log.error("Controllable is flagged as spawned but has no name. Skipping controllable.")
                    continue

                name = controllable_data["name"]
                log.debug("Adding controllable {}".format(name))
                controllable_state = lgsvl.ObjectState()
                controllable_state.transform = self.read_transform(controllable_data["transform"])
                try:
                    controllable = self.sim.controllable_add(name, controllable_state)
                    # attr is what the collision callback records for this object.
                    controllable.attr = controllable_state.transform.position.x
                    policy = controllable_data["policy"]
                    if policy:
                        controllable.control(policy)
                except Exception as e:
                    msg = "Failed to add controllable {}, please make sure you have the correct simulator".format(name)
                    log.error(msg)
                    log.error("Original exception: " + str(e))
            else:
                uid = controllable_data["uid"]
                log.debug("Setting policy for controllable {}".format(uid))
                controllable = self.sim.get_controllable_by_uid(uid)
                policy = controllable_data["policy"]
                if policy:
                    controllable.control(policy)
                
    def add_ego(self):
        """Spawn every EGO vehicle, then connect its bridge and set up Dreamview.

        For each entry of ``self.ego_agents`` this method:

        1. builds the agent state (transform and optional ``initial_speed``);
        2. adds the agent to the simulator, preferring ``self.sensor_conf``,
           then the entry's ``sensorsConfigurationId``, then the agent name;
        3. registers a collision callback that records the collision, stores
           the seed under ``<workdir>/collision/`` and stops the simulation;
        4. connects the bridge and configures Dreamview (HD map, vehicle,
           modules and, when present, the destination).

        Raises:
            MyException: if the agent cannot be added, or if the bridge /
                Dreamview setup fails with anything but a ``RuntimeWarning``.
        """
        default_modules = [
            'Localization',
            'Transform',
            'Routing',
            'Prediction',
            'Planning',
            'Control',
        ]

        def _on_collision(agent1, agent2, contact):
            self.is_collision = True
            self.collision_time = time.time() - self.start_time
            name1 = "STATIC OBSTACLE" if agent1 is None else agent1.name
            name2 = "STATIC OBSTACLE" if agent2 is None else agent2.name
            print("{} collided with {} at {}".format(name1, name2, contact))
            self.seed.store(self.workdir + "/collision/" + str(self.seed.get_hash()))
            # Only agent-to-agent collisions are recorded in collision_object.
            if agent1 is not None and agent2 is not None:
                self.collision_object.add(agent1.attr)
                self.collision_object.add(agent2.attr)

            log.info("Stopping simulation")
            self.sim.stop()

        for i, agent in enumerate(self.ego_agents):
            agent_name = agent["id"] if "id" in agent else agent["variant"]

            agent_state = lgsvl.AgentState()
            if 'initial_speed' in agent:
                speed = agent['initial_speed']
                agent_state.velocity = lgsvl.Vector(speed['x'], speed['y'], speed['z'])
            agent_state.transform = self.read_transform(agent["transform"])

            agent_destination = None
            if "destinationPoint" in agent:
                # Only the position is used: the destination rotation is to be
                # set once it is supported by DreamView.
                target = agent["destinationPoint"]["position"]
                agent_destination = lgsvl.Vector(target["x"], target["y"], target["z"])

            try:
                log.info(self.sensor_conf)
                if self.sensor_conf:
                    ego = self.sim.add_agent(self.sensor_conf, lgsvl.AgentType.EGO, agent_state)
                elif "sensorsConfigurationId" in agent:
                    ego = self.sim.add_agent(agent["sensorsConfigurationId"], lgsvl.AgentType.EGO, agent_state)
                else:
                    ego = self.sim.add_agent(agent_name, lgsvl.AgentType.EGO, agent_state)
                # attr is what the collision callback records for this agent.
                ego.attr = agent_state.transform.position.x
                ego.on_collision(_on_collision)
            except Exception as e:
                msg = "Failed to add agent {}, please make sure you have the correct simulator".format(agent_name)
                log.error(msg)
                log.error("Original exception: " + str(e))
                raise MyException from e

            try:
                bridge_host = self.connect_bridge(ego, i)[0]

                # Comma-separated module names; blank entries are ignored so an
                # unset or empty variable falls back to the default modules.
                modules_env = os.environ.get("LGSVL__AUTOPILOT_{}_VEHICLE_MODULES".format(i), "")
                modules = [name.strip() for name in modules_env.split(",") if name.strip()]
                if not modules:
                    modules = default_modules

                dv = lgsvl.dreamview.Connection(self.sim, ego, bridge_host)

                hd_map = os.environ.get("LGSVL__AUTOPILOT_HD_MAP")
                if not hd_map:
                    # Derive the HD map name from the scene name by splitting
                    # it into space-separated words.
                    hd_map = ' '.join(self.split_pascal_case(self.sim.current_scene))

                dv.set_hd_map(hd_map)
                dv.set_vehicle(os.environ.get("LGSVL__AUTOPILOT_{}_VEHICLE_CONFIG".format(i), agent["variant"]))
                if agent_destination is not None:
                    dv.setup_apollo(agent_destination.x, agent_destination.z, modules)
                else:
                    log.info("No destination set for EGO {}".format(agent_name))
                    for mod in modules:
                        dv.enable_module(mod)
            except RuntimeWarning as e:
                # agent_name is used because the entry may have no "id".
                msg = "Skipping bridge connection for vehicle: {}".format(agent_name)
                log.warning("Original exception: " + str(e))
                log.warning(msg)
            except Exception as e:
                msg = "Something went wrong with bridge / dreamview connection."
                log.error("Original exception: " + str(e))
                log.error(msg)
                raise MyException from e

    def add_npc(self):
        for agent in self.npc_agents:
            if "id" in agent:
                agent_name = agent["id"]
            else:
                agent_name = agent["variant"]
            agent_state = lgsvl.AgentState()
            agent_state.transform = self.read_transform(agent["transform"])
            agent_color = lgsvl.Vector(agent["color"]["r"], agent["color"]["g"], agent["color"]["b"]) if "color" in agent else None

            try:
                npc = self.sim.add_agent(agent_name, lgsvl.AgentType.NPC, agent_state, agent_color)
                npc.attr = agent_state.transform.position.x
            except Exception as e:
                msg = "Failed to add agent {}, please make sure you have the correct simulator".format(agent_name)
                log.error(msg)
                log.error("Original exception: " + str(e))
                # sys.exit(1)
                raise MyException

            if agent["behaviour"]["name"] == "NPCWaypointBehaviour":
                waypoints = self.read_waypoints(agent["waypoints"])
                if waypoints:
                    npc.follow(waypoints)
            elif agent["behaviour"]["name"] == "NPCLaneFollowBehaviour":
                npc.follow_closest_lane(
                    True,
                    agent["behaviour"]["parameters"]["maxSpeed"],
                    agent["behaviour"]["parameters"]["isLaneChange"]
                )

    def add_pedestrian(self):
        for agent in self.pedestrian_agents:
            if "id" in agent:
                agent_name = agent["id"]
            else:
                agent_name = agent["variant"]
            agent_state = lgsvl.AgentState()
            agent_state.transform = self.read_transform(agent["transform"])

            try:
                pedestrian = self.sim.add_agent(agent_name, lgsvl.AgentType.PEDESTRIAN, agent_state)
                pedestrian.attr = agent_state.transform.position.x
            except Exception as e:
                msg = "Failed to add agent {}, please make sure you have the correct simulator".format(agent_name)
                log.error(msg)
                log.error("Original exception: " + str(e))
                # sys.exit(1)
                raise MyException

            waypoints = self.read_waypoints(agent["waypoints"])
            if waypoints:
                pedestrian.follow(waypoints)

    def read_transform(self, transform_data):
        """Build an ``lgsvl.Transform`` from scenario data.

        Args:
            transform_data: mapping with ``"position"`` and ``"rotation"``
                entries, each passed to ``lgsvl.Vector.from_json``.

        Returns:
            The populated transform.
        """
        result = lgsvl.Transform()
        for field in ("position", "rotation"):
            setattr(result, field, lgsvl.Vector.from_json(transform_data[field]))
        return result

    def read_waypoints(self, waypoints_data):
        waypoints = []
        for waypoint_data in waypoints_data:
            position = lgsvl.Vector.from_json(waypoint_data["position"])
            speed = waypoint_data["speed"]
            angle = lgsvl.Vector.from_json(waypoint_data["angle"])
            if "wait_time" in waypoint_data:
                wait_time = waypoint_data["wait_time"]
            else:
                wait_time = waypoint_data["waitTime"]
            trigger = self.read_trigger(waypoint_data)

            if 'trigger_distance' in waypoint_data:
                td = waypoint_data['trigger_distance']
                waypoint = lgsvl.DriveWaypoint(position, speed, angle=angle, idle=wait_time, trigger_distance=td, trigger=trigger)
            else:
                waypoint = lgsvl.DriveWaypoint(position, speed, angle=angle, idle=wait_time, trigger=trigger)

            waypoints.append(waypoint)

        return waypoints

    def read_trigger(self, waypoint_data):
        if "trigger" not in waypoint_data:
            return None
        effectors_data = waypoint_data["trigger"]["effectors"]
        if len(effectors_data) == 0:
            return None

        effectors = []
        for effector_data in effectors_data:
            effector = lgsvl.TriggerEffector(effector_data["typeName"], effector_data["parameters"])
            effectors.append(effector)
        trigger = lgsvl.WaypointTrigger(effectors)

        return trigger

    def split_pascal_case(self, s):
        matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z\d])|(?<=[A-Z\d])(?=[A-Z][a-z])|$)', s)
        return [m.group(0) for m in matches]

    def calc_distance(self, ego, npc):
        ego_x = ego["x"]
        ego_y = ego["y"]
        ego_z = ego["z"]
        npc_x = npc["x"]
        npc_y = npc["y"]
        npc_z = npc["z"]
        dis = math.pow(npc_x - ego_x, 2) + math.pow(npc_y - ego_y, 2) + math.pow(npc_z - ego_z, 2)
        dis = math.sqrt(dis)
        return dis

    def calc_score(self):
        """Compute the distance-based feedback score from the recorded trace.

        The gzipped trace ``in.txt.gz`` is copied from the SVL Simulator
        configuration directory to ``/tmp`` and unpacked there. Every line of
        the resulting ``/tmp/in.txt`` is parsed as JSON: dict lines are
        collected as EGO frames and non-empty list lines as NPC frames. The
        i-th EGO frame is paired with the i-th NPC frame, and the smallest
        EGO-to-NPC distance (starting from ``self.maxint``) is divided by
        ``self.npc_count``.

        Returns:
            The minimum distance divided by the number of NPC agents.

        Raises:
            AssertionError: if the numbers of EGO and NPC frames differ.
        """
        # NOTE(review): the simulator version is hard-coded in this path.
        os.system("sleep 1 && cp ~/.config/unity3d/LGElectronics/SVLSimulator-2021.2.2/in.txt.gz /tmp")
        os.system("gunzip -f /tmp/in.txt.gz")

        ego_list = []
        npc_list = []

        # The context manager closes the trace even if a line fails to parse.
        with open("/tmp/in.txt", 'r') as f:
            for line in f:
                tmp_obj = json.loads(line)
                if isinstance(tmp_obj, dict):
                    ego_list.append(tmp_obj)
                elif isinstance(tmp_obj, list) and tmp_obj:
                    # NOTE(review): empty NPC lists are dropped, so a frame
                    # without NPCs makes the two lists differ in length.
                    npc_list.append(tmp_obj)

        # Raised explicitly so the check is not stripped when running with -O.
        if len(ego_list) != len(npc_list):
            raise AssertionError(
                "Trace has {} ego frames but {} NPC frames".format(len(ego_list), len(npc_list))
            )

        minD = self.maxint
        for ego_frame, npcs in zip(ego_list, npc_list):
            ego = ego_frame["Position"]
            for npc_entry in npcs:
                minD = min(minD, self.calc_distance(ego, npc_entry["Position"]))

        log.info(" *** minD is " + str(minD) + " *** ")

        # NOTE(review): raises ZeroDivisionError when npc_count is 0.
        score = minD / float(self.npc_count)
        return score

    @timeout_decorator.timeout(180)
    def run(self, duration=0.0, force_duration=False, loop=False):
        """Set up the scenario, run it in the simulator and return its score.

        The whole call is wrapped by ``timeout_decorator.timeout(180)``.

        Args:
            duration: value passed to ``sim.run``.
            force_duration: currently unused; the waypoint-traversal callback
                that consulted it is disabled.
            loop: when true, the scenario is set up and run again after each
                simulation ends instead of returning.

        Returns:
            ``100`` if a collision was recorded; otherwise the result of
            ``calc_score()`` when the feedback mode is ``"avfuzzer"``, else
            ``0``.

        Raises:
            SystemExit: with code ``-1`` if scenario setup raises
                ``MyException``.
        """
        log.debug("Duration is set to {}.".format(duration))
        self.setup_sim()

        try:
            while True:
                self.load_scene()
                self.load_agents()
                self.set_weather()
                self.set_time()
                self.add_ego()  # Must go first since dreamview api may call sim.run()
                self.add_npc()
                self.add_pedestrian()
                self.add_controllables()

                log.info("Starting scenario...")
                # Reference point for the collision time recorded by the callback.
                self.start_time = time.time()
                self.sim.run(duration)
                log.info("Scenario simulation ended.")

                if not loop:
                    break
                self.reset()
        except MyException:
            log.error("Program exit!")
            # sys.exit rather than the interactive-only site helper exit().
            sys.exit(-1)

        self.close()

        # calculate feedback score
        score = self.calc_score() if self.feedback == "avfuzzer" else 0

        # A collision overrides whatever score was computed.
        if self.is_collision:
            score = 100

        return score

