#include <gtest/gtest.h>

#include <cstdint>
#include <fstream>
#include <string>
#include <thread>
#include <utility>
#include <zmq.hpp>

#include "common.h"
#include "curvezmq.h"
#include "driver_common.h"
#include "driver_requester.h"
#include "driver_responder.h"
#include "ssh.h"
#include "test_common.h"
#include "tls.h"
#include "transcript.pb.h"
#include "util.h"

// Ports shared by every test in this file: the responder helper functions
// construct their DriverResponder with one of these, and each test's
// DriverRequester is given the same pair together with host "127.0.0.1".
static constexpr std::uint16_t kClientPort_ = 7777;
static constexpr std::uint16_t kServerPort_ = 7778;

static void run_driver_tls_client_() {
  std::unique_ptr<ClientIFace> client{
      std::make_unique<TLS12Client>(tls_test_context(), TLSVersion::k1_3)};

  DriverResponder driver_client{std::move(client), kClientPort_};

  auto init_succeeded = driver_client.init();
  ASSERT_TRUE(init_succeeded);

  driver_client.work_loop();
}

static void run_driver_tls_server_() {
  std::unique_ptr<ServerIFace> server{
      std::make_unique<TLS12Server>(tls_test_context(), TLSVersion::k1_3)};

  DriverResponder driver_server{std::move(server), kServerPort_};

  auto init_succeeded = driver_server.init();
  ASSERT_TRUE(init_succeeded);

  driver_server.work_loop();
}

static void run_driver_ssh_client_() {
  std::unique_ptr<ClientIFace> client{
      std::make_unique<SSHClient>(ssh_test_context())};

  DriverResponder driver_client{std::move(client), kClientPort_};

  auto init_succeeded = driver_client.init();
  ASSERT_TRUE(init_succeeded);

  driver_client.work_loop();
}

static void run_driver_ssh_server_() {
  std::unique_ptr<ServerIFace> server{
      std::make_unique<SSHServer>(ssh_test_context())};

  DriverResponder driver_server{std::move(server), kServerPort_};

  auto init_succeeded = driver_server.init();
  ASSERT_TRUE(init_succeeded);

  driver_server.work_loop();
}

static void run_driver_curvezmq_client_() {
  std::unique_ptr<ClientIFace> client{
      std::make_unique<CurveZMQClient>(curvezmq_test_context())};

  DriverResponder driver_client{std::move(client), kClientPort_};

  auto init_succeeded = driver_client.init();
  ASSERT_TRUE(init_succeeded);

  driver_client.work_loop();
}

static void run_driver_curvezmq_server_() {
  std::unique_ptr<ServerIFace> server{
      std::make_unique<CurveZMQServer>(curvezmq_test_context())};

  DriverResponder driver_server{std::move(server), kServerPort_};

  auto init_succeeded = driver_server.init();
  ASSERT_TRUE(init_succeeded);

  driver_server.work_loop();
}

/*
TEST(Driver, stub) {
  DriverRequester controller{"127.0.0.1", kClientPort_, "127.0.0.1",
                             5000};

  controller.init();

  controller.send_message(make_no_data_command(Command::SERVER_LISTEN),
                          Endpoint::SERVER);
}
*/

// Drives a CurveZMQ client/server pair, each running in its own thread, through
// a fixed command sequence: listen, connect, one RECV/SEND exchange, then STOP
// for both endpoints. The test finishes once both responder threads have been
// joined.
TEST(Driver, basic) {
  DriverRequester controller{"127.0.0.1", kClientPort_, "127.0.0.1",
                             kServerPort_};

  std::thread client_worker(run_driver_curvezmq_client_);
  std::thread server_worker(run_driver_curvezmq_server_);

  controller.init();

  // Small helpers so the command sequence below reads as a script.
  const auto to_server = [&controller](const auto& command) {
    controller.send_message(command, Endpoint::SERVER);
  };
  const auto to_client = [&controller](const auto& command) {
    controller.send_message(command, Endpoint::CLIENT);
  };

  to_server(make_no_data_command(Command::SERVER_LISTEN));
  to_client(make_no_data_command(Command::CLIENT_CONNECT));

  // The server is told to receive before the client is told to send.
  to_server(make_data_command(Command::RECV, 10000));
  to_client(make_data_command(Command::SEND, 10000));

  to_server(make_no_data_command(Command::STOP));
  to_client(make_no_data_command(Command::STOP));

  client_worker.join();
  server_worker.join();
}

// Runs a TLS client/server pair, each in its own thread, through a transcript
// loaded from ./interactive_transcript.json: listen, connect, send the
// transcript command to the client and then the server, then STOP both.
TEST(Driver, transcript_tls) {
  // Load and parse the transcript before any thread is started. A fatal
  // assertion returns from the test body, and returning while a joinable
  // std::thread is still alive would call std::terminate, so every check that
  // can fail has to happen up here.
  const auto transcript_text = read_file("./interactive_transcript.json");
  ASSERT_TRUE(static_cast<bool>(transcript_text))
      << "could not read ./interactive_transcript.json";

  const auto transcript = transcript_of_json(transcript_text->c_str());
  ASSERT_TRUE(transcript.has_value())
      << "could not parse ./interactive_transcript.json";

  const auto transcript_cmd = make_transcript_command(**transcript);

  DriverRequester controller{"127.0.0.1", kClientPort_, "127.0.0.1",
                             kServerPort_};

  std::thread client_thread{&run_driver_tls_client_};
  std::thread server_thread{&run_driver_tls_server_};

  controller.init();

  controller.send_message(make_no_data_command(Command::SERVER_LISTEN),
                          Endpoint::SERVER);

  controller.send_message(make_no_data_command(Command::CLIENT_CONNECT),
                          Endpoint::CLIENT);

  // Both endpoints receive the same transcript command, client first.
  controller.send_message(transcript_cmd, Endpoint::CLIENT);
  controller.send_message(transcript_cmd, Endpoint::SERVER);

  controller.send_message(make_no_data_command(Command::STOP),
                          Endpoint::CLIENT);

  controller.send_message(make_no_data_command(Command::STOP),
                          Endpoint::SERVER);

  client_thread.join();
  server_thread.join();
}

// Round-trips a STOP command through zmq_message_of_command and
// command_of_zmq_message and checks that the command kind survives.
TEST(Driver, command_serialization) {
  // Command::STOP is the spelling used by the other tests in this file.
  auto stop_cmd = make_no_data_command(Command::STOP);
  zmq::message_t zmsg;

  zmq_message_of_command(stop_cmd, &zmsg);
  const auto cmd2 = command_of_zmq_message(zmsg);

  // Fail cleanly instead of dereferencing an empty result below.
  ASSERT_TRUE(static_cast<bool>(cmd2))
      << "command_of_zmq_message produced no command";
  ASSERT_EQ(stop_cmd.kind(), cmd2->kind());
}

// Checks that transcript_of_json yields a value for a two-round transcript
// written in JSON.
TEST(Driver, transcript_deserialization) {
  static const char kTwoRoundJson[] =
      "{\"rounds\":[{\"clientSendNbytes\":\"100\",\"serverSendNbytes\":\"100\","
      "\"sleepUsec\":\"500\"},{\"clientSendNbytes\":\"100\","
      "\"serverSendNbytes\":\"100\",\"sleepUsec\":\"500\"}]}";

  const auto parsed = transcript_of_json(kTwoRoundJson);

  ASSERT_TRUE(parsed.has_value());
}

// Parses a two-round JSON transcript and then checks that json_of_transcript
// yields a value for the parsed object. Only the presence of each result is
// checked; the serialized text is not compared against the input.
TEST(Driver, transcript_deserialization_and_serialization) {
  static const char kTwoRoundJson[] =
      "{\"rounds\":[{\"clientSendNbytes\":\"100\",\"serverSendNbytes\":\"100\","
      "\"sleepUsec\":\"500\"},{\"clientSendNbytes\":\"100\","
      "\"serverSendNbytes\":\"100\",\"sleepUsec\":\"500\"}]}";

  const auto parsed = transcript_of_json(kTwoRoundJson);
  ASSERT_TRUE(parsed.has_value());

  // Safe to dereference: the assertion above returns early when empty.
  const auto& transcript_ptr = *parsed;
  const auto serialized = json_of_transcript(*transcript_ptr);
  ASSERT_TRUE(serialized.has_value());
}
