#include <stdlib.h>
#include <signal.h>
#include <alarm.h>
#include <time.h>
#include <socket.h>

// uncomment for debug prints
// #define DEBUG

// platform dependent structs, etc.
#include "arch_runner.h"

#ifndef PAGE_SIZE
    #define PAGE_SIZE 4096
#endif

/*
 * The loader's primary job is to set up program state (memory mappings + register content) and signal (exception) handling before code execution.
 * Then, it loads the provided register state with as little "pollution of the execution environment" as possible, while still being able to record the state when the program "crashes".
 * Once the program "crashes" (all tests should be designed so that they eventually crash, or loop forever until the watchdog alarm raises SIGALRM), the full state is reported back to the server over the socket.
 * ("crash" meaning any signal like SIGILL, SIGSEGV, etc., being sent by the kernel.)
 * ("pollution" includes memory mappings and registers with content not matching the provided program state. Ideally we only have a single rx memory mapping (plus a small writable section for the signal handler stack) and no wrong registers.)
 */


/*
 * The state received over the socket dictates all register contents as well as all memory mappings (including address, content, and size) to set before program execution.
 * It starts with the register content, followed by the number of memory mappings (at most MAX_MEMORY_MAPPINGS, 64 by default) and one descriptor per mapping.
 * A descriptor does not include the mapping's data itself: the data of all mappings follows the descriptors, concatenated in descriptor order (each mapping contributes exactly `size` bytes).
 */

// Maximum number of memory mappings a single state may describe.
// Can be overridden at compile time (e.g. -DMAX_MEMORY_MAPPINGS=128).
#ifndef MAX_MEMORY_MAPPINGS
    #define MAX_MEMORY_MAPPINGS 64
#endif /* MAX_MEMORY_MAPPINGS */


// Bits of memory_mapping.protection; OR them together for combined permissions.
#define MEMORY_MAPPING_READ 1
#define MEMORY_MAPPING_WRITE 2
#define MEMORY_MAPPING_EXEC 4

// Descriptor of one memory mapping as exchanged over the socket (packed: three consecutive uint64_t fields).
// The mapping's data is not part of the descriptor; it follows all descriptors in the stream.
struct __attribute__((packed)) memory_mapping {
    // virtual address to map to. Need not be page aligned: if it is not, the whole enclosing page is mapped,
    // but data is written starting at address.
    uint64_t address;
    // size of the data in bytes. Need not be a multiple of PAGE_SIZE: the mapping is rounded up to whole pages,
    // but only [address, address + size) contains defined data.
    uint64_t size;
    // mapping protection: MEMORY_MAPPING_READ (1) | MEMORY_MAPPING_WRITE (2) | MEMORY_MAPPING_EXEC (4).
    // Everything else is reserved for now.
    uint64_t protection;
    // NOTE: there is no per-mapping data offset; each mapping's data is read sequentially from the
    // stream in descriptor order.
};

// Crash information reported back to the server (write_state sends it right after the register state).
struct crash_info {
    // signal number associated with the crash.
    // A fixed-size 64-bit integer is used even though it is larger than needed, to make parsing easier.
    int64_t signo;
    // address associated with the crash, as returned by arch_get_crash_address
    // (for SIGILL, this should be the same as pc; for SIGSEGV, this should be the faulting address).
    uint64_t address;
};

// The state exchanged over the socket can be envisioned the following way:
// struct state_file {
//     struct register_state regs;
//     struct crash_info crash; // output only, input states do not contain this!
//     uint64_t mappings_count;
//     struct memory_mapping mappings[mappings_count];
//     uint8_t mapping_data[sum(m.size for m in mappings)]; // data of each mapping, concatenated in descriptor order
// }

// Mapping descriptors of the state currently being executed (filled by receive_state).
// The first unused entry has an address of 0 and a size of 0 (write_state counts entries up to it).
// If all entries are used, the end is given by MAX_MEMORY_MAPPINGS.
static struct memory_mapping mappings[MAX_MEMORY_MAPPINGS];

// Register state that was requested for the next execution.
// It is read from the socket into this buffer and handed to arch_check_registers / arch_setup_registers.
static struct register_state registers_before;

// Socket connected to the server: states are read from it and results are written to it.
static int socket_fd;

// Register state recorded by the signal handler after the program crashed.
// NOTE(review): unlike the other globals this one is not static; presumably so other translation
// units (arch code) can reference it — confirm before changing its linkage.
struct register_state registers_after;

// Memory used both as the loader's own stack (main passes its top to wrap_main) and as the
// alternate signal stack (installed via sigaltstack in setup_signal_handling).
// It stays writable while the rest of the data segment is made read-only before execution.
static __attribute__((aligned(PAGE_SIZE))) uint8_t stack[PAGE_SIZE * 20];

// Defined by the linker script (declared as functions only so their addresses can be taken).
// Start of the data segment.
extern void _data_start();
// End of the data segment.
extern void _data_end();

static void write_state(struct register_state* state, struct crash_info* crash_info) {
    
    /* open output file */
    uint64_t to_write = 0;
    
    /* write register state */
    queue(socket_fd, (uint8_t*) state, sizeof(*state));
    to_write += sizeof(*state);
    
    /* write crash info */
    queue(socket_fd, (uint8_t*) crash_info, sizeof(*crash_info));
    to_write += sizeof(*crash_info);
    
    /* write mappings count */
    uint64_t mappings_count;
    for(mappings_count = 0; mappings_count < MAX_MEMORY_MAPPINGS && (mappings[mappings_count].address || mappings[mappings_count].size) ; mappings_count ++);
    queue(socket_fd, (uint8_t*) &mappings_count, sizeof(mappings_count));
    to_write += sizeof(mappings_count);
    
    /* write mappings */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        // write mapping 
        queue(socket_fd, (uint8_t*) &mappings[mapping_idx], sizeof(mappings[mapping_idx]));
        to_write += sizeof(mappings[mapping_idx]);
    }
    
    /* write mappings data */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        uint64_t written = 0;
        // make mapping readable if necessary
        if(!(mappings[mapping_idx].protection & MEMORY_MAPPING_READ)) {
            uint64_t mapping_base = mappings[mapping_idx].address - (mappings[mapping_idx].address % PAGE_SIZE);
            uint64_t mapping_size = mappings[mapping_idx].size + (mappings[mapping_idx].address - mapping_base);
            if(mapping_size % PAGE_SIZE) {
                mapping_size += PAGE_SIZE - (mapping_size % PAGE_SIZE);
            }
            // TODO: we may want to restore the original protection after?
            // Right now we don't need to, but once we implement some sort of multi-execution mode, we must.
            mprotect(
                (void*)mapping_base,
                mapping_size,
                PROT_READ | PROT_WRITE
            );
        }
        queue(socket_fd, (uint8_t*) mappings[mapping_idx].address, mappings[mapping_idx].size);
        to_write += mappings[mapping_idx].size;
    }
    
    uint64_t written = flush();
    if(written != to_write) {
        writef(stderr, "failed to write output (%u / %u)!\n", written, to_write); // this is a fatal error so not wrapped in debug
        close(socket_fd);
        exit(-1);
    }
    
    /* delete mappings */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        // make mapping readable if necessary
        uint64_t mapping_base = mappings[mapping_idx].address - (mappings[mapping_idx].address % PAGE_SIZE);
        uint64_t mapping_size = mappings[mapping_idx].size + (mappings[mapping_idx].address - mapping_base);
        if(mapping_size % PAGE_SIZE) {
            mapping_size += PAGE_SIZE - (mapping_size % PAGE_SIZE);
        }
        // TODO: we may want to restore the original protection after?
        // Right now we don't need to, but once we implement some sort of multi-execution mode, we must.
        munmap(
            (void*)mapping_base,
            mapping_size
        );
    }
    /* delete other state */
    memset8(mappings, 0, sizeof(mappings[0]) * mappings_count);
}

/*
 * Create the memory mappings described by `mappings` and fill them with data read from the socket.
 *
 * Works in three passes:
 *   1. map every page touched by any mapping as fresh anonymous read-write memory,
 *   2. read each mapping's data (exactly `size` bytes, in descriptor order) from socket_fd to its address,
 *   3. apply each mapping's requested protection.
 * All data is written before any protection is applied, so read-only and exec-only mappings can still be
 * filled. Protection gets its own pass to cope with overlapping mappings.
 *
 * mappings:       descriptors to set up (need not be page aligned; pages are rounded outward).
 * mappings_count: number of valid entries in `mappings`.
 *
 * Exits the process if a mapping's data cannot be read completely. mmap/mprotect failures are
 * currently only reported in DEBUG builds.
 */
static void setup_mappings(struct memory_mapping* mappings, uint64_t mappings_count) {
    /* pass 1: map every page touched by a mapping as read-write memory */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        struct memory_mapping cur_mapping = mappings[mapping_idx];
        // align base address and size to multiples of PAGE_SIZE (make sure whole data fits, so extend mapping as necessary)
        uint64_t mapping_base = cur_mapping.address - (cur_mapping.address % PAGE_SIZE);
        uint64_t mapping_size = cur_mapping.size + (cur_mapping.address - mapping_base);
        if(mapping_size % PAGE_SIZE) {
            mapping_size += PAGE_SIZE - (mapping_size % PAGE_SIZE);
        }

        // map the range page by page
        for(uint64_t start = mapping_base; start < mapping_base + mapping_size; start += PAGE_SIZE) {
            // TODO: MAP_FIXED without NO_REPLACE is scary!
            // A successful MAP_FIXED mmap returns `start` (page aligned); error returns (-1 / -errno) are not page aligned.
            if((uint64_t)mmap((void*)start, PAGE_SIZE, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED | MAP_POPULATE, -1, 0) % PAGE_SIZE){
                // TODO: add this print back in so we don't miss mapping related errors!
                #ifdef DEBUG
                    writef(stderr, "mmap failed for mapping %u (base=0x%x, size=0x%x)\n", mapping_idx, start, PAGE_SIZE);
                #endif /* DEBUG */
                // exit(-1);
            }
        }
    }

    /* pass 2: fill each mapping with its data, read from the socket in descriptor order */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        struct memory_mapping cur_mapping = mappings[mapping_idx];
        if(nread(socket_fd, (void*)cur_mapping.address, cur_mapping.size) != cur_mapping.size) {
            #ifdef DEBUG
                writef(stderr, "failed to setup mapping data for mapping %u (base=0x%x, size=0x%x)\n", mapping_idx, cur_mapping.address, cur_mapping.size);
            #endif /* DEBUG */
            exit(-1); // TODO: check that this does not cause issues!
        }
    }

    /* pass 3: apply the requested protections (separate loop to cope with overlapping mappings) */
    for(uint64_t mapping_idx = 0; mapping_idx < mappings_count; mapping_idx ++) {
        struct memory_mapping cur_mapping = mappings[mapping_idx];
        // align base address and size to multiples of PAGE_SIZE (make sure whole data fits, so extend mapping as necessary)
        uint64_t mapping_base = cur_mapping.address - (cur_mapping.address % PAGE_SIZE);
        uint64_t mapping_size = cur_mapping.size + (cur_mapping.address - mapping_base);
        if(mapping_size % PAGE_SIZE) {
            mapping_size += PAGE_SIZE - (mapping_size % PAGE_SIZE);
        }

        // convert from mapping protection to mprotect prot argument
        int mprotect_prot = ((MEMORY_MAPPING_READ  & cur_mapping.protection) ? PROT_READ  : 0) |
                            ((MEMORY_MAPPING_WRITE & cur_mapping.protection) ? PROT_WRITE : 0) |
                            ((MEMORY_MAPPING_EXEC  & cur_mapping.protection) ? PROT_EXEC  : 0);
        if(mprotect((void*) mapping_base, mapping_size, mprotect_prot)) {
            // best effort, like the mmap failures above: only reported in debug builds
            #ifdef DEBUG
                writef(stderr, "mprotect failed for mapping %u (base=0x%x, size=0x%x)\n", mapping_idx, mapping_base, mapping_size);
            #endif /* DEBUG */
        }
    }
}


// Empty signal set; the signal handler installs it as the new mask to unblock every signal again.
static struct sigset sigset_clear_all;

// Receives the next state from the server and starts executing it (never returns).
static void receive_state(void);

/*
 * Handler for every signal the test program may raise. setup_signal_handling installs it with
 * SA_SIGINFO | SA_ONSTACK, so it runs on the `stack` buffer.
 *
 * Steps, in this order:
 *   1. make the loader's data segment writable again (receive_state made it read-only before
 *      execution; presumably this must happen first so the globals written below are writable),
 *   2. unblock all signals (this handler never returns normally, so the kernel never restores the mask),
 *   3. cancel the watchdog alarm,
 *   4. record registers and crash information and report them via write_state,
 *   5. receive and start the next state (receive_state does not return).
 *
 * NOTE(review): sa_mask is empty, so SIGALRM is not blocked while this handler runs (unless it is
 * the signal being handled). If the watchdog fires before ualarm(0), the handler may be re-entered.
 */
static __attribute__((noinline)) void signal_handler(int signo, struct siginfo* info, void* ucontext) {

    /* make everything read-write */
    // part of the data segment below the signal stack
    // (assumes _data_start is page aligned, as mprotect requires; `stack` is PAGE_SIZE aligned)
    uint64_t temp_size = ((uint64_t)stack - (uint64_t)_data_start);
    if(temp_size) {
        mprotect((void*)_data_start, temp_size, PROT_READ | PROT_WRITE);
    }
    // part of the data segment above the signal stack, rounded up to whole pages
    temp_size = ((uint64_t)_data_end - (uint64_t)stack - sizeof(stack));
    if(temp_size) {
        if(temp_size % PAGE_SIZE) {
            temp_size += PAGE_SIZE;
            temp_size -= temp_size % PAGE_SIZE;
        }
        mprotect(&stack[sizeof(stack)], temp_size, PROT_READ | PROT_WRITE);
    }
    
    /* unblock all signals (including the one currently being handled) */
    sigprocmask(SIG_SETMASK, &sigset_clear_all, NULL);
    
    // cancel the watchdog alarm armed by receive_state
    ualarm(0);
    
    /* record crash information */
    // registers
    arch_store_registers(signo, info, ucontext, &registers_after);
    // general information
    struct crash_info crash_info;
    crash_info.signo = signo;
    crash_info.address = arch_get_crash_address(signo, info, ucontext);
    
    /* report the post-crash state (and mapping contents) to the server */
    write_state(&registers_after, &crash_info);
    
    /* done with this snippet: receive and start the next one (receive_state does not return) */
    receive_state();
}

/*
 * Install signal_handler for every signal a crashing test may raise, plus SIGALRM for the watchdog
 * armed by receive_state. The handler runs on the dedicated `stack` buffer (alternate signal stack).
 *
 * Exits the process if the alternate stack or any handler cannot be installed. Otherwise the
 * corresponding crash would never be reported back to the server.
 */
static void setup_signal_handling() {
    /* setup signal handling stack */
    struct sigstack sig_stack = {
        .ss_sp = stack,
        .ss_flags = 0,
        .ss_size = sizeof(stack)
    };
    if(sigaltstack(&sig_stack, NULL)) {
        writef(stderr, "setting signal stack failed!\n");
        exit(-1);
    }

    /* setup signal handler for all the signals we may get */
    struct sigaction sig;
    // start from an all-zero struct; this also empties sa_mask, so only the handled signal is blocked
    memset8(&sig, 0, sizeof(sig));
    sig.sa_flags = SA_SIGINFO | SA_ONSTACK;
    sig.sa_sigaction = signal_handler;
    sig.sa_restorer = NULL;

    #ifdef __x86_64__
        // The x86_64 kernel refuses to set up a signal frame unless SA_RESTORER (0x04000000) is set.
        // No actual restorer is needed: signal_handler never returns (it ends in receive_state).
        sig.sa_flags |= 0x04000000;
    #endif /* __x86_64__ */

    const int handled_signals[] = {
        SIGILL, SIGSEGV, SIGALRM, SIGFPE, SIGBUS, SIGTRAP, SIGSYS, SIGABRT
    };
    for(uint64_t i = 0; i < sizeof(handled_signals) / sizeof(handled_signals[0]); i++) {
        int res = (int)sigaction(handled_signals[i], &sig, NULL);
        if(res) {
            writef(stderr, "installing handler for signal %d failed: %d\n", handled_signals[i], res);
            exit(-1);
        }
    }

    /* setup sigset to clear all signals */
    memset8(&sigset_clear_all, 0, sizeof(sigset_clear_all));
}

/*
 * Receive the next state from the server, set it up, and start executing it.
 *
 * Reads, in order: the requested register state, the number of mappings, the mapping descriptors,
 * and (via setup_mappings) the mapping data. It then makes the loader's data segment read-only,
 * except for `stack`, which the signal handler needs. Finally it arms the watchdog alarm and loads
 * the requested registers including the program counter, so this function never returns.
 *
 * Exits the process on a short read from the socket, or if more than MAX_MEMORY_MAPPINGS mappings
 * are requested.
 */
static void receive_state() {
    /* parse state from socket */

    // registers
    if(nread(socket_fd, &registers_before, sizeof(registers_before)) != sizeof(registers_before)) {
        writef(stderr, "failed to read registers\n");
        exit(-1);
    }

    // crash info: not part of the input (it is only sent back by write_state)
    // TODO: do not send crash info!

    // mappings_count
    uint64_t mappings_count;
    if(nread(socket_fd, &mappings_count, sizeof(mappings_count)) != sizeof(mappings_count)) {
        writef(stderr, "failed to read mappings_count\n");
        exit(-1);
    }

    if(mappings_count > MAX_MEMORY_MAPPINGS) {
        writef(stderr, "cannot create %u memory mappings. Maximum supported is %u. Recompile with MAX_MEMORY_MAPPINGS set to higher number if required!\n", mappings_count, MAX_MEMORY_MAPPINGS);
        exit(-1);
    }
    // memory mappings
    if(nread(socket_fd, mappings, sizeof(struct memory_mapping)*mappings_count) != sizeof(struct memory_mapping) * mappings_count) {
        writef(stderr, "failed to read memory mappings\n");
        exit(-1);
    }
    // Clear the unused tail so the entry after the last received descriptor is all-zero,
    // which is how write_state finds the end of the list.
    for(uint64_t mapping_idx = mappings_count; mapping_idx < MAX_MEMORY_MAPPINGS; mapping_idx ++) {
        mappings[mapping_idx].address = 0;
        mappings[mapping_idx].size = 0;
        mappings[mapping_idx].protection = 0;
    }

    setup_mappings(mappings, mappings_count);

    // perform a final check whether everything seems fine
    arch_check_registers(&registers_before);

    // make everything (except the signal stack) read only
    uint64_t temp_size = ((uint64_t)stack - (uint64_t)_data_start);
    if(temp_size) {
        mprotect((void*)_data_start, temp_size, PROT_READ);
    }
    // part above the signal stack, rounded up to whole pages
    temp_size = ((uint64_t)_data_end - (uint64_t)stack - sizeof(stack));
    if(temp_size) {
        if(temp_size % PAGE_SIZE) {
            temp_size += PAGE_SIZE;
            temp_size -= temp_size % PAGE_SIZE;
        }
        mprotect(&stack[sizeof(stack)], temp_size, PROT_READ);
    }

    // setup alarm in case we run into an infinite loop (0.5 seconds should only be reached when an infinite loop is encountered)
    ualarm(500000);
    // start execution by restoring full register state (including pc, so we never end up back here)
    arch_setup_registers(&registers_before);
}

// basically main function, but with a new stack!
/*
 * Loader entry point, running on the static `stack` buffer instead of the original process stack (see main).
 *
 * usage: <loader> <ip> <port>
 *   ip:   IPv4 address of the server as a single integer. It is stored into the socket address
 *         as-is (no byte-order conversion, see the TODO below).
 *   port: TCP port of the server; converted to network byte order with htons.
 *
 * Connects to the server, installs the signal handlers, and hands control to receive_state, which
 * never returns. From then on, every crash reports its state and starts the next one from the
 * signal handler. Exits with -1 on invalid arguments or if the connection cannot be established.
 */
void wrapped_main(int argc, char** argv) {
    /* parse arguments */
    if(argc != 3) {
        writef(stderr, "usage: %s <ip> <port>\n", argv[0]);
        exit(-1);
    }
    // TODO: we could sanity check more arguments (e.g. reject non-numeric input)
    uint64_t ip = strtou64(argv[1], NULL);
    uint64_t port = strtou64(argv[2], NULL);
    // reject values that do not fit an IPv4 address / TCP port instead of silently truncating them
    if(ip > 0xFFFFFFFFull) {
        writef(stderr, "invalid ip (must fit in 32 bits): %s\n", argv[1]);
        exit(-1);
    }
    if(port > 0xFFFF) {
        writef(stderr, "invalid port (must be at most 65535): %s\n", argv[2]);
        exit(-1);
    }

    /* unmap everything we don't need (stack, heap, vdso, ...) */
    // TODO: unmap everything

    /* connect to server */
    socket_fd = (int)sys_socket(AF_INET, SOCK_STREAM, 0);

    if(socket_fd < 0) {
        writef(stderr, "failed to open socket: %d\n", socket_fd);
        exit(-1);
    }

    struct sockaddr_in4 sock_addr = {
        .type = AF_INET,
        .port = htons((uint16_t)port),
        .address = ip, // TODO: also invert that?
        .padding = 0
    };

    int connect_res = (int)sys_connect(socket_fd, &sock_addr, sizeof(sock_addr));

    if(connect_res < 0) {
        writef(stderr, "failed to connect socket: %d\n", connect_res);
        exit(-1);
    }

    /* setup signal handling */
    setup_signal_handling();

    // receive and run the first state; never returns
    // TODO: exit runner cleanly when connection drops (a short read currently makes receive_state exit with an error)
    receive_state();

    // unreachable
    writef(stderr, "this code should not be reachable!\n");
    exit(-1);
}

/*
 * Process entry point: continue on the statically allocated `stack` buffer (via wrap_main ->
 * wrapped_main) so the original process stack is no longer needed and can be unmapped later,
 * leaving fewer mappings. Note that the same buffer also serves as the alternate signal stack.
 */
void main(int argc, char** argv) {
    // One past the end of `stack`, minus 8 bytes.
    // NOTE(review): the -8 was needed on x86_64; presumably it provides the stack alignment the
    // ABI expects at function entry — confirm against wrap_main's implementation.
    uint8_t* new_stack_top = &stack[sizeof(stack)] - 8;
    wrap_main(argc, argv, new_stack_top);
}
