import ast
import os
import re
import json

# from importlib_metadata import entry_points
from ConstructKB.Import_Level.core import *
from ConstructKB.Import_Level.core.source_visitor import SourceVisitor
# from wheel_inspect import inspect_wheel
import tarfile
from zipfile import ZipFile



class Tree:
    """One file or directory in the package tree populated by ``build_dir_tree``."""

    def __init__(self, name):
        # Single path component (file or directory name), not a full path.
        self.name = name
        # Child nodes; build_dir_tree only appends ``.py`` files and directories.
        self.children, self.parent = [], None
        # Name -> signature metadata for what this node defines or re-exports.
        self.cargo = {}
        # Decoded source text and parsed AST (filled in for ``.py`` files only).
        self.source, self.ast = '', None

    def __str__(self):
        return str(self.name)

def parse_import(tree):
    """Collect the names imported by every ``from ... import ...`` statement.

    The whole AST is walked, so imports nested inside functions or classes
    are included as well.

    Args:
        tree: an ``ast`` node (typically the module returned by
            ``ast.parse``), or ``None`` when the file failed to parse.

    Returns:
        dict mapping the source module to the list of imported names in
        source order (``'*'`` for a star import). When the statement names a
        module (``from a.b import x`` or ``from .a import x``), the key is
        that module string; the relative level is not recorded. When it does
        not (``from . import x``, ``from .. import y``), the key is the integer
        relative level. Returns ``None`` if *tree* is ``None`` or is not an
        AST node.
    """
    if tree is None:
        return None
    module_item_dict = {}
    try:
        for node in ast.walk(tree):
            if not isinstance(node, ast.ImportFrom):
                continue
            # Bare relative imports have no module name: key on the level.
            key = node.level if node.module is None else node.module
            module_item_dict.setdefault(key, []).extend(alias.name for alias in node.names)
    except AttributeError:
        # ast.walk raises AttributeError for objects that are not AST nodes.
        return None
    return module_item_dict
 
def gen_AST(filename):
    """Parse a Python source file into an AST.

    The file is read as text with the platform default encoding.

    Returns:
        The parsed ``ast.Module``, or ``None`` if the file cannot be decoded
        or is not valid Python source.
    """
    try:
        with open(filename) as handle:
            source = handle.read()
        return ast.parse(source, mode='exec')
    except (SyntaxError, UnicodeDecodeError, ValueError):
        # ValueError: ast.parse rejects NUL bytes this way on Python < 3.12.
        return None

def parse_pyx(filename):
    """Collect function names from a Cython (``.pyx``) file with a regex.

    For every line, the text between the first ``def `` and the following
    ``(`` is captured. Only the substring ``def `` is required, so ``cpdef``
    and ``cdef`` lines match as well. For example, ``cdef int f(`` yields
    ``'int f'``.

    Returns:
        list of captured names, at most one per line, in file order.
    """
    pattern = re.compile(r'def ([\s\S]*?)\(')
    all_func_names = []
    with open(filename) as handle:
        for line in handle:
            match = pattern.search(line)
            if match:
                all_func_names.append(match.group(1))
    return all_func_names

def extract_class(filename):
    """Run ``SourceVisitor`` over a Python file and return its result and AST.

    Any failure is treated as non-Python input: unreadable file, syntax error
    (e.g. Cython or Python 2), or an error inside the visitor. In that case
    ``({}, None)`` is returned.

    Returns:
        ``(visitor.result, tree)`` on success, otherwise ``({}, None)``.
    """
    try:
        with open(filename) as handle:
            source = handle.read()
        tree = ast.parse(source, mode='exec')
        visitor = SourceVisitor()
        visitor.visit(tree)
        return visitor.result, tree
    except Exception:  # to avoid non-python code
        if os.fspath(filename).endswith('pyx'):
            # Best-effort Cython scan. It must not raise out of this handler:
            # the same read that just failed (missing file, bad encoding) would
            # fail again here.
            # NOTE(review): the result is discarded, so this currently has no
            # effect on the return value.
            try:
                parse_pyx(filename)
            except (OSError, ValueError):
                pass
        return {}, None  # return empty

def extract_class_from_source(source):
    """Run ``SourceVisitor`` over an in-memory source string.

    Returns:
        ``(visitor.result, tree)`` on success. On any failure (non-Python
        text or a visitor error), the exception is printed and ``({}, None)``
        is returned.
    """
    try:
        module_ast = ast.parse(source, mode='exec')
        collector = SourceVisitor()
        collector.visit(module_ast)
        found = collector.result
    except Exception as err:  # to avoid non-python code
        print(err)
        return {}, None  # return empty
    return found, module_ast

def build_dir_tree(node):
    """Recursively populate *node* and its descendants from the filesystem.

    ``node.name`` is resolved relative to the current working directory. The
    function ``chdir``s into each directory it descends into and always
    restores the previous cwd, even when an error is raised.

    * A node named ``test``/``tests``/``testing`` is not descended into.
      Its parent still registers it as a child, because it is a directory.
    * For a directory, each ``.py`` file or sub-directory entry is appended
      to ``node.children``. It is also registered in ``node.cargo`` under its
      stem with a ``('*', '*', '*', '*')`` placeholder.
    * For a ``.py`` file, the source is read as UTF-8, with undecodable bytes
      dropped, and analysed with ``extract_class_from_source``. Each signature
      gets the file's dotted module path (``leaf2root``) appended. The result
      is stored in ``node.cargo``, the AST in ``node.ast``, and
      ``node.origin`` is set to ``True``.
    """
    if node.name in ['test', 'tests', 'testing']:
        return
    if os.path.isdir(node.name):
        # Restore the exact previous cwd rather than chdir('..'), which lands
        # in the wrong place when node.name is a symlink to a directory.
        saved_cwd = os.getcwd()
        os.chdir(node.name)
        try:
            for item in os.listdir('.'):
                child_node = Tree(item)
                child_node.parent = node
                build_dir_tree(child_node)
                if child_node.name.endswith('.py') or os.path.isdir(child_node.name):
                    node.cargo[os.path.splitext(child_node.name)[0]] = ('*', '*', '*', '*')
                    node.children.append(child_node)
        finally:
            os.chdir(saved_cwd)
    elif node.name.endswith('.py'):
        with open(node.name, 'rb') as handle:
            node.source = handle.read().decode("utf-8", errors="ignore")
        res, tree = extract_class_from_source(node.source)
        module_prefix = leaf2root(node)

        def _with_prefix(signature):
            # Raw visitor signatures are (arg_names, len_defaults, lineno).
            arg_names, len_defaults, node_lineno = signature
            return (arg_names, len_defaults, node_lineno, module_prefix)

        res_new = {}
        for k, v in res.items():
            if isinstance(v, tuple):
                res_new[k] = _with_prefix(v)
            elif isinstance(v, dict):
                # Class: method name -> signature.
                res_new[k] = {t1: _with_prefix(t2) for t1, t2 in v.items()}
            else:
                print('wrong')
        node.cargo = res_new
        node.ast = tree
        node.origin = True
        

def leaf2root(node):
    """Return the dotted path from the tree root down to *node*.

    * For ``__init__.py``, the file itself is dropped, giving the package
      path: ``pkg/sub/__init__.py`` -> ``'pkg.sub'``.
    * Otherwise, the node name is cut at its first ``.``:
      ``pkg/sub/mod.py`` -> ``'pkg.sub.mod'``.
    * A parentless non-``__init__`` node yields a leading dot (``'.mod'``).
    """
    lineage = []
    cursor = node
    while cursor is not None:
        lineage.append(cursor.name)
        cursor = cursor.parent
    # lineage runs leaf -> root; drop the leaf and flip to root -> leaf order.
    ancestors = ".".join(reversed(lineage[1:]))
    if node.name == '__init__.py':
        return ancestors
    stem = node.name.split('.')[0]
    return f"{ancestors}.{stem}"

 
def find_child_by_name(node, name):
    """Return the first child of *node* whose name equals *name* exactly, else ``None``."""
    return next((child for child in node.children if child.name == name), None)


def find_node_by_name(nodes, name):
    """Return the first node named *name*, with or without its extension.

    ``'mod'`` matches both a ``mod`` directory and ``mod.py``, whichever comes
    first in *nodes*. Returns ``None`` when nothing matches.
    """
    def _matches(candidate):
        return name in (candidate.name, os.path.splitext(candidate.name)[0])

    return next(filter(_matches, nodes), None)


def _walk_route(start, route_names):
    """Follow *route_names* down from *start* by child lookup; ``None`` if any step is missing."""
    node = start
    for segment in route_names:
        node = find_node_by_name(node.children, segment)
        if node is None:
            return None
    return node


def go_to_that_node(root, cur_node, visit_path):
    """Resolve the module in ``from <visit_path> import ...`` to a tree node.

    Args:
        root: root ``Tree`` of the package being analysed.
        cur_node: the node whose source contains the import statement.
        visit_path: dotted module path (``ImportFrom.module``).

    Returns:
        The matching node, or ``None``. The first path segment is tried
        against each of these, in order:

        1. the siblings of *cur_node* (same directory);
        2. the root's name;
        3. the name of *cur_node*'s parent directory.

        The remaining segments are then followed child by child. For
        branches 1 and 3, a directory result is replaced by its
        ``__init__.py`` child, or ``None`` if it has none.
    """
    route_node_names = visit_path.split('.')
    head, rest = route_node_names[0], route_node_names[1:]
    parent = cur_node.parent
    # A parentless node has no siblings; avoid the AttributeError on None.
    siblings = parent.children if parent is not None else []

    sibling = find_node_by_name(siblings, head)
    if sibling is not None:
        target = _walk_route(sibling, rest)
    elif head == root.name:
        # NOTE(review): this branch returns immediately, so a package directory
        # is returned as-is instead of being resolved to its __init__.py like
        # the other branches. Kept for backward compatibility; confirm intent.
        return _walk_route(root, rest)
    elif parent is not None and head == parent.name:
        target = _walk_route(parent, rest)
    else:
        target = None

    if target is not None and not target.name.endswith('.py'):
        target = find_node_by_name(target.children, '__init__.py')
    return target

def tree_infer_levels(root_node):
    """Propagate imported names between modules, then flatten them into API records.

    1. Collect every ``.py`` node in breadth-first order.
    2. Visit those nodes in reverse, so the deepest files come first. For
       each ``from X import ...`` whose target ``go_to_that_node`` resolves,
       copy the imported entries of the target's ``cargo`` into the
       importing node's ``cargo``. A star import copies everything. Bare
       relative imports (``from . import x``) are skipped.
    3. Emit ``make_API_full_name`` records for every file, under its dotted
       module path.

    Returns:
        list of API record dicts.
    """
    from collections import deque

    leaf_stack = []
    working_queue = deque([root_node])
    while working_queue:
        tmp_node = working_queue.popleft()
        if tmp_node.name.endswith('.py'):
            leaf_stack.append(tmp_node)
        working_queue.extend(tmp_node.children)

    # Presumably deepest-first so re-exports bubble up toward the package root.
    for node in reversed(leaf_stack):
        module_item_dict = parse_import(node.ast)
        if module_item_dict is None:
            continue

        for module_name, names in module_item_dict.items():
            if module_name is None or isinstance(module_name, int):
                continue  # bare relative imports are not resolved
            dst_node = go_to_that_node(root_node, node, module_name)
            if dst_node is None:
                continue
            # The same module can be imported several times in one file, so
            # '*' need not be the first name recorded for it.
            if '*' in names:
                node.cargo.update(dst_node.cargo)
            else:
                for api in names:
                    if api in dst_node.cargo:
                        node.cargo[api] = dst_node.cargo[api]

    API_name_lst = []
    for node in leaf_stack:
        API_prefix = leaf2root(node).strip('.')
        API_name_lst.extend(make_API_full_name(node.cargo, API_prefix))
    return API_name_lst

class function_node(object):
    """Flat record describing one discovered API entry.

    Callers in this module serialize instances through ``__dict__``. The
    attribute assignment order below therefore fixes the key order of the
    emitted JSON.
    """

    def __init__(self, API_name, loc_name, args, args_default,
                 filepath='*', lineno='*', namespace='*'):
        # Public (exposed) dotted name vs. dotted name where it is defined.
        self.API_name, self.loc_name = API_name, loc_name
        # ';'-joined argument names and the defaults value supplied by the
        # caller (len_defaults in this module).
        self.args, self.args_default = args, args_default
        # Origin info: callers here pass the defining module's dotted prefix
        # as filepath, the definition line, and the enclosing class name
        # ('*' when there is none).
        self.filepath, self.lineno, self.namespace = filepath, lineno, namespace

def make_API_full_name(meta_data, API_prefix):
    """Turn a module's ``cargo`` mapping into a flat list of API records.

    Args:
        meta_data: mapping name -> metadata.
            * A ``tuple`` value describes a function, or is a
              ``('*', '*', '*', '*')`` sub-module placeholder.
            * A ``dict`` value describes a class, as method name -> tuple.
            Tuples are expected to be
            ``(arg_names, len_defaults, lineno, defining_module_prefix)``.
            Tuples of any other length are skipped.
        API_prefix: dotted path under which the names are exposed.

    Returns:
        list of ``function_node.__dict__`` dicts.
        * A function yields ``<prefix>.<name>``.
        * A class yields ``<prefix>.<Class>`` from its ``__init__``, when
          present and well-formed.
        * A class also yields ``<prefix>.<Class>.<method>`` for every
          well-formed method, ``__init__`` included.
    """
    def _record(api_name, loc_parts, sig, namespace):
        # function_node(API_name, loc_name, args, args_default, filepath, lineno, namespace)
        return function_node(api_name, '.'.join(loc_parts), ";".join(sig[0]),
                             sig[1], sig[3], sig[2], namespace).__dict__

    API_lst = []
    for k, v in meta_data.items():
        if isinstance(v, tuple):
            # Malformed entries (e.g. raw 3-tuples) are skipped silently.
            if len(v) == 4:
                API_lst.append(_record(f'{API_prefix}.{k}', [v[3], k], v, '*'))
        elif isinstance(v, dict):
            # A constructor also exposes the class itself as a callable API.
            # A malformed __init__ skips only this entry, not the whole class.
            if '__init__' in v and len(v['__init__']) == 4:
                ctor = v['__init__']
                API_lst.append(_record(f'{API_prefix}.{k}', [ctor[3], k], ctor, '*'))
            for f_name, args in v.items():
                if len(args) == 4:
                    API_lst.append(_record(f'{API_prefix}.{k}.{f_name}',
                                           [args[3], k, f_name], args, k))
    return API_lst


def search_targets(root_dir, targets):
    """Find the first directory under *root_dir* that contains every target.

    A target ``t`` is satisfied by a sub-directory named ``t`` or, failing
    that, by a file ``t.py`` in the same directory. Directories are visited
    in ``os.walk`` top-down order.

    Returns:
        list of the matching paths, one per target in *targets* order, from
        the first directory where all targets are present; ``None`` if no
        directory contains all of them.
    """
    for root, dirs, files in os.walk(root_dir):
        # Collect matches per directory, so partial hits in earlier
        # directories do not leak into the result.
        found = []
        for t in targets:
            if t in dirs:
                found.append(os.path.join(root, t))
            elif t + '.py' in files:
                found.append(os.path.join(root, t + '.py'))
        if len(found) == len(targets):
            return found
    return None




def _attach_module_prefix(signatures, prefix):
    """Append *prefix* to every raw signature tuple returned by ``extract_class``.

    This produces the 4-tuple layout that ``make_API_full_name`` accepts.
    Class entries (dicts) are handled per method. Values that are neither
    tuples nor dicts are dropped.
    """
    annotated = {}
    for name, sig in signatures.items():
        if isinstance(sig, tuple):
            annotated[name] = (*sig, prefix)
        elif isinstance(sig, dict):
            annotated[name] = {meth: (*meth_sig, prefix) for meth, meth_sig in sig.items()}
    return annotated


def process_single_module(module_path):
    """Extract API records for a single Python file or a package directory.

    Args:
        module_path: path to a source file, or to a package directory.

    Returns:
        list of API record dicts as produced by ``make_API_full_name``.
        * For a file, names are exposed under the file's stem
          (``foo.py`` -> ``foo.func``).
        * For a directory, the tree is built with the cwd temporarily set to
          its parent. The original cwd is always restored.
    """
    if os.path.isfile(module_path):
        module_name = os.path.splitext(os.path.basename(module_path))[0]
        res, _tree = extract_class(module_path)
        return make_API_full_name(_attach_module_prefix(res, module_name), module_name)

    # abspath strips trailing separators, and guarantees a non-empty dirname
    # for bare relative inputs such as "pkg".
    module_path = os.path.abspath(module_path)
    root_node = Tree(os.path.basename(module_path))
    cwd = os.getcwd()
    os.chdir(os.path.dirname(module_path))
    try:
        build_dir_tree(root_node)
        return tree_infer_levels(root_node)
    finally:
        os.chdir(cwd)  # go back cwd

def construct_pre_annonation(client, client_path, libs, lib_path=None):
    """Extract API annotations for a client project and its libraries, then save them.

    Args:
        client: client project name; also used in the output file name.
        client_path: path to the client's source (file or package directory).
        libs: list of library names. Each is located at
            ``os.path.join(lib_path, name)``.
        lib_path: directory containing the libraries. It must be set when
            *libs* is non-empty, otherwise ``os.path.join`` raises TypeError.

    Output:
        Writes ``<cwd>/pre_knowledge/<client>_pre_annotations.json``. It maps
        each library name and the client name to ``{API_name: record}``.
        Records whose API name ends in a ``self_init`` segment have
        ``.self_init`` removed from ``API_name`` and ``loc_name``.
    """
    annotations = {}
    for lib_name in libs + [client]:
        lib_dir = client_path if lib_name == client else os.path.join(lib_path, lib_name)
        print(lib_dir)
        lib_annotations = {}
        for line in process_single_module(lib_dir) or []:
            if line['API_name'].split('.')[-1] == 'self_init':
                line['API_name'] = line['API_name'].replace('.self_init', '')
                line['loc_name'] = line['loc_name'].replace('.self_init', '')
            lib_annotations[line['API_name']] = line
        annotations[lib_name] = lib_annotations

    out_dir = os.path.join(os.getcwd(), 'pre_knowledge')
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, f'{client}_pre_annotations.json'), 'w') as f:
        json.dump(annotations, f, indent=4)