/*
 * <stdint.h> must be visible before the first use of uint8_t / uint64_t
 * below; the #include of it that follows later in this file comes after
 * these types are already needed. Standard headers may be included more
 * than once, so the repeat is harmless.
 */
#include <stdint.h>

/* Size in bytes of one vector register (v0..v31). */
#define VEC_REG_SIZE 16

/* Raw byte image of a single vector register. */
typedef struct vv {
    uint8_t v[VEC_REG_SIZE];
} vv;

/* One 64-bit general-purpose register value. */
typedef uint64_t reg;

/*
 * General-purpose register bank: x0..x30 followed by sp.
 * Every member is a 64-bit `reg`, declared in register-number order.
 * x30 can also be accessed by the name `lr` through an anonymous union
 * (C11); both names refer to the same storage.
 */
struct gp {
    reg x0,  x1,  x2,  x3,  x4,  x5,  x6,  x7;
    reg x8,  x9,  x10, x11, x12, x13, x14, x15;
    reg x16, x17, x18, x19, x20, x21, x22, x23;
    reg x24, x25, x26, x27, x28, x29;
    union {
        reg x30;
        reg lr;     /* alias of x30 */
    };
    reg sp;
};

/* 64-bit raw FP register word. NOTE(review): not referenced in this excerpt. */
typedef uint64_t freg;

/*
 * Contents of one FP register slot, readable as a float, a double, or the
 * raw 64-bit pattern. When VECTOR is defined the slot also carries a full
 * 16-byte `vv`, because vector and floating point registers are merged on
 * aarch64; the union then grows to the vector-register size.
 *
 * Member order is significant: brace-initialisers target `flt`.
 */
union fpv {
    float    flt;
    double   dbl;
    uint64_t u;
#ifdef VECTOR
    vv       v;     /* whole vector register sharing this slot */
#endif
};
/*
 * Scalar floating-point register bank d0..d31, in register-number order.
 * Each slot is a `union fpv`, so its size follows that union (it widens
 * when VECTOR is defined).
 */
struct fp {
    union fpv d0,  d1,  d2,  d3,  d4,  d5,  d6,  d7;
    union fpv d8,  d9,  d10, d11, d12, d13, d14, d15;
    union fpv d16, d17, d18, d19, d20, d21, d22, d23;
    union fpv d24, d25, d26, d27, d28, d29, d30, d31;
};

/* Vector register bank v0..v31, each a raw VEC_REG_SIZE-byte image. */
struct vec {
    vv v0,  v1,  v2,  v3,  v4,  v5,  v6,  v7;
    vv v8,  v9,  v10, v11, v12, v13, v14, v15;
    vv v16, v17, v18, v19, v20, v21, v22, v23;
    vv v24, v25, v26, v27, v28, v29, v30, v31;
};

/* https://github.com/corsix/amx/blob/main/aarch64.h */
#include <stdint.h>

/*
 * Emit a raw AMX instruction word whose low 5-bit operand field is an
 * immediate rather than a register:
 *     .word 0x201000 + (op << 5) + imm5
 * Both `op` and `imm5` must be compile-time constants ("i" constraints).
 * Three `nop`s are emitted in front of the word; this mirrors the upstream
 * corsix/amx encoding (the reason is not shown here).
 * The "memory" clobber tells the compiler the instruction may read or write
 * arbitrary memory.
 */
#define AMX_NOP_OP_IMM5(op, imm5) \
    __asm("nop\nnop\nnop\n.word (0x201000 + (%0 << 5) + %1)" : : "i"(op), "i"(imm5) : "memory")

/*
 * Emit a raw AMX instruction whose operand is the general-purpose register
 * the compiler allocated for `gpr` (converted to uint64_t, "r" constraint):
 *     .word 0x201000 + (op << 5) + <register number>
 *
 * AMX is emitted via .word (presumably because the assembler has no
 * mnemonics for it), so the register number has to be computed at assembly
 * time from the register *name*. %1 expands to a name such as "x7" or "x17",
 * so "0%1" reads as the hex literal 0x7 or 0x17. For two-digit registers
 * the hex reading overshoots by 6 per tens digit (0x17 == 23 == 17 + 6);
 * "- ((0%1 >> 4) * 6)" removes that, leaving the decimal register number
 * (0..30) in bits 0-4, below the opcode field that starts at bit 5.
 *
 * The "memory" clobber tells the compiler the instruction may read or write
 * arbitrary memory (e.g. through an address passed in `gpr`).
 */
#define AMX_OP_GPR(op, gpr) \
    __asm(".word (0x201000 + (%0 << 5) + 0%1 - ((0%1 >> 4) * 6))" : : "i"(op), "r"((uint64_t)(gpr)) : "memory")

/*
 * One wrapper per AMX opcode, built on AMX_OP_GPR / AMX_NOP_OP_IMM5.
 * The meaning of the 64-bit `gpr` operand is instruction-specific
 * (see the corsix/amx documentation).
 */

/* ops 0-7: moves between memory and the X / Y / Z register files */
#define AMX_LDX(gpr)     AMX_OP_GPR(0, gpr)
#define AMX_LDY(gpr)     AMX_OP_GPR(1, gpr)
#define AMX_STX(gpr)     AMX_OP_GPR(2, gpr)
#define AMX_STY(gpr)     AMX_OP_GPR(3, gpr)
#define AMX_LDZ(gpr)     AMX_OP_GPR(4, gpr)
#define AMX_STZ(gpr)     AMX_OP_GPR(5, gpr)
#define AMX_LDZI(gpr)    AMX_OP_GPR(6, gpr)
#define AMX_STZI(gpr)    AMX_OP_GPR(7, gpr)

/* ops 8-9: extracts */
#define AMX_EXTRX(gpr)   AMX_OP_GPR(8, gpr)
#define AMX_EXTRY(gpr)   AMX_OP_GPR(9, gpr)

/* ops 10-16: multiply-accumulate family */
#define AMX_FMA64(gpr)   AMX_OP_GPR(10, gpr)
#define AMX_FMS64(gpr)   AMX_OP_GPR(11, gpr)
#define AMX_FMA32(gpr)   AMX_OP_GPR(12, gpr)
#define AMX_FMS32(gpr)   AMX_OP_GPR(13, gpr)
#define AMX_MAC16(gpr)   AMX_OP_GPR(14, gpr)
#define AMX_FMA16(gpr)   AMX_OP_GPR(15, gpr)
#define AMX_FMS16(gpr)   AMX_OP_GPR(16, gpr)

/* op 17: immediate form, imm5 = 0 for SET, 1 for CLR */
#define AMX_SET()        AMX_NOP_OP_IMM5(17, 0)
#define AMX_CLR()        AMX_NOP_OP_IMM5(17, 1)

/* ops 18-22: vector / matrix / lookup-table operations */
#define AMX_VECINT(gpr)  AMX_OP_GPR(18, gpr)
#define AMX_VECFP(gpr)   AMX_OP_GPR(19, gpr)
#define AMX_MATINT(gpr)  AMX_OP_GPR(20, gpr)
#define AMX_MATFP(gpr)   AMX_OP_GPR(21, gpr)
#define AMX_GENLUT(gpr)  AMX_OP_GPR(22, gpr)

/* https://github.com/corsix/amx/blob/main/emulate.h */

/*
 * One 64-byte AMX register. Every member is a different lane view of the
 * same 64 bytes; lane counts are written as 64 / element size to make that
 * invariant explicit. Member order matches upstream (brace-initialisers
 * target `u8`).
 */
typedef union amx_reg {
    uint8_t  u8 [64 / sizeof(uint8_t)];
    uint16_t u16[64 / sizeof(uint16_t)];
    uint32_t u32[64 / sizeof(uint32_t)];
    int8_t   i8 [64 / sizeof(int8_t)];
    int16_t  i16[64 / sizeof(int16_t)];
    int32_t  i32[64 / sizeof(int32_t)];
    _Float16 f16[64 / sizeof(_Float16)];
    float    f32[64 / sizeof(float)];
    double   f64[64 / sizeof(double)];
} amx_reg;

/* Full AMX register file: 8 X rows, 8 Y rows and 64 Z rows of amx_reg. */
typedef struct amx_state {
    amx_reg x[8];
    amx_reg y[8];
    amx_reg z[64];
} amx_state;


/*
 * Build a 64-bit AMX load/store operand: the address of *ptr in the low
 * bits, plus (row + flags * 64) shifted into bits 56 and up. flags == 1
 * therefore sets bit 62 on top of the row field. `&*` forces `ptr` to be a
 * pointer (or array) expression.
 */
#define PTR_ROW_FLAGS(ptr, row, flags)                \
    (((uint64_t)&*(ptr))                              \
     + (((uint64_t)((row) + (flags) * 64)) << 56))

/*
 * Store the live AMX X, Y and Z registers into *dst, then issue AMX_CLR().
 *
 * Rows are walked two at a time with flags = 1 (bit 62 of the operand);
 * the stride of 2 suggests each store moves the pair (row, row + 1).
 * NOTE(review): confirm pair semantics, and any alignment requirement on
 * pair transfers for dst, against the corsix/amx ldst documentation.
 * X and Y have only 8 rows, so they are stored only while row < 8; the
 * instruction order (STX, STY, STZ for rows 0..6, then STZ alone) matches
 * the original two-loop form.
 */
static void capture_state(amx_state* dst) {
    for (uint32_t row = 0; row < 64; row += 2) {
        if (row < 8) {
            AMX_STX(PTR_ROW_FLAGS(dst->x[row].u8, row, 1));
            AMX_STY(PTR_ROW_FLAGS(dst->y[row].u8, row, 1));
        }
        AMX_STZ(PTR_ROW_FLAGS(dst->z[row].u8, row, 1));
    }
    /* Presumably ends the AMX session begun by inject_state's AMX_SET(). */
    AMX_CLR();
}

/*
 * Issue AMX_SET(), then load the AMX X, Y and Z registers from *src.
 *
 * Mirrors capture_state: rows are walked two at a time with flags = 1
 * (bit 62 of the operand), presumably loading the pair (row, row + 1)
 * per instruction. NOTE(review): confirm pair semantics and alignment
 * requirements against the corsix/amx ldst documentation.
 * X and Y are loaded only while row < 8; the instruction order (LDX, LDY,
 * LDZ for rows 0..6, then LDZ alone) matches the original two-loop form.
 */
static void inject_state(const amx_state* src) {
    AMX_SET();
    for (uint32_t row = 0; row < 64; row += 2) {
        if (row < 8) {
            AMX_LDX(PTR_ROW_FLAGS(src->x[row].u8, row, 1));
            AMX_LDY(PTR_ROW_FLAGS(src->y[row].u8, row, 1));
        }
        AMX_LDZ(PTR_ROW_FLAGS(src->z[row].u8, row, 1));
    }
}

/*
 * Complete saved register context. Its contents depend on build options:
 *   always            : general-purpose bank `gp` and `pstate`
 *   FLOATS or VECTOR  : `fpsr` plus the FP/vector bank. The bank is an
 *                       anonymous union because vector and FP registers
 *                       share storage, so `fp` and (with VECTOR) `vec` are
 *                       two views of the same bytes.
 *   AMX               : a copy of the AMX X/Y/Z register file.
 */
struct regs {
    struct gp gp;
    reg       pstate;
#if defined(FLOATS) || defined(VECTOR)
    reg       fpsr;
    union {
        struct fp  fp;
# ifdef VECTOR
        struct vec vec;
# endif
    };
#endif
#ifdef AMX
    struct amx_state amx_state;
#endif
};
