import json
import logging
import os
import pdb
import pickle as pk
from copy import deepcopy
from typing import Optional

import pandas as pd
from packaging.requirements import Requirement
from tqdm import tqdm

from utils import *

# Module-wide logger named 'Brother': records at INFO and above are passed
# to a dedicated stream handler (logging.StreamHandler writes to stderr
# when no stream is given).
streamhandler = logging.StreamHandler()
logger = logging.getLogger('Brother')
logger.setLevel(logging.INFO)
logger.addHandler(streamhandler)


class InvalidDependentException(Exception):
    """Signal that no version of a dependent package depends on the given dependency package."""

class Brother:
    def __init__(self,
                 deps_fpaths:list,
                 reversed_deps:dict,
                 releases_info:dict,
                 ) -> None:
        """Load the dependency tables and keep the lookup dictionaries.

        Args:
            deps_fpaths: paths forwarded to ``_load_deps_to_dict``, which fills
                ``self.deps``; the result is then merged into ``self.merged_deps``.
            reversed_deps: kept as ``self.reversed_deps``. Other methods read it
                as ``reversed_deps[(name, version)]['dependents']``.
            releases_info: kept as ``self.releases_info``. Other methods read
                ``releases_info[name]['releases']`` and ``releases_info[name]['latest']``.
        """
        # self.deps has to exist before it can be merged.
        self._load_deps_to_dict(deps_fpaths=deps_fpaths)
        merged = self._merge_deps_together(self.deps)
        self.merged_deps = merged

        # Caller-supplied lookup tables are stored as-is (no copy).
        self.reversed_deps, self.releases_info = reversed_deps, releases_info
        

    def _load_deps_to_dict(self, deps_fpaths:list)->None:
        """Populate ``self.deps`` with one entry per CSV path.

        Each path is read with pandas, indexed by the ``name`` and ``version``
        columns, and passed through ``gen_deps`` (not defined in this file;
        presumably provided by ``utils`` -- confirm). Results are stored in the
        order of ``deps_fpaths``.
        """
        self.deps = []
        # extend() consumes the generator item by item, so entries are stored
        # one at a time as each file is processed.
        self.deps.extend(
            gen_deps(pd.read_csv(csv_path, index_col=["name","version"]))
            for csv_path in deps_fpaths
        )

    def _merge_deps_together(self,deps_list)->dict:
        """Merge several dependency mappings into one, caching the result on disk.

        If ``merge_deps.pkl`` exists in the current working directory it is
        unpickled and returned as-is; ``deps_list`` is not consulted at all in
        that case. Otherwise the first mapping is deep-copied and every key of
        the later mappings that is not present yet is added (first occurrence
        wins), and the result is pickled to ``merge_deps.pkl``.

        Args:
            deps_list: sequence of mappings; must be non-empty when no cache
                file exists.

        Returns:
            The merged mapping (or the cached object).

        Raises:
            IndexError: ``deps_list`` is empty and no cache file exists.
        """
        cache_path = 'merge_deps.pkl'
        if os.path.isfile(cache_path):
            # NOTE(review): the cache is not keyed on the inputs, so a stale
            # file silently overrides deps_list. pickle must only be loaded
            # from trusted files -- unpickling can execute arbitrary code.
            with open(cache_path,'rb') as f:
                return pk.load(f)

        total_deps = deepcopy(deps_list[0])
        for deps in deps_list[1:]:
            for idx, item in deps.items():
                if idx not in total_deps:
                    total_deps[idx] = deepcopy(item)

        # Context manager so the cache file is flushed and closed
        # deterministically instead of being left to the garbage collector.
        with open(cache_path,'wb') as f:
            pk.dump(total_deps, f)
        return total_deps
    

    def get_direct_dependencies(self, name_version:tuple)->set:
        """Return the direct dependencies of ``name_version`` as (name, release) pairs.

        ``self.merged_deps[name_version]["raw_dependencies"]`` is parsed as a
        JSON object mapping a dependency name to its specifier string. For each
        dependency, everything after the first ``;`` (the marker) is dropped,
        the rest is parsed as a requirement, and every release listed in
        ``self.releases_info[name]['releases']`` that the specifier contains
        (pre-releases allowed) is added to the result.

        Dependencies whose requirement cannot be parsed, that are missing from
        ``self.releases_info``, or releases that cannot be compared against the
        specifier, are skipped.

        Args:
            name_version: (name, version) key into ``self.merged_deps``.

        Returns:
            A set of (dependency_name, release) tuples; empty when
            ``name_version`` has no entry in ``self.merged_deps``.
        """
        logger.debug(f"Request dependencies of {name_version}")
        direct_dependencies = set()

        if name_version not in self.merged_deps:
            logger.debug(f"No dependencies infomation of {name_version}")
            return direct_dependencies

        raw_dependencies = json.loads(self.merged_deps[name_version]["raw_dependencies"])
        for depend_name, specs in raw_dependencies.items():
            try:
                # Keep only the part before any ';' (the marker). Done inside
                # the try so a non-string specifier is skipped, not fatal.
                req = Requirement(depend_name+' '+specs.split(';')[0])
            except Exception:
                # Exception rather than a bare except, so KeyboardInterrupt
                # and SystemExit still propagate.
                logger.debug("The specifier does not follow semantic versioning")
                continue
            spec = req.specifier
            spec.prereleases = True
            logger.debug(f"Current specifier {spec}")

            # the dependency is removed from pypi by project owner
            if depend_name not in self.releases_info:
                continue
            for release in self.releases_info[depend_name]['releases']:
                try:
                    is_contained = release in spec
                except Exception:
                    logger.debug(f"{release} does not follow semantic versioning")
                    continue
                if is_contained:
                    direct_dependencies.add((depend_name, release))
        return direct_dependencies

    def check_dependents_version(self, name_version:tuple, dependent:str, latest_only:bool)->Optional[str]:
        """Classify which releases of ``dependent`` depend on ``name_version``.

        A release of ``dependent`` counts as depending on ``name_version`` when
        its ``raw_dependencies`` entry in ``self.merged_deps`` names
        ``name_version[0]`` and the declared specifier (text after the first
        ``;`` dropped, pre-releases allowed) contains ``name_version[1]``.

        Args:
            name_version: (name, version) of the dependency.
            dependent: name of the candidate dependent; must be a key of
                ``self.releases_info``.
            latest_only: when true, only the release stored under
                ``self.releases_info[dependent]["latest"]`` is examined.

        Returns:
            "latest": only the latest release depends on ``name_version``.
            "non-latest": only other releases do. With ``latest_only`` it just
                means the latest release does not; other releases are not
                examined.
            "mixed": both the latest and some other release do.
            None: no release qualifies (also logged at error level).

        Raises:
            KeyError: ``dependent`` is not in ``self.releases_info``.
        """
        logger.debug(f"Checking version of the current relationship: {dependent} -> {name_version}")

        dependency_name = name_version[0]
        dependency_version = name_version[1]

        def _depends_on(dependent_index:tuple)->bool:
            """Return True when this release's requirement on the dependency admits its version."""
            if dependent_index not in self.merged_deps:
                return False
            raw_dependencies = json.loads(self.merged_deps[dependent_index]["raw_dependencies"])
            if dependency_name not in raw_dependencies:
                return False
            try:
                # Keep only the part before any ';' (the marker).
                clean_specs = raw_dependencies[dependency_name].split(';')[0]
                spec = Requirement(dependency_name+' '+clean_specs).specifier
                spec.prereleases = True
            except Exception:
                # Exception rather than a bare except, so KeyboardInterrupt
                # and SystemExit still propagate.
                logger.debug("The specifier does not follow semantic versioning")
                return False
            logger.debug(f"Current specifier {spec}")
            try:
                return dependency_version in spec
            except Exception:
                logger.debug(f"{dependency_version} does not follow semantic versioning")
                return False

        latest_dependent = (dependent, self.releases_info[dependent]["latest"])
        has_latest = _depends_on(latest_dependent)
        if latest_only:
            return "latest" if has_latest else "non-latest"

        # any() stops at the first matching non-latest release.
        has_non_latest = any(
            _depends_on((dependent, release))
            for release in self.releases_info[dependent]["releases"]
            if (dependent, release) != latest_dependent
        )

        if has_latest and has_non_latest:
            return "mixed"
        if has_latest:
            return "latest"
        if has_non_latest:
            return "non-latest"
        # No relation found: report it and hand back None explicitly.
        logger.error(f"{dependent} has no dependencies relationship on {name_version}")
        return None

    def get_direct_dependents(self, name_version:tuple, latest_only:bool)->dict:
        """Map every direct dependent of ``name_version`` to its dependency state.

        Dependents are read from ``self.reversed_deps[name_version]['dependents']``;
        those that are not keys of ``self.releases_info`` are left out. Each
        remaining package name is mapped to whatever
        ``check_dependents_version`` returns for it, e.g.
        ``{"package_A": "latest", "package_B": "mixed"}``.

        Returns an empty dict when ``name_version`` is not in ``self.reversed_deps``.
        """
        logger.debug(f"Request direct dependents of {name_version}")

        # Guard clause: nothing is known to depend on this release.
        if name_version not in self.reversed_deps:
            logger.debug("No direct dependents")
            return {}

        return {
            pkg: self.check_dependents_version(name_version, pkg, latest_only)
            for pkg in self.reversed_deps[name_version]['dependents']
            if pkg in self.releases_info
        }

    def get_brothers(self, name_version:tuple, is_inactive:bool, latest_only:bool)->dict:
        """Return the "brothers" of ``name_version``: packages sharing a direct dependency with it.

        Every direct dependency of ``name_version`` is resolved with
        ``get_direct_dependencies``; the direct dependents of each of those are
        collected with ``get_direct_dependents`` and merged into one dict.

        Args:
            name_version: (name, version) of the package to inspect.
            is_inactive: currently unused by this method.
            latest_only: when true, keep only brothers whose state is "latest".

        Returns:
            Dict mapping brother package name to "latest", "non-latest" or
            "mixed". With ``latest_only`` false, states reported for different
            dependencies are combined: differing states become "mixed". A
            ``None`` state (no relation could be confirmed) never overrides or
            mixes with a real state; it is kept only if nothing better is seen.
        """
        direct_dependencies = self.get_direct_dependencies(name_version=name_version)

        # NOTE(review): name_version's own package is presumably listed among
        # the dependents of its dependencies, so it may appear in the result --
        # confirm whether callers filter it out.
        brothers = dict()
        for nam_ver in tqdm(direct_dependencies):
            current_brothers = self.get_direct_dependents(name_version=nam_ver, latest_only=latest_only)
            for key, value in current_brothers.items():
                if latest_only:
                    if value == "latest":
                        brothers[key] = value
                    continue
                if value is None:
                    # Unconfirmed relation: remember the package, but do not
                    # let it turn an existing state into "mixed".
                    brothers.setdefault(key, None)
                elif brothers.get(key) is None:
                    # First real state for this package.
                    brothers[key] = value
                elif brothers[key] != value:
                    brothers[key] = "mixed"
        return brothers
    

    def init_vuln(self, vulns_path)->None:
        """Build ``self.vuln_name_ver_list`` from the JSON files in ``vulns_path``.

        Every directory entry is opened and parsed as JSON. Each document must
        provide ``"packageName"`` and an iterable ``"affect_version"``; one
        (packageName, version) tuple is recorded per affected version.
        Duplicates are collapsed, so the order of the resulting list is not
        meaningful.

        Args:
            vulns_path: directory holding the vulnerability files. Every entry
                is opened, so a sub-directory or non-JSON file raises.
        """
        logger.info("initializing vuln_name_ver_list")
        name_ver_set = set()
        for fname in os.listdir(vulns_path):
            fpath = os.path.join(vulns_path, fname)
            # JSON text is UTF-8 by specification; do not rely on the
            # platform's default encoding.
            with open(fpath, 'r', encoding='utf-8') as f:
                vuln = json.load(f)
            package_name = vuln["packageName"]
            name_ver_set.update((package_name, ver) for ver in vuln["affect_version"])
        self.vuln_name_ver_list = list(name_ver_set)

    def detect_vulns(self, package_name:str, version:str)->Optional[tuple]:
        """Search the direct and transitive dependencies of a release for a known vulnerability.

        Runs a breadth-first traversal starting at (package_name, version),
        expanding each release with ``get_direct_dependencies``, and stops at
        the first dependency found in ``self.vuln_name_ver_list``. The starting
        release itself is not checked. ``init_vuln`` must have been called
        first so that ``self.vuln_name_ver_list`` exists.

        Args:
            package_name: name of the package to start from.
            version: version of the package to start from.

        Returns:
            The (name, version) tuple of the first vulnerable dependency found,
            or None when none is reachable. The result is truthy exactly when a
            vulnerability was detected.
        """
        logger.debug(f"traversing {package_name}:{version} dependents")
        index = (package_name, version)
        # Build the lookup set once: testing membership in the list is O(n)
        # for every edge visited. Entries are the tuples built by init_vuln.
        vulnerable = set(self.vuln_name_ver_list)
        dependencies = {index}
        # The list is used as a FIFO read through q_index (nothing is popped),
        # so len(queue) doubles as the number of releases discovered so far.
        queue = [index]
        q_index = 0
        while len(queue)>q_index:
            # NOTE(review): progress goes to stdout via print rather than the
            # logger; kept unchanged in case callers rely on it.
            print(f"{100*q_index/len(queue)}% {q_index} / {len(queue)}")
            cur_index = queue[q_index]
            logger.debug(f"iteratering queue({len(queue)-q_index}), current_index: {cur_index}")
            q_index += 1

            for upstream_index in self.get_direct_dependencies(cur_index):
                if upstream_index in vulnerable:
                    return upstream_index
                if upstream_index not in dependencies:
                    dependencies.add(upstream_index)
                    queue.append(upstream_index)
        return None