import ast
import subprocess

from rules import *
from source import Source
from unit import *
from testCases import *

from helper import *
import copy

"""Inspired by https://github.com/SMAT-Lab/Scalpel/blob/d0df274dea179d022fe5305ce2c130a9d2c22ee1/scalpel/core/util.py#L145
Modified by Ossim Belias."""


class Rewriter:
    """Insert rule-generated statements into a parsed Python module.

    ``rule_func`` is called once per statement unit yielded by
    ``extend_ast_tree(self.tree)``; every statement it returns is inserted
    before that unit via ``unit.insert_stmts_before``, so both ``rewrite``
    methods mutate ``self.tree`` in place.
    """

    def __init__(self, src, rule_func=None):
        """
        :param src: Python source text; parsed immediately (a SyntaxError
            propagates to the caller).
        :param rule_func: per-unit rule callable; must be set before calling
            ``rewrite`` or ``rewrite_v2``.
        """
        self.src = src
        self.tree = ast.parse(src)
        self.rule_func = rule_func

    def get_source(self):
        """
        :return: ``ast.unparse`` of the current tree. This is the original
            program only until a ``rewrite*`` call: those methods mutate
            ``self.tree``, after which the rewritten program is returned.
            Comments and original formatting are not preserved by unparse.
        """
        return ast.unparse(self.tree)

    def _require_rule_func(self):
        """Fail fast when no rule function has been configured."""
        if self.rule_func is None:
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("rule_func can not be None type!")

    @staticmethod
    def _check_stmts(new_stmts):
        """Reject rule results that are not a list of statements."""
        if not isinstance(new_stmts, list):
            raise TypeError("The return type of rule func should be a list")

    @staticmethod
    def _span(node):
        """Return the ``(lineno, end_lineno)`` span of *node*."""
        return node.lineno, node.end_lineno

    def _render(self):
        """Fill in missing locations on the mutated tree and unparse it."""
        return ast.unparse(ast.fix_missing_locations(self.tree))

    def rewrite(self, line=None):
        """Apply ``rule_func`` to each unit and insert the statements it returns.

        ``rule_func(node=..., source=...)`` must return ``(new_stmts, rules)``
        where ``rules[i]`` describes ``new_stmts[i]``.

        :param line: optional ``(old_begin, old_end, new_begin, new_end)``;
            when given, only the unit spanning exactly lines
            ``new_begin..new_end`` is rewritten (the ``old_*`` values are unused).
        :return: ``(new_src, stmts)`` where each ``stmts`` entry is
            ``(node, [(rule, new_stmt), ...])``.
        :raises ValueError: if ``rule_func`` is None.
        :raises TypeError: if ``rule_func`` does not return a list of statements.
        """
        self._require_rule_func()

        source = Source(tree=self.tree)
        units = extend_ast_tree(self.tree)
        wanted = None
        if line is not None:
            _old_begin, _old_end, new_begin, new_end = line
            wanted = (new_begin, new_end)

        stmts = []
        for unit in units:
            if wanted is not None and self._span(unit.node) != wanted:
                continue
            new_stmts, rules = self.rule_func(node=unit.node, source=source)
            self._check_stmts(new_stmts)
            if new_stmts:
                paired = [(rules[i], stmt) for i, stmt in enumerate(new_stmts)]
                stmts.append((unit.node, paired))
                unit.insert_stmts_before(new_stmts)

        return self._render(), stmts

    # Patched variant of ``rewrite``. The original author notes it "helps save
    # the comment of a file": presumably callers use the returned ``nodes``
    # spans to splice new statements into the original text, since
    # ``ast.unparse`` itself drops comments — TODO confirm with callers.
    def rewrite_v2(self, line=None):
        """Like ``rewrite``, but also report where each statement was inserted.

        ``rule_func(node=..., source=...)`` must return only the list of new
        statements (no rules).

        :param line: optional ``(begin, end)``; when given, only the unit
            spanning exactly lines ``begin..end`` is rewritten.
        :return: ``(new_src, stmts, nodes)``: ``stmts`` entries are
            ``(node, new_stmts)``; ``nodes`` holds one
            ``((lineno, end_lineno), new_stmt)`` per inserted statement, the
            span being read from the unit's node after insertion.
        :raises ValueError: if ``rule_func`` is None.
        :raises TypeError: if ``rule_func`` does not return a list.
        """
        self._require_rule_func()

        source = Source(tree=self.tree)
        units = extend_ast_tree(self.tree)
        wanted = None
        if line is not None:
            begin, end = line
            wanted = (begin, end)

        stmts = []
        nodes = []
        for unit in units:
            if wanted is not None and self._span(unit.node) != wanted:
                continue
            new_stmts = self.rule_func(node=unit.node, source=source)
            self._check_stmts(new_stmts)
            if not new_stmts:
                continue
            stmts.append((unit.node, new_stmts))
            unit.insert_stmts_before(new_stmts)
            nodes.extend((self._span(unit.node), stmt) for stmt in new_stmts)

        return self._render(), stmts, nodes


def get_number_passed_test(text):
    """Extract the "N passed" count from pytest's final summary line.

    pytest ends its output with a line such as
    ``==== 1 failed, 5 passed, 2 warnings in 0.12s ====``.

    :param text: captured pytest stdout.
    :return: the count exactly as printed (a str, e.g. ``"5"``), or ``0``
        when the text is blank or the summary has no "passed" entry.
        NOTE(review): the mixed str/int return type is kept for backward
        compatibility; callers comparing against an int must convert.
    """
    assert text is not None, "text must not be none"

    lines = [ln for ln in text.split('\n') if ln]
    if not lines:
        return 0

    tokens = lines[-1].replace("=", "").split()
    # Walk (count, label) neighbours; zip stops at the last token, so a
    # summary without "passed" no longer reads past the end of the list.
    for count, label in zip(tokens, tokens[1:]):
        # Exact label match: "xpassed" (unexpected passes) must not count.
        if label.rstrip(",") == "passed":
            return count

    return 0


def remove_string(src, str_to_remove):
    """Return *src* with every occurrence of *str_to_remove* deleted.

    Both arguments are asserted to be non-None before any work is done.
    """
    preconditions = (
        (src, "src must not be none"),
        (str_to_remove, "string to remove must not be none"),
    )
    for value, failure_message in preconditions:
        assert value is not None, failure_message

    return src.replace(str_to_remove, "")


def _make_log_call(log_file, message):
    """Build the ``log(file=..., message=...)`` statement injected after a rewrite."""
    # repr() always yields a valid literal, even for paths containing quotes;
    # .body[0] gives a statement rather than a whole Module node.
    return ast.parse(f"log(file={log_file!r}, message={message!r})").body[0]


def _is_injected_log_call(stmt, log_file):
    """Return True if *stmt* is a ``log(...)`` call produced by ``_make_log_call``."""
    if not (isinstance(stmt, ast.Expr) and isinstance(stmt.value, ast.Call)):
        return False
    call = stmt.value
    if not (isinstance(call.func, ast.Name) and call.func.id == "log"):
        return False
    # Matching on the file keyword keeps user-written log(...) calls intact.
    return any(
        kw.arg == "file"
        and isinstance(kw.value, ast.Constant)
        and kw.value.value == log_file
        for kw in call.keywords
    )


def _remove_injected_log_calls(tree, log_file):
    """Delete injected log statements from every statement list in *tree*, in place."""
    for node in ast.walk(tree):
        for _field, value in ast.iter_fields(node):
            if isinstance(value, list):
                value[:] = [s for s in value if not _is_injected_log_call(s, log_file)]


def _drop_injected_log_def(tree):
    """Remove the first top-level ``def log`` (the helper prepended by ``rewrite``)."""
    for index, stmt in enumerate(tree.body):
        if isinstance(stmt, ast.FunctionDef) and stmt.name == "log":
            del tree.body[index]
            return


def _write_source(file, tree):
    """Unparse *tree* and overwrite *file* with it (print adds a trailing newline)."""
    new_src = ast.unparse(ast.fix_missing_locations(tree))
    with open(file, "w") as fp:
        print(new_src, file=fp)


def _count_passed_tests(system):
    """Run pytest with *system* as working directory and return its pass count."""
    import os

    # cwd= instead of a shell "cd X; pytest": no quoting problems with spaces,
    # and a missing directory raises instead of silently testing the wrong dir.
    completed = subprocess.run(
        ["pytest"], cwd=os.path.expanduser(system), capture_output=True
    )
    return get_number_passed_test(text=completed.stdout.decode())


def _record_attempt(file, node, rule, ok):
    """Append one attempt row to parsed.csv (``"1"`` = kept, ``"0"`` = rolled back)."""
    row = [file, str(ast.unparse(node)),
           str((node.lineno, node.end_lineno)), str(int(ok)), rule]
    write_row_in_csv(filename="parsed.csv", row=row)


def rewrite(info, file, rule_func, n):
    """Apply ``rule_func`` to each statement of *file*, keeping only changes
    after which the project's pytest pass count still equals *n*.

    For each unit where the rule returns new statements, the candidate module
    is written to *file*, pytest is run in ``info["system"]``, and the change
    is kept if the pass count equals *n*; otherwise the tree is rolled back to
    the last accepted version. Attempts are logged to parsed.csv, rule errors
    to not_parse.csv and unparsable files to exclude_files.csv. *file* is
    finally overwritten with the accepted result.

    :param info: statistics dict; reads ``"system"`` and increments
        ``"tot_for"``, ``"tot_for_transf"``, ``"tot_for_transf_test_ok"`` and
        ``"number_of_files_with_for"``.
    :param file: path of the Python file to transform (rewritten in place).
    :param rule_func: ``rule_func(node=..., source=...) -> (stmts_list, rule)``.
    :param n: expected pass count, compared with ``!=`` to
        ``get_number_passed_test``'s result (a str when a count is found) —
        NOTE(review): confirm the caller passes the same type.
    """
    import os

    assert info is not None, "info must not be none"
    assert file is not None, "code to transform must not be none"
    assert rule_func is not None, "missing transformation rules"

    # Helper prepended to the module so that, when a transformed statement is
    # executed by the tests, a marker row is appended to test_cases.csv.
    log_def_src = "def log(file, message):\n\tfp = open(file, \'a\')\n\tprint(message, file=fp)\n\tfp.close()\n"
    with open(file, 'r') as fp:
        code = fp.readlines()
    code.insert(0, log_def_src)
    src = "".join(code)

    tree = None
    source = None
    try:
        tree = ast.parse(src)
        source = Source(tree)
    except Exception as error:
        row = [file, str(error)]
        write_row_in_csv(filename="exclude_files.csv", row=row)
        print(error)

    # Bail out before extend_ast_tree/deepcopy ever see a None tree.
    if tree is None:
        print(f"could not parse {file}")
        return

    units = extend_ast_tree(tree)
    copy_tree = copy.deepcopy(tree)  # last version accepted by the tests
    log_file = f"{os.getcwd()}/test_cases.csv"
    has_for = False

    # TODO: parallelize — each attempt currently rewrites `file` and runs the
    # whole suite, so attempts cannot simply run concurrently.
    for unit in units:
        if isinstance(unit.node, ast.For):
            has_for = True
            info["tot_for"] += 1

        # Snapshot for error reporting, taken before the rule runs.
        copy_node = copy.deepcopy(unit.node)

        try:
            new_stmts, rule = rule_func(node=unit.node, source=source)

            if not isinstance(new_stmts, list):
                raise Exception("The return type of rule func should be a list")

            if len(new_stmts) > 0:
                info["tot_for_transf"] += 1
                span = str((unit.node.lineno, unit.node.end_lineno))
                message = ",".join([file, "passed", span])
                new_stmts.append(_make_log_call(log_file, message))
                unit.insert_stmts_before(new_stmts)

                # Write the candidate program and run the project's tests.
                _write_source(file, tree)
                number_test_passed = _count_passed_tests(info["system"])

                if n != number_test_passed:
                    # Pass count changed: roll back to the last accepted tree.
                    # NOTE(review): `units` presumably still references nodes
                    # of the previous tree object, so later insert_stmts_before
                    # calls mutate that discarded tree rather than this copy;
                    # subsequent candidates may be tested and recorded without
                    # actually being applied. Needs re-traversal after rollback.
                    tree = copy.deepcopy(copy_tree)
                    _record_attempt(file, unit.node, rule, ok=False)
                else:
                    # Accepted: strip only our marker calls and snapshot.
                    _remove_injected_log_calls(tree, log_file)
                    tree = ast.fix_missing_locations(tree)
                    copy_tree = copy.deepcopy(tree)
                    _record_attempt(file, unit.node, rule, ok=True)
                    info["tot_for_transf_test_ok"] += 1

        except Exception as error:
            # The rule or the test run failed: discard this attempt.
            tree = copy.deepcopy(copy_tree)
            row = [str(ast.unparse(copy_node)), str(error)]
            write_row_in_csv(filename="not_parse.csv", row=row)
            print(error)

    if has_for:
        info["number_of_files_with_for"] += 1

    # Remove the prepended helper from the AST itself; this replaces slicing
    # the first four lines off the unparsed text.
    _drop_injected_log_def(tree)
    _write_source(file, tree)


class Remove(ast.NodeTransformer):
    """Delete expression statements that call the plain name ``func``.

    ``log(x)`` is removed when ``func == "log"``; attribute calls such as
    ``obj.log(x)`` and non-call expressions are kept. If removal empties a
    block body that Python requires to be non-empty, a ``pass`` is inserted
    so the tree still unparses to valid source.
    """

    # Statement-list fields that may not be empty when their construct exists.
    _REQUIRED_BODIES = ("body", "finalbody")

    def __init__(self, func):
        """
        :param func: bare function name whose call statements are removed.
        """
        super().__init__()
        self.func = func

    def visit_Expr(self, node):
        """Drop the statement if it is ``<func>(...)``; otherwise keep it."""
        call = node.value
        if (isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == self.func):
            return None
        return node

    def generic_visit(self, node):
        """Visit children, then backfill ``pass`` into bodies emptied by removal."""
        had_stmts = {
            field: bool(getattr(node, field, None))
            for field in self._REQUIRED_BODIES
        }
        node = super().generic_visit(node)
        if not isinstance(node, ast.Module):  # an empty module is valid
            for field, was_nonempty in had_stmts.items():
                stmts = getattr(node, field, None)
                # isinstance guard: Lambda/IfExp have an expression `body`.
                if was_nonempty and isinstance(stmts, list) and not stmts:
                    stmts.append(ast.Pass())
        return node


def main():
    """Demo: strip ``log(...)`` call statements from ./tests/rule_1.py and print it.

    The result is followed by a separator line.
    """
    source_text = open_file(file_path="./tests/rule_1.py")
    cleaned = Remove("log").visit(ast.parse(source_text))
    print(ast.unparse(ast.fix_missing_locations(cleaned)))
    print("============================================")


if __name__ == '__main__':
    # Script entry point only; importing this module runs nothing.
    main()
