#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"

#include "tls.h"

#include <netdb.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cerrno>
#include <cstdint>
#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "log.h"
#include "net_util.h"
#include "util.h"

/// Process-wide OpenSSL initialization: initializes the SSL library and
/// registers all ciphers/digests. Call once before creating any TLS client or
/// server from this file.
void tls_global_init() {
  // Legacy initializers; their results are intentionally not inspected.
  static_cast<void>(SSL_library_init());
  static_cast<void>(OpenSSL_add_all_algorithms());
}

/// Sends exactly `nbytes` payload bytes, all equal to 77, over `p_ssl`.
///
/// Repeats SSL_write_ex() until every byte has been accepted.
///
/// @param p_ssl   TLS connection to write to (must not be null).
/// @param nbytes  number of payload bytes to send; 0 sends nothing.
/// @throws std::runtime_error if SSL_write_ex() fails or reports a zero-byte
///         write (which would otherwise loop forever).
static void write_(SSL *p_ssl, std::size_t nbytes) {
  // Fill value; recv_() asserts that every received byte equals this value.
  static const std::uint8_t BYTE_VAL = 77;

  const std::vector<std::uint8_t> buf(nbytes, BYTE_VAL);

  std::size_t nbytes_written_cumul{0};
  std::size_t nbytes_left{nbytes};

  while (nbytes_left > 0) {
    std::size_t write_nbytes{0};

    const auto write_rc =
        SSL_write_ex(p_ssl, buf.data() + nbytes_written_cumul, nbytes_left,
                     &write_nbytes);

    if (write_rc == 0) {
      LOG_AND_THROW_ERROR("SSL_write_ex()", std::runtime_error);
    }

    LOG(trace) << "SSL_write_ex() returned " << write_nbytes;

    // A "successful" zero-byte write would never make progress.
    if (write_nbytes == 0) {
      LOG_AND_THROW_ERROR("SSL_write_ex() wrote 0 bytes", std::runtime_error);
    }

    nbytes_written_cumul += write_nbytes;
    nbytes_left -= write_nbytes;
  }
}

/// Reads exactly `nbytes` payload bytes from `p_ssl`.
///
/// Repeats SSL_read_ex() until `nbytes` bytes have arrived. In debug builds it
/// then asserts that every byte equals 77, the value write_() sends.
///
/// @param p_ssl   TLS connection to read from (must not be null).
/// @param nbytes  number of payload bytes to receive; 0 reads nothing.
/// @throws std::runtime_error if SSL_read_ex() fails (including the peer
///         closing the connection) or reports a zero-byte read.
static void recv_(SSL *p_ssl, std::size_t nbytes) {
  static const std::uint8_t BYTE_VAL = 77;
  static_cast<void>(BYTE_VAL); // Only referenced by assert(); silence NDEBUG.

  std::vector<std::uint8_t> buf(nbytes);

  // size_t, not int: an int offset overflows (undefined behavior) once more
  // than INT_MAX bytes have been received.
  std::size_t nbytes_read_cumul{0};
  std::size_t nbytes_left{nbytes};

  while (nbytes_left > 0) {
    std::size_t nbytes_read{0};

    const auto read_rc = SSL_read_ex(p_ssl, buf.data() + nbytes_read_cumul,
                                     nbytes_left, &nbytes_read);

    if (read_rc == 0) {
      LOG_AND_THROW_ERROR("SSL_read_ex()", std::runtime_error);
    }

    LOG(trace) << "SSL_read_ex() returned " << nbytes_read;

    // A "successful" zero-byte read would never make progress.
    if (nbytes_read == 0) {
      LOG_AND_THROW_ERROR("SSL_read_ex() read 0 bytes", std::runtime_error);
    }

    nbytes_read_cumul += nbytes_read;
    nbytes_left -= nbytes_read;
  }

  assert(std::all_of(buf.cbegin(), buf.cend(),
                     [](std::uint8_t x) -> bool { return x == BYTE_VAL; }));
}

/// Gracefully closes a TCP socket.
///
/// Steps:
/// 1. Half-closes the write side with shutdown(SHUT_WR).
/// 2. Drains incoming data until the peer closes its side (recv() == 0) or
///    recv() fails with an error other than EINTR.
/// 3. close()s the descriptor.
///
/// Negative descriptors are ignored.
///
/// NOTE(review): the drain blocks until the peer closes or errors; consider a
/// receive timeout if peers may hang.
///
/// @param sock socket descriptor to close; it must not be used afterwards.
static void cleanup_socket_(int sock) {
  if (sock < 0) {
    return;
  }

  shutdown(sock, SHUT_WR);

  constexpr std::size_t NBYTES{512};
  auto buf = std::array<std::uint8_t, NBYTES>{};

  // Drain any remaining data in the socket.
  for (;;) {
    const ssize_t recv_rc = ::recv(sock, buf.data(), NBYTES, 0);

    if (recv_rc > 0) {
      LOG(debug) << "During cleanup, read " << recv_rc << " bytes.";
      continue;
    }

    if (recv_rc == 0) {
      LOG(info) << "Connection close detected.";
      break;
    }

    // Capture errno before anything else (including logging) can change it.
    const int err = errno;
    if (err == EINTR) {
      continue; // Interrupted by a signal: simply retry.
    }

    // Other errors (e.g. ECONNRESET, ENOTCONN, EBADF) do not go away by
    // retrying, so stop draining instead of spinning forever.
    LOG(warning) << "During close, recv() returned " << err << " ("
                 << std::strerror(err) << ")";
    break;
  }

  close(sock);
}

/// Maps a TLSVersion to the matching OpenSSL protocol constant
/// (TLS1_2_VERSION or TLS1_3_VERSION). The result is passed to
/// SSL_CTX_set_{min,max}_proto_version().
///
/// Any other value hits UNREACHABLE; if that macro returns, the result is -1.
static int int_of_tls_version_(TLSVersion value) {
  switch (value) {
  case TLSVersion::k1_2:
    return TLS1_2_VERSION;
  case TLSVersion::k1_3:
    return TLS1_3_VERSION;
  default:
    break;
  }

  UNREACHABLE;
  return -1;
}

/* CLIENT *********************************************************************/

/// Private state of a TLS12Client, held through `internal_`. Handles start out
/// empty (nullptr / -1) and are released by TLS12Client::cleanup().
struct TLS12Client::TLS12ClientInternals {
  TLSVersion version;     // Used as both min and max protocol in try_init().
  SSL_CTX *ctx = nullptr; // Created in try_init().
  int sock = -1;          // TCP socket, created in connect().
  SSL *ssl = nullptr;     // TLS session bound to `sock`, created in connect().
};

/// Releases every handle still held by the client (see cleanup()).
TLS12Client::~TLS12Client() {
  cleanup();
}

/// Creates an uninitialized client (state kUninit); call try_init() and then
/// connect() before sending or receiving.
///
/// @param conn_ctx connection parameters (server host/port), handed to the
///                 ClientIFace base.
/// @param version  TLS version to pin in try_init().
TLS12Client::TLS12Client(std::shared_ptr<const ConnectionContext> conn_ctx,
                         TLSVersion version)
    // Move the by-value shared_ptr: avoids an extra atomic refcount inc/dec.
    : ClientIFace(std::move(conn_ctx)),
      internal_{new TLS12ClientInternals{version}}, state_{State::kUninit} {}

void TLS12Client::cleanup() {
  if (state_ == kConnected && internal().ssl) {
    SSL_shutdown(internal().ssl);
  }

  if (internal().sock > 0) {
    cleanup_socket_(internal().sock);
    internal().sock = -1;
  }

  if (internal().ssl) {
    SSL_free(internal().ssl);
    internal().ssl = nullptr;
  }

  if (internal().ctx) {
    SSL_CTX_free(internal().ctx);
    internal().ctx = nullptr;
  }
}

bool TLS12Client::try_init() {
  if (state_ != State::kUninit) {
    LOG(error) << "Tried to initialize client from wrong state.";
    return false;
  }

  { // Setup context
    auto ctx = SSL_CTX_new(TLS_client_method());
    auto rc = SSL_CTX_set_min_proto_version(
        ctx, int_of_tls_version_(internal().version));
    assert(rc == 1);
    rc = SSL_CTX_set_max_proto_version(ctx,
                                       int_of_tls_version_(internal().version));
    assert(rc == 1);

    if (!ctx) {
      LOG(error) << "Could not create SSL context.";
      state_ = State::kInvalid;
      return false;
    } else {
      internal().ctx = ctx;
    }
  }

  state_ = State::kInit;
  return true;
}

void TLS12Client::connect() {
  std::string str_port = std::to_string(conn_ctx_->server_port);
  auto addrinfo = addrinfo_ipv4_tcp(conn_ctx_.get()->server_hostname.c_str(),
                                    str_port.c_str());

  if (!addrinfo) {
    LOG_AND_THROW_ERROR("addrinfo failed.", std::runtime_error);
  }

  internal().sock = socket_from_addrinfo(*addrinfo);

  if (internal().sock == -1) {
    LOG_AND_THROW_ERROR("connect() failed.", std::runtime_error);
  }

  connect_socket_from_addrinfo(internal().sock, *addrinfo);
  freeaddrinfo(addrinfo);
  addrinfo = nullptr;

  internal().ssl = SSL_new(internal().ctx);

  if (!internal().ssl) {
    LOG_AND_THROW_ERROR("SSL_new() failed.", std::runtime_error);
  }

  auto ssl_set_fd_rc = SSL_set_fd(internal().ssl, internal().sock);

  if (ssl_set_fd_rc == 0) {
    LOG_AND_THROW_ERROR("SSL_set_fd", std::runtime_error);
  }

  SSL_connect(internal().ssl);

  LOG(info) << "Client connected.";

  state_ = State::kConnected;
}

void TLS12Client::recv(std::size_t nbytes) { recv_(internal().ssl, nbytes); }

void TLS12Client::send(std::size_t nbytes) { write_(internal().ssl, nbytes); }

/* SERVER *********************************************************************/

/// PSK server callback registered on the server SSL_CTX.
///
/// This server authenticates with a certificate and has no pre-shared keys.
/// Returning 0 tells OpenSSL that the presented identity is unknown, so no
/// PSK is used.
///
/// `psk` is an output buffer owned by OpenSSL whose contents are unspecified
/// on entry. It must never be read (e.g. with strlen()), and any PSK written
/// there must not exceed `max_psk_len` bytes.
///
/// @return length of the PSK written to `psk`; always 0 here.
static unsigned int psk_callback_(SSL *ssl __attribute__((unused)),
                                  const char *id __attribute__((unused)),
                                  unsigned char *psk __attribute__((unused)),
                                  unsigned int max_psk_len
                                  __attribute__((unused))) {
  LOG(trace) << "In psk callback.";
  return 0;
}

/// Private state of a TLS12Server, held through `internal_`. Handles start out
/// empty (nullptr / -1) and are released by TLS12Server::cleanup().
struct TLS12Server::TLS12ServerInternals {
  TLSVersion version;                  // Used as min and max in try_init().
  SSL_CTX *ctx = nullptr;              // Created in try_init().
  struct addrinfo *servinfo = nullptr; // Bind address resolved in try_init().
  int server_sock = -1;                // Bound/listening socket.
  int client_sock = -1;                // Connection returned by accept().
  struct sockaddr_storage client_addr = {}; // Peer address from accept().
  socklen_t client_addr_size = 0;           // In/out length for accept().

  SSL *ssl = nullptr; // TLS session bound to `client_sock`.
};

/// Creates an uninitialized server (state kUninit). The expected call order is
/// try_init(), listen(), accept().
///
/// @param conn_ctx connection parameters (port, certificate and key paths),
///                 handed to the ServerIFace base.
/// @param version  TLS version to pin in try_init().
TLS12Server::TLS12Server(std::shared_ptr<const ConnectionContext> conn_ctx,
                         TLSVersion version)
    // Move the by-value shared_ptr: avoids an extra atomic refcount inc/dec.
    : ServerIFace(std::move(conn_ctx)),
      internal_{new TLS12ServerInternals{version}}, state_{State::kUninit} {}

/// Releases every handle still held by the server (see cleanup()).
TLS12Server::~TLS12Server() {
  cleanup();
}

/// Releases all server resources and moves the server to kDead. Order:
/// 1. If connected, starts the TLS shutdown and frees the session.
/// 2. Gracefully closes the client socket.
/// 3. Closes the listening socket.
/// 4. Frees the resolved address and the SSL context.
///
/// Handles are reset to their empty values, so repeated calls are harmless.
void TLS12Server::cleanup() {
  LOG(trace) << "Cleaning up server.";

  auto &in = internal();

  if (state_ == State::kConnected && in.ssl != nullptr) {
    SSL_shutdown(in.ssl);
  }

  if (in.ssl != nullptr) {
    SSL_free(in.ssl);
    in.ssl = nullptr;
  }

  if (in.client_sock > 0) {
    cleanup_socket_(in.client_sock);
    in.client_sock = -1;
  }

  if (in.server_sock > 0) {
    close(in.server_sock);
    in.server_sock = -1;
  }

  if (in.servinfo != nullptr) {
    freeaddrinfo(in.servinfo);
    in.servinfo = nullptr;
  }

  if (in.ctx != nullptr) {
    SSL_CTX_free(in.ctx);
    in.ctx = nullptr;
  }

  state_ = State::kDead;
}

bool TLS12Server::try_init() {
  if (state_ != State::kUninit) {
    LOG(error) << "Tried to initialize server from wrong state.";
    return false;
  }

  { // Setup context
    auto ctx = SSL_CTX_new(TLS_server_method());
    auto rc = SSL_CTX_set_min_proto_version(
        ctx, int_of_tls_version_(internal().version));
    assert(rc == 1);
    rc = SSL_CTX_set_max_proto_version(ctx,
                                       int_of_tls_version_(internal().version));
    assert(rc == 1);

    if (!ctx) {
      LOG(error) << "Could not create SSL context.";
      return false;
    } else {
      internal().ctx = ctx;
    }
  }

  { // Setup server cert
    auto rc = SSL_CTX_use_certificate_file(
        internal().ctx, conn_ctx_.get()->tls_server_cert->c_str(),
        SSL_FILETYPE_PEM);

    if (rc <= 0) {
      LOG(error) << "Could not set SSL cert.";
      return false;
    }
  }

  { // Setup server key
    auto rc = SSL_CTX_use_PrivateKey_file(
        internal().ctx, conn_ctx_.get()->tls_server_key->c_str(),
        SSL_FILETYPE_PEM);

    if (rc <= 0) {
      LOG(error) << "Could not set SSL key.";
      return false;
    }
  }

  LOG(trace) << "Set up SSL cert and key.";

  SSL_CTX_set_psk_server_callback(internal().ctx, psk_callback_);

  std::string str_port = std::to_string(conn_ctx_->server_port);
  internal().servinfo = addrinfo_ipv4_tcp("localhost", str_port.c_str());
  internal().server_sock = socket_from_addrinfo(*internal().servinfo);
  bind_socket_from_addrinfo(internal().server_sock, *internal().servinfo);

  LOG(info) << "TLS 1.2 server initialized.";

  state_ = State::kInit;
  return true;
}

/// Starts listening on the bound server socket with backlog `kBacklog`.
/// Requires state kInit; on success the state becomes kListening.
/// @throws std::runtime_error on wrong state or if ::listen() fails.
void TLS12Server::listen() {
  const bool in_init_state = (state_ == State::kInit);
  if (!in_init_state) {
    LOG_AND_THROW_ERROR("listen() called from wrong state.",
                        std::runtime_error);
  }

  LOG(trace) << "Calling listen().";

  if (::listen(internal().server_sock, kBacklog) == -1) {
    LOG_AND_THROW_ERROR("listen().", std::runtime_error);
  }

  LOG(trace) << "listen() succeeded.";

  state_ = State::kListening;
}

/// Accepts one TCP connection and performs the server side of the TLS
/// handshake on it.
///
/// Requires state kListening; on success the state becomes kConnected. On
/// failure, the client socket and SSL session created so far stay in
/// internal() and are released by cleanup().
///
/// @throws std::runtime_error on wrong state, or on failure of accept(),
///         SSL_new(), SSL_set_fd() or SSL_accept().
void TLS12Server::accept() {
  if (state_ != State::kListening) {
    LOG_AND_THROW_ERROR("accept() called from wrong state.",
                        std::runtime_error);
  }

  internal().client_addr_size = sizeof(struct sockaddr_storage);

  internal().client_sock =
      ::accept(internal().server_sock,
               reinterpret_cast<struct sockaddr *>(&(internal().client_addr)),
               &(internal().client_addr_size));

  // Check the result before any other call. A failed accept() is reported
  // before errno can be overwritten, and -1 is never passed to
  // set_socket_reuseaddr().
  if (internal().client_sock == -1) {
    LOG_AND_THROW_ERROR("accept()", std::runtime_error);
  } else {
    LOG(trace) << "accept() succeeded.";
  }

  // NOTE(review): SO_REUSEADDR affects bind(), so on an accepted socket this
  // is most likely a no-op. Consider applying it to server_sock before
  // binding instead.
  set_socket_reuseaddr(internal().client_sock);

  assert(internal().ctx);
  internal().ssl = SSL_new(internal().ctx);

  if (!internal().ssl) {
    LOG_AND_THROW_ERROR("SSL_new()", std::runtime_error);
  } else {
    LOG(trace) << "SSL_new() succeeded.";
  }

  const auto ssl_set_fd_rc =
      SSL_set_fd(internal().ssl, internal().client_sock);

  if (ssl_set_fd_rc == 0) { // Zero is error case
    LOG_AND_THROW_ERROR("SSL_set_fd", std::runtime_error);
  } else {
    LOG(trace) << "SSL_set_fd() succeeded.";
  }

  const auto ssl_accept_rc = SSL_accept(internal().ssl);

  if (ssl_accept_rc <= 0) {
    ERR_print_errors_fp(stderr);
    LOG_AND_THROW_ERROR("SSL_accept", std::runtime_error);
  } else {
    LOG(trace) << "SSL_accept() succeeded.";
  }

  state_ = State::kConnected;
}

void TLS12Server::recv(std::size_t nbytes) { recv_(internal().ssl, nbytes); }

void TLS12Server::send(std::size_t nbytes) { write_(internal().ssl, nbytes); }

#pragma GCC diagnostic pop
