#include "driver_responder.h"

#include <cassert>
#include <chrono>
#include <string>
#include <thread>
#include <type_traits>
#include <utility>
#include <zmq.hpp>

#include "driver_common.h"
#include "log.h"
#include "transcript.pb.h"

// Private implementation state for DriverResponder: the ZeroMQ context and the
// socket that init() creates from it and work_loop() exchanges commands over.
struct DriverResponder::DriverResponderInternals {
  // Context that init() hands to the socket constructor.
  zmq::context_t ctx;
  // Starts out default-constructed; init() replaces it with a REP socket.
  // Members are destroyed in reverse declaration order, so sock is torn down
  // before ctx.
  // NOTE(review): that ordering looks load-bearing for ZeroMQ shutdown -- keep
  // ctx declared first unless confirmed otherwise.
  zmq::socket_t sock;
};

// Defined out of line in this translation unit, after the definition of
// DriverResponderInternals above. No custom teardown is performed here.
DriverResponder::~DriverResponder() = default;

// Builds a responder around exactly one endpoint interface.
//
// endpoint_iface: holds either a client or a server interface. Whichever
//                 alternative is present is moved into client_ or server_
//                 respectively; the other member stays empty.
// port:           stored in port_; init() later binds the command socket to
//                 it.
DriverResponder::DriverResponder(
    std::variant<std::unique_ptr<ClientIFace>, std::unique_ptr<ServerIFace>>
        endpoint_iface,
    std::uint16_t port)
    : internal_{new DriverResponderInternals{}},
      client_{},
      server_{},
      port_{port} {
  std::visit(
      [this](auto &iface) {
        using Held = std::decay_t<decltype(iface)>;
        constexpr bool kIsClient =
            std::is_same_v<Held, std::unique_ptr<ClientIFace>>;
        constexpr bool kIsServer =
            std::is_same_v<Held, std::unique_ptr<ServerIFace>>;

        // Reject an unhandled alternative at compile time. A run-time
        // assert(false) would vanish under NDEBUG and leave the responder
        // with neither role set.
        static_assert(kIsClient || kIsServer,
                      "DriverResponder: unhandled endpoint interface type");

        if constexpr (kIsClient) {
          LOG(info) << "Driver responder is in the client role.";
          client_ = std::move(iface);
        } else {
          LOG(info) << "Driver responder is in the server role.";
          server_ = std::move(iface);
        }
      },
      endpoint_iface);
}

bool DriverResponder::init() {
  internal().sock = zmq::socket_t{internal().ctx, zmq::socket_type::rep};

  std::string bind_addr = "tcp://*:" + std::to_string(port_);

  LOG(info) << "Binding to " << bind_addr << ".";

  internal().sock.bind(bind_addr);

  if (client_role()) {
    return (*client_)->try_init();
  } else {
    assert(server_role());
    return (*server_)->try_init();
  }
}

// Serves driver commands until a STOP command arrives, then cleans up the
// endpoint interface.
//
// Each iteration receives one command on the command socket, replies with an
// ACK carrying the same kind, and only then executes the command against the
// client or server interface, depending on the responder's role.
void DriverResponder::work_loop() {
  LOG(debug) << "Entering work loop.";

  bool done = false;

  while (!done) {
    LOG(debug) << "Begin receiving message.";
    // NOTE(review): cmd is dereferenced below without a check; confirm that
    // recv_command_synch never hands back an empty result.
    auto cmd = recv_command_synch(&internal().sock);
    LOG(debug) << "End receiving message.";

    LOG(info) << "Received " << (*cmd) << ".";

    // The acknowledgement goes out before the command is acted upon, so the
    // sender is not kept waiting on the socket while the command runs.
    send_command_synch(make_ack_command(cmd->kind()), &internal().sock);

    switch (cmd->kind()) {
      case Command::ACK:
        // Nothing to do beyond the acknowledgement already sent.
        break;
      case Command::STOP:
        done = true;
        break;
      case Command::CLIENT_CONNECT:
        assert(client_role());
        (*client_)->connect();
        break;
      case Command::SERVER_LISTEN:
        assert(server_role());
        (*server_)->listen();
        (*server_)->accept();
        break;
      case Command::SEND: {
        const std::size_t nbytes = cmd->nbytes();
        if (client_role()) {
          (*client_)->send(nbytes);
        } else {
          assert(server_role());
          (*server_)->send(nbytes);
        }
        break;
      }
      case Command::RECV: {
        const std::size_t nbytes = cmd->nbytes();
        if (client_role()) {
          (*client_)->recv(nbytes);
        } else {
          assert(server_role());
          (*server_)->recv(nbytes);
        }
        break;
      }
      case Command::TRANSCRIPT:
        LOG(debug) << "Running transcript.";
        execute_transcript(cmd->transcript());
        break;
      default:
        // Unknown command kinds trip the assertion in debug builds; with
        // assertions disabled they are acknowledged and otherwise ignored.
        assert(false);
        break;
    }
  }

  if (client_role()) {
    (*client_)->cleanup();
  } else {
    assert(server_role());
    (*server_)->cleanup();
  }

  LOG(info) << "Exiting.";
}

// Replays every round of `transcript` from this responder's side.
//
// Within a round the client sends client_send_nbytes and then receives
// server_send_nbytes; the server sends server_send_nbytes and then receives
// client_send_nbytes. Transfers whose byte count is not positive are skipped.
// Each round ends with a sleep of sleep_usec microseconds.
void DriverResponder::execute_transcript(const Transcript &transcript) {
  for (const auto &step : transcript.rounds()) {
    const auto client_bytes = step.client_send_nbytes();
    const auto server_bytes = step.server_send_nbytes();

    if (!client_role()) {
      // Server side of the round: send first, then receive.
      assert(server_role());

      if (server_bytes > 0) {
        LOG(info) << "Begin server sending " << server_bytes << " bytes.";
        (*server_)->send(server_bytes);
        LOG(info) << "End server sending " << server_bytes << " bytes.";
      }

      if (client_bytes > 0) {
        LOG(info) << "Begin server receiving " << client_bytes << " bytes.";
        (*server_)->recv(client_bytes);
        LOG(info) << "End server receiving " << client_bytes << " bytes.";
      }
    } else {
      // Client side of the round: send first, then receive.
      if (client_bytes > 0) {
        LOG(info) << "Begin client sending " << client_bytes << " bytes.";
        (*client_)->send(client_bytes);
        LOG(info) << "End client sending " << client_bytes << " bytes.";
      }

      if (server_bytes > 0) {
        LOG(info) << "Begin client receiving " << server_bytes << " bytes.";
        (*client_)->recv(server_bytes);
        LOG(info) << "End client receiving " << server_bytes << " bytes.";
      }
    }

    // Pause between rounds for the duration given in the transcript.
    const auto pause_usec = step.sleep_usec();
    LOG(info) << "Sleeping for " << pause_usec << " microseconds.";
    std::this_thread::sleep_for(std::chrono::microseconds(pause_usec));
  }
}
