use snow::{params::NoiseParams, Builder};

use std::{
    io::{self, Read, Write},
    net::{TcpListener, TcpStream},
};

// Hard-coded key material shared by `SnowClient` and `SnowServer`.
// NOTE(review): these secrets are committed to source, so they provide no real
// confidentiality — presumably only acceptable for tests/benchmarks; confirm
// before using this transport for anything else. They are presumably
// Curve25519 keys (the tests only exercise the `25519` DH function) — confirm.

/// Pre-shared key installed in PSK slot 0 by both the client and the server.
static SHARED_SECRET: &[u8; 32] = b"i don't care for fidget spinners";

/// Client static public key; the server uses it as its remote public key.
static CLIENT_PK: [u8; 32] = [
    100, 102, 7, 78, 65, 124, 149, 255, 162, 243, 107, 2, 231, 246, 97, 249, 69, 21, 87, 80, 159,
    63, 0, 180, 11, 222, 164, 96, 238, 122, 142, 55,
];

/// Client static private key; the client uses it as its local private key.
static CLIENT_SK: [u8; 32] = [
    246, 76, 89, 43, 73, 7, 205, 24, 91, 102, 84, 39, 109, 194, 35, 83, 239, 83, 62, 53, 3, 21, 95,
    45, 110, 68, 94, 250, 139, 96, 203, 154,
];

/// Server static public key; the client uses it as its remote public key.
static SERVER_PK: [u8; 32] = [
    80, 205, 126, 220, 7, 147, 219, 41, 223, 247, 163, 69, 4, 168, 65, 243, 133, 204, 217, 53, 102,
    101, 120, 40, 199, 201, 104, 139, 142, 58, 134, 49,
];

/// Server static private key; the server uses it as its local private key.
static SERVER_SK: [u8; 32] = [
    90, 192, 96, 147, 194, 124, 174, 168, 59, 232, 100, 254, 189, 21, 243, 65, 69, 224, 11, 49, 50,
    147, 168, 44, 73, 70, 47, 134, 83, 7, 6, 153,
];

/// Generates the `send_handshake_msg` / `recv_handshake_msg` helpers shared by
/// `SnowClient` and `SnowServer`.
///
/// The expanding type must have `noise: Option<snow::HandshakeState>` and
/// `stream: Option<TcpStream>` fields, and the free `send` / `recv` framing
/// functions must be in scope at the expansion site (items inside a
/// `macro_rules!` body resolve where the macro is invoked).
macro_rules! make_handshake_funcs {
    () => {
        /// Writes the next handshake message (with an empty payload) and sends
        /// it to the peer as a single length-prefixed frame.
        ///
        /// # Panics
        ///
        /// Panics if the handshake state or the stream is missing, or if snow
        /// refuses to produce the next handshake message.
        fn send_handshake_msg(&mut self) {
            let mut buf = vec![0u8; 65535];

            let len = self
                .noise
                .as_mut()
                .expect("handshake state missing; call try_init() first")
                .write_message(&[], &mut buf)
                .expect("failed to write handshake message");
            send(
                self.stream.as_mut().expect("no connected stream"),
                &buf[..len],
            );
        }

        /// Receives one frame from the peer and feeds it to the handshake state
        /// as the next handshake message. Any payload it carries is discarded.
        ///
        /// # Panics
        ///
        /// Panics if the handshake state or the stream is missing, if reading
        /// the frame fails, or if snow rejects the handshake message.
        fn recv_handshake_msg(&mut self) {
            let msg = recv(self.stream.as_mut().expect("no connected stream"))
                .unwrap_or_else(|e| panic!("failed to receive handshake message: {}", e));
            let mut payload = vec![0u8; 65535];

            self.noise
                .as_mut()
                .expect("handshake state missing; call try_init() first")
                .read_message(&msg, &mut payload)
                .expect("failed to read handshake message");
        }
    };
}

/// Hyper-basic stream transport receiver: reads one frame made of a 32-bit
/// big-endian length prefix followed by that many payload bytes.
///
/// Frames longer than 65535 bytes (the Noise protocol's maximum message size,
/// and the most any sender in this module emits) are rejected before any
/// allocation, so a corrupt or hostile length prefix cannot force a huge
/// allocation.
///
/// # Errors
///
/// Returns any I/O error from `stream` (including `UnexpectedEof` for a
/// truncated frame) and `InvalidData` for an oversized length prefix.
fn recv<R: Read + ?Sized>(stream: &mut R) -> io::Result<Vec<u8>> {
    const MAX_MSG_LEN: usize = 65535;

    let mut len_prefix = [0u8; 4];
    stream.read_exact(&mut len_prefix)?;
    // Lossless on 32- and 64-bit targets.
    let msg_len = u32::from_be_bytes(len_prefix) as usize;
    if msg_len > MAX_MSG_LEN {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "frame length {} exceeds maximum of {}",
                msg_len, MAX_MSG_LEN
            ),
        ));
    }

    let mut msg = vec![0u8; msg_len];
    stream.read_exact(&mut msg)?;
    Ok(msg)
}

/// Hyper-basic stream transport sender: writes `buf` as one frame made of a
/// 32-bit big-endian length prefix followed by the payload.
///
/// Prefix and payload are assembled into one buffer so they are handed to the
/// writer in a single `write_all` call.
///
/// # Panics
///
/// Panics if `buf` is longer than `u32::MAX` bytes or if the write fails.
fn send<W: Write + ?Sized>(stream: &mut W, buf: &[u8]) {
    let len = u32::try_from(buf.len()).expect("message too large");

    let mut frame = Vec::with_capacity(4 + buf.len());
    frame.extend_from_slice(&len.to_be_bytes());
    frame.extend_from_slice(buf);

    stream
        .write_all(&frame)
        .unwrap_or_else(|e| panic!("failed to write frame: {}", e));
}

/// Shape of a supported Noise handshake, expressed as round trips. It fixes
/// the order in which each side sends and receives handshake messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HandshakePattern {
    /// One message: the initiator sends, the responder receives (one-way patterns).
    Rtt0_5,
    /// Two messages: the initiator sends, then the responder replies.
    Rtt1,
    /// Three messages: initiator, responder, then initiator again.
    Rtt1_5,
}

impl HandshakePattern {
    /// Classifies the handshake pattern of a Noise protocol name such as
    /// `Noise_XXpsk3_25519_ChaChaPoly_SHA256`.
    ///
    /// The pattern is the second `_`-separated component. It is split into a
    /// base pattern (leading uppercase letters/digits, e.g. `XX`) and optional
    /// `+`-joined `psk<N>` modifiers, which do not change the message count.
    /// Only the fundamental patterns are recognized:
    ///
    /// - one-way `N`, `K`, `X` -> [`HandshakePattern::Rtt0_5`]
    /// - `NN NK NX KN KK KX IN IK IX` -> [`HandshakePattern::Rtt1`]
    /// - `XN XK XX` -> [`HandshakePattern::Rtt1_5`]
    ///
    /// Returns `None` when there is no second component, or for deferred
    /// patterns (e.g. `NK1`), `fallback`/other modifiers, and anything else
    /// unrecognized.
    fn of_str(value: &str) -> Option<HandshakePattern> {
        let pattern = value.split('_').nth(1)?;

        // e.g. "XXpsk3" -> base "XX", modifiers "psk3".
        let base_len = pattern
            .find(|c: char| !(c.is_ascii_uppercase() || c.is_ascii_digit()))
            .unwrap_or(pattern.len());
        let (base, modifiers) = pattern.split_at(base_len);

        if !modifiers.is_empty() && !modifiers.split('+').all(Self::is_psk_modifier) {
            return None;
        }

        match base.as_bytes() {
            [b'N' | b'K' | b'X'] => Some(Self::Rtt0_5),
            [b'N' | b'K' | b'I', b'N' | b'K' | b'X'] => Some(Self::Rtt1),
            [b'X', b'N' | b'K' | b'X'] => Some(Self::Rtt1_5),
            _ => None,
        }
    }

    /// Returns `true` for a single `psk<N>` modifier such as `psk0` or `psk12`.
    fn is_psk_modifier(modifier: &str) -> bool {
        modifier.strip_prefix("psk").map_or(false, |index| {
            !index.is_empty() && index.bytes().all(|b| b.is_ascii_digit())
        })
    }
}

// Client //////////////////////////////////////////////////////////////////////

/// Client (Noise initiator) side of a TCP transport secured with the `snow` crate.
///
/// Lifecycle, as driven by the methods in this file: `new` parses the
/// configured Noise parameters, `try_init` builds the initiator handshake
/// state, `connect` opens the TCP stream, runs the handshake and switches to
/// transport mode, after which `send`/`recv` exchange encrypted frames.
#[derive(Debug)]
pub struct SnowClient {
    /// Host part of the server endpoint dialled by `connect`.
    server_hostname: String,
    /// Port part of the server endpoint dialled by `connect`.
    server_port: u16,
    /// Noise protocol parameters parsed from `Config::snow_noise_params`.
    params: NoiseParams,
    /// Handshake state; set by `try_init`, taken when entering transport mode.
    noise: Option<snow::HandshakeState>,
    /// TCP connection to the server; set by `connect`.
    stream: Option<TcpStream>,
    /// Post-handshake state used to encrypt/decrypt application payloads.
    transport: Option<snow::TransportState>,
    /// Handshake message sequence to run, derived from the protocol name.
    handshake_pattern: HandshakePattern,
}

impl SnowClient {
    pub fn new(config: crate::interface::Config) -> SnowClient {
        SnowClient {
            server_hostname: config.server_hostname,
            server_port: config.server_port,
            params: config.snow_noise_params.clone().unwrap().parse().unwrap(),
            noise: None,
            stream: None,
            transport: None,
            handshake_pattern: HandshakePattern::of_str(&config.snow_noise_params.unwrap())
                .unwrap(),
        }
    }

    make_handshake_funcs!();

    fn perform_handshake(&mut self) {
        match self.handshake_pattern {
            HandshakePattern::Rtt0_5 => {
                self.send_handshake_msg();
            }
            HandshakePattern::Rtt1 => {
                self.send_handshake_msg();
                self.recv_handshake_msg();
            }
            HandshakePattern::Rtt1_5 => {
                self.send_handshake_msg();
                self.recv_handshake_msg();
                self.send_handshake_msg();
            }
        }
    }
}

impl crate::interface::ClientIFace for SnowClient {
    /// Builds the Noise initiator state from the hard-coded client private
    /// key, server public key and pre-shared key (PSK slot 0).
    ///
    /// Always returns `true`; a builder error panics instead.
    fn try_init(&mut self) -> bool {
        log::info!("Client init...");

        let builder = Builder::new(self.params.clone());

        self.noise = Some(
            builder
                .local_private_key(&CLIENT_SK)
                .remote_public_key(&SERVER_PK)
                .psk(0, SHARED_SECRET)
                .build_initiator()
                .expect("failed to build Noise initiator"),
        );

        log::info!("Client: {:?}", self);

        true
    }

    /// Connects to `server_hostname:server_port`, runs the handshake and
    /// switches to transport mode.
    ///
    /// # Panics
    ///
    /// Panics if `try_init` was not called first, if the connection fails, or
    /// if the handshake fails.
    fn connect(&mut self) {
        log::info!("Connecting...");

        let endpoint = format!("{}:{}", self.server_hostname, self.server_port);
        let stream = TcpStream::connect(&endpoint)
            .unwrap_or_else(|e| panic!("failed to connect to {}: {}", endpoint, e));
        self.stream = Some(stream);

        self.perform_handshake();

        let handshake = self
            .noise
            .take()
            .expect("try_init() must be called before connect()");
        self.transport = Some(
            handshake
                .into_transport_mode()
                .expect("handshake did not finish cleanly"),
        );

        log::info!("Handshake complete! Moving into transport mode.");
    }

    /// Receives and decrypts one frame, asserting that it carried exactly
    /// `nbytes` bytes of plaintext.
    ///
    /// # Panics
    ///
    /// Panics if not connected, on a read or decryption error, or if the
    /// plaintext length differs from `nbytes`.
    fn recv(&mut self, nbytes: usize) {
        let mut buf = vec![0u8; 65535];

        let stream = self
            .stream
            .as_mut()
            .expect("recv() called before connect()");
        let msg = recv(stream).unwrap_or_else(|e| panic!("client failed to read frame: {}", e));

        let len = self
            .transport
            .as_mut()
            .expect("recv() called before the handshake completed")
            .read_message(&msg, &mut buf)
            .expect("client failed to decrypt message");
        log::info!("Client received {0} bytes.", len);
        assert_eq!(len, nbytes, "client received an unexpected payload size");
    }

    /// Encrypts `nbytes` zero bytes and sends them as one frame.
    ///
    /// `nbytes` must fit in a single Noise message (at most 65535 bytes
    /// including the AEAD tag); otherwise encryption fails and this panics.
    fn send(&mut self, nbytes: usize) {
        let mut buf = vec![0u8; 65535];
        let data = vec![0u8; nbytes];
        let len = self
            .transport
            .as_mut()
            .expect("send() called before the handshake completed")
            .write_message(&data, &mut buf)
            .expect("client failed to encrypt message (payload too large?)");
        send(
            self.stream.as_mut().expect("send() called before connect()"),
            &buf[..len],
        );
        log::info!("Client sent {0} bytes of goodput.", nbytes);
    }

    /// Logs the cleanup; the stream and state are released when `self` is dropped.
    fn cleanup(&mut self) {
        log::info!("Cleaning up...");
    }
}

use crate::interface::ClientIFace;

impl SnowClient {
    /// Moves `self` to the heap and exposes it through a C-style vtable.
    ///
    /// Ownership of the boxed client passes to the returned
    /// `ClientInterfaceFFI`: `p_this` is the raw `Box` pointer, and `free`
    /// reconstructs and drops that `Box`, so it must be called exactly once
    /// (never calling it leaks; calling it twice is a double free).
    ///
    /// NOTE(review): every entry point below dereferences `p_this` unchecked.
    /// They are sound only if callers pass the `p_this` stored in this struct,
    /// never after `free`, and never concurrently from two threads — confirm
    /// that `crate::ffi` enforces this. They are declared as safe `fn`s even
    /// though they rely on that contract; making them `unsafe fn` depends on
    /// the field types of `ClientInterfaceFFI`, which are not visible here.
    pub fn into_interface_ffi(self) -> crate::ffi::ClientInterfaceFFI {
        fn try_init(p_this: *mut core::ffi::c_void) -> bool {
            // SAFETY: assumes `p_this` is the live, unaliased pointer produced
            // by `Box::into_raw` below (see the method-level note).
            let ref_this = unsafe { &mut *(p_this as *mut SnowClient) };
            ref_this.try_init()
        }

        fn connect(p_this: *mut core::ffi::c_void) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowClient) };
            ref_this.connect()
        }

        fn send(p_this: *mut core::ffi::c_void, nbytes: usize) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowClient) };
            ref_this.send(nbytes)
        }

        fn recv(p_this: *mut core::ffi::c_void, nbytes: usize) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowClient) };
            ref_this.recv(nbytes)
        }

        fn cleanup(p_this: *mut core::ffi::c_void) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowClient) };
            ref_this.cleanup()
        }

        fn free(p_this: *mut core::ffi::c_void) {
            let p_this = p_this as *mut SnowClient;
            // SAFETY: assumes `p_this` came from `Box::into_raw` below and has
            // not been freed yet; ownership returns to the `Box`, which drops it.
            let boxed = unsafe { Box::from_raw(p_this) };
            drop(boxed);
        }

        // Leak the client onto the heap; `free` is responsible for reclaiming it.
        let p_this = Box::into_raw(Box::new(self)) as *mut core::ffi::c_void;

        crate::ffi::ClientInterfaceFFI {
            p_this,
            try_init,
            free,
            connect,
            send,
            recv,
            cleanup,
        }
    }
}

// Server //////////////////////////////////////////////////////////////////////

/// Server (Noise responder) side of a TCP transport secured with the `snow` crate.
///
/// Lifecycle, as driven by the methods in this file: `new` parses the
/// configured Noise parameters, `try_init` builds the responder handshake
/// state, `listen` binds the socket, `accept` takes one connection, runs the
/// handshake and switches to transport mode, after which `send`/`recv`
/// exchange encrypted frames.
#[derive(Debug)]
pub struct SnowServer {
    /// Port bound on 127.0.0.1 by `listen`.
    server_port: u16,
    /// Noise protocol parameters parsed from `Config::snow_noise_params`.
    params: NoiseParams,
    /// Handshake state; set by `try_init`, taken when entering transport mode.
    noise: Option<snow::HandshakeState>,
    /// Listening socket; set by `listen`, used by `accept`.
    listener: Option<TcpListener>,
    /// The accepted client connection; set by `accept`.
    stream: Option<TcpStream>,
    /// Post-handshake state used to encrypt/decrypt application payloads.
    transport: Option<snow::TransportState>,
    /// Handshake message sequence to run, derived from the protocol name.
    handshake_pattern: HandshakePattern,
}

impl SnowServer {
    pub fn new(config: crate::interface::Config) -> SnowServer {
        SnowServer {
            server_port: config.server_port,
            params: config.snow_noise_params.clone().unwrap().parse().unwrap(),
            noise: None,
            listener: None,
            stream: None,
            transport: None,
            handshake_pattern: HandshakePattern::of_str(&config.snow_noise_params.unwrap())
                .unwrap(),
        }
    }

    make_handshake_funcs!();

    fn perform_handshake(&mut self) {
        match self.handshake_pattern {
            HandshakePattern::Rtt0_5 => {
                self.recv_handshake_msg();
            }
            HandshakePattern::Rtt1 => {
                self.recv_handshake_msg();
                self.send_handshake_msg();
            }
            HandshakePattern::Rtt1_5 => {
                self.recv_handshake_msg();
                self.send_handshake_msg();
                self.recv_handshake_msg();
            }
        }
    }
}

impl crate::interface::ServerIFace for SnowServer {
    /// Builds the Noise responder state from the hard-coded server private
    /// key, client public key and pre-shared key (PSK slot 0).
    ///
    /// Always returns `true`; a builder error panics instead.
    fn try_init(&mut self) -> bool {
        log::info!("Server init...");
        let builder = Builder::new(self.params.clone());

        self.noise = Some(
            builder
                .local_private_key(&SERVER_SK)
                .remote_public_key(&CLIENT_PK)
                .psk(0, SHARED_SECRET)
                .build_responder()
                .expect("failed to build Noise responder"),
        );

        log::info!("Server: {:?}", self);

        true
    }

    /// Binds the listening socket on `127.0.0.1:<server_port>`.
    ///
    /// NOTE(review): this binds loopback only, so remote clients cannot
    /// connect; confirm that this is intended.
    fn listen(&mut self) {
        use std::net::SocketAddr;
        log::info!("Server listening...");
        let addr = SocketAddr::from(([127, 0, 0, 1], self.server_port));
        let listener =
            TcpListener::bind(addr).unwrap_or_else(|e| panic!("failed to bind {}: {}", addr, e));
        self.listener = Some(listener);
    }

    /// Accepts one connection, runs the handshake and switches to transport mode.
    ///
    /// # Panics
    ///
    /// Panics if `listen`/`try_init` were not called first, if accepting
    /// fails, or if the handshake fails.
    fn accept(&mut self) {
        log::info!("Server accepting...");
        let (stream, _peer) = self
            .listener
            .as_ref()
            .expect("listen() must be called before accept()")
            .accept()
            .unwrap_or_else(|e| panic!("failed to accept connection: {}", e));
        self.stream = Some(stream);
        log::info!("Accepted connection!");

        self.perform_handshake();

        // Transition the state machine into transport mode now that the handshake is complete.
        let handshake = self
            .noise
            .take()
            .expect("try_init() must be called before accept()");
        self.transport = Some(
            handshake
                .into_transport_mode()
                .expect("handshake did not finish cleanly"),
        );

        log::info!("Handshake complete! Moving into transport mode.");
    }

    /// Receives and decrypts one frame, asserting that it carried exactly
    /// `nbytes` bytes of plaintext.
    ///
    /// # Panics
    ///
    /// Panics if no connection was accepted, on a read or decryption error, or
    /// if the plaintext length differs from `nbytes`.
    fn recv(&mut self, nbytes: usize) {
        let mut buf = vec![0u8; 65535];

        let stream = self
            .stream
            .as_mut()
            .expect("recv() called before accept()");
        let msg = recv(stream).unwrap_or_else(|e| panic!("server failed to read frame: {}", e));

        let len = self
            .transport
            .as_mut()
            .expect("recv() called before the handshake completed")
            .read_message(&msg, &mut buf)
            .expect("server failed to decrypt message");
        log::info!("Server received {0} bytes.", len);
        assert_eq!(len, nbytes, "server received an unexpected payload size");
    }

    /// Encrypts `nbytes` zero bytes and sends them as one frame.
    ///
    /// `nbytes` must fit in a single Noise message (at most 65535 bytes
    /// including the AEAD tag); otherwise encryption fails and this panics.
    fn send(&mut self, nbytes: usize) {
        let mut buf = vec![0u8; 65535];
        let data = vec![0u8; nbytes];
        let len = self
            .transport
            .as_mut()
            .expect("send() called before the handshake completed")
            .write_message(&data, &mut buf)
            .expect("server failed to encrypt message (payload too large?)");
        send(
            self.stream.as_mut().expect("send() called before accept()"),
            &buf[..len],
        );
        log::info!("Server sent {0} bytes of goodput.", nbytes);
    }

    /// Logs the cleanup; sockets and state are released when `self` is dropped.
    fn cleanup(&mut self) {
        log::info!("Server cleaning up...");
    }
}

use crate::interface::ServerIFace;

impl SnowServer {
    /// Moves `self` to the heap and exposes it through a C-style vtable.
    ///
    /// Ownership of the boxed server passes to the returned
    /// `ServerInterfaceFFI`: `p_this` is the raw `Box` pointer, and `free`
    /// reconstructs and drops that `Box`, so it must be called exactly once
    /// (never calling it leaks; calling it twice is a double free).
    ///
    /// NOTE(review): every entry point below dereferences `p_this` unchecked.
    /// They are sound only if callers pass the `p_this` stored in this struct,
    /// never after `free`, and never concurrently from two threads — confirm
    /// that `crate::ffi` enforces this. They are declared as safe `fn`s even
    /// though they rely on that contract; making them `unsafe fn` depends on
    /// the field types of `ServerInterfaceFFI`, which are not visible here.
    pub fn into_interface_ffi(self) -> crate::ffi::ServerInterfaceFFI {
        fn try_init(p_this: *mut core::ffi::c_void) -> bool {
            // SAFETY: assumes `p_this` is the live, unaliased pointer produced
            // by `Box::into_raw` below (see the method-level note).
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.try_init()
        }

        fn listen(p_this: *mut core::ffi::c_void) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.listen()
        }

        fn accept(p_this: *mut core::ffi::c_void) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.accept()
        }

        fn send(p_this: *mut core::ffi::c_void, nbytes: usize) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.send(nbytes)
        }

        fn recv(p_this: *mut core::ffi::c_void, nbytes: usize) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.recv(nbytes)
        }

        fn cleanup(p_this: *mut core::ffi::c_void) {
            // SAFETY: same contract as `try_init` above.
            let ref_this = unsafe { &mut *(p_this as *mut SnowServer) };
            ref_this.cleanup()
        }

        fn free(p_this: *mut core::ffi::c_void) {
            let p_this = p_this as *mut SnowServer;
            // SAFETY: assumes `p_this` came from `Box::into_raw` below and has
            // not been freed yet; ownership returns to the `Box`, which drops it.
            let boxed = unsafe { Box::from_raw(p_this) };
            drop(boxed);
        }

        // Leak the server onto the heap; `free` is responsible for reclaiming it.
        let p_this = Box::into_raw(Box::new(self)) as *mut core::ffi::c_void;

        crate::ffi::ServerInterfaceFFI {
            p_this,
            try_init,
            free,
            listen,
            accept,
            send,
            recv,
            cleanup,
        }
    }
}

/// Installs a stderr logger for the test binary. It is registered through
/// `#[ctor::ctor]`, so it runs before any test and the `log::info!` output
/// from this module is visible.
///
/// Verbosity is set to 5 (presumably "log everything") with second-resolution
/// timestamps.
#[cfg(test)]
#[ctor::ctor]
fn init() {
    let mut logger = stderrlog::new();
    logger
        .verbosity(5)
        .timestamp(stderrlog::Timestamp::Second);
    logger.init().unwrap();
}

#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end exchange over every combination of handshake pattern, DH
    /// function, cipher and hash function listed below.
    ///
    /// One-way patterns (e.g. `Kpsk0`) are not listed: the exchange in
    /// `test_snow_impl` also sends server -> client, which the Noise spec does
    /// not allow after a one-way handshake.
    #[test]
    fn test_snow() {
        let patterns = ["XX", "NX"];
        let dh_functions = ["25519"];
        let ciphers = ["ChaChaPoly", "AESGCM"];
        let hash_functions = ["SHA256", "SHA512", "BLAKE2s"];

        for pattern in patterns {
            for dh in dh_functions {
                for cipher in ciphers {
                    for hash in hash_functions {
                        let snow_noise_params =
                            format!("Noise_{}_{}_{}_{}", pattern, dh, cipher, hash);
                        test_snow_impl(&snow_noise_params);
                    }
                }
            }
        }
    }

    /// Runs one handshake plus a 100-byte exchange in each direction for the
    /// given Noise protocol name, with client and server on separate threads.
    ///
    /// A panic on either side (including a failed size assertion) fails the test.
    fn test_snow_impl(snow_noise_params: &str) {
        let cfg = crate::interface::Config {
            client_hostname: "localhost".to_string(),
            server_hostname: "localhost".to_string(),
            client_port: 50001,
            server_port: 8081,
            snow_noise_params: Some(snow_noise_params.to_string()),
        };

        let mut client = SnowClient::new(cfg.clone());
        let mut server = SnowServer::new(cfg);

        assert!(client.try_init());
        assert!(server.try_init());

        // Bind before spawning anything so the client can never race ahead of
        // the listener (previously papered over with a sleep).
        server.listen();

        let server_thread = std::thread::spawn(move || {
            server.accept();
            server.recv(100);
            server.send(100);
            server.cleanup();
        });

        let client_thread = std::thread::spawn(move || {
            client.connect();
            client.send(100);
            client.recv(100);
            client.cleanup();
        });

        // Propagate panics from either side; ignoring the join result would
        // let failed assertions pass silently.
        client_thread.join().expect("client thread panicked");
        server_thread.join().expect("server thread panicked");
    }
}
