#include <stdlib.h>
#include <signal.h>
#include <alarm.h>
#include <time.h>

#include "arch_runner.h"


// Range-tracking state used by signal_handler(): was_defined is 1 while a run
// of consecutive "defined" encodings is open, first_defined is where it began.
static uint64_t was_defined = 0;
static uint64_t first_defined = 0;

// First and last (inclusive) encodings to test, parsed from argv in
// wrapped_main().
static uint64_t first;
static uint64_t end;

// Two-page code buffer: a fixed prologue followed by the slot that holds the
// encoding under test.
static uint32_t* code_base;
// Points at the tested slot inside code_base; its value is also the counter
// that signal_handler() advances after every run.
static uint64_t* cur_instr;

// stack used for signal handling (installed with sigaltstack).
static __attribute__((aligned(PAGE_SIZE))) uint8_t signal_handler_stack[PAGE_SIZE * 4];

// "normal" stack, used by wrapped_main() (see main()) and by the tested code
// (see run_instr()). NOTE(review): it is described as mapped read-only while a
// gadget runs; no code in view does that - presumably wrap_main/arch_runner,
// confirm.
// Stack pointers are derived from the end of this array, and its size is a
// multiple of 16, so aligning the start to 16 bytes keeps those pointers at the
// 16-byte alignment the common 64-bit ABIs (x86_64 SysV, MIPS n64) expect.
static __attribute__((aligned(16))) uint8_t stack[PAGE_SIZE * 20];

// defined by linker script (not referenced elsewhere in this file).
// start of data segment.
extern void _data_start();
// end of data segment.
extern void _data_end();

// Empty signal set; signal_handler() installs it to unblock every signal,
// since the handler never returns.
static struct sigset sigset_clear_all;

// Signature used to jump into the code buffer.
typedef void (*void_func)(void);


/*
 * Start one test run.
 *
 * Arms a ualarm(5000) watchdog (units presumably microseconds - confirm in
 * alarm.h) and makes the two-page code buffer read+exec. It then moves $sp to
 * PAGE_SIZE below the top of `stack` and calls the start of the code buffer,
 * which runs the prologue and then reaches the slot holding the encoding
 * under test.
 *
 * Never returns normally: every run is expected to end in a signal handled by
 * signal_handler(), which starts the next run itself.
 *
 * NOTE(review): the asm changes $sp without telling the compiler, so nothing
 * after it may rely on the previous stack frame. This only works because the
 * function never returns. `move $sp` is MIPS assembler syntax.
 */
static void run_instr() {
    ualarm(5000);
    mprotect(code_base, 2 * PAGE_SIZE, PROT_READ | PROT_EXEC);
    asm volatile("move $sp, %0" :: "r" (stack + sizeof(stack) - PAGE_SIZE));
    ((void_func)(void*)(code_base))();
    // Reached only if the tested code returns: read address 0 to force a fault,
    // so this case also ends in signal_handler(). (A NULL read is UB in C; it
    // relies on the volatile load actually being emitted.)
    *(volatile uint8_t*)NULL;
}

/*
 * Handler for every signal a test run can end with. It is installed by
 * setup_signal_handling() to run on signal_handler_stack.
 *
 * Classifies the encoding in *cur_instr. It is "undefined" only when the run
 * ended with signal 4 (SIGILL on Linux) and the PC exactly at the tested slot.
 * Any other signal or PC counts as "defined": a fault elsewhere, a SIGALRM
 * timeout, a trap, and so on. Consecutive defined encodings are coalesced, and
 * each finished range is printed to stderr as "0x<first> - 0x<last>".
 *
 * Never returns: it advances *cur_instr and either starts the next run via
 * run_instr() or exits once the encoding `end` has been tested.
 *
 * NOTE(review): sa_mask is empty, so only the signal being handled is blocked
 * on entry, and the mask is cleared completely below. A SIGALRM from the
 * ualarm() armed by run_instr() can therefore nest into this handler before
 * the next run re-arms it. The nested handler would classify or skip an
 * encoding and abandon this invocation. Consider disarming the timer first
 * thing here (confirm ualarm(0) semantics in alarm.h).
 * NOTE(review): if end == UINT64_MAX, *cur_instr wraps to 0 and never exceeds
 * end, so the scan does not terminate.
 */
static __attribute__((noinline)) void signal_handler(int signo, struct siginfo* info, void* ucontext) {
    
    /* The handler never returns, so unblock everything now; otherwise the
     * signal being handled would stay blocked for all following runs. */
    sigprocmask(SIG_SETMASK, &sigset_clear_all, NULL);
    
    /* Make the code buffer writable again so *cur_instr can be advanced. */
    mprotect(code_base, 2 * PAGE_SIZE, PROT_READ | PROT_WRITE);
    
    /* Undefined == SIGILL raised by the tested slot itself. */
    if(signo != 4 || SIG_PC(ucontext) != (uint64_t)(void*)cur_instr) {
        /* Defined: open a new range if none is open. */
        if(!was_defined) {
            was_defined = 1;
            first_defined = *cur_instr;
        }
    } else if(was_defined){
        /* Undefined after a defined run: the range ended at the previous encoding. */
        writef(stderr, "0x%x - 0x%x\n", first_defined, *cur_instr - 1);
        was_defined = 0;
    }
    
    /* Move on to the next encoding. */
    *cur_instr += 1;
    
    /* Progress report every 2^20 encodings. */
    if((*cur_instr % (1024 * 1024)) == 0) {
        writef(stderr, "progress: 0x%x / 0x%x\n", *cur_instr, end);
        // printf("progress: %u%%\n", (*cur_instr - first) * 100ull / (end - first));
    }
    
    /* `end` is inclusive: stop once it has been tested, flushing any open range. */
    if(*cur_instr > end) {
        if(was_defined) {
            writef(stderr, "0x%x - 0x%x\n", first_defined, *cur_instr - 1);
        }
        exit(0);
    }
    
    /* Start the next run; run_instr() does not return. */
    run_instr();
}

/*
 * Install signal_handler() for every signal a test run may end with.
 *
 * The handler runs on signal_handler_stack (SA_ONSTACK). Presumably this is
 * because the tested code executes with $sp pointing into `stack` (see
 * run_instr()). Every failure here is fatal: with a handler missing, the first
 * faulting encoding would terminate the runner instead of being recorded.
 */
static void setup_signal_handling() {
    /* setup signal handling stack */
    struct sigstack sig_stack = {
        .ss_sp = signal_handler_stack,
        .ss_flags = 0,
        .ss_size = sizeof(signal_handler_stack)
    };
    if(sigaltstack(&sig_stack, NULL)) {
        writef(stderr, "setting signal stack failed!\n");
        exit(-1);
    }

    /* Signals a tested encoding can end with; SIGALRM is the watchdog armed
     * by run_instr() for encodings that never finish. */
    const int handled_signals[] = {
        SIGILL, SIGSEGV, SIGALRM, SIGFPE, SIGBUS, SIGTRAP, SIGSYS, SIGABRT
    };

    /* Zero the whole struct so no field reaches the kernel uninitialised;
     * this also leaves sa_mask empty. */
    struct sigaction sig;
    memset8(&sig, 0, sizeof(sig));
    sig.sa_flags = SA_SIGINFO | SA_ONSTACK;
    sig.sa_sigaction = signal_handler;
    sig.sa_restorer = NULL;

    for(unsigned i = 0; i < sizeof(handled_signals) / sizeof(handled_signals[0]); i++) {
        if(sigaction(handled_signals[i], &sig, NULL)) {
            writef(stderr, "setting signal handler failed!\n");
            exit(-1);
        }
    }

    /* setup sigset to clear all signals */
    memset8(&sigset_clear_all, 0, sizeof(sigset_clear_all));
}

void wrapped_main(int argc, char** argv) {
    setup_signal_handling();
    code_base = mmap(NULL, PAGE_SIZE * 2, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_POPULATE, -1, 0);
    for(int i = 0; i < 31; i++) {
        code_base[i] = 0b00000000000101011000000000000000 | (i + 1);
    }
    cur_instr = (uint64_t*)&code_base[31];
    *cur_instr = strtou64(argv[1], NULL);
    first = *cur_instr;
    end = strtou64(argv[2], NULL);
    run_instr();
}


void main(int argc, char** argv) {
    // use a different stack, so we can unmap the default stack to have less mappings
    wrap_main(argc, argv, stack + sizeof(stack) - 8 /* for some reason x86_64 needs a -8 here. Perhaps to be 16-byte aligned? */);
}
