#include <boost/program_options.hpp>
#include <boost/property_tree/ini_parser.hpp>
#include <boost/property_tree/ptree.hpp>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <exception>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>

#include "curvezmq.h"
#include "driver_responder.h"
#include "etgen_rs.h"
#include "log.h"
#include "rs.h"
#include "ssh.h"
#include "tls.h"
#include "transport_iface.h"
#include "util.h"

/// Role this process plays in the connection under test.
enum class Mode {
  kClient,  // connects out; requires --hostname and --target-port
  kServer,  // accepts connections; requires --listen-port
};

static std::optional<Mode> mode_of_string_(const std::string &s) {
  if (s.compare("client") == 0) {
    return Mode::kClient;
  } else if (s.compare("server") == 0) {
    return Mode::kServer;
  } else {
    return {};
  }
}

/// Human-readable, capitalized name of a Mode, used when printing options.
/// Any value outside the enumerators hits UNREACHABLE.
static std::string string_of_mode_(const Mode &m) {
  if (m == Mode::kClient) {
    return "Client";
  }
  if (m == Mode::kServer) {
    return "Server";
  }

  UNREACHABLE;
}

/// Secure-channel protocol selected with --protocol.
/// Enumerator order (and therefore the implicit values 0..5) is unchanged.
enum class Protocol {
  kTLS,       // "tls"   -> TLS 1.2
  kSSH,       // "ssh"
  kCurveZMQ,  // "curvezmq"
  kSnow,      // "snow"
  kSecio,     // "secio"
  kTLS1_3,    // "tls13" -> TLS 1.3
};

static std::optional<Protocol> protocol_of_string_(const std::string &s) {
  if (s.compare("tls") == 0) {
    return Protocol::kTLS;
  } else if (s.compare("ssh") == 0) {
    return Protocol::kSSH;
  } else if (s.compare("curvezmq") == 0) {
    return Protocol::kCurveZMQ;
  } else if (s.compare("snow") == 0) {
    return Protocol::kSnow;
  } else if (s.compare("secio") == 0) {
    return Protocol::kSecio;
  } else if (s.compare("tls13") == 0) {
    return Protocol::kTLS1_3;
  } else {
    return {};
  }
}

/// Display name of a Protocol, used when printing options.
/// Note the TLS variants are shown with their version ("TLS1.2" / "TLS1.3").
static std::string string_of_protocol_(const Protocol &p) {
  switch (p) {
    case Protocol::kTLS:      return "TLS1.2";
    case Protocol::kTLS1_3:   return "TLS1.3";
    case Protocol::kSSH:      return "SSH";
    case Protocol::kCurveZMQ: return "CurveZMQ";
    case Protocol::kSnow:     return "Snow";
    case Protocol::kSecio:    return "Secio";
    default:
      UNREACHABLE;
      break;
  }

  UNREACHABLE;
}

/// Command-line configuration for this driver.
///
/// Obtain an instance via make_or_exit(). It either returns a fully validated
/// object or terminates the process after printing usage: EXIT_SUCCESS for
/// --help, EXIT_FAILURE on any parse or validation error.
class ProgramOptions {
 public:
  ProgramOptions() = default;

  /// Parses and validates argv.
  ///
  /// Positional arguments map, in order, to MODE CONTROL_PORT CONFIG_FILEPATH
  /// PROTOCOL. Client mode additionally requires --hostname and
  /// --target-port. Server mode requires --listen-port. --log-level must be
  /// one of trace|debug|info|warning|error|fatal (default "info").
  ///
  /// Never returns on --help or on error; it calls std::exit instead.
  static ProgramOptions make_or_exit(int argc, char **argv) {
    namespace po = boost::program_options;

    ProgramOptions retval{};

    std::string mode_string{};
    std::string protocol_string{};

    // clang-format off
    po::options_description desc{"Program options"};
    desc.add_options()
      ("help,h", "produce help message")
      ("mode,m", po::value<std::string>(&mode_string)->required(), "client|server")
      ("control-port,c", po::value<std::uint16_t>(&retval.control_port_)->required(), "listen for control commands on this port")
      ("protocol,p", po::value<std::string>(&protocol_string)->required(), "tls|ssh|curvezmq|snow|secio|tls13")
      ("hostname,H", po::value<std::string>(&retval.hostname_), "required client option, connect to given target host")
      ("target-port,t", po::value<std::uint16_t>(&retval.target_port_), "required client option, connect to target port")
      ("listen-port,L", po::value<std::uint16_t>(&retval.listen_port_), "required server option, listen for client connections on this port")
      ("config-filepath,C", po::value<std::string>(&retval.config_filepath_)->required(), "path to config file")
      ("log-level,l", po::value<std::string>(&retval.log_level_)->default_value("info"), "trace|debug|info|warning|error|fatal");
    // clang-format on

    po::positional_options_description p;
    p.add("mode", 1);
    p.add("control-port", 1);
    p.add("config-filepath", 1);
    p.add("protocol", 1);

    po::variables_map vm{};

    // Prints usage (stdout on success, stderr otherwise) and terminates.
    // Named to avoid shadowing std::exit.
    auto usage_and_exit = [&](int exit_code) {
      std::ostream &os = (exit_code == EXIT_SUCCESS) ? std::cout : std::cerr;

      os << "Usage:\t" << argv[0]
         << " [OPTIONS] MODE CONTROL_PORT CONFIG_FILEPATH PROTOCOL"
         << std::endl;
      os << desc << std::endl;
      std::exit(exit_code);
    };

    try {
      po::store(
          po::command_line_parser(argc, argv).options(desc).positional(p).run(),
          vm);

      // Handle --help BEFORE notify(): notify() enforces the ->required()
      // options and would otherwise reject a bare `--help` with an error.
      if (vm.count("help")) {
        usage_and_exit(EXIT_SUCCESS);
      }

      po::notify(vm);

      auto protocol = protocol_of_string_(protocol_string);
      if (!protocol) {
        throw po::invalid_option_value(protocol_string);
      }
      retval.protocol_ = *protocol;

      auto mode = mode_of_string_(mode_string);
      if (!mode) {
        throw po::invalid_option_value(mode_string);
      }
      retval.mode_ = *mode;

      // Mode-specific required options that boost cannot express directly.
      switch (retval.mode_) {
        case Mode::kClient:
          if (vm.count("hostname") != 1) {
            throw po::required_option("hostname");
          }
          if (vm.count("target-port") != 1) {
            throw po::required_option("target-port");
          }
          break;
        case Mode::kServer:
          if (vm.count("listen-port") != 1) {
            throw po::required_option("listen-port");
          }
          break;
        default:
          UNREACHABLE;
      }

      const auto &ll = retval.log_level_;
      if (ll != "trace" && ll != "debug" && ll != "info" && ll != "warning" &&
          ll != "error" && ll != "fatal") {
        throw po::invalid_option_value(ll);
      }

    } catch (const std::exception &ex) {
      std::cerr << "ERROR:\t" << ex.what() << "\n" << std::endl;
      usage_and_exit(EXIT_FAILURE);
    }

    return retval;
  }

  /// Multi-line, tab-separated dump of every option (used for logging).
  friend std::ostream &operator<<(std::ostream &stream,
                                  const ProgramOptions &po) {
    stream << "Mode:\t" << string_of_mode_(po.mode_) << "\n";
    stream << "Hostname:\t" << po.hostname() << "\n";
    stream << "Listen port:\t" << po.listen_port() << "\n";
    stream << "Target port:\t" << po.target_port() << "\n";
    stream << "Control port:\t" << po.control_port() << "\n";
    stream << "Config filepath:\t" << po.config_filepath() << "\n";
    stream << "Log level:\t" << po.log_level() << "\n";
    stream << "Protocol:\t" << string_of_protocol_(po.protocol());
    return stream;
  }

  const std::string &hostname() const { return hostname_; }
  const std::uint16_t &listen_port() const { return listen_port_; }
  const std::uint16_t &target_port() const { return target_port_; }
  const std::uint16_t &control_port() const { return control_port_; }
  const Mode &mode() const { return mode_; }
  const std::string &log_level() const { return log_level_; }
  const std::string &config_filepath() const { return config_filepath_; }
  const Protocol &protocol() const { return protocol_; }

 private:
  std::string hostname_{"0.0.0.0"};
  std::uint16_t listen_port_{0};
  std::uint16_t target_port_{0};
  std::uint16_t control_port_{0};
  // make_or_exit() always overwrites mode_ and protocol_. They are initialized
  // so a default-constructed object never holds indeterminate enum values.
  Mode mode_{Mode::kClient};
  std::string log_level_{"null"};
  std::string config_filepath_{"/dev/null"};
  Protocol protocol_{Protocol::kTLS};
};

static std::optional<ConnectionContext> connection_context_of_ini_file_(
    const std::string &ini_filepath, Protocol protocol) {
  ConnectionContext retval{};

  try {
    boost::property_tree::ptree pt;
    boost::property_tree::ini_parser::read_ini(ini_filepath, pt);

    switch (protocol) {
      case Protocol::kTLS:
      case Protocol::kTLS1_3:
        retval.tls_server_key = pt.get<std::string>("tls.server_key");
        retval.tls_server_cert = pt.get<std::string>("tls.server_cert");
        break;
      case Protocol::kSSH:
        retval.ssh_server_pubkey = pt.get<std::string>("ssh.server_pubkey");
        retval.ssh_server_privkey = pt.get<std::string>("ssh.server_privkey");
        break;
      case Protocol::kCurveZMQ:
        retval.curvezmq_client_cert =
            pt.get<std::string>("curvezmq.client_cert");
        retval.curvezmq_server_cert =
            pt.get<std::string>("curvezmq.server_cert");
        break;
      case Protocol::kSnow:
        retval.snow_noise_params = pt.get<std::string>("snow.noise_params");
        break;
      case Protocol::kSecio:
        break;
      default:
        UNREACHABLE;
        break;
    }
  } catch (const std::exception &ex) {
    LOG(error) << ex.what();
    return {};
  }

  return retval;
}

/// Instantiates the client-side implementation for `protocol`.
///
/// Both TLS variants use TLS12Client, parameterized by TLSVersion. Snow and
/// Secio both use RSClient, parameterized by RSKind. `conn_ctx` is passed to
/// the chosen client's constructor.
static std::unique_ptr<ClientIFace> make_client_(
    Protocol protocol, std::shared_ptr<ConnectionContext> conn_ctx) {
  switch (protocol) {
    case Protocol::kTLS:
      return std::make_unique<TLS12Client>(conn_ctx, TLSVersion::k1_2);
    case Protocol::kTLS1_3:
      return std::make_unique<TLS12Client>(conn_ctx, TLSVersion::k1_3);
    case Protocol::kSSH:
      return std::make_unique<SSHClient>(conn_ctx);
    case Protocol::kCurveZMQ:
      return std::make_unique<CurveZMQClient>(conn_ctx);
    case Protocol::kSnow:
      return std::make_unique<RSClient>(conn_ctx, RSKind::kSnow);
    case Protocol::kSecio:
      return std::make_unique<RSClient>(conn_ctx, RSKind::kSecio);
    default:
      UNREACHABLE;
      break;
  }

  UNREACHABLE;
}

/// Instantiates the server-side implementation for `protocol`.
///
/// This mirrors make_client_: both TLS variants use TLS12Server,
/// parameterized by TLSVersion. Snow and Secio both use RSServer, each with
/// its own RSKind. `conn_ctx` is passed to the chosen server's constructor.
static std::unique_ptr<ServerIFace> make_server_(
    Protocol protocol, std::shared_ptr<ConnectionContext> conn_ctx) {
  switch (protocol) {
    case Protocol::kSSH:
      return std::make_unique<SSHServer>(conn_ctx);
    case Protocol::kTLS:
      return std::make_unique<TLS12Server>(conn_ctx, TLSVersion::k1_2);
    case Protocol::kSnow:
      // Was RSKind::kSecio (copy-paste error); the Snow server must use the
      // Snow kind, matching make_client_.
      return std::make_unique<RSServer>(conn_ctx, RSKind::kSnow);
    case Protocol::kCurveZMQ:
      return std::make_unique<CurveZMQServer>(conn_ctx);
    case Protocol::kSecio:
      return std::make_unique<RSServer>(conn_ctx, RSKind::kSecio);
    case Protocol::kTLS1_3:
      return std::make_unique<TLS12Server>(conn_ctx, TLSVersion::k1_3);
    default:
      UNREACHABLE;
      break;
  }

  UNREACHABLE;
}

/// A protocol endpoint with exclusive ownership: either a client or a server
/// implementation. Requires <variant>.
using Endpoint_t =
    std::variant<std::unique_ptr<ClientIFace>, std::unique_ptr<ServerIFace>>;

/// Builds the endpoint for the requested role. Mode::kServer yields the
/// make_server_ result and Mode::kClient yields the make_client_ result, both
/// for the same `protocol` and `conn_ctx`.
static Endpoint_t make_endpoint_(Mode mode, Protocol protocol,
                                 std::shared_ptr<ConnectionContext> conn_ctx) {
  switch (mode) {
    case Mode::kServer:
      return make_server_(protocol, conn_ctx);
    case Mode::kClient:
      return make_client_(protocol, conn_ctx);
    default:
      UNREACHABLE;
      break;
  }

  UNREACHABLE;
}

/// Driver entry point.
///
/// Steps:
/// 1. Parse the command line; make_or_exit exits on --help or on error.
/// 2. Initialize logging via log_init() and ffi_log_init().
/// 3. Load the protocol-specific config file.
/// 4. Fill in the peer hostname and port for the selected mode.
/// 5. Build the endpoint and hand it, with the control port, to a
///    DriverResponder, then run its init() and work_loop().
///
/// @return EXIT_FAILURE if the config file cannot be parsed, otherwise
///         EXIT_SUCCESS once work_loop() returns.
int main(int argc, char **argv) {
  const auto program_options = ProgramOptions::make_or_exit(argc, argv);

  // Maps the already-validated --log-level string to the integer passed to
  // ffi_log_init(). NOTE(review): "fatal" and "error" both map to 0 —
  // presumably the FFI logger has no separate fatal level; confirm.
  auto int_of_log_level = [](const std::string &s) {
    if (s.compare("fatal") == 0) {
      return 0;
    }
    if (s.compare("error") == 0) {
      return 0;
    }
    if (s.compare("warning") == 0) {
      return 1;
    }
    if (s.compare("info") == 0) {
      return 2;
    }
    if (s.compare("debug") == 0) {
      return 3;
    }
    if (s.compare("trace") == 0) {
      return 4;
    }

    UNREACHABLE;
  };

  log_init(program_options.log_level());
  ffi_log_init(int_of_log_level(program_options.log_level()));

  LOG(info) << "Program options:\n" << program_options;

  auto parsed_conn_ctx = connection_context_of_ini_file_(
      program_options.config_filepath(), program_options.protocol());

  if (!parsed_conn_ctx) {
    LOG(fatal) << "Error parsing config file.";
    // Return instead of std::exit() so automatic objects in main are
    // destroyed normally.
    return EXIT_FAILURE;
  }

  if (program_options.mode() == Mode::kClient) {
    parsed_conn_ctx->server_hostname = program_options.hostname();
    parsed_conn_ctx->server_port = program_options.target_port();
  } else {
    // --hostname is a client-only option, so server mode always uses the
    // fixed "localhost" together with --listen-port.
    parsed_conn_ctx->server_hostname = "localhost";
    parsed_conn_ctx->server_port = program_options.listen_port();
  }

  LOG(info) << "Configuration:\n" << *parsed_conn_ctx;

  // parsed_conn_ctx is not used after this point, so move rather than copy.
  auto conn_ctx =
      std::make_shared<ConnectionContext>(std::move(*parsed_conn_ctx));

  auto endpoint = make_endpoint_(program_options.mode(),
                                 program_options.protocol(), conn_ctx);

  auto responder =
      DriverResponder{std::move(endpoint), program_options.control_port()};

  responder.init();
  responder.work_loop();

  return EXIT_SUCCESS;
}
