#! /usr/bin/env python

# import pdb
import os
import sys
import logging
from z3 import Int, Not, BoolVal, simplify
from lexerAnant import lexer_main
from lexerAnant import *
import subprocess
import ply.yacc as yacc
import parserFormula
import parserConcrete

# ---------------------------------------------------------------------------
# Parser state. The grammar actions below fill these in while parsing.
# ---------------------------------------------------------------------------
# z3 Int constants: for each declared variable x, both x and x_post.
z3Vars = []
# Variable name ('x' or 'x_post') -> index of its constant in z3Vars.
namesToZ3VarsIndex = {}
# Declared (unprimed) variable names.
names = set()
# Primed counterparts ('x_post') of the declared names.
namesPost = set()
# True while the stem is parsed; p_stem sets it to False before the cycle.
isStem = True
# Constraints collected for the stem part of the lasso.
stemRelation = []
# Constraints collected for the cycle part (loop body and loop guard).
cycleRelation = []


def p_program(p):
    'program : varlist BEGIN lasso END'
    # Top-level rule: a variable declaration list followed by the lasso
    # between BEGIN and END. Presumably the start symbol: PLY uses the first
    # rule unless a `start` name is defined (it could come from the
    # `from lexerAnant import *` line). The sub-rule actions do all the work
    # by filling the module-level state, so nothing is built here.
    pass


def p_varlist(p):
    "varlist : VAR vars_decl ';' "
    # Declaration header. The vars_decl actions have already registered every
    # declared variable in names / namesPost / z3Vars, so there is no value
    # to produce here.
    pass


def _declare_int_var(name):
    """Register declared variable name together with its primed copy.

    Adds ``name`` to ``names`` and ``name + '_post'`` to ``namesPost``.
    Creates a z3 Int constant for each, appends it to ``z3Vars``, and records
    its position in ``namesToZ3VarsIndex`` (unprimed first, then primed).
    """
    names.add(name)
    namesPost.add(name + '_post')
    for var in (name, name + '_post'):
        z3Vars.append(Int(var))
        namesToZ3VarsIndex[var] = len(z3Vars) - 1


def p_vars_decl(p):
    "vars_decl : NAME ':' INT"
    # First declaration in the list: `x : int`.
    _declare_int_var(p[1])


def p_vars_decl_recurrent(p):
    "vars_decl : vars_decl ',' NAME ':' INT"
    # Each further declaration: `..., y : int`.
    _declare_int_var(p[3])


def p_lasso(p):
    'lasso : stem cycle'
    # A lasso is a stem followed by a single while-loop cycle. Both parts
    # record their constraints in their own actions, so nothing is built here.
    pass


def p_stem(p):
    'stem : statement_list'
    global isStem
    # This reduction happens only after the whole stem statement list has
    # been reduced, so every stem statement action has already run with
    # isStem True. Statements parsed from here on (the cycle) are routed to
    # cycleRelation instead of stemRelation.
    isStem = False


def p_cycle(p):
    "cycle : WHILE boolexpression DO statement_list DONE ';' "
    # p[2] is the loop guard as a list of conjuncts. The body statements were
    # reduced (and recorded) before this rule, so the guard's constraints end
    # up after them in cycleRelation.
    cycleRelation.extend(p[2])


def p_statement_list_base(p):
    'statement_list : statement'
    # Statements record their constraints in their own actions; the list
    # itself carries no value.
    pass


def p_statement_list_recurrence(p):
    'statement_list : statement_list statement'
    # Left-recursive list: statements are reduced (and their actions run) in
    # source order. No value is accumulated here.
    pass


def p_statement_skip(p):
    "statement : SKIP ';'"
    # `skip;` adds no constraint to either relation.
    pass


def p_statement_assign(p):
    "statement : NAME '=' expression ';' "
    # `x = e;` becomes the constraint x_post == e. It is added to the stem or
    # the cycle relation depending on which part is being parsed (isStem is
    # only read here, so no `global` declaration is needed).
    if p[1] not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(p[1]))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)

    target = z3Vars[namesToZ3VarsIndex[p[1] + '_post']]
    p[0] = target == p[3]
    (stemRelation if isStem else cycleRelation).append(p[0])


def p_statement_random(p):
    "statement : NAME '=' RANDOM ';' "
    # A nondeterministic assignment adds no constraint in this action; it
    # only checks that the target variable was declared.
    if p[1] not in names:
        print("Parsing Error : Variable {} is Undeclared.".format(p[1]))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)


def p_statement_assume(p):
    "statement : ASSUME boolexpression ';' "
    # Every conjunct of the assumption becomes a constraint of whichever part
    # (stem or cycle) is currently being parsed.
    target = stemRelation if isStem else cycleRelation
    target.extend(p[2])


def p_expression_binop(p):
    '''expression : expression '+' expression
                  | expression '-' expression'''
    # Only + and - are accepted between expressions: the over-approximation
    # must stay linear, so '*' and '/' are deliberately absent from the
    # grammar (constant * variable is handled by p_expression_coef_var).
    left, op, right = p[1], p[2], p[3]
    if op == '-':
        p[0] = left - right
    elif op == '+':
        p[0] = left + right


def p_expression_uminus(p):
    "expression : '-' expression %prec UMINUS"
    # Unary negation; binds with the UMINUS precedence (declared elsewhere).
    operand = p[2]
    p[0] = -operand


def p_expression_group(p):
    "expression : '(' expression ')'"
    # Parentheses only steer parsing; forward the inner expression unchanged.
    inner = p[2]
    p[0] = inner


def p_expression_coef_var(p):
    "expression : NUMBER '*' NAME"
    # Only constant * variable products are allowed, which keeps every
    # expression linear. The variable may be unprimed or primed ('x_post').
    name = p[3]
    if name not in names and name not in namesPost:
        print("Parsing Error : Variable {} is Undeclared.".format(name))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    p[0] = p[1] * z3Vars[namesToZ3VarsIndex[name]]


def p_expression_number(p):
    "expression : NUMBER"
    # A numeric literal is used as-is (its value comes from the lexer).
    literal = p[1]
    p[0] = literal


def p_expression_name(p):
    "expression : NAME"
    # A variable reference, unprimed ('x') or primed ('x_post'), resolves to
    # its z3 constant.
    name = p[1]
    if name not in names and name not in namesPost:
        print("Parsing Error : Variable {} is Undeclared.".format(name))
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)
    p[0] = z3Vars[namesToZ3VarsIndex[name]]


def p_bool_expression_paranthesis(p):
    '''boolexpression : '(' boolexpression ')' '''
    # Grouping only; the inner conjunct list is passed through untouched.
    conjuncts = p[2]
    p[0] = conjuncts


def p_bool_expression_not(p):
    '''boolexpression : NOT boolexpression %prec UMINUS'''
    # p[2] is a list of conjuncts. Negating a conjunction would give a
    # disjunction, which this list-of-conjuncts representation cannot hold,
    # so only a single constraint may be negated. This is an explicit check
    # rather than an assert, so it still runs under `python -O` and reports
    # like the other parse errors.
    if len(p[2]) != 1:
        print("Parsing Error : NOT formula with a conjunct found")
        sys.exit(1)
    p[0] = [Not(p[2][0])]


def p_bool_expression_and(p):
    '''boolexpression : boolexpression AND boolexpression'''
    # A conjunction is kept as the flat list of its conjuncts, so AND is list
    # concatenation into a fresh list (neither operand list is mutated).
    p[0] = list(p[1]) + list(p[3])


def p_bool_expression(p):
    '''boolexpression : constraint'''
    # A single constraint is a one-element conjunction.
    conjuncts = [p[1]]
    p[0] = conjuncts


def p_constraint_1(p):
    '''constraint : expression EQ expression '''
    # Equality between two linear expressions.
    lhs, rhs = p[1], p[3]
    p[0] = lhs == rhs


def p_constraint_2(p):
    '''constraint : expression LE expression '''
    # Non-strict upper bound: lhs <= rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs <= rhs


def p_constraint_3(p):
    '''constraint : expression GE expression '''
    # Non-strict lower bound: lhs >= rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs >= rhs


def p_constraint_4(p):
    '''constraint : expression LT expression '''
    # Strict upper bound: lhs < rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs < rhs


def p_constraint_5(p):
    '''constraint : expression GT expression '''
    # Strict lower bound: lhs > rhs.
    lhs, rhs = p[1], p[3]
    p[0] = lhs > rhs


def p_constraint_true(p):
    '''constraint : TRUE '''
    # The literal `true` becomes the z3 boolean constant True.
    p[0] = BoolVal(True)


def p_error(p):
    """Report a syntax error and abort with a non-zero exit status.

    PLY calls this with the offending token, or with None when the input
    ended unexpectedly. A bare ``sys.exit()`` would exit with status 0 and
    make a failed parse look like success to the calling shell, so exit
    with status 1 instead.
    """
    print("Error while Parsing Anant Program")
    if p:
        print("Syntax error at '%s': line= %d" % (p.value, p.lineno))
    else:
        print("Syntax error at EOF")
    sys.exit(1)


def getAAPb(quantifierFreeFormula):
    """Convert z3 constraints into per-constraint coefficient tables.

    Each formula is simplified with z3 and printed on one line. The line is
    then split by ``parserFormula`` into one or more constraints. Each
    constraint maps a variable name (unprimed or ``_post``) and the key
    ``'rhsConstant'`` to numbers. Missing entries count as 0.

    NOTE(review): the comparison operator / normal form of the rows is
    decided by parserFormula and is not visible here. Presumably each row is
    read as A.x + AP.x_post (op) b; confirm against parserFormula.

    Args:
        quantifierFreeFormula: iterable of z3 formulas (e.g. stemRelation
            or cycleRelation).

    Returns:
        (A, AP, b), three parallel lists with one entry per parsed
        constraint. A[i] maps every name in ``names`` to its coefficient.
        AP[i] maps every ``name + '_post'`` to its coefficient. b[i] is
        the constraint's right-hand-side constant.
    """
    A = []
    AP = []
    b = []
    fParser = parserFormula.parserFormula_main()
    for formula in quantifierFreeFormula:
        # The formula parser expects a single line of text.
        fStr = str(simplify(formula)).replace('\n', '')
        for cstr in fParser.parse(fStr):
            A.append({var: cstr.get(var, 0) for var in names})
            AP.append({var + '_post': cstr.get(var + '_post', 0)
                       for var in names})
            b.append(cstr.get('rhsConstant', 0))
    return A, AP, b


def _reset_parse_state():
    """Clear the module-level tables filled in by the grammar actions.

    The p_* actions accumulate into module globals. Without this reset, a
    second main_func call in the same process would see the previous
    program's variables and constraints, and isStem would already be False,
    routing stem statements into the cycle relation. Lists are cleared in
    place so that existing references to them stay valid.
    """
    global isStem
    del z3Vars[:]
    namesToZ3VarsIndex.clear()
    names.clear()
    namesPost.clear()
    del stemRelation[:]
    del cycleRelation[:]
    isStem = True


def _matrix_row(coeffs, keys):
    """Return str(coeffs[key]) + ' ' for every key, concatenated in order."""
    return ''.join(str(coeffs[key]) + ' ' for key in keys)


def main_func(filename, extractInv):
    """Build ``build/z3Input`` for the program in *filename*.

    Pipeline (all paths relative to the current working directory):
      1. run Interproc on *filename*; its output minus the first line is
         saved to ``build/interproc.out``;
      2. ``parserConcrete.main_func`` turns that into an Anant lasso
         program, saved to ``build/overapprox.out``;
      3. the lasso is parsed with this module's grammar;
      4. the cycle and stem constraints are written to ``build/z3Input``.

    Args:
        filename: path of the input program handed to Interproc.
        extractInv: truthy to set ``parserConcrete.extractInv`` (stored as
            a bool); ``parserConcrete.getOverApprox`` is always set to True.
    """
    _reset_parse_state()
    build_dir = os.path.join(os.getcwd(), 'build')
    interproc_out = os.path.join(build_dir, 'interproc.out')
    overapprox_out = os.path.join(build_dir, 'overapprox.out')

    # Build the lexer and the parser (PLY collects the p_* rules of this module).
    anantL = lexer_main()
    anantP = yacc.yacc(debug=0, write_tables=0)

    out = runInterproc(filename)
    # Remove the first line of Interproc's output.
    out = out[out.find('\n') + 1:]
    with open(interproc_out, 'w') as f_i:
        f_i.write(out)

    parserConcrete.getOverApprox = True
    parserConcrete.extractInv = bool(extractInv)
    out = parserConcrete.main_func(interproc_out)
    with open(overapprox_out, 'w') as f_o:
        f_o.write(out)

    with open(overapprox_out, 'r') as f:
        anantP.parse(f.read(), lexer=anantL)

    # One fixed variable order for the header and every row/column below.
    order = list(names)
    post_order = [var + '_post' for var in order]

    with open(os.path.join(build_dir, 'z3Input'), 'w') as z3Input:
        # Header: variable count, then the variable names.
        z3Input.write(str(len(order)) + '\n')
        z3Input.write(''.join(var + ' ' for var in order))
        z3Input.write('\n\n')

        # Cycle: row count, all A rows, all AP rows, then one b per line.
        A, AP, b = getAAPb(cycleRelation)
        z3Input.write(str(len(A)) + '\n')
        for d in A:
            z3Input.write(_matrix_row(d, order) + '\n')
        for d in AP:
            z3Input.write(_matrix_row(d, post_order) + '\n')
        for b_i in b:
            z3Input.write(str(b_i) + '\n')
        z3Input.write('\n')

        # Stem: row count, then per constraint one line holding its AP
        # coefficients, its A coefficients and its b.
        A, AP, b = getAAPb(stemRelation)
        z3Input.write(str(len(A)) + '\n')
        for a_row, ap_row, b_i in zip(A, AP, b):
            z3Input.write(_matrix_row(ap_row, post_order)
                          + _matrix_row(a_row, order)
                          + str(b_i) + '\n')
        z3Input.write('\n\n0\n')
        z3Input.write('0')


def runInterproc(filename):
    """Run ``build/interproc`` on *filename* and return its text output.

    Interproc is invoked with ``-display text -domain octagon``. If the
    binary or *filename* is missing, or Interproc fails, the error is logged
    and the process exits with status 1.
    """
    exec_file = os.path.abspath(os.path.join(os.getcwd(), "build/interproc"))
    filename = os.path.abspath(filename)
    # Explicit checks instead of asserts, so they still run under `python -O`.
    for path in (exec_file, filename):
        if not os.path.isfile(path):
            logging.error("not a file: {}".format(path))
            sys.exit(1)
    # Argument list without a shell: paths containing spaces or shell
    # metacharacters reach Interproc verbatim and cannot inject commands.
    args = [exec_file, "-display", "text", "-domain", "octagon", filename]
    try:
        out = subprocess.check_output(args)
        return out.decode("utf-8")
    except (OSError, subprocess.CalledProcessError, UnicodeDecodeError):
        logging.exception("Error in running Interproc")
        # Non-zero status: a bare sys.exit() would report success.
        sys.exit(1)


if __name__ == "__main__":
    # main_func requires the program file and an extractInv flag. Calling it
    # with no arguments, as before, always raised TypeError.
    if len(sys.argv) < 2:
        print("usage: {} <program-file> [--extract-inv]".format(sys.argv[0]))
        sys.exit(1)
    main_func(sys.argv[1], "--extract-inv" in sys.argv[2:])
