#include "curvezmq.h"

#include <curve.h>
#include <zauth.h>
#include <zsocket.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include "log.h"
#include "transport_iface.h"
#include "util.h"
#include "zframe.h"

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wwritable-strings"

/* CLIENT *********************************************************************/

/*
 * Raw libcurve/CZMQ handles used by a CurveZMQClient; released in
 * CurveZMQClient::cleanup().
 */
struct CurveZMQClient::CurveZMQClientInternals {
  // This client's certificate; passed to curve_client_new() in try_init().
  zcert_t *client_cert{nullptr};
  // Server certificate; connect() reads its public key.
  zcert_t *server_cert{nullptr};
  // CurveZMQ client instance created in try_init().
  curve_client_t *client{nullptr};
};

/*
 * Create a client for the given connection context. No libcurve state is
 * created here; that happens in try_init().
 *
 * conn_ctx is a by-value sink parameter, so it is moved into the base class
 * instead of being copied (avoiding an extra atomic refcount increment).
 */
CurveZMQClient::CurveZMQClient(
    std::shared_ptr<const ConnectionContext> conn_ctx)
    : ClientIFace(std::move(conn_ctx)),
      internal_{new CurveZMQClientInternals{}} {}

CurveZMQClient::~CurveZMQClient() {
  // Release whatever libcurve/CZMQ handles are still held.
  cleanup();
}

/*
 * Release the CurveZMQ client and any certificates still held. Every handle
 * is reset to nullptr after release, so calling this more than once is
 * harmless.
 */
void CurveZMQClient::cleanup() {
  auto &state = internal();

  if (state.client != nullptr) {
    curve_client_destroy(&state.client);
    state.client = nullptr;
  }
  if (state.client_cert != nullptr) {
    zcert_destroy(&state.client_cert);
    state.client_cert = nullptr;
  }
  if (state.server_cert != nullptr) {
    zcert_destroy(&state.server_cert);
    state.server_cert = nullptr;
  }
}

bool CurveZMQClient::try_init() {
#ifdef DEBUG
  bool verbose{true};
#else
  bool verbose{false};
#endif  // DEBUG

  const auto client_cert_path = *(conn_ctx_.get()->curvezmq_client_cert);
  const auto server_cert_path = *(conn_ctx_.get()->curvezmq_server_cert);

  LOG(debug) << "Loading client cert: " << client_cert_path;
  internal().client_cert = zcert_load(client_cert_path.c_str());

  LOG(debug) << "Loading server cert: " << client_cert_path;
  internal().server_cert = zcert_load(server_cert_path.c_str());

  internal().client = curve_client_new(&internal().client_cert);
  curve_client_set_metadata(internal().client, "Client",
                            "CURVEZMQ/curve_client");
  // curve_client_set_metadata(internal().client, "Identity", "E475DA11");
  curve_client_set_verbose(internal().client, verbose);

  return true;
}

/*
 * Connect the CurveZMQ client to tcp://<server_hostname>:<server_port>. The
 * server is identified by the public key in the loaded server certificate.
 */
void CurveZMQClient::connect() {
  const auto &context = *conn_ctx_.get();

  std::string server_endpoint{"tcp://"};
  server_endpoint += context.server_hostname;
  server_endpoint += ":";
  server_endpoint += std::to_string(context.server_port);
  LOG(debug) << "Connecting to " << server_endpoint;

  auto *server_key = reinterpret_cast<std::uint8_t *>(
      zcert_public_key(internal().server_cert));
  // libcurve takes a non-const char*, hence the const_cast.
  curve_client_connect(internal().client,
                       const_cast<char *>(server_endpoint.c_str()),
                       server_key);
}

/*
 * Assumes that the client and server are synced up on the same transcript.
 */
void CurveZMQClient::recv(std::size_t nbytes) {
  LOG(debug) << "Begin client receiving " << nbytes << " bytes.";
  auto msg = curve_client_recv(internal().client);
  auto data = zmsg_pop(msg);
  assert(data);

  if (zframe_size(data) != nbytes) {
    LOG_AND_THROW_ERROR("Incorrect number of bytes received",
                        std::runtime_error);
  }

  LOG(info) << "Received " << nbytes << " bytes.";

  for (std::size_t idx = 0; idx < nbytes; ++idx) {
    assert(zframe_data(data)[idx] == static_cast<std::uint8_t>(idx));
  }

  zframe_destroy(&data);
  zmsg_destroy(&msg);
  LOG(debug) << "End client receiving " << nbytes << " bytes.";
}

void CurveZMQClient::send(std::size_t nbytes) {
  LOG(debug) << "Begin client sending " << nbytes << " bytes.";
  zframe_t *data = zframe_new(nullptr, nbytes);
  for (std::size_t idx = 0; idx < nbytes; ++idx) {
    zframe_data(data)[idx] = static_cast<std::uint8_t>(idx);
  }
  auto msg = zmsg_new();
  zmsg_prepend(msg, &data);
  curve_client_send(internal().client, &msg);

  zframe_destroy(&data);
  zmsg_destroy(&msg);
  LOG(debug) << "End client sending " << nbytes << " bytes.";
}

/* SERVER *********************************************************************/

/*
 * Raw libcurve/CZMQ handles used by a CurveZMQServer; released in
 * CurveZMQServer::cleanup().
 */
struct CurveZMQServer::CurveZMQServerInternals {
  // Certificates loaded in try_init(); only server_cert feeds the codec.
  zcert_t *client_cert{nullptr};
  zcert_t *server_cert{nullptr};
  // CZMQ context from which the authenticator and the socket are created.
  zctx_t *ctx{nullptr};
  zauth_t *auth{nullptr};
  // ROUTER socket bound in try_init().
  void *zmq_sock{nullptr};
  // Server-side CurveZMQ codec.
  curve_codec_t *server{nullptr};
  // Copy of the first peer identity frame seen in accept(), reused by send().
  zframe_t *sender{nullptr};
};

/*
 * Create a server for the given connection context. Sockets, certificates and
 * the codec are set up later in try_init().
 *
 * conn_ctx is a by-value sink parameter, so it is moved into the base class
 * instead of being copied (avoiding an extra atomic refcount increment).
 */
CurveZMQServer::CurveZMQServer(
    std::shared_ptr<const ConnectionContext> conn_ctx)
    : ServerIFace(std::move(conn_ctx)),
      internal_{new CurveZMQServerInternals{}} {}

CurveZMQServer::~CurveZMQServer() {
  // Tear down every libcurve/CZMQ handle that is still held.
  cleanup();
}

/*
 * Tear down the server in this order:
 *   1. the saved peer identity,
 *   2. the CurveZMQ codec,
 *   3. the ROUTER socket,
 *   4. the authenticator,
 *   5. the CZMQ context,
 *   6. both certificates.
 *
 * The socket and authenticator are released before the context they were
 * created from. Every handle is reset to nullptr, so calling this again is
 * harmless.
 */
void CurveZMQServer::cleanup() {
  LOG(debug) << "Running server cleanup routine.";
  auto &state = internal();

  if (state.sender != nullptr) {
    LOG(debug) << "Destroying sender.";
    zframe_destroy(&state.sender);
    state.sender = nullptr;
  }

  if (state.server != nullptr) {
    curve_codec_destroy(&state.server);
    state.server = nullptr;
  }

  if (state.zmq_sock != nullptr) {
    // zsocket_destroy() needs the owning context, which must still be alive.
    assert(state.ctx);
    zsocket_destroy(state.ctx, state.zmq_sock);
    state.zmq_sock = nullptr;
  }

  if (state.auth != nullptr) {
    zauth_destroy(&state.auth);
    state.auth = nullptr;
  }

  if (state.ctx != nullptr) {
    zctx_destroy(&state.ctx);
    state.ctx = nullptr;
  }

  if (state.server_cert != nullptr) {
    zcert_destroy(&state.server_cert);
    state.server_cert = nullptr;
  }

  if (state.client_cert != nullptr) {
    zcert_destroy(&state.client_cert);
    state.client_cert = nullptr;
  }
}

/*
 * Prepare the server side of the CurveZMQ connection:
 *   - load the client and server certificates;
 *   - install a CZMQ authenticator configured with CURVE_ALLOW_ANY for "*";
 *   - bind a ROUTER socket to 127.0.0.1 (loopback only) on the configured port;
 *   - create the server-side CurveZMQ codec from the server certificate.
 *
 * Returns true on success. Failures are reported via LOG_AND_THROW_ERROR with
 * std::runtime_error. The `return false` after each one only matters if that
 * macro does not throw. Partially-created handles are released by cleanup().
 */
bool CurveZMQServer::try_init() {
#ifdef DEBUG
  bool verbose{true};
#else
  bool verbose{false};
#endif  // DEBUG

  const auto &client_cert_path = *(conn_ctx_.get()->curvezmq_client_cert);
  const auto &server_cert_path = *(conn_ctx_.get()->curvezmq_server_cert);

  LOG(debug) << "Loading client cert: " << client_cert_path;
  // NOTE(review): nothing in this file uses client_cert on the server side,
  // so a failed load is tolerated here, as before.
  internal().client_cert = zcert_load(client_cert_path.c_str());

  LOG(debug) << "Loading server cert: " << server_cert_path;
  internal().server_cert = zcert_load(server_cert_path.c_str());
  if (!internal().server_cert) {
    // curve_codec_new_server() below requires a valid server certificate.
    LOG_AND_THROW_ERROR("zcert_load() failed for server cert",
                        std::runtime_error);
    return false;
  }

  //  Install the authenticator
  internal().ctx = zctx_new();
  if (!internal().ctx) {
    LOG_AND_THROW_ERROR("zctx_new()", std::runtime_error);
    return false;
  }

  internal().auth = zauth_new(internal().ctx);
  if (!internal().auth) {
    LOG_AND_THROW_ERROR("zauth_new()", std::runtime_error);
    return false;
  }
  zauth_set_verbose(internal().auth, verbose);
  zauth_configure_curve(internal().auth, "*", CURVE_ALLOW_ANY);

  internal().zmq_sock = zsocket_new(internal().ctx, ZMQ_ROUTER);
  if (!internal().zmq_sock) {
    LOG_AND_THROW_ERROR("zsocket_new()", std::runtime_error);
    return false;
  }

  const std::string bind_endpoint =
      std::string{"tcp://127.0.0.1:"} +
      std::to_string(conn_ctx_.get()->server_port);
  LOG(debug) << "Binding to " << bind_endpoint;
  // zsocket_bind() returns -1 on failure.
  const auto rc =
      zsocket_bind(internal().zmq_sock, "%s", bind_endpoint.c_str());
  if (rc == -1) {
    LOG_AND_THROW_ERROR("zsocket_bind()", std::runtime_error);
    return false;
  }

  internal().server =
      curve_codec_new_server(internal().server_cert, internal().ctx);
  if (!internal().server) {
    LOG_AND_THROW_ERROR("curve_codec_new_server()", std::runtime_error);
    return false;
  }

  curve_codec_set_verbose(internal().server, verbose);

  //  Set some metadata properties
  curve_codec_set_metadata(internal().server, "Server", "CURVEZMQ/curve_codec");

  return true;
}

void CurveZMQServer::listen() {
  // Intentionally empty: the ROUTER socket is already bound in try_init(),
  // and the handshake happens in accept().
}

void CurveZMQServer::accept() {
  while (!curve_codec_connected(internal().server)) {
    zframe_t *sender = zframe_recv(internal().zmq_sock);

    // Save the sender for later sends.
    if (!internal().sender) {
      internal().sender = zframe_dup(sender);
    }

    zframe_t *input = zframe_recv(internal().zmq_sock);
    assert(input);
    zframe_t *output = curve_codec_execute(internal().server, &input);
    assert(output);
    zframe_send(&sender, internal().zmq_sock, ZFRAME_MORE);
    zframe_send(&output, internal().zmq_sock, 0);
  }
  //  Check client metadata
  auto *client_name = reinterpret_cast<const char *>(
      zhash_lookup(curve_codec_metadata(internal().server), "client"));
  assert(client_name);
  assert(streq(client_name, "CURVEZMQ/curve_client"));
}

/*
 * Receive one encrypted frame from the client, decrypt it with the server
 * codec, and check that the cleartext is exactly `nbytes` bytes long.
 *
 * Assumes that the client and server are synced up on the same transcript.
 * Fails via LOG_AND_THROW_ERROR (std::runtime_error) on a receive failure, a
 * decode failure, or a size mismatch. All frames are released before that.
 */
void CurveZMQServer::recv(std::size_t nbytes) {
  // The first frame is the sender identity. Replies use the identity saved in
  // accept(), so this copy is discarded straight away.
  zframe_t *sender = zframe_recv(internal().zmq_sock);
  if (!sender) {
    LOG_AND_THROW_ERROR("zframe_recv() failed reading sender",
                        std::runtime_error);
    return;
  }
  zframe_destroy(&sender);

  zframe_t *encrypted = zframe_recv(internal().zmq_sock);
  if (!encrypted) {
    LOG_AND_THROW_ERROR("zframe_recv() failed reading payload",
                        std::runtime_error);
    return;
  }

  zframe_t *cleartext = curve_codec_decode(internal().server, &encrypted);
  // No-op when the codec has already consumed (and nulled) the input frame.
  zframe_destroy(&encrypted);
  if (!cleartext) {
    LOG_AND_THROW_ERROR("curve_codec_decode()", std::runtime_error);
    return;
  }

  const auto cleartext_nbytes = zframe_size(cleartext);
  zframe_destroy(&cleartext);

  LOG(info) << "Received " << cleartext_nbytes << " bytes.";

  if (nbytes != cleartext_nbytes) {
    LOG_AND_THROW_ERROR("Incorrect number of bytes received",
                        std::runtime_error);
  }
}

void CurveZMQServer::send(std::size_t nbytes) {
  auto sender = zframe_dup(internal().sender);

  zframe_t *data = zframe_new(nullptr, nbytes);
  for (std::size_t idx = 0; idx < nbytes; ++idx) {
    zframe_data(data)[idx] = static_cast<std::uint8_t>(idx);
  }

  auto ciphertext = curve_codec_encode(internal().server, &data);

  zframe_send(&sender, internal().zmq_sock, ZFRAME_MORE);
  zframe_send(&ciphertext, internal().zmq_sock, 0);

  zframe_destroy(&ciphertext);
  zframe_destroy(&data);
  zframe_destroy(&sender);
}

#pragma GCC diagnostic pop
