#!/usr/bin/env python3

from abc import ABCMeta, abstractmethod
import argparse
import asyncio
from enum import Enum
import logging
import socket
import time
import threading
import sys

import zmq

from command_pb2 import Command
from ptadapter import ClientAdapter, ServerAdapter

import common as cmn


class Endpoint(Enum):
    """Which side of the benchmark this process runs as (chosen by ``mode``)."""

    CLIENT = 1
    SERVER = 2

    @staticmethod
    def of_str(value: str):
        """Map a mode string to an :class:`Endpoint`.

        Only the exact lowercase member names ("client", "server") are
        accepted; any other value returns ``None``.

        Declared ``@staticmethod`` so it also works when called on a member:
        without it, ``Endpoint.CLIENT.of_str(...)`` binds the member as
        ``value`` and raises ``TypeError``.
        """
        return next(
            (member for member in Endpoint if member.name.lower() == value), None
        )


class SynchServer:
    """Synchronous wrapper around a ptadapter ``ServerAdapter``.

    Each async step is run to completion on a caller-supplied event loop via
    :meth:`execute_instruction`. The adapter is created with forward address
    ``127.0.0.1:forward_port`` and one transport listening on
    ``listen_host:listen_port``.
    """

    class Instruction(Enum):
        INIT = 1
        STOP = 2

    def __init__(
        self,
        pt_exec,
        transport,
        options,
        listen_port,
        forward_port,
        listen_host="192.168.10.127",
    ):
        """Store the configuration; the adapter is only created by INIT.

        Args:
            pt_exec: passed through to ``ServerAdapter`` as its executable.
            transport: transport name registered with ``add_transport``.
            options: transport options passed to ``add_transport``.
            listen_port: port the transport listens on.
            forward_port: local port on 127.0.0.1 given to ``ServerAdapter``.
            listen_host: address the transport listens on. The default keeps
                the value that used to be hard-coded.
        """
        self._pt_exec = pt_exec
        self._transport = transport
        self._options = options
        self._listen_port = listen_port
        self._forward_port = forward_port
        self._listen_host = listen_host
        # Created by _init(); None until INIT has run.
        self._srv_adapter = None

    async def _init(self):
        logging.info(f"{self._transport}")

        # NOTE(review): positional args are assumed to be
        # (pt_exec, state, forward_host, forward_port) per ptadapter — confirm.
        self._srv_adapter = ServerAdapter(
            self._pt_exec,
            None,
            "127.0.0.1",
            self._forward_port,
        )

        self._srv_adapter.add_transport(
            self._transport,
            self._listen_host,
            self._listen_port,
            options=self._options,
        )

        logging.info("Server has initialized.")

        await self._srv_adapter.start()
        logging.info("Server has started.")

    async def _stop(self):
        if self._srv_adapter is None:
            # STOP before INIT: nothing was started, so there is nothing to stop.
            logging.warning("Server stop requested before initialization.")
            return
        await self._srv_adapter.stop()
        logging.info("Server has stopped.")
        await self._srv_adapter.wait()
        logging.info("Server has exited.")

    def execute_instruction(self, ins: Instruction, loop):
        """Run the coroutine for ``ins`` to completion on ``loop``.

        Raises:
            ValueError: if ``ins`` is not a known :class:`SynchServer.Instruction`.
        """
        if ins == SynchServer.Instruction.INIT:
            loop.run_until_complete(self._init())
        elif ins == SynchServer.Instruction.STOP:
            loop.run_until_complete(self._stop())
        else:
            raise ValueError(f"unknown server instruction: {ins!r}")


class SynchClient:
    """Synchronous wrapper around a ptadapter ``ClientAdapter``.

    Each async step (start the adapter, open a transport connection, send or
    receive filler bytes, stop) is run to completion on a caller-supplied
    event loop.
    """

    class Instruction(Enum):
        INIT = 1
        CONNECT = 2
        STOP = 3

    def __init__(self, pt_exec, transports, options, connect_host, connect_port):
        self._client_adapter = None
        self._pt_exec = pt_exec
        # NOTE(review): despite the plural this appears to be a single transport
        # name. It is wrapped in a list for ClientAdapter and passed directly to
        # open_transport_connection.
        self._transports = transports
        self._options = options
        self._connect_host = connect_host
        self._connect_port = connect_port
        # Stream pair set by CONNECT; None while not connected.
        self._reader = None
        self._writer = None

    async def _init(self):
        logging.info(f"{self._transports}")
        # NOTE(review): "./x" is presumably the adapter's state directory — confirm.
        self._client_adapter = ClientAdapter(self._pt_exec, "./x", [self._transports])
        logging.info("Client has initialized.")
        await self._client_adapter.start()
        logging.info("Client has started.")

    async def _stop(self):
        # Tolerate STOP without a prior CONNECT (or INIT): only tear down what exists.
        if self._writer is not None:
            self._writer.close()
            await self._writer.wait_closed()
            self._reader = None
            self._writer = None

        if self._client_adapter is not None:
            await self._client_adapter.stop()
            logging.info("Client has stopped.")
            await self._client_adapter.wait()
            logging.info("Client has exited.")

    async def _connect(self):
        logging.info("Client connecting...")

        reader, writer = await self._client_adapter.open_transport_connection(
            self._transports, self._connect_host, self._connect_port, args=self._options
        )

        self._reader = reader
        self._writer = writer

        logging.info("Client has connected.")

    async def _send(self, nbytes: int):
        # Filler payload: nbytes copies of 0xEF, built without a Python-level list.
        self._writer.write(b"\xef" * nbytes)
        await self._writer.drain()

    async def _recv(self, nbytes: int):
        # Read and discard exactly nbytes.
        await self._reader.readexactly(nbytes)

    def send(self, nbytes, loop):
        """Write ``nbytes`` filler bytes to the connection and wait for the drain."""
        loop.run_until_complete(self._send(nbytes))

    def recv(self, nbytes, loop):
        """Block until exactly ``nbytes`` bytes have been read from the connection."""
        loop.run_until_complete(self._recv(nbytes))

    def execute_instruction(self, ins: Instruction, loop):
        """Run the coroutine for ``ins`` to completion on ``loop``.

        Raises:
            ValueError: if ``ins`` is not a known :class:`SynchClient.Instruction`.
        """
        if ins == SynchClient.Instruction.INIT:
            loop.run_until_complete(self._init())
        elif ins == SynchClient.Instruction.CONNECT:
            loop.run_until_complete(self._connect())
        elif ins == SynchClient.Instruction.STOP:
            loop.run_until_complete(self._stop())
        else:
            raise ValueError(f"unknown client instruction: {ins!r}")


class Role(Enum):
    """The side a :class:`Protocol` instance plays in the exchange."""

    CLIENT = 1  # used by ClientProtocol
    SERVER = 2  # used by ServerProtocol


class Protocol(metaclass=ABCMeta):
    """Abstract endpoint driven one control command at a time.

    Concrete subclasses implement :meth:`handle_command`; a ``True`` return
    tells the caller's command loop that the run is finished.
    """

    def __init__(self, role: Role):
        # Which side (client or server) this instance represents.
        self._role = role

    @abstractmethod
    def handle_command(self, cmd) -> bool:
        """Carry out ``cmd``; return True when no further commands are expected."""
        ...


class ClientProtocol(Protocol):
    """Client side of the control protocol, backed by a :class:`SynchClient`.

    The PT client adapter is started during construction. Control commands
    then connect through it, replay transcripts, and finally stop it.
    """

    def __init__(self, server_addr, server_port, pt_exec, transport, options, loop):
        super().__init__(Role.CLIENT)
        self._server_addr = server_addr
        self._server_port = server_port
        self._loop = loop

        self._synch_client = SynchClient(
            pt_exec, transport, options, server_addr, server_port
        )

        # Start the adapter now so CLIENT_CONNECT only has to open the connection.
        self._synch_client.execute_instruction(SynchClient.Instruction.INIT, self._loop)

    def handle_command(self, cmd) -> bool:
        """Execute one control command; return True only after STOP.

        Raises:
            ValueError: for a command kind this endpoint does not support.
        """
        kind = cmd.kind
        if kind == Command.Kind.STOP:
            self._synch_client.execute_instruction(
                SynchClient.Instruction.STOP, self._loop
            )
            return True

        if kind == Command.Kind.CLIENT_CONNECT:
            logging.info("Connecting to server")
            self._synch_client.execute_instruction(
                SynchClient.Instruction.CONNECT, self._loop
            )
        elif kind in (Command.Kind.SEND, Command.Kind.RECV):
            # NOTE(review): standalone SEND/RECV are accepted but ignored on the
            # client, while the server side acts on them — confirm this is intended.
            pass
        elif kind == Command.Kind.TRANSCRIPT:
            self._run_transcript(cmd.transcript)
        else:
            logging.error(f"Received unsupported command {cmd}")
            # Raise instead of `assert False`: assertions vanish under `python -O`.
            raise ValueError(f"unsupported command kind: {kind}")

        return False

    def _run_transcript(self, transcript):
        """Replay each round from the client's side: send, receive, then sleep."""
        for rnd in transcript.rounds:
            if rnd.client_send_nbytes > 0:
                self._synch_client.send(rnd.client_send_nbytes, self._loop)
                logging.info(f"Sent {rnd.client_send_nbytes} bytes.")
            if rnd.server_send_nbytes > 0:
                self._synch_client.recv(rnd.server_send_nbytes, self._loop)
                logging.info(f"Received {rnd.server_send_nbytes} bytes.")

            if rnd.sleep_usec > 0:
                logging.info(f"Sleeping for {rnd.sleep_usec} usec.")
                time.sleep(rnd.sleep_usec / 1_000_000)


class ServerProtocol(Protocol):
    """Server side of the control protocol.

    It owns a plain TCP listening socket on an ephemeral port, plus a
    :class:`SynchServer` whose forward port is that socket's port. Control
    commands then accept a connection on the socket and exchange filler bytes
    over it.
    """

    def __init__(self, port, pt_exec, transport, options, loop):
        super().__init__(Role.SERVER)
        # Ephemeral port (0) on all interfaces. NOTE(review): an earlier version
        # bound "localhost"; SynchServer forwards to 127.0.0.1, so loopback may
        # suffice — confirm before narrowing.
        self._sock = socket.create_server(("", 0), reuse_port=True)
        self._conn_sock = None

        self._loop = loop

        # The PT server listens on `port` and is told to forward to our socket.
        forward_port = self._sock.getsockname()[1]

        self._pt_server = SynchServer(
            pt_exec, transport, options, listen_port=port, forward_port=forward_port
        )

        self._pt_server.execute_instruction(SynchServer.Instruction.INIT, self._loop)

    def handle_command(self, cmd) -> bool:
        """Execute one control command; return True only after STOP.

        Raises:
            RuntimeError: SEND/RECV/TRANSCRIPT arrived before a connection was
                accepted via SERVER_LISTEN.
            ValueError: for a command kind this endpoint does not support.
        """
        kind = cmd.kind
        if kind == Command.Kind.STOP:
            self._close_sockets()
            self._pt_server.execute_instruction(
                SynchServer.Instruction.STOP, self._loop
            )
            return True

        if kind == Command.Kind.SERVER_LISTEN:
            self._sock.listen()
            logging.info("Server listening.")
            logging.debug(f"Server sock: {self._sock}")
            # Blocks until a connection arrives (expected: the PT server's forward).
            self._conn_sock, _ = self._sock.accept()
            logging.info("Received connection.")
        elif kind == Command.Kind.SEND:
            send_determined(self._require_conn(), cmd.nbytes)
        elif kind == Command.Kind.RECV:
            recv_determined(self._require_conn(), cmd.nbytes)
        elif kind == Command.Kind.TRANSCRIPT:
            self._run_transcript(cmd.transcript)
        else:
            logging.error(f"Received unsupported command {cmd}")
            # Raise instead of `assert False`: assertions vanish under `python -O`.
            raise ValueError(f"unsupported command kind: {kind}")

        return False

    def _require_conn(self):
        """Return the accepted connection, failing clearly if SERVER_LISTEN has not run."""
        if self._conn_sock is None:
            raise RuntimeError("no client connection; SERVER_LISTEN must come first")
        return self._conn_sock

    def _close_sockets(self):
        """Half-close and drain the accepted connection (if any), then close both sockets."""
        if self._conn_sock is not None:
            try:
                self._conn_sock.shutdown(socket.SHUT_WR)
            except OSError as e:
                # The peer may already be gone; keep tearing down regardless.
                logging.warning(f"Received {e} during shutdown().")
            recv_rest(self._conn_sock)
            self._conn_sock.close()
            self._conn_sock = None
        self._sock.close()

    def _run_transcript(self, transcript):
        """Replay each round from the server's side: receive, send, then sleep."""
        conn = self._require_conn()
        for rnd in transcript.rounds:
            if rnd.client_send_nbytes > 0:
                recv_determined(conn, rnd.client_send_nbytes)
            if rnd.server_send_nbytes > 0:
                send_determined(conn, rnd.server_send_nbytes)

            if rnd.sleep_usec > 0:
                logging.info(f"Sleeping for {rnd.sleep_usec} usec.")
                time.sleep(rnd.sleep_usec / 1_000_000)


def recv_determined(sock, nbytes):
    """Receive and discard exactly ``nbytes`` bytes from ``sock``.

    Exits the process with status -1 on a socket error, or when the peer
    closes the connection before ``nbytes`` bytes have arrived.
    """
    nbytes_rem = nbytes

    while nbytes_rem > 0:
        try:
            # Never request more than is still owed. Over-reading would swallow
            # bytes belonging to the next transcript round and desynchronise the
            # exchange.
            b = sock.recv(min(4096, nbytes_rem))
        except OSError as e:
            logging.error(f"Error on recv: {e}.")
            sys.exit(-1)
        if not b:
            logging.error(
                f"Error on recv: connection closed with {nbytes_rem} bytes outstanding."
            )
            sys.exit(-1)
        logging.info(f"Received {len(b)} bytes.")
        nbytes_rem -= len(b)


def send_determined(sock, nbytes):
    """Send ``nbytes`` filler bytes (each 0xEF) over ``sock``.

    Exits the process with status -1 if the send fails.
    """
    try:
        # Bytes repetition builds the payload in C, with no intermediate list.
        sock.sendall(b"\xef" * nbytes)
    except OSError as e:
        logging.error(f"Error on send: {e}.")
        sys.exit(-1)
    logging.info(f"Sent {nbytes} bytes.")


def recv_rest(sock):
    """Drain ``sock`` until ``recv`` returns no data (the peer has closed).

    Best-effort: any exception is logged as a warning and swallowed.
    """
    try:
        # Each truthy chunk is simply discarded; an empty read ends the loop.
        while sock.recv(4096):
            pass
    except Exception as e:
        logging.warning(f"Received {e} during recv_rest().")


def work_loop(socket: zmq.Socket, protocol: Protocol):
    """Serve control commands on ``socket`` until ``protocol`` reports completion.

    Every received command is acknowledged before it is handed to
    ``protocol.handle_command``. The loop ends once that call returns a
    truthy value.
    """
    while True:
        cmd = cmn.recv_command_synch(socket)
        ack = cmn.make_ack_command(cmd.kind)
        cmn.send_command_synch(socket, ack)

        if protocol.handle_command(cmd):
            break


def execute_protocol(context: zmq.Context, port: int, protocol: Protocol):
    """Bind a REP control socket on ``port`` (all interfaces) and serve ``protocol``.

    The socket is closed when the command loop finishes or raises, so it does
    not outlive the run.
    """
    sock = context.socket(zmq.REP)
    try:
        sock.bind(f"tcp://*:{port}")
        work_loop(sock, protocol)
    finally:
        sock.close()


def main(args: argparse.Namespace):
    """Run one endpoint (client or server), driven by control commands over ZeroMQ.

    Steps:
      1. Configure logging and validate the command line.
      2. Load the protocol configuration.
      3. Build the matching :class:`Protocol`.
      4. Serve control commands on ``args.control_port`` until the protocol
         finishes.

    Invalid arguments terminate the process via ``sys.exit`` with a message.
    Explicit checks replace the former asserts, which ``python -O`` strips.
    """
    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
        "critical": logging.CRITICAL,
    }
    level = log_levels.get(args.log_level)
    # level=None leaves the logging default in place (previous behavior).
    logging.basicConfig(level=level, stream=sys.stderr)
    if level is None:
        logging.warning(f"Unknown log level {args.log_level!r}; using the default.")
    logging.info(f"Command line arguments: {args}")

    endpoint = Endpoint.of_str(args.mode)
    if endpoint is None:
        sys.exit(f"error: unknown mode {args.mode!r} (expected client|server)")

    protocol = args.protocol
    # Validate before indexing the config, so a bad name is not reported as a KeyError.
    if protocol not in ("obfs4", "proteus"):
        sys.exit(f"error: unknown protocol {protocol!r} (expected obfs4|proteus)")

    if endpoint == Endpoint.CLIENT and (
        args.hostname is None or args.target_port is None
    ):
        sys.exit("error: client mode requires --hostname and --target-port")
    if endpoint == Endpoint.SERVER and args.listen_port is None:
        sys.exit("error: server mode requires --listen-port")

    cfg = cmn.parse_config(args.config_filepath)
    logging.info(f"Protocol configuration: {cfg}")
    pt_exec = cfg[protocol]["exe_path"]

    context = zmq.Context()
    ev_loop = asyncio.new_event_loop()

    try:
        if endpoint == Endpoint.CLIENT:
            logging.info("Running in client mode.")
            client_opts = cmn.parse_transport_options(cfg, protocol, "client")

            client_protocol = ClientProtocol(
                args.hostname,
                args.target_port,
                pt_exec,
                protocol,
                client_opts,
                ev_loop,
            )

            execute_protocol(context, args.control_port, client_protocol)
        else:
            logging.info("Running in server mode.")
            server_opts = cmn.parse_transport_options(cfg, protocol, "server")

            server_protocol = ServerProtocol(
                args.listen_port, pt_exec, protocol, server_opts, ev_loop
            )

            execute_protocol(context, args.control_port, server_protocol)
    finally:
        # The loop was created here, so it is closed here, on success or error.
        ev_loop.close()

    logging.info("Done!")


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with mode, control_port, config_filepath, protocol,
        hostname, target_port, listen_port and log_level.

    ``choices`` make argparse reject an unknown mode, protocol or log level up
    front with a usage error (exit status 2), instead of failing later.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode", type=str, choices=["client", "server"], help="client|server"
    )
    parser.add_argument(
        "control_port", type=int, help="listen from control commands on this port"
    )
    parser.add_argument("config_filepath", type=str, help="path to config file")
    parser.add_argument(
        "protocol", type=str, choices=["obfs4", "proteus"], help="obfs4|proteus"
    )

    parser.add_argument(
        "-H",
        "--hostname",
        help="required client option, connect to given target host",
        type=str,
    )

    parser.add_argument(
        "-t",
        "--target-port",
        help="required client option, connect to target port",
        type=int,
    )

    parser.add_argument(
        "-L",
        "--listen-port",
        help="required server option, listen for client connections on this port",
        type=int,
    )

    parser.add_argument(
        "-l",
        "--log-level",
        help="debug|info|warning|error|critical",
        type=str,
        choices=["debug", "info", "warning", "error", "critical"],
        default="info",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the chosen endpoint.
    cli_args = parse_args()
    main(cli_args)
