use capstone::prelude::*;
use anyhow::Result;
use goblin::Object;

use clap::{Parser, Subcommand};

use serde::{Serialize, Deserialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::fs::read;

// Command-line arguments. Note: the `///` comments on the fields are rendered
// by clap as `--help` text, so they are user-facing.
#[derive(Parser, Debug)]
struct Args {
    /// Path of the file to analyze
    #[arg(short, long)]
    path: PathBuf,

    /// Architecture of the file
    // Restricting the accepted values lets clap reject typos with a proper
    // usage error (and list the choices in `--help`) instead of the program
    // failing later on an unknown architecture string.
    #[arg(value_parser = ["x86", "arm"])]
    arch: String,
}

/// Instruction set of the image being scanned.
// The upper-case variant names are kept as-is so existing users of the enum
// keep compiling; silence clippy's acronym-casing lint for them.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Arch {
    /// x86-64, disassembled in 64-bit mode with Intel syntax.
    AMD64,
    /// AArch64 (ARM 64-bit).
    AARCH64,
}

fn gadgets_capstone(addr: usize, data: &[u8], arch: Arch) -> Result<()> {
    let (cs, needles, step_size) = if matches!(arch, Arch::AMD64) {
        let cs = Capstone::new()
            .x86()
            .mode(arch::x86::ArchMode::Mode64)
            .syntax(arch::x86::ArchSyntax::Intel)
            .detail(true)
            .build()
            .expect("Failed to create Capstone object");

        let needles = [
            "mov cr0",
            "mov cr3",
            "mov cr4",
            "wrmsr",
            "lidt",
            "lgdt",
            "lldt",
            "ltr",
            "wrgsbase",
            "wrfsbase",
            "swapgs",
            "popfq",
            "stac",
            "clac",
            "iretq",
        ].to_vec();
        (cs, needles, 1)
    } else if matches!(arch, Arch::AARCH64) {
        let cs = Capstone::new()
            .arm64()
            .mode(arch::arm64::ArchMode::Arm)
            .detail(false)
            .build()
            .expect("Failed to create Capstone object");

        // on aarch64, just doing a .contains for `msr <register-name> results in
        // false positives, because things like `msr elr_el1` will also match on
        // `msr elr_el12`. To solve this we suffix with commas and pop the commas
        // off when storing results.
        let needles = [
            "msr elr_el1,",
            "msr pan,",
            "msr sctlr_el1,",
            "msr spsr_el1,",
            "msr tcr_el1,",
            "msr ttbr0_el1,",
            "msr ttbr1_el1,",
            "msr vbar_el1,",
        ].to_vec();
        (cs, needles, 4)
    } else {
        unreachable!();
    };

    let mut found_map: HashMap<String, Vec<(u64, String)>> = HashMap::new();
    for n in &needles {
        let needle_out = if matches!(arch, Arch::AARCH64) {
            let mut tmp = n.to_string();
            tmp.pop();
            tmp
        } else {
            n.to_string()
        };
        found_map.insert(needle_out, vec![]);
    }

    for idx in (0..data.len()).step_by(step_size) {
        // fast path for some amd64 opcodes we don't care about
        if matches!(arch, Arch::AMD64) && (data[idx] == 0x00 || data[idx] == 0xcc || data[idx] == 0xc3 || data[idx] == 0x90) {
            continue;
        }
        let insns = match cs.disasm_count(&data[idx..], (addr+idx) as u64, 1) {
            Ok(insns) => { insns },
            Err(_) => continue
        };

        for insn in insns.as_ref() {
            for needle in &needles {
                if matches!(arch, Arch::AMD64)  {
                    let detail: InsnDetail = cs.insn_detail(insn)?;
                    if let ArchDetail::X86Detail(d) = detail.arch_detail() {
                        if d.prefix() != &[0u8,0,0,0] {
                            continue;
                        }
                    }
                }

                let needle_out = if matches!(arch, Arch::AARCH64) {
                    let mut tmp = needle.to_string();
                    tmp.pop();
                    tmp
                } else {
                    needle.to_string()
                };

                if insn.to_string().contains(needle) {
                    let curr = found_map.get_mut(&needle_out).unwrap();
                    let mut mnemonic = insn.mnemonic().unwrap_or("").to_string();
                    let op_str = insn.op_str().unwrap_or("");
                    if !op_str.is_empty() {
                        mnemonic.push(' ');
                        mnemonic.push_str(op_str);
                    }

                    curr.push((insn.address(), mnemonic));
                }
            }
        }
    }

    //println!("~~~~~~~~~~~~~~ STATS ~~~~~~~~~~~~~~");
    //for (needle, addrs) in found_map.iter() {
    //    println!("{needle}: {}", addrs.len());
    //}

    //println!();

    //println!("~~~~~~~~~~~~~~ INSNS ~~~~~~~~~~~~~~");
    //for (needle, addrs) in found_map.iter() {
    //    println!("{needle}:");
    //    for (_addr, insn) in addrs {
    //        println!("    {insn}");
    //    }
    //}

    let serialized = serde_json::to_string(&found_map)?;
    println!("{}", serialized);

    Ok(())
}


fn main() -> Result<()> {
    let args = Args::parse();

    let input_path = args.path;
    let arch_str = args.arch;

    let (arch, base) = if arch_str == "x86" {
        (Arch::AMD64, 0xffffffff81000000)
    } else if arch_str == "arm" {
        (Arch::AARCH64, 0xffff800080010000)
    } else {
        panic!("Invalid Architecture! Must be one of: x86, arm")
    };

    let input_data = read(input_path)?;

    gadgets_capstone(base, &input_data, arch)?;

    Ok(())
}
