import numpy as np
from scipy.stats import pearsonr
from matplotlib import pyplot as plt

# Temperature used in the Kd -> dG conversion below; same value as used by
# AutoDock4. Presumably Kelvin, to match the units of R -- confirm.
T = 298.15
# Gas constant in kcal / (K * mol); the division by 1000 rescales 1.9872
# (presumably given in cal / (K * mol)) to kcal.
R = 1.9872 / 1000

def Kd_to_dG(value):
    """Return ``R * T * ln(value)`` using the module-level constants.

    ``value`` may be a scalar or anything ``np.log`` accepts; the result has
    the units of ``R * T`` (kcal/mol according to the comment on ``R``).
    """
    thermal_energy = R * T
    return thermal_energy * np.log(value)


class VariantPair:
    """Two related entries: a reference (``ref``) and its mutant (``mut``).

    Both values are stored exactly as passed in; elsewhere in this file they
    are ``PDB`` objects.
    """

    def __init__(self, ref, mut):
        self.ref, self.mut = ref, mut


class PDB:
    """One PDB entry together with the best docking poses found for it.

    Attributes:
        code: Entry code, stored as given.
        affinity: ``Kd_to_dG(affinity)``, i.e. the constructor argument
            converted by ``Kd_to_dG`` rather than the raw value.
        resolution: Stored as given.
        casf: ``casf`` coerced to ``bool``.
        min_energy: Initially ``None``; filled in later with a pose.
        min_rmsd: Initially ``None``; filled in later with a pose.
    """

    def __init__(self, code, affinity, resolution, casf):
        self.code = code
        # The raw value is converted once here; only the result is kept.
        converted = Kd_to_dG(affinity)
        self.affinity = converted
        self.resolution = resolution
        self.casf = True if casf else False
        # Placeholders until docking results are attached.
        self.min_energy = self.min_rmsd = None


class Pose:
    """A single docking pose: its ``affinity`` and its ``rmsd``.

    Both values are stored exactly as passed in.
    """

    def __init__(self, affinity, rmsd):
        self.affinity, self.rmsd = affinity, rmsd


class MethodData:
    """Docking results of one method at one box span, loaded via a cursor.

    On construction the method id is resolved by name, then the PDB entries,
    variant pairs, per-PDB best poses and run timings are fetched through
    ``cursor`` and the variant energy table is assembled.

    Attributes:
        cursor: Cursor used for every query (qmark-style ``?`` placeholders).
        title: Stored as given; not used inside this class.
        name: Method name looked up in ``Methods.name``.
        box_span: Box span used to select runs.
        id: ``method_id`` of the method.
        pdbs: ``PDB`` objects having at least one ``status = 1`` run for
            this method.
        pdb_ids: Mapping from the results-database ``pdb_id`` to its ``PDB``.
        variants: ``VariantPair`` objects whose two members are both in
            ``pdbs``.
        timing: float64 array built from the ``time`` column of the runs
            matching the method and box span.
        variant_energies: float32 array, one row per usable variant pair
            (column layout documented in ``construct_variant_energies``).
    """

    def __init__(self, title, name, box_span, cursor):
        """Resolve the method id and load all data for it.

        Raises:
            ValueError: If no row in ``Methods`` has the given ``name``.
        """
        # Method info
        self.cursor = cursor
        self.title = title
        self.box_span = box_span
        self.name = name
        cursor.execute("SELECT method_id FROM Methods WHERE name = ?;", (name,))
        row = cursor.fetchone()
        if row is None:
            # Without this check an unknown name surfaces as an opaque
            # "'NoneType' object is not subscriptable" TypeError.
            raise ValueError(f"Unknown method name: {name!r}")
        self.id = row[0]
        # Results info
        self.variants = []
        self.pdbs = []
        self.pdb_ids = {}
        # Construct data (order matters: poses are attached to the PDB
        # objects created first, and the variant table reads those poses).
        self.fetch_pdb_data()
        self.fetch_docking_data()
        self.construct_variant_energies()

    def fetch_pdb_data(self):
        """Load the PDB entries and the variant pairs for this method.

        Fills ``self.pdbs`` and ``self.pdb_ids`` with one ``PDB`` per entry
        that has at least one ``status = 1`` run for this method, then fills
        ``self.variants`` with every selected reference/mutant pair whose two
        members are both among those entries. Pairs with a missing member are
        counted and reported in a single warning line on stdout.
        """
        self.cursor.execute(
            """
            SELECT
                PDB.code, PDB.pdb_id, BPDB.pdb_id, PDB.affinity, BPDB.resolution, PDB.CASF2016
            FROM
                PDB
            INNER JOIN
                PDBBind.PDB AS BPDB
            ON
                BPDB.code = PDB.code
            WHERE
                PDB.pdb_id
                IN (
                    SELECT DISTINCT
                        pdb_id
                    FROM
                        Runs
                    WHERE
                        status = 1
                        AND
                        method_id = ?
                    )
            """,
            (self.id,)
        )
        # The mutation table refers to entries by their PDBBind id, which is
        # distinct from the results-database id used as key of self.pdb_ids.
        pdbbind = {}

        for code, results_id, pdbbind_id, affinity, res, casf in self.cursor.fetchall():
            pdb = PDB(code, affinity, res, casf)
            self.pdb_ids[results_id] = pdb
            pdbbind[pdbbind_id] = pdb
            self.pdbs.append(pdb)

        # NOTE(review): the meaning of the mutations_within_ligand filter is
        # not visible from this file -- confirm against the PDBBind schema.
        self.cursor.execute(
            """
            SELECT
                ref_id, mut_id
            FROM
                PDBBind.PDBPDBMutation
            WHERE
                mutations_within_ligand LIKE '%1%'
                AND
                LENGTH(mutations_within_ligand) < 10
            """)

        not_present = 0
        for ref_id, mut_id in self.cursor.fetchall():
            if ref_id not in pdbbind or mut_id not in pdbbind:
                not_present += 1
                continue

            self.variants.append(VariantPair(pdbbind[ref_id], pdbbind[mut_id]))

        if not_present:
            print("WARNING: no docking results for {:d} variant pairs on method {}".format(not_present, self.name))

    def _store_poses(self, rows, attribute):
        """Attach a ``Pose`` built from each row to the matching ``PDB``."""
        for delta_g, rmsd, pdb_id in rows:
            pdb = self.pdb_ids.get(pdb_id)
            if pdb is None:
                # The docking queries do not filter on run status while the
                # PDB query does, so a row may refer to an entry that was
                # never loaded; such rows are skipped instead of raising
                # KeyError.
                continue
            # If several rows share a pdb_id, the last one wins.
            setattr(pdb, attribute, Pose(delta_g, rmsd))

    def fetch_docking_data(self):
        """Load the best poses per run and the run timings.

        For every run of this method at ``self.box_span`` two poses are
        stored on the corresponding ``PDB``: ``min_energy`` (row with the
        lowest ``delta_g``) and ``min_rmsd`` (row with the lowest ``rmsd``).
        ``self.timing`` receives the ``time`` column of those runs.

        Both pose queries select a bare column next to ``min(...)``; this
        relies on the database returning that column from the row holding
        the minimum (documented SQLite behaviour -- presumably the backend
        here, confirm).
        """
        self.cursor.execute(
            """
            SELECT
                min(delta_g), rmsd, Runs.pdb_id
            FROM
                PDB
            INNER JOIN
                Runs
            ON
                Runs.pdb_id = PDB.pdb_id
            INNER JOIN
                Results
            ON
                Results.run_id = Runs.run_id
            WHERE
                method_id = ?
                AND
                box_span = ?
            GROUP BY
                Runs.run_id
            """,
            (self.id, self.box_span)
        )

        query = self.cursor.fetchall()
        print(len(query), self.name)
        self._store_poses(query, "min_energy")

        self.cursor.execute(
            """
            SELECT
                delta_g, min(rmsd), Runs.pdb_id
            FROM
                PDB
            INNER JOIN
                Runs
            ON
                Runs.pdb_id = PDB.pdb_id
            INNER JOIN
                Results
            ON
                Results.run_id = Runs.run_id
            WHERE
                method_id = ?
                AND
                box_span = ?
            GROUP BY
                Runs.run_id
            """,
            (self.id, self.box_span)
        )
        self._store_poses(self.cursor.fetchall(), "min_rmsd")

        self.cursor.execute(
            """
            SELECT
                time
            FROM
                Runs
            WHERE
                method_id = ?
                AND
                box_span = ?
            """,
            (self.id, self.box_span)
        )
        self.timing = np.array(self.cursor.fetchall(), "float64")

    def construct_variant_energies(self):
        """Build ``self.variant_energies`` for DDG calculations.

        Only variant pairs whose reference and mutant both carry a
        ``min_energy`` and a ``min_rmsd`` pose are used. The result is a
        float32 array of shape ``(n_pairs, 8)`` with columns:

            0 ref resolution          1 mut resolution
            2 ref experimental energy 3 mut experimental energy
            4 ref docking energy (min-energy pose)
            5 mut docking energy (min-energy pose)
            6 ref docking energy (min-RMSD pose)
            7 mut docking energy (min-RMSD pose)

        When no pair qualifies a warning is printed and the array is empty
        with shape ``(0, 8)``.
        """
        usable = [
            pair for pair in self.variants
            if all(
                pose is not None
                for pose in (
                    pair.ref.min_energy, pair.ref.min_rmsd,
                    pair.mut.min_energy, pair.mut.min_rmsd,
                )
            )
        ]

        if not usable:
            print(f"WARNING: No valid variants found for method {self.name}")
            self.variant_energies = np.empty((0, 8), dtype="float32")
            return

        self.variant_energies = np.zeros((len(usable), 8), dtype="float32")
        for i, pair in enumerate(usable):
            ref, mut = pair.ref, pair.mut
            self.variant_energies[i, 0] = ref.resolution
            self.variant_energies[i, 1] = mut.resolution
            self.variant_energies[i, 2] = ref.affinity
            self.variant_energies[i, 3] = mut.affinity
            self.variant_energies[i, 4] = ref.min_energy.affinity
            self.variant_energies[i, 5] = mut.min_energy.affinity
            self.variant_energies[i, 6] = ref.min_rmsd.affinity
            self.variant_energies[i, 7] = mut.min_rmsd.affinity

    def resolution_filter(self, max_resolution):
        """Return a tuple of the PDBs with ``resolution <= max_resolution``.

        Note that this is a tuple of ``PDB`` objects, not a boolean mask.
        """
        return tuple(pdb for pdb in self.pdbs if pdb.resolution <= max_resolution)

    def filtered_energy(self, min_rmsd, max_resolution, prev_mask=None):
        """Return experimental vs. docking energies without outliers.

        Args:
            min_rmsd: If truthy use the min-RMSD pose, else the min-energy
                pose.
            max_resolution: Upper bound (inclusive) on the PDB resolution.
            prev_mask: Accepted for signature compatibility; not used.

        Returns:
            float32 array of shape ``(n, 2)``: column 0 is the PDB affinity,
            column 1 the affinity of the selected pose. PDBs lacking the
            selected pose are skipped, and rows whose docking energy lies
            outside 2 IQR-widths of the quartiles are removed. The array is
            empty (shape ``(0, 2)``) when nothing qualifies.
        """
        pose_attr = "min_rmsd" if min_rmsd else "min_energy"

        rows = []
        for pdb in self.resolution_filter(max_resolution):
            pose = getattr(pdb, pose_attr)
            if pose is not None:
                rows.append((pdb.affinity, pose.affinity))

        # reshape keeps the two-column layout even when rows is empty.
        data = np.array(rows, dtype="float32").reshape(-1, 2)

        non_outliers = self.non_outliers(data[:, 1], 2)

        return data[non_outliers, :]

    def filtered_rmsd(self, min_rmsd, max_resolution, prev_mask=None):
        """Return the RMSD of the selected pose for the PDBs passing filters.

        Args:
            min_rmsd: If truthy use the min-RMSD pose, else the min-energy
                pose.
            max_resolution: Upper bound (inclusive) on the PDB resolution.
            prev_mask: Optional boolean mask with one entry per element of
                ``self.pdbs`` (same order); combined with the resolution
                filter by logical AND.

        Returns:
            1-D float32 array of RMSD values, in ``self.pdbs`` order, for
            PDBs that have the selected pose, satisfy the resolution bound
            and are enabled in ``prev_mask``.
        """
        pose_attr = "min_rmsd" if min_rmsd else "min_energy"
        poses = [getattr(pdb, pose_attr) for pdb in self.pdbs]

        # NaN is only a placeholder: entries without a pose are masked out.
        data = np.array(
            [np.nan if pose is None else pose.rmsd for pose in poses],
            dtype="float32",
        )
        mask = np.array(
            [
                pose is not None and pdb.resolution <= max_resolution
                for pdb, pose in zip(self.pdbs, poses)
            ],
            dtype=bool,
        )
        if prev_mask is not None:
            mask = np.logical_and(mask, prev_mask)

        return data[mask]

    def filtered_variant_energies(self, min_rmsd, max_resolution):
        """Return experimental and docking energies of the variant pairs.

        Only pairs where both the reference and the mutant resolution are
        ``<= max_resolution`` are kept. The four returned columns are
        (ref experimental, mut experimental, ref docking, mut docking), the
        docking columns coming from the min-RMSD pose when ``min_rmsd`` is
        truthy and from the min-energy pose otherwise.
        """
        columns = [2, 3, 6, 7] if min_rmsd else [2, 3, 4, 5]

        mask = np.all(self.variant_energies[:, [0, 1]] <= max_resolution, axis=1)

        return self.variant_energies[mask, :][:, columns]

    def pearson(self, min_rmsd, max_resolution):
        """Return ``pearsonr`` of docking vs. experimental energies.

        The data come from ``filtered_energy`` with the same arguments, so
        outliers and PDBs without the selected pose are excluded.
        """
        energies = self.filtered_energy(min_rmsd, max_resolution)
        return pearsonr(energies[:, 1], energies[:, 0])

    def non_outliers(self, data, constant):
        """Return a boolean mask of the values that are not outliers.

        A value is kept when it lies within ``constant`` interquartile
        ranges below the 25th percentile and above the 75th percentile
        (bounds inclusive). ``data`` may be any array-like; empty input
        yields an empty mask instead of reaching ``np.percentile``.
        """
        data = np.asarray(data)
        if data.size == 0:
            return np.zeros(data.shape, dtype=bool)

        upper_quartile = np.percentile(data, 75)
        lower_quartile = np.percentile(data, 25)
        margin = (upper_quartile - lower_quartile) * constant
        return (data >= lower_quartile - margin) & (data <= upper_quartile + margin)
