#! /usr/bin/env python

import os
import sys
import re
import logging
from importlib import reload
from z3 import Tactic, And, Exists, Int
from lexert2lasso import lexer_main
from lexert2lasso import *
# import subprocess
import ply.yacc as yacc
import parserFormula
import parserAnant

# Module-level parser state. The PLY rule functions below read and mutate
# these globals; main_func() resets all of them except `names` before each
# parse.
# z3 Int constants, looked up through namesToZ3VarsIndex.
z3Vars = []
# Maps a variable name (`x`, `x_post`, or an SSA version such as `x1`) to
# its index in z3Vars.
namesToZ3VarsIndex = {}
# Declared program variables. Nothing in this file adds to this set; it is
# presumably populated by the caller before main_func() — TODO confirm.
names = set([])
# Per-variable count of assignments seen so far in the stem / cycle. The
# current SSA version of `x` is `x<count>`, or plain `x` when the count is 0.
namesUpdateIndexForStem = {}
namesUpdateIndexForCycle = {}
# True while the stem is being parsed; p_stem() sets it to False.
isStem = True
# z3 constraints collected for the stem and for the cycle.
stemRelation = []
cycleRelation = []
# SSA versions introduced by assignment statements; getGuardsAssignments()
# existentially quantifies and eliminates them.
quantifiedVarsStem = []
quantifiedVarsCycle = []


class MyError(Exception):
    """Exception carrying an arbitrary payload in ``value``.

    ``str(err)`` is the ``repr`` of the payload.
    """

    def __init__(self, value):
        # Forward to Exception so `args` is populated: exceptions are rebuilt
        # from `args` when unpickled, which failed while `args` was empty.
        super().__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)


def p_program_lasso(p):
    'program : STEMIS stem CYCLEIS cycle'
    # Top-level rule. The stem/cycle sub-rules already stored their results
    # in the module-level relations, so nothing is synthesized here.
    return None


def p_stem(p):
    'stem : statement_list'
    # Runs once the whole stem has been reduced. From here on, statements
    # belong to the cycle. The stem is closed by equating every variable's
    # `_post` copy with its last SSA version (or the variable itself when the
    # stem never assigned it).
    global isStem
    isStem = False
    for var, version in namesUpdateIndexForStem.items():
        lastName = var + str(version) if version != 0 else var
        postIndex = namesToZ3VarsIndex[var + "_post"]
        lastIndex = namesToZ3VarsIndex[lastName]
        stemRelation.append(z3Vars[postIndex] == z3Vars[lastIndex])


def p_cycle(p):
    "cycle : statement_list"
    # The cycle is complete. Close it by equating each variable's `_post`
    # copy with its last SSA version in the cycle (or the variable itself
    # when the cycle never assigned it).
    for var, version in namesUpdateIndexForCycle.items():
        lastName = var + str(version) if version != 0 else var
        postIndex = namesToZ3VarsIndex[var + "_post"]
        lastIndex = namesToZ3VarsIndex[lastName]
        cycleRelation.append(z3Vars[postIndex] == z3Vars[lastIndex])


def p_statement_list_base(p):
    'statement_list : statement'
    # Statements record their effects in the module-level relations
    # themselves; the list rule carries no value.
    return None


def p_statement_list_recurrence(p):
    'statement_list : statement_list statement'
    # Left-recursive list; each statement already recorded its effect, so
    # there is nothing to combine here.
    return None


def p_statement_skip(p):
    "statement : SKIP ';'"
    # `skip;` changes no state and adds no constraint.
    return None


def p_statement_assign(p):
    "statement : NAME ':' '=' expression ';' "
    # `x := e;` in SSA form:
    #   * bump x's version counter for the current section;
    #   * bind a z3 Int for the new version `x<k>` and mark it for
    #     elimination;
    #   * record `x<k> == e`.
    # `e` is reduced before this rule runs, so it still refers to x's
    # previous version.
    # NOTE(review): version 1 of `x` is named `x1`, which collides with a
    # declared variable literally named `x1` — confirm inputs avoid this.
    var = p[1]
    if var not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(var))
        # Non-zero status so callers see the failure; a bare sys.exit()
        # reports success.
        sys.exit(1)

    if isStem:
        counters, quantified, relation = (namesUpdateIndexForStem,
                                          quantifiedVarsStem, stemRelation)
    else:
        counters, quantified, relation = (namesUpdateIndexForCycle,
                                          quantifiedVarsCycle, cycleRelation)
    counters[var] += 1
    name = var + str(counters[var])

    # Reuse the z3 variable if this SSA name was created before.
    i = namesToZ3VarsIndex.get(name)
    if i is None:
        z3Vars.append(Int(name))
        i = len(z3Vars) - 1
        namesToZ3VarsIndex[name] = i

    quantified.append(z3Vars[i])
    p[0] = z3Vars[i] == p[4]
    relation.append(p[0])


def p_statement_random(p):
    "statement : NAME ':' '=' NONDET '(' ')' ';' "
    # `x := nondet();`: introduce a fresh SSA version of x and mark it for
    # elimination. No constraint is recorded, so the new value is
    # unconstrained.
    var = p[1]
    if var not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(var))
        # Non-zero status so callers see the failure; a bare sys.exit()
        # reports success.
        sys.exit(1)

    if isStem:
        counters, quantified = namesUpdateIndexForStem, quantifiedVarsStem
    else:
        counters, quantified = namesUpdateIndexForCycle, quantifiedVarsCycle
    counters[var] += 1
    name = var + str(counters[var])

    # Reuse the z3 variable if this SSA name was created before.
    i = namesToZ3VarsIndex.get(name)
    if i is None:
        z3Vars.append(Int(name))
        i = len(z3Vars) - 1
        namesToZ3VarsIndex[name] = i

    quantified.append(z3Vars[i])


def p_statement_assume(p):
    "statement : ASSUME boolexpression ';' "
    # `assume c;` adds the constraint c to the section being parsed.
    target = stemRelation if isStem else cycleRelation
    target.append(p[2])


def p_expression_binop(p):
    '''expression : expression '+' expression
                  | expression '-' expression
                  | expression '*' expression
                  | expression '/' expression'''
    # Operands are Python ints (literals) or z3 integer expressions.
    lhs, op, rhs = p[1], p[2], p[3]
    if op == '+':
        p[0] = lhs + rhs
    elif op == '-':
        p[0] = lhs - rhs
    elif op == '*':
        p[0] = lhs * rhs
    elif op == '/':
        if isinstance(lhs, int) and isinstance(rhs, int):
            # Two literals: Python 3's `/` would yield a float (4/2 -> 2.0)
            # and pull z3 into Real arithmetic. Keep integer division, as
            # Python 2's `/` did for ints.
            p[0] = lhs // rhs
        else:
            # At least one z3 Int operand: z3 performs integer division.
            p[0] = lhs / rhs


def p_expression_uminus(p):
    "expression : '-' expression %prec UMINUS"
    # Unary minus binds via the UMINUS precedence entry.
    operand = p[2]
    p[0] = -operand


def p_expression_group(p):
    "expression : '(' expression ')'"
    # Parentheses only group; the inner value passes through unchanged.
    inner = p[2]
    p[0] = inner


def p_expression_coef_var(p):
    "expression : NUMBER '*' NAME"
    # `c * x`: the literal times the current SSA version of x in the section
    # being parsed (plain `x` while unassigned, else `x<count>`).
    var = p[3]
    if var not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(var))
        # Non-zero status so callers see the failure; a bare sys.exit()
        # reports success.
        sys.exit(1)
    counters = namesUpdateIndexForStem if isStem else namesUpdateIndexForCycle
    version = counters[var]
    name = var + str(version) if version != 0 else var
    p[0] = p[1] * z3Vars[namesToZ3VarsIndex[name]]


def p_expression_number(p):
    "expression : NUMBER"
    # A literal evaluates to the value the lexer attached to the token.
    literal = p[1]
    p[0] = literal


def p_expression_name(p):
    "expression : NAME"
    # A variable reference denotes its current SSA version in the section
    # being parsed: plain `x` while unassigned, otherwise `x<count>`.
    var = p[1]
    if var not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(var))
        # Non-zero status so callers see the failure; a bare sys.exit()
        # reports success.
        sys.exit(1)
    counters = namesUpdateIndexForStem if isStem else namesUpdateIndexForCycle
    version = counters[var]
    name = var + str(version) if version != 0 else var
    p[0] = z3Vars[namesToZ3VarsIndex[name]]


def p_bool_expression_paranthesis(p):
    '''boolexpression : '(' boolexpression ')' '''
    # Parentheses only group; the inner formula passes through unchanged.
    inner = p[2]
    p[0] = inner


def p_bool_expression_and(p):
    '''boolexpression : boolexpression AND boolexpression'''
    # Conjunction of the two sub-formulas as a z3 And.
    left, right = p[1], p[3]
    p[0] = And(left, right)


def p_bool_expression(p):
    '''boolexpression : constraint'''
    # A single comparison is already a complete boolean formula.
    constraint = p[1]
    p[0] = constraint


def p_constraint_1(p):
    '''constraint : expression EQ expression '''
    # Equality between two arithmetic expressions.
    lhs, rhs = p[1], p[3]
    p[0] = lhs == rhs


def p_constraint_2(p):
    '''constraint : expression LE expression '''
    # Non-strict upper bound: lhs <= rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs <= rhs


def p_constraint_3(p):
    '''constraint : expression GE expression '''
    # Non-strict lower bound: lhs >= rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs >= rhs


def p_constraint_4(p):
    '''constraint : expression LT expression '''
    # Strict upper bound: lhs < rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs < rhs


def p_constraint_5(p):
    '''constraint : expression GT expression '''
    # Strict lower bound: lhs > rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs > rhs


def p_error(p):
    """Report a syntax error raised by PLY and terminate the process.

    ``p`` is the offending token, or None when input ended unexpectedly.
    """
    print("Error while Parsing T2 Lasso")
    if p:
        print("Syntax error at '%s': line= %d" % (p.value, p.lineno))
    else:
        print("Syntax error at EOF")
    # Exit with a non-zero status so callers and shell scripts can detect
    # the failure; a bare sys.exit() reports success (status 0).
    sys.exit(1)


# Although the parameter names refer to the cycle, the function is also
# invoked on the stem (with isCycle=False).
def getGuardsAssignments(cycleRelation, quantifiedVarsCycle, isCycle):
    """Split a transition relation into guards and per-variable assignments.

    Despite the parameter names, this is used for both the stem
    (``isCycle=False``) and the cycle (``isCycle=True``).

    Args:
        cycleRelation: list of z3 constraints over pre-state variables, SSA
            versions and ``<var>_post`` variables. Must be non-empty:
            element 0 is read unconditionally.
        quantifiedVarsCycle: z3 variables (the SSA versions collected by the
            assignment rules) to eliminate by quantifier elimination; may be
            empty.
        isCycle: when False, the function returns before computing
            ``underApproxFormulas``, which is then always empty.

    Returns:
        A 6-tuple ``(quantifierFreeFormulaCycle, quantifierFreeFormulaGuards,
        correctFormulaCycle, correctFormulaCycleAssignments,
        assignmentsCycleDict, underApproxFormulas)``:

        * quantifierFreeFormulaCycle -- the relation with
          ``quantifiedVarsCycle`` eliminated (first subgoal of the 'qe'
          tactic result), or ``cycleRelation`` itself when there was
          nothing to eliminate;
        * quantifierFreeFormulaGuards -- the result of further eliminating
          every ``<var>_post`` variable, i.e. the constraints left on the
          remaining (pre-state) variables;
        * correctFormulaCycle -- the guards plus the relation formulas that
          mention a post-state variable (see the pre/post checks below);
        * correctFormulaCycleAssignments -- formulas recognised as fixing
          some ``<var>_post``;
        * assignmentsCycleDict -- for every variable, a dict of
          ``{term: coefficient, 'rhsConstant': c}`` for its assignment, or
          ``{}`` when none was recognised;
        * underApproxFormulas -- (cycle only) guards whose first parsed
          constraint mentions a variable with no recognised assignment.
    """
    t_tactic = Tactic('qe')
    count = len(cycleRelation)
    # Conjoin the whole relation into a single formula.
    formula = cycleRelation[0]
    for i in range(1, count):
        try:
            formula = And(formula, cycleRelation[i])
        except Exception:
            logging.exception("Error in getGuardsAssignments")
            # NOTE(review): a bare sys.exit() exits with status 0, so this
            # failure is reported to the shell as success.
            sys.exit()

    # Dead initialisation: always overwritten below.
    quantifierFreeFormulaGuards = []
    if quantifiedVarsCycle:
        # Existentially quantify the SSA versions and eliminate them.
        quantifiedFormula = Exists(quantifiedVarsCycle, formula)
        quantifierFreeFormulaCycle = t_tactic(quantifiedFormula)[0]
        # NOTE(review): [0] raises IndexError if qe returns an empty goal
        # (formula simplified to True) — confirm that cannot happen here.
        formula = quantifierFreeFormulaCycle[0]
        for i in range(1, len(quantifierFreeFormulaCycle)):
            formula = And(formula, quantifierFreeFormulaCycle[i])
    else:
        quantifierFreeFormulaCycle = cycleRelation

    # Eliminate every post-state variable as well; what remains are the
    # guards of the transition.
    quantifiedPostVars = []
    for name in names:
        vfipost = namesToZ3VarsIndex[name + "_post"]
        quantifiedPostVars.append(z3Vars[vfipost])
    # quantifiedVarsCycle.extend(quantifiedPostVars)
    quantifiedFormulaGuards = Exists(quantifiedPostVars, formula)
    quantifierFreeFormulaGuards = t_tactic(quantifiedFormulaGuards)[0]

    # Keep the guards, plus each relation formula that mentions both a
    # program variable and a post-state variable.
    correctFormulaCycle = list(quantifierFreeFormulaGuards)
    for formula in quantifierFreeFormulaCycle:
        preVarFound = False
        postVarFound = False
        fStr = str(formula)  # formula.__str__().encode("ascii")
        fParser = parserFormula.parserFormula_main()
        cstrList = fParser.parse(fStr)

        for i in range(0, len(cstrList)):
            # It is confusing whether the following for loop for preVars is
            # actually needed.
            # NOTE(review): `var in key` is a substring test, so it also
            # matches `<var>_post` keys (and longer names containing var).
            # Whenever postVarFound is set, preVarFound is set too.
            for var in names:
                for key in cstrList[i].keys():
                    # We may have a key like i * j
                    if var in key:
                        preVarFound = True
            for var in names:
                varpost = var + '_post'
                for key in cstrList[i].keys():
                    if varpost in key:
                        postVarFound = True

        # It is confusing what the if condition should be.
        if preVarFound is True and postVarFound is True:
            # if postVarFound == True:
            correctFormulaCycle.append(formula)

    # For each variable, eliminate the *other* post-state variables and look
    # for a formula that fixes `<name>_post` in terms of the rest.
    correctFormulaCycleAssignments = []  # list(quantifierFreeFormulaGuards)
    assignmentsCycleDict = {}
    for name in names:
        quantifiedPostVars = []
        assignmentsCycleDict[name] = {}
        for otherName in names:
            if otherName != name:
                vfipost = namesToZ3VarsIndex[otherName + "_post"]
                quantifiedPostVars.append(z3Vars[vfipost])

        # NOTE(review): IndexError if correctFormulaCycle is empty.
        formula = correctFormulaCycle[0]
        for i in range(1, len(correctFormulaCycle)):
            formula = And(formula, correctFormulaCycle[i])
        if quantifiedPostVars:
            qfCyc = Exists(quantifiedPostVars, formula)
            qfFreeCyc = t_tactic(qfCyc)[0]
        else:
            qfFreeCyc = correctFormulaCycle

        for formula in qfFreeCyc:
            matchFoundForCurrentFormula = False  # Default False.
            fStr = str(formula)  # formula.__str__().encode("ascii")
            fParser = parserFormula.parserFormula_main()
            cstrList = fParser.parse(fStr)
            namePost = name + "_post"
            if namePost not in cstrList[0].keys():
                # This Formula is not an assignment of namePost.
                continue
            if len(cstrList) == 2:
                # Presumably parserFormula yields two constraints for an
                # equality, so this formula already fixes namePost — TODO
                # confirm against parserFormula.
                correctFormulaCycleAssignments.append(formula)
                d = {}
                # Move every other term to the right-hand side by negating
                # its coefficient; 'rhsConstant' is copied unchanged.
                # NOTE(review): `assert` is stripped under `python -O`.
                for key in cstrList[0].keys():
                    if key == namePost:
                        assert cstrList[0][
                            key] == 1, "post var coefficient not 1"
                        continue
                    if key != 'rhsConstant':
                        coeff = cstrList[0][key] * -1
                        d[key] = coeff
                    else:
                        coeff = cstrList[0][key]
                        d[key] = coeff
                assignmentsCycleDict[name] = d
                break

            # Single inequality: look for a formula whose coefficients are
            # exactly the negation of this one's (its mirror inequality).
            # Presumably the pair together pins namePost down like an
            # equality. Note that the loop also visits `formula` itself.
            # Compare current formula against every other formula in qfFreeCyc.
            for formulaMatch in qfFreeCyc:
                # fStrMatch = formulaMatch.__str__().encode("ascii")
                fStrMatch = str(formulaMatch)
                if fStrMatch == 'True':
                    continue
                cstrListForMatch = fParser.parse(fStrMatch)
                if len(cstrListForMatch) == 2:
                    continue
                isMatch = True  # Default True.
                for key in cstrList[0].keys():
                    if key not in cstrListForMatch[0]:
                        isMatch = False
                        break
                    if cstrListForMatch[0][key] != -1 * cstrList[0][key]:
                        isMatch = False
                        break

                if isMatch:
                    matchFoundForCurrentFormula = True
                    correctFormulaCycleAssignments.append(formula)
                    d = {}
                    # Same right-hand-side extraction as the equality case.
                    for key in cstrList[0].keys():
                        if key == namePost:
                            assert cstrList[0][
                                key] == 1, "post var coefficient not 1"
                            continue
                        if key != 'rhsConstant':
                            coeff = cstrList[0][key] * -1
                            d[key] = coeff
                        else:
                            coeff = cstrList[0][key]
                            d[key] = coeff
                    assignmentsCycleDict[name] = d
                    break
            if matchFoundForCurrentFormula:
                break

    underApproxFormulas = []
    if not isCycle:
        return quantifierFreeFormulaCycle, quantifierFreeFormulaGuards, \
               correctFormulaCycle, correctFormulaCycleAssignments, \
               assignmentsCycleDict, underApproxFormulas

    # Cycle only: for each variable without a recognised assignment, collect
    # the guards whose first parsed constraint mentions it.
    for var in names:
        if assignmentsCycleDict[var] != {}:
            continue
        for guard in quantifierFreeFormulaGuards:
            gStr = str(guard)  # guard.__str__().encode("ascii")
            fParser = parserFormula.parserFormula_main()
            cstrList = fParser.parse(gStr)
            if var in cstrList[0].keys():
                underApproxFormulas.append(guard)
    # print 'underApproxFormulas'
    # print underApproxFormulas

    return (quantifierFreeFormulaCycle, quantifierFreeFormulaGuards,
            correctFormulaCycle, correctFormulaCycleAssignments,
            assignmentsCycleDict, underApproxFormulas)

    # Neglect the sign == or <=. correctFormulaCycleAssignments contains
    # only assignments i.e. ==


def createLassoInInterprocFormat(quantifierFreeFormulaStemGuards,
                                 assignmentsStemDict,
                                 quantifierFreeFormulaGuards,
                                 assignmentsCycleDict, underApproxFormulas):
    """Write the lasso as an Interproc program to ``<cwd>/build/lasso_interproc``.

    The generated program declares every variable in ``names`` as ``int``,
    emits one ``assume`` per stem guard followed by the stem assignments,
    then a ``while`` loop whose condition is the conjunction of the cycle
    guards (``true`` when there are none). The loop body holds the cycle
    assignments and, if any, one ``assume`` of the under-approximation
    formulas. The ``build`` directory is created when missing.

    Args:
        quantifierFreeFormulaStemGuards: iterable of z3 formulas, each
            written as its own ``assume`` before the loop.
        assignmentsStemDict: per-variable assignment dicts for the stem, in
            the format consumed by ``writeAssignments``.
        quantifierFreeFormulaGuards: iterable of z3 formulas and-ed into the
            loop condition.
        assignmentsCycleDict: per-variable assignment dicts for the loop body.
        underApproxFormulas: z3 formulas and-ed into one ``assume`` at the
            end of the loop body; nothing is written when empty.
    """
    # z3 prints `True`/`False`/`Not`; Interproc spells them in lowercase.
    # Whole-word matching leaves identifiers such as `isTrue` or `NotDone`
    # untouched.
    substs = {"True": "true", "False": "false", "Not": "not"}
    regexp = re.compile(r"\b(True|False|Not)\b")

    def toInterproc(formula):
        return regexp.sub(lambda m: substs[m.group(1)], str(formula))

    def conjunction(formulas):
        # An empty conjunction is `true`; writing nothing would produce the
        # syntactically invalid `while()`.
        return " and ".join(toInterproc(f) for f in formulas) or "true"

    buildDir = os.path.join(os.getcwd(), "build")
    os.makedirs(buildDir, exist_ok=True)
    # `with` closes the file even if formatting a formula raises.
    with open(os.path.join(buildDir, "lasso_interproc"), "w") as lasso:
        lasso.write("var ")
        lasso.write(", ".join("{} : int".format(var) for var in names))
        lasso.write(";\n")

        lasso.write("begin\n")
        for formula in quantifierFreeFormulaStemGuards:
            lasso.write("assume({});\n".format(toInterproc(formula)))

        writeAssignments(assignmentsStemDict, lasso)

        lasso.write("while({}) do\n".format(
            conjunction(quantifierFreeFormulaGuards)))

        writeAssignments(assignmentsCycleDict, lasso)

        if underApproxFormulas:
            lasso.write("assume({});\n".format(
                conjunction(underApproxFormulas)))

        lasso.write("done;\nend")


def writeAssignments(correctFormulaAssignments, lasso):
    """Emit one ``<var> = <rhs>;`` line per variable in ``names``.

    ``correctFormulaAssignments[var]`` maps term names to coefficients; the
    special key ``"rhsConstant"`` holds the constant term. Terms are
    rendered as ``coef*term`` and joined with `` + ``. A falsy (empty) entry
    means no assignment was recovered, and the variable becomes ``random``.
    """
    # NOTE(review): statements are written one after another, but each rhs
    # presumably refers to pre-state values (it comes from a pre/post
    # relation). If an rhs mentions a variable assigned on an earlier line
    # (e.g. a swap), the sequential program reads the updated value —
    # confirm whether temporaries are needed.
    def render(term, coeff):
        coeffText = str(coeff)
        if term == "rhsConstant":
            return coeffText
        return "{}*{}".format(coeffText, term)

    for var in names:
        lasso.write("{} = ".format(var))
        terms = correctFormulaAssignments[var]
        rhs = (" + ".join(render(t, c) for t, c in terms.items())
               if terms else "random")
        lasso.write(rhs)
        lasso.write(";\n")


def main_func(parseStr, extractInv):
    """Translate a T2 lasso into Interproc and hand it to parserAnant.

    Parses ``parseStr`` (stem, then cycle) into z3 relations and extracts
    guards and assignments for each part. The result is written to
    ``<cwd>/build/lasso_interproc``, and ``parserAnant.main_func`` is then
    run on that file with ``extractInv``.

    The module-level ``names`` set must already hold the program variables;
    it is read here but not reset.
    """
    global z3Vars, namesToZ3VarsIndex
    global namesUpdateIndexForStem, namesUpdateIndexForCycle
    global isStem, stemRelation, cycleRelation
    global quantifiedVarsStem, quantifiedVarsCycle

    # Start every run from a clean parser state.
    z3Vars, namesToZ3VarsIndex = [], {}
    namesUpdateIndexForStem, namesUpdateIndexForCycle = {}, {}
    isStem = True
    stemRelation, cycleRelation = [], []
    quantifiedVarsStem, quantifiedVarsCycle = [], []

    # Each declared variable gets a pre-state and a post-state z3 Int; the
    # grammar rules append SSA versions during parsing.
    for var in names:
        namesUpdateIndexForStem[var] = 0
        namesUpdateIndexForCycle[var] = 0
        for z3Name in (var, "{}_post".format(var)):
            z3Vars.append(Int(z3Name))
            namesToZ3VarsIndex[z3Name] = len(z3Vars) - 1

    # yacc.yacc() collects the p_* rule functions from the calling module.
    lassoLexer = lexer_main()
    lassoParser = yacc.yacc(debug=0, write_tables=0)
    lassoParser.parse(parseStr, lexer=lassoLexer)

    (_, stemGuards, _, _, stemAssignments, _) = \
        getGuardsAssignments(stemRelation, quantifiedVarsStem, False)
    (_, cycleGuards, _, _, cycleAssignments, underApproxFormulas) = \
        getGuardsAssignments(cycleRelation, quantifiedVarsCycle, True)

    # Writes `build/lasso_interproc` under the current directory.
    createLassoInInterprocFormat(stemGuards, stemAssignments, cycleGuards,
                                 cycleAssignments, underApproxFormulas)

    reload(parserAnant)
    parserAnant.main_func(os.getcwd() + "/build/lasso_interproc", extractInv)


if __name__ == "__main__":
    # main_func() requires the lasso text and the extractInv argument, and it
    # reads the module-level `names` set, which nothing in this file fills
    # in. A bare `main_func()` call could only fail with a TypeError, so
    # explain how the module is meant to be used. sys.exit(str) prints the
    # message to stderr and exits with status 1.
    sys.exit("usage: import this module, populate `names`, then call "
             "main_func(parseStr, extractInv)")
