#include <stdlib.h>

#include "arch_runner.h"

#include <ucontext.h>

/*
 * Validates the starting pc stored in `registers_before`.
 *
 * Two checks are performed; if either fails a diagnostic is written to
 * stderr and the process exits with status -1:
 *   1. pc must have bits 63..47 and bit 0 clear.
 *   2. The two 32-bit words located at pc must be 0xf9400020 followed by
 *      0xf9400421 (the state-restoring loads named in the error message).
 *
 * registers_before - register description whose `pc` field is checked.
 */
void arch_check_registers(struct register_state* registers_before){
    // Mask selects the upper 17 bits (non-user half of the address space)
    // and bit 0 (odd address).
    // NOTE(review): only bit 0 is tested, yet the code below reads 32-bit
    // words at pc — confirm whether bit 1 should be rejected as well.
    if(registers_before->pc & 0xffff800000000001ull) {
        writef(stderr, "pc register must be a user address aligned to 2 bytes (is 0x%x)\n", registers_before->pc);
        exit(-1);
    }

    // TODO: make sure mapping exists
    uint32_t* ip = (uint32_t*)registers_before->pc;

    if(ip[0] != 0xf9400020 || ip[1] != 0xf9400421) {
        writef(stderr, "aarch64 requires two additional instructions at pc to restore state: f9400020 (ldr x0, [x1,  #0]) and 0xf9400421 (ldr x1, [x1,  #8]).\n");
        // Report both words that were actually found. The second conversion
        // used to be the literal text "0x0x", so ip[1] was never printed.
        writef(stderr, "was: 0x%x 0x%x\n", ip[0], ip[1]);
        exit(-1);
    }
}

// defined in loader_asm.S
// Assembly routine that arch_setup_registers() below hands the same
// register_state pointer to as its final step.
extern void arch_setup_registers2(struct register_state* registers_before);

/*
 * Sets up registers from `registers_before`.
 *
 * In builds that define AMX, inject_state() is first called on the AMX
 * portion of the register description. In every build the description is
 * then passed on to arch_setup_registers2().
 *
 * registers_before - register description to set up from.
 */
void arch_setup_registers(struct register_state* registers_before) {
#if defined(AMX)
    inject_state(&registers_before->regs.amx_state);
#endif

    arch_setup_registers2(registers_before);
}

// defined in loader_asm.S
// Neither routine is called from the C code visible in this part of the
// file; each takes the register_state to operate on.
extern void arch_store_vector_registers(struct register_state* state);
extern void arch_store_amx_registers(struct register_state* state);

/*
 * Copies the register values recorded in a signal context into `state`.
 *
 * signo    - signal number (not used by this function).
 * info     - signal information (not used by this function).
 * ucontext - opaque context pointer; interpreted as a ucontext_t.
 * state    - destination. Always receives pc, the general purpose
 *            registers, sp and pstate. Receives the vector registers when
 *            VECTOR is defined, fpsr when FLOATS or VECTOR is defined, and
 *            the AMX state (via capture_state()) when AMX is defined.
 */
void arch_store_registers(int signo, struct siginfo* info, void* ucontext, struct register_state* state) {
    ucontext_t* context = (ucontext_t *)ucontext;
    mcontext_t* saved = &context->uc_mcontext;

    #ifdef VECTOR
    // The 32 vector registers are read from the reserved area of the machine
    // context, viewed as an array of vv with element 0 skipped.
    // NOTE(review): presumably element 0 covers the record header plus
    // fpsr/fpcr; whether this layout varies between kernels is unverified.
    vv* vec_out = (vv*)&state->regs.vec;
    vv* vec_in = (vv*)saved->__reserved;
    for (unsigned reg = 0; reg < 32; reg++) {
        vec_out[reg] = vec_in[reg + 1];
    }
    #endif

    state->pc = saved->pc;

    // Starting at x0, copy as many bytes as the gp block holds minus the
    // size of its sp member; sp is then copied on its own from its own field.
    const size_t gp_bytes_without_sp = sizeof(state->regs.gp) - sizeof(state->regs.gp.sp);
    memcpy(&state->regs.gp.x0, &saved->regs, gp_bytes_without_sp);
    memcpy(&state->regs.gp.sp, &saved->sp, sizeof(state->regs.gp.sp));

    state->regs.pstate = saved->pstate;

    #if defined(FLOATS) || defined(VECTOR)
    // fpsr is element 1 of the reserved area when viewed as an array of freg.
    freg* reserved_as_freg = (freg*)saved->__reserved;
    state->regs.fpsr = reserved_as_freg[1];
    #endif

    #ifdef AMX
    capture_state(&state->regs.amx_state);
    #endif
}

/*
 * Returns the fault address stored in the signal information, converted to
 * a 64-bit integer.
 *
 * signo    - signal number (not used by this function).
 * info     - signal information; its _sigfault.si_addr field is read.
 * ucontext - signal context (not used by this function).
 */
uint64_t arch_get_crash_address(int signo, struct siginfo* info, void* ucontext) {
    uint64_t crash_address = (uint64_t) info->_sigfault.si_addr;

    return crash_address;
}

// defined in loader_asm.S
// Not called from the C code visible in this part of the file.
// NOTE(review): presumably runs main with `stack` as its stack — confirm
// against loader_asm.S.
extern  void wrap_main(int argc, char** argv, void* stack);
