import ast
import testCases
from model import Scope, ScopeType, Symbol
from collections import deque


class SymbolTable(ast.NodeVisitor):
    """
    Map every symbol to the scopes in which it is potentially used.

    ``build`` walks ``tree``; nodes without a dedicated ``visit_*`` method are
    walked by ``ast.NodeVisitor.generic_visit``. While walking, a tree of
    ``Scope`` objects is built: a new scope is opened for every ``for`` loop,
    ``if`` / ``else`` branch and function definition, linked to the enclosing
    scope, and recorded as an instruction of that enclosing scope.

    Attributes:
        tree: the ``ast`` tree being analysed.
        currentScope: the scope currently being filled.
        symbolTable: ``Symbol`` -> list of ``Scope`` where it occurs; a scope
            is appended once per occurrence, so entries may repeat.
        scope_order: stack of enclosing scopes, popped by ``endScope``.
        scope_lst: every scope created so far, in creation order.
        numScope: id given to the next scope created.
    """

    def __init__(self, tree):
        self.tree = tree
        # the module-level scope always has id 0; ids of new scopes start at 1
        self.currentScope = Scope(id_scope=0, scope_type=ScopeType.MAIN)
        self.symbolTable = dict()
        self.scope_order = deque()
        self.scope_lst = [self.currentScope]
        self.numScope = 1

    def __del__(self):
        # Drop references eagerly. getattr guards against __init__ having
        # failed before every attribute was set (e.g. Scope(...) raising),
        # which would otherwise make __del__ itself raise AttributeError.
        self.tree = None
        self.currentScope = None
        scope_order = getattr(self, "scope_order", None)
        if scope_order is not None:
            scope_order.clear()
        symbol_table = getattr(self, "symbolTable", None)
        if symbol_table is not None:
            symbol_table.clear()

    def __str__(self):
        # one "symbol:<symbol> <scopes>" line per table entry
        return "\n".join(
            "symbol:{key} {value}".format(key=key, value=value)
            for key, value in self.symbolTable.items()
        )

    def put(self, key: Symbol, value: Scope):
        """Record that symbol ``key`` occurs in scope ``value`` (duplicates are kept)."""
        self.symbolTable.setdefault(key, []).append(value)

    def get(self, key: Symbol) -> list[Scope]:
        """Return the scopes recorded for ``key``, or a new empty list if it is unknown.

        The returned list is the one stored in the table; mutating it mutates the table.
        """
        return self.symbolTable.get(key, [])

    def beginScope(self, scope_type):
        """Create, register and return a new scope of ``scope_type``.

        The current scope is pushed on ``scope_order`` so ``endScope`` can
        restore it. The caller is responsible for linking the new scope and
        making it ``currentScope``.
        """
        self.scope_order.append(self.currentScope)
        new_scope = Scope(id_scope=self.getNumScop(), scope_type=scope_type)
        self.scope_lst.append(new_scope)
        self.numScope += 1
        return new_scope

    def endScope(self):
        """Restore the scope that was current before the matching ``beginScope``."""
        self.currentScope = self.scope_order.pop()

    def keys(self):
        """Yield every symbol present in the table."""
        yield from self.symbolTable

    def getNumScop(self):
        """Return the id that the next created scope will receive."""
        return self.numScope

    def getScopes(self):
        """Yield every scope created so far, main scope first."""
        yield from self.scope_lst

    @staticmethod
    def addLink(parent: Scope, child: Scope):
        """Make ``child`` an inner scope of ``parent`` (and ``parent`` its parent)."""
        parent.innerScopes.append(child)
        child.parent = parent

    def build(self):
        """Walk ``self.tree`` and fill the symbol table and the scope tree."""
        self.visit(self.tree)

    @staticmethod
    def getNames(node):
        """Return the ``ast.Name`` nodes bound by an assignment/loop target.

        Supports a plain name and (possibly nested) tuple/list destructuring
        such as ``a, (b, c)`` or ``[a, b]``; a starred element (``*rest``)
        contributes its inner name. Other elements inside a tuple/list
        (attributes, subscripts) are skipped.

        Raises:
            Exception: if ``node`` itself is neither a name nor a tuple/list.
        """
        if isinstance(node, ast.Name):
            return [node]
        if not isinstance(node, (ast.Tuple, ast.List)):
            raise Exception(
                "Unsupported target node: {}".format(type(node).__name__)
            )
        names = []
        for element in node.elts:
            if isinstance(element, ast.Starred):
                element = element.value
            if isinstance(element, ast.Name):
                names.append(element)
            elif isinstance(element, (ast.Tuple, ast.List)):
                names.extend(SymbolTable.getNames(element))
        return names

    def visit_Assign(self, node: ast.Assign):
        """Record a single-name assignment ``name = value`` in the current scope.

        Chained (``a = b = v``), destructuring, attribute and subscript
        assignments are ignored, and the right-hand side is not visited.
        """
        targets = node.targets
        if len(targets) != 1 or not isinstance(targets[0], ast.Name):
            return

        symbol = Symbol.symbol(n=targets[0].id)
        self.put(key=symbol, value=self.currentScope)
        # the assigned expression is stored as the instruction's node
        self.currentScope.insert_instruction(symbol=symbol, lineno=node.lineno, node=node.value)
        self.currentScope.insert_var(symbol=symbol, lineno=node.lineno)

    def visit_For(self, node: ast.For):
        """Open a FOR scope, record the loop variable(s) in it and visit the body.

        The iterable expression and the loop's ``else`` clause are not visited.
        """
        for_scope = self.beginScope(scope_type=ScopeType.FOR)
        self.addLink(parent=self.currentScope, child=for_scope)
        self.currentScope.insert_instruction(
            symbol=Symbol.symbol(n="FOR"), lineno=node.lineno, node=for_scope
        )
        self.currentScope = for_scope

        loop_target = node.target
        iterable = node.iter
        for name in self.getNames(node=loop_target):
            symbol = Symbol.symbol(n=name.id)
            # FOR scopes expose their iteration variable; with a tuple
            # target only the last name remains stored here.
            self.currentScope.iter_var = symbol
            # NOTE(review): visit_Assign stores the assigned value as the
            # instruction node, whereas here the loop target node is stored
            # (with the iterable's line number) — confirm which is intended.
            self.currentScope.insert_instruction(
                symbol=symbol, lineno=iterable.lineno, node=loop_target
            )
            self.currentScope.insert_var(symbol=symbol, lineno=node.lineno)
            self.put(key=symbol, value=self.currentScope)

        for child in node.body:
            self.visit(child)
        self.endScope()

    def visit_Call(self, node: ast.Call):
        """Record a method call ``obj.method(args)`` whose receiver is a plain name.

        The receiver is added to the symbol table and the call becomes an
        instruction of the current scope. Argument names are added to the
        symbol table when an argument is a bare name or a binary operation
        with bare-name operands. Other calls (plain functions, chained
        attributes) are ignored and nested calls are not visited.
        """
        func = node.func
        if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name):
            return

        symbol = Symbol.symbol(n=func.value.id)
        self.put(key=symbol, value=self.currentScope)

        for arg in node.args:
            # TODO handle more argument kinds (nested calls, nested BinOps, ...)
            if isinstance(arg, ast.Name):
                operands = [arg]
            elif isinstance(arg, ast.BinOp):
                operands = [arg.left, arg.right]
            else:
                operands = []
            for operand in operands:
                if isinstance(operand, ast.Name):
                    self.put(key=Symbol.symbol(n=operand.id), value=self.currentScope)

        self.currentScope.insert_instruction(symbol=symbol, lineno=node.lineno, node=node)

    def visit_If(self, node: ast.If):
        """Open an IF scope, plus a sibling ELSE scope when there is an else/elif branch.

        The test expression and the body are visited inside the IF scope and
        the ``orelse`` statements inside the ELSE scope. Both scopes are linked
        to, and recorded as instructions of, the enclosing scope.
        """
        if_scope = self.beginScope(scope_type=ScopeType.IF)
        self.currentScope.insert_instruction(
            symbol=Symbol.symbol(n="IF"), lineno=node.lineno, node=if_scope
        )
        self.addLink(parent=self.currentScope, child=if_scope)  # TODO to change

        self.currentScope = if_scope
        self.visit(node.test)
        for child in node.body:
            self.visit(child)
        self.endScope()

        if not node.orelse:
            return

        # The AST records no position for the ``else`` keyword. It can be no
        # earlier than the line after the body's last statement, which keeps
        # the ELSE instruction between the body and the else-branch.
        last = node.body[-1]
        lineno = (getattr(last, "end_lineno", None) or last.lineno) + 1

        else_scope = self.beginScope(scope_type=ScopeType.ELSE)
        self.currentScope.insert_instruction(
            symbol=Symbol.symbol(n="ELSE"), lineno=lineno, node=else_scope
        )
        self.addLink(parent=self.currentScope, child=else_scope)
        self.currentScope = else_scope
        for child in node.orelse:
            self.visit(child)
        self.endScope()

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Open a FUNCTION_DEF scope and visit the function body inside it.

        Parameters, decorators and default values are not recorded.
        """
        function_scope = self.beginScope(scope_type=ScopeType.FUNCTION_DEF)
        self.currentScope.insert_instruction(
            symbol=Symbol.symbol(n="FUNCTION_DEF"), lineno=node.lineno, node=function_scope
        )
        # Link like the FOR / IF / ELSE scopes so the function scope gets a
        # parent and is listed among the enclosing scope's inner scopes.
        self.addLink(parent=self.currentScope, child=function_scope)

        self.currentScope = function_scope
        for child in node.body:
            self.visit(child)
        self.endScope()


def main():
    """Build the symbol table of the ``wait_what`` test case and print each scope."""
    tree = ast.parse(testCases.wait_what)
    table = SymbolTable(tree=tree)
    table.build()
    for scope in table.getScopes():
        print(scope)


# run the demo only when executed as a script, not on import
if __name__ == "__main__":
    main()
