#include "rs.h"

#include <etgen_rs.h>

#include <cassert>
#include <memory>

#include "transport_iface.h"
#include "util.h"

/* CLIENT *********************************************************************/

struct RSClient::Internal {
  // Opaque client handle produced by snow_client_create/secio_client_create.
  // Starts out null and is reset to null again by RSClient::cleanup().
  void *client = nullptr;
};

RSClient::RSClient(std::shared_ptr<const ConnectionContext> conn_ctx,
                   RSKind rs_kind)
    : ClientIFace{conn_ctx}, internal_{new RSClient::Internal} {
  // Noise parameters are optional; the FFI receives a null pointer when they
  // are absent.
  const char *const noise_params = conn_ctx->snow_noise_params.has_value()
                                       ? conn_ctx->snow_noise_params->c_str()
                                       : nullptr;

  // NOTE(review): every string pointer below borrows from *conn_ctx. Confirm
  // that the *_client_create functions copy them rather than retaining them.
  ConfigRS config{
      .client_hostname = conn_ctx->client_hostname.c_str(),
      .server_hostname = conn_ctx->server_hostname.c_str(),
      .client_port = conn_ctx->client_port,
      .server_port = conn_ctx->server_port,
      .snow_noise_params = noise_params,
  };

  // Pick the FFI constructor that matches the requested transport flavour.
  if (rs_kind == RSKind::kSnow) {
    internal().client = snow_client_create(config);
  } else if (rs_kind == RSKind::kSecio) {
    internal().client = secio_client_create(config);
  } else {
    UNREACHABLE;
  }
}

RSClient::~RSClient() {
  // Release the native client handle; cleanup() does nothing if it is null.
  cleanup();
}

// Tears down and frees the native client handle, then clears it so that
// repeated calls (e.g. explicit cleanup() followed by the destructor) are
// harmless.
//
// NOTE(review): the handle is always freed with snow_client_destroy, even
// when the constructor obtained it from secio_client_create. Confirm in
// etgen_rs.h that this destroy function is valid for secio handles too;
// otherwise the RSKind should be recorded and used to dispatch here.
void RSClient::cleanup() {
  void *const handle = internal().client;
  if (handle == nullptr) {
    return;
  }
  rs_client_cleanup(handle);
  snow_client_destroy(handle);
  internal().client = nullptr;
}
// Forwards initialization to the FFI and reports its result as a bool.
// The handle must be live (i.e. cleanup() has not run yet).
bool RSClient::try_init() {
  void *const handle = internal().client;
  assert(handle != nullptr);
  return rs_client_init(handle);
}
// Asks the FFI client to connect; requires a live handle.
void RSClient::connect() {
  void *const handle = internal().client;
  assert(handle != nullptr);
  rs_client_connect(handle);
}

// Delegates a receive of `nbytes` to the FFI client; requires a live handle.
void RSClient::recv(std::size_t nbytes) {
  void *const handle = internal().client;
  assert(handle != nullptr);
  rs_client_recv(handle, nbytes);
}

// Delegates a send of `nbytes` to the FFI client; requires a live handle.
void RSClient::send(std::size_t nbytes) {
  void *const handle = internal().client;
  assert(handle != nullptr);
  rs_client_send(handle, nbytes);
}

/* SERVER *********************************************************************/

struct RSServer::Internal {
  // Opaque server handle produced by snow_server_create/secio_server_create.
  // Starts out null and is reset to null again by RSServer::cleanup().
  void *server = nullptr;
};

RSServer::RSServer(std::shared_ptr<const ConnectionContext> conn_ctx,
                   RSKind rs_kind)
    : ServerIFace{conn_ctx}, internal_{new RSServer::Internal} {
  // Noise parameters are optional; the FFI receives a null pointer when they
  // are absent.
  const char *const noise_params = conn_ctx->snow_noise_params.has_value()
                                       ? conn_ctx->snow_noise_params->c_str()
                                       : nullptr;

  // NOTE(review): every string pointer below borrows from *conn_ctx. Confirm
  // that the *_server_create functions copy them rather than retaining them.
  ConfigRS config{
      .client_hostname = conn_ctx->client_hostname.c_str(),
      .server_hostname = conn_ctx->server_hostname.c_str(),
      .client_port = conn_ctx->client_port,
      .server_port = conn_ctx->server_port,
      .snow_noise_params = noise_params,
  };

  // Pick the FFI constructor that matches the requested transport flavour.
  if (rs_kind == RSKind::kSnow) {
    internal().server = snow_server_create(config);
  } else if (rs_kind == RSKind::kSecio) {
    internal().server = secio_server_create(config);
  } else {
    UNREACHABLE;
  }
}

RSServer::~RSServer() {
  // Release the native server handle; cleanup() does nothing if it is null.
  cleanup();
}

// Tears down and frees the native server handle, then clears it so that
// repeated calls (e.g. explicit cleanup() followed by the destructor) are
// harmless.
//
// NOTE(review): the handle is always freed with snow_server_destroy, even
// when the constructor obtained it from secio_server_create. Confirm in
// etgen_rs.h that this destroy function is valid for secio handles too;
// otherwise the RSKind should be recorded and used to dispatch here.
void RSServer::cleanup() {
  void *const handle = internal().server;
  if (handle == nullptr) {
    return;
  }
  rs_server_cleanup(handle);
  snow_server_destroy(handle);
  internal().server = nullptr;
}

// Forwards initialization to the FFI and reports its result as a bool.
// The handle must be live (i.e. cleanup() has not run yet).
bool RSServer::try_init() {
  void *const handle = internal().server;
  assert(handle != nullptr);
  return rs_server_init(handle);
}
// Asks the FFI server to start listening; requires a live handle.
void RSServer::listen() {
  void *const handle = internal().server;
  assert(handle != nullptr);
  rs_server_listen(handle);
}

// Asks the FFI server to accept a connection; requires a live handle.
void RSServer::accept() {
  void *const handle = internal().server;
  assert(handle != nullptr);
  rs_server_accept(handle);
}

// Delegates a receive of `nbytes` to the FFI server; requires a live handle.
// Plain call (no `return` of a void expression), matching RSClient::recv and
// the other wrappers in this file.
void RSServer::recv(std::size_t nbytes) {
  assert(internal().server);
  rs_server_recv(internal().server, nbytes);
}

// Delegates a send of `nbytes` to the FFI server; requires a live handle.
// Plain call (no `return` of a void expression), matching RSClient::send and
// the other wrappers in this file.
void RSServer::send(std::size_t nbytes) {
  assert(internal().server);
  rs_server_send(internal().server, nbytes);
}
