#include "ssh.h"

#include <libssh/libssh.h>
#include <libssh/server.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include "log.h"
#include "util.h"

// Logs the most recent libssh error recorded on SRC (an ssh_session or an
// ssh_bind) and throws the same text as a std::runtime_error.
// The error string is fetched once and reused for both the log and the throw.
#define SSH_LOG_AND_THROW_ERROR(SRC)                       \
  do {                                                     \
    const char *const ssh_error_msg_ = ssh_get_error(SRC); \
    LOG(error) << ssh_error_msg_;                          \
    throw std::runtime_error(ssh_error_msg_);              \
  } while (false)

// Writes `nbytes` copies of a fixed filler byte (77) to `channel`, calling
// ssh_channel_write() repeatedly until every byte has been accepted.
//
// ssh_channel_write() takes a 32-bit length, so each call is capped at
// UINT32_MAX bytes. Passing the size_t remainder directly would silently
// truncate it; an exact multiple of 2^32 would become a zero-length write.
//
// A call that returns zero or a negative libssh error code is reported via
// LOG_AND_THROW_ERROR with std::runtime_error.
static void channel_write_(ssh_channel channel, std::size_t nbytes) {
  assert(channel);
  static const std::uint8_t BYTE_VAL = 77;

  const std::vector<std::uint8_t> buf(nbytes, BYTE_VAL);

  std::size_t nbytes_written_cumul{0};
  std::size_t nbytes_left{nbytes};

  while (nbytes_left > 0) {
    const auto chunk = static_cast<std::uint32_t>(
        std::min<std::size_t>(nbytes_left, UINT32_MAX));
    const auto write_nbytes =
        ssh_channel_write(channel, buf.data() + nbytes_written_cumul, chunk);

    LOG(trace) << "ssh_channel_write() returned " << write_nbytes;

    if (write_nbytes > 0) {
      nbytes_written_cumul += static_cast<std::size_t>(write_nbytes);
      nbytes_left -= static_cast<std::size_t>(write_nbytes);
    } else {
      LOG_AND_THROW_ERROR("ssh_channel_write() failed.", std::runtime_error);
    }
  }
}

// Reads exactly `nbytes` bytes from the stdout stream of `channel`. In debug
// builds, it asserts that every byte equals the filler value (77) that
// channel_write_() sends.
//
// ssh_channel_read() takes a 32-bit count, so each call is capped at
// UINT32_MAX bytes. In blocking mode, a return value of 0 means end of file:
// the peer closed the channel before all bytes arrived. A negative value is a
// libssh error. Both cases are reported via LOG_AND_THROW_ERROR with
// std::runtime_error.
static void channel_read_(ssh_channel channel, std::size_t nbytes) {
  assert(channel);
  static const std::uint8_t BYTE_VAL = 77;

  std::vector<std::uint8_t> buf(nbytes);

  // size_t rather than int: an int offset overflows once more than INT_MAX
  // bytes have been read.
  std::size_t nbytes_read_cumul{0};
  std::size_t nbytes_left{nbytes};

  while (nbytes_left > 0) {
    const auto chunk = static_cast<std::uint32_t>(
        std::min<std::size_t>(nbytes_left, UINT32_MAX));
    const auto nbytes_read = ssh_channel_read(
        channel, buf.data() + nbytes_read_cumul, chunk, /*is_stderr=*/0);

    LOG(trace) << "ssh_channel_read() returned " << nbytes_read;

    if (nbytes_read > 0) {
      nbytes_read_cumul += static_cast<std::size_t>(nbytes_read);
      nbytes_left -= static_cast<std::size_t>(nbytes_read);
    } else if (nbytes_read == 0) {
      LOG_AND_THROW_ERROR(
          "ssh_channel_read() reached EOF before all bytes arrived.",
          std::runtime_error);
    } else {  // Error on read
      LOG_AND_THROW_ERROR("ssh_channel_read() failed.", std::runtime_error);
    }
  }

  assert(std::all_of(buf.cbegin(), buf.cend(),
                     [](std::uint8_t x) -> bool { return x == BYTE_VAL; }));
}

/* CLIENT *********************************************************************/

// libssh handles owned by an SSHClient. Every handle starts out null;
// SSHClient::cleanup() releases the session.
struct SSHClient::SSHClientInternals {
  ssh_session session{nullptr};
  // Kept in a member because ssh_options_set(SSH_OPTIONS_PORT) is handed a
  // pointer to an unsigned int.
  unsigned int server_port{0};
  ssh_channel channel{nullptr};
};

// Creates a client in State::kUninit with all libssh handles null; nothing
// is allocated until try_init() runs.
SSHClient::SSHClient(std::shared_ptr<const ConnectionContext> conn_ctx)
    : ClientIFace(conn_ctx),
      internal_(new SSHClientInternals()),
      state_(State::kUninit) {}

// Releases whatever libssh resources the client still holds.
SSHClient::~SSHClient() {
  cleanup();
}

bool SSHClient::try_init() {
  if (state_ != State::kUninit) {
    LOG(error) << "Tried to initialize client from wrong state.";
    return false;
  }

  internal().session = ssh_new();

  if (!internal().session) {
    LOG(error) << "ssh_new() failed:" << ssh_get_error(internal().session);
    state_ = State::kInvalid;
    return false;
  }

  LOG(trace) << "ssh_new() completed.";

  ssh_options_set(internal().session, SSH_OPTIONS_HOST,
                  conn_ctx_.get()->server_hostname.c_str());

  internal().server_port = conn_ctx_.get()->server_port;

  ssh_options_set(internal().session, SSH_OPTIONS_PORT,
                  &internal().server_port);

  ssh_options_set(internal().session, SSH_OPTIONS_USER, "user");

#ifdef DEBUG
  {
    auto log_level = SSH_LOG_FUNCTIONS;
    ssh_options_set(internal().session, SSH_OPTIONS_LOG_VERBOSITY, &log_level);
  }
#endif  // DEBUG

  LOG(trace) << "ssh_options_set() completed.";

  state_ = State::kInit;
  return true;
}

void SSHClient::cleanup() {
  if (internal().session) {
    if (state_ == State::kConnected) {
      ssh_disconnect(internal().session);
      LOG(trace) << "ssh_disconnect() completed.";
    }

    ssh_free(internal().session);
    LOG(trace) << "ssh_free() completed.";
    internal().session = nullptr;
  }

  state_ = State::kDead;
}

void SSHClient::connect() {
  if (state_ != State::kInit) {
    LOG_AND_THROW_ERROR("connect() called from wrong state.",
                        std::runtime_error);
  }

  assert(internal().session);

  auto connect_rc = ssh_connect(internal().session);

  if (connect_rc != SSH_OK) {
    SSH_LOG_AND_THROW_ERROR(internal().session);
  }

  LOG(trace) << "connect() complete.";

  const auto *password = "hunter2";
  auto userauth_pass_rc =
      ssh_userauth_password(internal().session, nullptr, password);

  if (userauth_pass_rc != SSH_AUTH_SUCCESS) {
    SSH_LOG_AND_THROW_ERROR(internal().session);
  }

  /*
  // Authenticate
  auto userauth_none_rc = ssh_userauth_none(internal().session, nullptr);

  if (userauth_none_rc != SSH_AUTH_SUCCESS) {
    SSH_LOG_AND_THROW_ERROR(internal().session);
  }
  */

  LOG(trace) << "ssh_userauth_none() completed.";

  // auto banner = ssh_get_issue_banner(internal().session);

  internal().channel = ssh_channel_new(internal().session);

  if (!internal().channel) {
    LOG_AND_THROW_ERROR("ssh_channel_new() failed.", std::runtime_error);
  }

  LOG(trace) << "ssh_channel_new() succeeded.";

  auto channel_open_session_rc = ssh_channel_open_session(internal().channel);

  if (channel_open_session_rc != SSH_OK) {
    SSH_LOG_AND_THROW_ERROR(internal().session);
  }

  LOG(trace) << "ssh_channel_open_session() completed.";
}

void SSHClient::recv(std::size_t nbytes) {
  channel_read_(internal().channel, nbytes);
}

void SSHClient::send(std::size_t nbytes) {
  channel_write_(internal().channel, nbytes);
}

/* SERVER *********************************************************************/

// libssh handles owned by an SSHServer. Every handle starts out null;
// SSHServer::cleanup() releases the session and the bind.
struct SSHServer::SSHServerInternals {
  ssh_bind bind{nullptr};
  ssh_session session{nullptr};
  ssh_channel channel{nullptr};
  // Kept in a member because ssh_bind_options_set(SSH_BIND_OPTIONS_BINDPORT)
  // is handed a pointer to it.
  unsigned int port{0};
};

// TODO: Cross-check this server implementation against libssh's sample
// callback-based server:
// <https://github.com/simonsj/libssh/blob/master/examples/samplesshd-cb.c>

// Creates a server in State::kUninit with all libssh handles null; nothing
// is allocated until try_init() runs.
SSHServer::SSHServer(std::shared_ptr<const ConnectionContext> conn_ctx)
    : ServerIFace(conn_ctx),
      internal_(new SSHServerInternals()),
      state_(State::kUninit) {}

// Releases whatever libssh resources the server still holds.
SSHServer::~SSHServer() {
  cleanup();
}

void SSHServer::cleanup() {
  if (internal().session) {
    ssh_free(internal().session);
    internal().session = nullptr;
  }

  if (internal().bind) {
    ssh_bind_free(internal().bind);
    internal().bind = nullptr;
  }

  state_ = State::kDead;
}

bool SSHServer::try_init() {
  if (state_ != State::kUninit) {
    LOG(error) << "Tried to initialize server from wrong state.";
    return false;
  }

  // Set up the bind.
  internal().bind = ssh_bind_new();

  if (!internal().bind) {
    LOG(error) << "ssh_bind_new() failed:" << ssh_get_error(internal().bind);
    state_ = State::kInvalid;
    return false;
  }

  LOG(trace) << "ssh_bind_new() completed.";

  // Bind config
  ssh_bind_options_set(internal().bind, SSH_BIND_OPTIONS_RSAKEY,
                       conn_ctx_.get()->ssh_server_privkey->c_str());

  ssh_bind_options_set(internal().bind, SSH_BIND_OPTIONS_BINDADDR,
                       conn_ctx_.get()->server_hostname.c_str());

  internal().port = conn_ctx_.get()->server_port;

  ssh_bind_options_set(internal().bind, SSH_BIND_OPTIONS_BINDPORT,
                       &internal().port);

  LOG(trace) << "ssh_bind_options() set.";

  internal().session = ssh_new();

  if (!internal().session) {
    LOG(error) << "ssh_new() failed:" << ssh_get_error(internal().session);
    state_ = State::kInvalid;
    return false;
  }

  LOG(trace) << "ssh_new() completed.";

  state_ = State::kInit;
  return true;
}

void SSHServer::listen() {
  if (state_ != State::kInit) {
    LOG_AND_THROW_ERROR("listen() called from wrong state.",
                        std::runtime_error);
  }

  assert(internal().bind && internal().session);

  auto listen_rc = ssh_bind_listen(internal().bind);

  if (listen_rc < 0) {
    SSH_LOG_AND_THROW_ERROR(internal().bind);
  }

  LOG(trace) << "ssh_bind_listen() completed.";

  state_ = State::kListening;
}

void SSHServer::accept() {
  if (state_ != State::kListening) {
    LOG_AND_THROW_ERROR("listen() called from wrong state.",
                        std::runtime_error);
  }

  auto bind_accept_rc = ssh_bind_accept(internal().bind, internal().session);

  if (bind_accept_rc != SSH_OK) {
    SSH_LOG_AND_THROW_ERROR(internal().bind);
  }

  LOG(trace) << "ssh_bind_accept() completed.";

  // Set up session parameters
  ssh_set_auth_methods(internal().session, SSH_AUTH_METHOD_PASSWORD);

  LOG(trace) << "ssh_set_auth_methods() completed.";

  auto key_exchange_rc = ssh_handle_key_exchange(internal().session);

  if (key_exchange_rc != SSH_OK) {
    SSH_LOG_AND_THROW_ERROR(internal().session);
  }

  LOG(trace) << "ssh_key_exchange() completed.";

  {  // Handle initial service request
    auto msg = ssh_message_get(internal().session);
    assert(ssh_message_type(msg) == SSH_REQUEST_SERVICE);
    ssh_message_service_reply_success(msg);
    ssh_message_free(msg);
  }

  auto handle_password_msg = [](ssh_message msg) {
    assert(ssh_message_type(msg) == SSH_REQUEST_AUTH);
    assert(ssh_message_subtype(msg) == SSH_AUTH_METHOD_PASSWORD);
    const auto *password = ssh_message_auth_password(msg);
    LOG(trace) << "Password: " << password << ".";
    ssh_message_auth_reply_success(msg, 0);
  };

  {  // Handle authentication
    auto msg = ssh_message_get(internal().session);

    switch (ssh_message_subtype(msg)) {
      case SSH_AUTH_METHOD_NONE:
        ssh_message_auth_set_methods(msg, SSH_AUTH_METHOD_PASSWORD);
        ssh_message_reply_default(msg);
        ssh_message_free(msg);
        msg = ssh_message_get(internal().session);
        break;
      case SSH_AUTH_METHOD_PASSWORD:
        break;
    }

    handle_password_msg(msg);
    ssh_message_free(msg);
  }

  internal().channel = ssh_channel_new(internal().session);

  if (!internal().channel) {
    LOG_AND_THROW_ERROR("ssh_channel_new() failed.", std::runtime_error);
  }

  {  // Handle channel establishment
    auto msg = ssh_message_get(internal().session);
    ssh_message_channel_request_open_reply_accept_channel(msg,
                                                          internal().channel);
    ssh_message_free(msg);
  }
}

void SSHServer::send(std::size_t nbytes) {
  channel_write_(internal().channel, nbytes);
}

void SSHServer::recv(std::size_t nbytes) {
  channel_read_(internal().channel, nbytes);
}
