"""
 Description: rewrite the code of a Python file, converting for loops into list comprehensions or list comprehensions into for loops
 Date: 16-05-2023
"""

from rewriter import Rewriter
import argparse
from rules import *
from enum import Enum
from collections import deque

from helper import get_all_files, copy_folder, create_directory
import asttokens, ast


# Kind of source chunk recorded in a Block; members are numbered 1-6 in this order.
Tag = Enum(
    "Tag",
    [
        ("COMMENT", 1),  # a single-line "#" comment
        ("BLOCK_COMMENT", 2),  # a comment block starting with triple quotes
        ("CODE", 3),  # an ordinary statement
        ("BLANK", 4),  # an empty line
        ("FOR", 5),  # a for loop together with its nested body statements
        ("LIST_COMP", 6),  # an assignment of a list comprehension
    ],
)


class Block:
    """A classified piece of source text.

    Attributes:
        block_type: the Tag describing what kind of text this is.
        value: the raw text covered by the block (None until set).
        depth: a line-count bookkeeping number chosen by whoever builds the block.
        statements: nested blocks, in source order (used by for-loop blocks).
        head: the loop header text of a for-loop block, otherwise None.
    """

    def __init__(self, block_type, value=None, depth=0):
        self.block_type = block_type
        self.value = value
        self.depth = depth
        self.statements = deque()  # nested Block objects, in source order
        self.head = None  # only filled in for for-loop blocks

    def add_statement(self, statement):
        """Append ``statement`` after the nested blocks already recorded."""
        self.statements.append(statement)

    def set_value(self, value):
        """Replace the block's text; ``value`` must not be None."""
        assert value is not None, "value must not be none"
        self.value = value

    def set_depth(self, depth):
        """Replace the block's depth; ``depth`` must not be None."""
        assert depth is not None, "depth must not be none"
        self.depth = depth

    def number_of_line_to_skip(self):
        """Return ``depth`` minus the number of nested comment blocks."""
        comment_count = sum(
            1
            for nested in self.statements
            if nested.block_type in (Tag.COMMENT, Tag.BLOCK_COMMENT)
        )
        return self.depth - comment_count

    def add_head(self, head):
        """Record the loop header text of a for-loop block."""
        self.head = head

    def set_block_tag(self, tag):
        """Replace the block's Tag; ``tag`` must not be None."""
        assert tag is not None, "tag must not be none"
        self.block_type = tag

    def __str__(self):
        # No separator between value and depth: text read from a file
        # normally carries its own trailing newline.
        return "".join([
            "block: ", str(self.block_type), "\n",
            "value: ", str(self.value),
            "depth: ", str(self.depth),
        ])


def contains_bracket(line, open_list):
    """Return True if ``line`` has an opening bracket outside strings and comments.

    :param line: a single source line.
    :param open_list: the opening bracket characters to look for,
        e.g. ["(", "{", "["].
    :return: True as soon as one of ``open_list`` appears in code; False when
        every such character sits inside a string literal or after a ``#``.
    """
    assert line is not None, "line must not be none"
    assert open_list is not None, "open_list must not be none"

    quote = None  # quote character of the string we are inside, if any
    escaped = False  # previous character was a backslash inside a string

    for char in line:
        if quote is not None:
            if escaped:
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == quote:
                # only the quote kind that opened the string can close it,
                # so the apostrophe in "it's" does not end a "..." string
                quote = None
        elif char in ('"', "'"):
            quote = char
        elif char == "#":
            return False  # the rest of the line is a comment
        elif char in open_list:
            return True

    return False


def calcul_indent(line):
    """Return how many space characters ``line`` starts with.

    Only the space character counts; a tab or any other character ends the run.
    """
    assert line is not None, "line must not be none"
    without_leading_spaces = line.lstrip(" ")
    return len(line) - len(without_leading_spaces)


def is_list_comp(tree):
    """Return True if the first statement of ``tree`` assigns a list comprehension.

    :param tree: an ``ast.Module``, typically from ``ast.parse`` of a single
        statement, so only its first statement is inspected.
    :return: True for statements such as ``x = [f(v) for v in xs]``; False
        otherwise, including for an empty module (blank or comment-only text).
    """
    assert tree is not None, "tree must not be none"

    # an empty or comment-only source parses to a module without statements
    if not tree.body:
        return False

    node = tree.body[0]
    return isinstance(node, ast.Assign) and isinstance(node.value, ast.ListComp)


def _scan_line(line, depth, quote, open_list, closed_list):
    """Advance bracket depth and string state over one physical source line.

    :param line: the physical line (usually ending with a newline).
    :param depth: bracket nesting depth before this line.
    :param quote: delimiter of a string left open by earlier lines, or None.
    :return: (depth, quote, continued) after the line; ``continued`` is True
        when the line ends with a backslash continuation outside strings and
        comments.
    """
    k = 0
    size = len(line)
    in_comment = False
    while k < size:
        char = line[k]
        if quote is not None:
            if char == "\\":
                k += 2  # skip the escaped character, which may be the newline
            elif line.startswith(quote, k):
                k += len(quote)
                quote = None
            else:
                k += 1
        elif char == "#":
            in_comment = True  # brackets and quotes in a comment are ignored
            break
        elif char in ('"', "'"):
            quote = char * 3 if line.startswith(char * 3, k) else char
            k += len(quote)
        else:
            if char in open_list:
                depth += 1
            elif char in closed_list:
                depth -= 1
            k += 1

    ends_with_backslash = line.rstrip("\r\n").endswith("\\")
    continued = quote is None and not in_comment and ends_with_backslash
    if quote is not None and len(quote) == 1 and not ends_with_backslash:
        # a single-quoted string cannot span lines without a backslash
        quote = None
    return depth, quote, continued


def _collect_statement(lines, start, open_list, closed_list):
    """Gather the physical lines of the logical statement starting at lines[start].

    Lines are consumed while a bracket is unclosed, a string is unterminated
    or a line ends with a backslash continuation; at least one line is taken.

    :return: (text, end, closed) where ``text`` joins lines[start:end] and
        ``closed`` is False when the end of the file was reached first.
    """
    depth, quote = 0, None
    end = start
    while end < len(lines):
        depth, quote, continued = _scan_line(lines[end], depth, quote, open_list, closed_list)
        end += 1
        if depth <= 0 and quote is None and not continued:
            return "".join(lines[start:end]), end, True
    return "".join(lines[start:end]), end, False


def _interior_line_count(start, end, closed):
    """Count the lines strictly between the first and the closing line of lines[start:end].

    When the statement was cut off by the end of the file (``closed`` is
    False) every line after the first one counts.
    """
    return max(end - start - (2 if closed else 1), 0)


def _make_code_block(text):
    """Wrap one statement in a depth-0 CODE block, tagged LIST_COMP when it assigns a list comprehension."""
    block = Block(block_type=Tag.CODE, value=text, depth=0)
    try:
        tree = ast.parse(text.lstrip())
    except (SyntaxError, ValueError):
        # fragments such as "if x:" or "else:" do not parse on their own
        return block
    if tree.body and is_list_comp(tree=tree):
        block.set_block_tag(tag=Tag.LIST_COMP)
    return block


def _extract_for_block(lines, start, open_list, closed_list):
    """Build the FOR block whose header starts at lines[start].

    The header may span several lines. The body is every following line that
    is blank or indented deeper than the header; each body statement becomes
    one nested block (BLANK, COMMENT, CODE or LIST_COMP).

    :return: (block, end) where ``end`` is the index of the first line after the block.
    """
    indentation = calcul_indent(lines[start])
    head, j, _ = _collect_statement(lines, start, open_list, closed_list)

    block = Block(block_type=Tag.FOR)
    block.add_head(head)

    body = []
    depth = 0  # one per body statement
    while j < len(lines):
        current = lines[j]
        stripped = current.strip()
        if not stripped:
            # whitespace-only and "\r\n" lines are blank lines too
            block.add_statement(Block(block_type=Tag.BLANK, value=current, depth=0))
            body.append(current)
            j += 1
        elif calcul_indent(current) <= indentation:
            break
        elif stripped.startswith("#"):
            block.add_statement(Block(block_type=Tag.COMMENT, value=current, depth=0))
            body.append(current)
            j += 1
        else:
            text, j, _ = _collect_statement(lines, j, open_list, closed_list)
            block.add_statement(_make_code_block(text))
            body.append(text)
        depth += 1

    block.set_value("".join(body))
    block.set_depth(depth)
    return block, j


def _search_comments(block):
    """Return the text of every COMMENT/BLOCK_COMMENT in ``block``, recursing into FOR blocks."""
    if block.block_type == Tag.COMMENT or block.block_type == Tag.BLOCK_COMMENT:
        return [block.value]
    comments = []
    if block.block_type == Tag.FOR:
        for stmt in block.statements:
            comments += _search_comments(stmt)
    return comments


def extract_info(file, debug=False):
    """Split a Python file into tagged blocks and strip its comment lines.

    Each logical statement becomes one block: physical lines joined by open
    brackets, unterminated strings or a trailing backslash stay together.
    ``for`` loops become a FOR block holding one nested block per body
    statement.

    :param file: path of the Python file to read.
    :param debug: when True, print every top-level block to stdout.
    :return: a tuple (lines, blocks, comments):
        lines -- the file's lines without the top-level "#" comment lines and
        the comment blocks that start with triple quotes;
        blocks -- a deque of top-level Block objects in source order;
        comments -- the text of every COMMENT and BLOCK_COMMENT block,
        including those nested in FOR blocks, in source order.
    """
    assert file is not None, "file must not be none"

    with open(file, 'r') as fp:
        lines = fp.readlines()

    open_list = ["(", "{", "["]  # for handling instructions spread over multiple lines
    closed_list = [")", "}", "]"]  # same

    removed = set()  # indices of the comment lines to drop from the code

    blocks = deque()
    i = 0
    while i < len(lines):
        line = lines[i]
        copy_line = line.strip()

        # comment blocks starting with """ or '''
        if copy_line.startswith(('"""', "'''")):
            text, end, closed = _collect_statement(lines, i, open_list, closed_list)
            removed.update(range(i, end))
            blocks.append(Block(block_type=Tag.BLOCK_COMMENT, value=text,
                                depth=_interior_line_count(i, end, closed)))
            i = end

        # for blocks; the space excludes names such as forward_tunnel(...)
        elif copy_line.startswith("for "):
            block, i = _extract_for_block(lines, i, open_list, closed_list)
            blocks.append(block)

        # simple comments
        elif copy_line.startswith("#"):
            blocks.append(Block(block_type=Tag.COMMENT, value=line, depth=1))
            removed.add(i)
            i += 1

        # an instruction whose brackets may spread it over several lines
        elif contains_bracket(line=line, open_list=open_list):
            text, i, _ = _collect_statement(lines, i, open_list, closed_list)
            blocks.append(_make_code_block(text))

        # blank line
        elif not copy_line:
            blocks.append(Block(block_type=Tag.BLANK, value=line, depth=1))
            i += 1

        # other code, possibly continued by a multi-line string or a backslash
        else:
            text, end, closed = _collect_statement(lines, i, open_list, closed_list)
            if end - i == 1 and closed:
                depth = 1
            else:
                depth = _interior_line_count(i, end, closed)
            blocks.append(Block(block_type=Tag.CODE, value=text, depth=depth))
            i = end

    comments = []
    for block in blocks:
        comments += _search_comments(block)

    # Remove comment lines by position: removing by value would delete the
    # first equal line instead (e.g. an earlier blank line for a blank line
    # inside a docstring, or a code line quoted in a docstring).
    code_lines = [text for index, text in enumerate(lines) if index not in removed]

    if debug:
        for block in blocks:
            print(block)
            print("----------------------------------")

    return code_lines, blocks, comments


def _take_list_comp(codes, original_text):
    """Consume the rewritten form of one list-comprehension statement from ``codes``.

    :param codes: deque of rewritten code lines (without newlines), consumed from the left.
    :param original_text: source text of the statement, used as a fallback.
    :return: the declaration when it is still a list comprehension; the
        declaration plus the following ``for`` loop (header and every
        deeper-indented line) when the rewrite produced one; otherwise
        ``original_text``.
    """
    declaration = codes.popleft() + "\n"
    tree = ast.parse(declaration.lstrip())
    if tree.body and is_list_comp(tree=tree):
        return declaration  # the comprehension was left unchanged

    if not codes or not codes[0].strip().startswith("for"):
        # unexpected rewrite: fall back to the original text and leave the
        # next line in place for the block it belongs to
        return original_text

    header = codes.popleft()
    indentation = calcul_indent(header)
    text = declaration + header + "\n"
    while codes and calcul_indent(codes[0]) > indentation:
        text += codes.popleft() + "\n"
    return text


def _restore_for_block(codes, block, comments):
    """Return the output chunks for one FOR block, consuming ``codes`` and ``comments``.

    :param codes: deque of rewritten code lines (without newlines).
    :param block: the FOR Block built by extract_info.
    :param comments: deque of pending comment texts in source order.
    """
    header = codes.popleft() + "\n"

    if header.strip().startswith("for"):
        # the loop was kept: replay its body statement by statement
        out = [header]
        for stmt in block.statements:
            tag = stmt.block_type
            if tag == Tag.LIST_COMP:
                out.append(_take_list_comp(codes, stmt.value))
            elif tag == Tag.CODE:
                out.append(codes.popleft() + "\n")
            elif tag == Tag.COMMENT or tag == Tag.BLOCK_COMMENT:
                out.append(comments.popleft())
            elif tag == Tag.BLANK:
                out.append("\n")
        return out

    # the loop was rewritten into a single statement: put the loop's comments
    # in front of it, re-indented to that statement's indentation
    indent = " " * calcul_indent(header)
    out = []
    for stmt in block.statements:
        if stmt.block_type == Tag.COMMENT or stmt.block_type == Tag.BLOCK_COMMENT:
            out.append(indent + comments.popleft().lstrip())
    out.append(header)  # header already ends with a newline
    return out


def _restore_comments(codes, blocks, comments):
    """Interleave rewritten code lines with the original comments and blank lines.

    ``codes``, ``blocks`` and ``comments`` are deques and are consumed.

    :param codes: non-empty rewritten code lines, without newlines.
    :param blocks: top-level Block objects from extract_info.
    :param comments: comment texts in source order.
    :return: list of output text chunks, each normally ending with a newline.
    :raises ValueError: if rewritten lines remain after every block is used.
    """
    lines = []
    while codes:
        if not blocks:
            raise ValueError("rewritten code has more lines than the original blocks")
        block = blocks.popleft()
        tag = block.block_type
        if tag == Tag.CODE:
            lines.append(codes.popleft() + "\n")
        elif tag == Tag.BLANK:
            lines.append("\n")
        elif tag == Tag.COMMENT or tag == Tag.BLOCK_COMMENT:
            lines.append(comments.popleft())
        # list comprehension successfully or unsuccessfully transformed
        elif tag == Tag.LIST_COMP:
            lines.append(_take_list_comp(codes, block.value))
        # for loop successfully or unsuccessfully transformed
        elif tag == Tag.FOR:
            lines.extend(_restore_for_block(codes, block, comments))

    # taking the remaining comments and blank lines
    while blocks:
        block = blocks.popleft()
        if block.block_type == Tag.COMMENT or block.block_type == Tag.BLOCK_COMMENT:
            lines.append(comments.popleft())
        elif block.block_type == Tag.BLANK:
            lines.append("\n")
    return lines


def conversion(file, out_file):
    """Rewrite one Python file with the list-comprehension rules and save the result.

    The file is split by extract_info, its comment-free code is rewritten by
    Rewriter with list_comp_rewrite_rules, and the rewritten code is written
    to ``out_file``. Any failure (including while reading or splitting the
    file) is appended to ./logs_error.txt together with the file name instead
    of being raised, so a batch run over a project continues.

    :param file: path of the Python file to convert.
    :param out_file: path of the file to write (may be the same as ``file``).
    """
    assert file is not None, "file must not be none"
    assert out_file is not None, "out file must not be none"

    # To handle eventual failure of the code
    with open("./logs_error.txt", "a") as error_file:
        try:
            code, blocks, comments = extract_info(file=file)
            src = "".join(code)

            # TODO pass the rules as parameter for more flexibility
            rewriter = Rewriter(src=src, rule_func=list_comp_rewrite_rules)
            new_code, _ = rewriter.rewrite()

            for comment in comments:
                new_code = new_code.replace(comment, "")

            # NOTE(review): the original filter also meant to drop lines that
            # start with a quote (docstrings re-emitted by unparse), but its
            # condition `not a or not b` was always true, so only empty lines
            # were ever dropped; that effective behavior is kept here.
            codes = deque(text for text in new_code.split("\n") if text)

            # Rebuild the code with its comments and blank lines restored.
            # NOTE(review): the rebuilt text is not written; the output below
            # is new_code, as before (the original's `"".join(new_code)` hints
            # that the rebuilt lines may have been intended — confirm before
            # switching). A misalignment raises here and skips the write.
            _restore_comments(codes, blocks, deque(comments))

            with open(out_file, 'w') as fp:
                print(new_code, file=fp)
        except Exception as error:
            message = str(error) + " " + file
            print(message, file=error_file)


def _str_to_bool(value):
    """Interpret the value given to the --t option.

    The empty string and common "false" spellings ("0", "false", "f", "no",
    "n", "off", case-insensitive) give False; any other text gives True, as
    bool(value) did for every non-empty string.
    """
    return value.lower() not in ("", "0", "false", "f", "no", "n", "off")


def main(argv=None):
    """Parse the command line and run the conversion.

    Modes:
      * --file F --out O: convert the single file F into O.
      * --p DIR --t: copy DIR to ./<name of DIR> and convert its .py files in place.
      * --p DIR: convert every .py file of DIR into ./csv/<file name>.

    :param argv: arguments to parse; None (the default) means sys.argv[1:].
    """
    import os

    parser = argparse.ArgumentParser(
        description='Python AST parser')

    parser.add_argument("--file",
                        type=str,
                        nargs='?',
                        dest='file',
                        help="path of the file to parse.")

    parser.add_argument("--out",
                        type=str,
                        nargs='?',
                        dest='out_file',
                        help="path of the result file.")

    parser.add_argument("--p",
                        type=str,
                        nargs='?',
                        dest='project',
                        help="project to analyze")

    # type=bool turned "--t False" into True and a bare "--t" into None
    parser.add_argument("--t",
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=False,
                        dest='test',
                        help="running test")

    args = parser.parse_args(argv)

    if args.file is not None and args.out_file is not None:
        conversion(file=args.file, out_file=args.out_file)

    elif args.project is not None:
        # normpath drops a trailing separator, so the name is never empty
        project_name = os.path.basename(os.path.normpath(args.project))
        if args.test:
            cp = "./" + project_name
            create_directory(path_to_directory=cp)
            copy_folder(folder=args.project, toDirectory=cp)
            files = get_all_files(directory=cp, extension=".py")

            for file in files:
                conversion(file=file, out_file=file)

        else:
            out_dir = "./csv"
            os.makedirs(out_dir, exist_ok=True)
            files = get_all_files(directory=args.project, extension=".py")
            for f in files:
                # NOTE(review): only the base name is kept, so files with the
                # same name in different sub-directories overwrite each other
                out_file = os.path.join(out_dir, os.path.basename(f))
                conversion(file=f, out_file=out_file)

    else:
        # neither a --file/--out pair nor a --p project was given
        parser.error("arguments missing ...")


# Start the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
