'''
Generate an interprocedural dominator graph for each bb target: list of (edge,level)
Merge all interprocedural dominator graphs and depth to generate dom_bits and dom_depth
Write the result to /root/instr-io/dom_bits_depth.txt
'''

import numpy as np

# Input files, opened in this order and kept open for the parsing stages below;
# all four handles are closed at the very end of the script.
#   f1: DominatorsOfTargetFunctions.txt (one line per bb target)
#   f2: BBTargets-inter.txt             (function call -> callsite)
#   f3: DominatorsOfTargets.txt         (dominator edges per target)
#   f4: BBtargets.txt                   (the list of targets)
f1, f2, f3, f4 = (
    open(path, 'r')
    for path in (
        '/root/instr-io/DominatorsOfTargetFunctions.txt',
        '/root/instr-io/BBTargets-inter.txt',
        '/root/instr-io/DominatorsOfTargets.txt',
        '/root/instr-io/BBtargets.txt',
    )
)

'''
First, parse DominatorsOfTargets.txt generated by afl-pass
Obtain the mapping from src:line -> (edge,level)
'''
# key: src_line ("Target:" field)
# value: list of unique tuples (edge_num, level, bb_id, next_dom_bb), in file order
srcline2edges = {}
for line in f3:
    if not line.strip():
        # a blank line carries no record; parsing it would fail on int('')
        continue
    split_line = line.split(' | ')
    bb_id = int(split_line[0].split('Current bb:')[-1].strip())
    edge = split_line[1].split('Edge number:')[-1].strip() # kept as a string
    # let the level start from 1 instead of 0
    intra_level = int(split_line[2].split('Level:')[-1].strip()) + 1
    src_line = split_line[3].split('Target:')[-1].strip()
    next_dom_bb = int(split_line[4].split('Next bb dom:')[-1].strip())
    record = (edge,intra_level,bb_id,next_dom_bb)
    # setdefault creates the list the first time a target is seen, so no
    # exception handling is needed and real parsing errors are not hidden
    edges_for_target = srcline2edges.setdefault(src_line, [])
    if record not in edges_for_target:
        edges_for_target.append(record)
for src_line in srcline2edges:
    print('found edges for',src_line)

'''
Then, parse BBTargets-inter.txt generated by opt-pass: getCSAdditionalTargets
Obtain the mapping from function call -> src:line
'''
# key: function call, e.g. "main -> doit" (first comma-separated field, as written in the file)
# value: sorted list of unique callsites, e.g. ["bof.c:37"] (last field, stripped)
call2srcline = {}
for line in f2:
    split_line = line.split(',')
    call = split_line[0] # call
    callsite = split_line[-1].strip()
    # maybe multiple callsites: setdefault starts an empty list when the call is
    # first met, replacing the former bare try/except around the dict lookup
    callsites = call2srcline.setdefault(call, [])
    callsites.append(callsite)
    # np.unique both de-duplicates and sorts the callsites
    call2srcline[call] = list(np.unique(np.asarray(callsites)))
    print(call,call2srcline[call])

'''
Finally, parse DominatorsOfTargetFunctions.txt
Generate an interprocedural dominator graph (IDG) for each bb target: list of list of (edge,level)
Finally, merge all IDGs to generate dom_bits and dom_depth
'''
# Per-edge result lists, filled further down; they are appended to together,
# so position i in each of them refers to the same edge.
dom_bits = []      # edge ids
dom_depth = []     # depth (level) recorded for the edge
dom_target = []    # High 16 bits: target ids for which the edge is the key edge; low 16 bits: target ids the edge belongs to
dom_bb_id = []     # BB id from which the edge is emitted
dom_next_dom = []  # The dom BB id that is reached

# Key-edge bookkeeping, also kept position-aligned with each other.
target_edge = []   # Key edge corresponding to the target, covering the target
target_id = []     # Bit mask of target ids for the matching target_edge entry; an id is an index into all_targets

# Targets in the order of BBtargets.txt; only the part after the last '/' is kept.
all_targets = [entry.strip().split('/')[-1] for entry in f4]
print(all_targets)

# Collect the key edges: every edge recorded at level 0 for a target is stored
# once in target_edge, and the matching target_id entry accumulates one bit
# (1<<idx) per target that shares this edge.
# NOTE(review): levels in srcline2edges are stored as raw level + 1, so level 0
# here presumably corresponds to a raw "Level:" of -1 -- confirm against the pass output.
for idx, target in enumerate(all_targets):
    dom_edges = srcline2edges.get(target)
    if dom_edges is None:
        # only a missing target is reported; any other error now surfaces
        print("cannot find key edge for "+target+" , maybe not instrumented.")
        continue
    for edge in dom_edges:
        if edge[1] != 0:
            continue
        if edge[0] not in target_edge:
            target_edge.append(edge[0])
            target_id.append(1<<idx)
        else:
            target_id[target_edge.index(edge[0])] |= 1<<idx

# Lookup tables so the merge below does not rescan the lists for every edge.
# Both are derived from the current list contents and kept in sync on append.
edge_pos = {known_edge: pos for pos, known_edge in enumerate(dom_bits)} # edge -> index in dom_* lists
key_edge_ids = dict(zip(target_edge, target_id)) # key edge -> target id mask (target_edge holds each edge once)

for line in f1: # iter each bb target, for each target build an interprocedural dom graph
    dom_graph_srcline = [] # the interprocedural dom graph, in src:line form
    split_line = line.split(',')
    for i in range(len(split_line)-1): # the last field is only ever used as a caller
        # handle c++ function name? maybe not necessary because of c++ mangling
        if (':' in split_line[i]) and ('::' not in split_line[i]): # should only trigger once
            dom_graph_srcline.append([split_line[i]]) # the bb target provided by user
            continue
        callee = split_line[i].strip()
        caller = split_line[i+1].strip()
        # get the callsites' src_lines
        call = caller+' -> '+callee
        if call in call2srcline:
            dom_graph_srcline.append(call2srcline[call])

    if not dom_graph_srcline:
        # blank line, or nothing on the line could be resolved: there is no
        # target to attribute edges to (indexing [0] used to raise here)
        if line.strip():
            print('no target or callsite found in line ',line.strip(),', skipped.')
        continue

    the_target = dom_graph_srcline[0]
    try:
        the_target_id = all_targets.index(the_target[0])
    except ValueError:
        # Without a valid id the bit 1<<id cannot be formed (a negative shift
        # raises ValueError); previously that error fired after the other
        # dom_* lists had already been appended to, leaving dom_target shorter
        # than its siblings. Skip such a target entirely instead.
        print('target ',the_target[0],' not listed in BBtargets.txt, skipped.')
        continue

    dom_graph_srcline.reverse() # for cg level calculation
    accumulation = 0 # sum of the max intra-procedural depths of the levels already processed
    max_depth = 0
    for calls_callsites in dom_graph_srcline: # iter function calls/bb target(only the first one)
        accumulation += max_depth
        max_depth = 0
        for src_line in calls_callsites: # iter callsites
            # get intra-procedural dom edges
            edges = srcline2edges.get(src_line)
            if edges is None:
                print('edges not found for target ',src_line,', maybe optimized.')
                continue
            for edge,level,bb_id,next_dom_bb in edges:
                if level>max_depth:
                    max_depth = level
                level += accumulation # update inter-procedural level
                pos = edge_pos.get(edge)
                if pos is None: # merge all edges and depth
                    pos = edge_pos[edge] = len(dom_bits)
                    dom_bits.append(edge)
                    dom_depth.append(level)
                    dom_bb_id.append(bb_id)
                    dom_next_dom.append(next_dom_bb)
                    # Low 16 bits store which target it belongs to
                    # (assumes fewer than 16 targets so the bit stays in the low half -- TODO confirm)
                    dom_target.append(1<<the_target_id)
                else:
                    if level>dom_depth[pos]: # keep the deepest level seen for the edge
                        dom_depth[pos] = level
                    dom_target[pos] |= 1<<the_target_id
                if edge in key_edge_ids: # Key edge, high 16 bits store which target it belongs to
                    dom_target[pos] |= key_edge_ids[edge]<<16
                    dom_depth[pos] = 0 # depth 0 marks a key edge

# individual targets which do not have inter are neglected above, add them here
# Lookup tables rebuilt from the current list contents so edges are not
# searched for linearly; known_edge_pos is kept in sync on every append.
known_edge_pos = {known_edge: pos for pos, known_edge in enumerate(dom_bits)} # edge -> index in dom_* lists
key_edge_mask = dict(zip(target_edge, target_id)) # key edge -> target id mask (target_edge holds each edge once)

for idx, target in enumerate(all_targets):
    # a target without recorded edges contributes nothing; only that case is
    # skipped now, other errors are no longer swallowed
    for edge,level,bb_id,next_dom_bb in srcline2edges.get(target, ()):
        pos = known_edge_pos.get(edge)
        if pos is None: # merge all edges and depth
            pos = known_edge_pos[edge] = len(dom_bits)
            dom_bits.append(edge)
            dom_depth.append(level)
            dom_bb_id.append(bb_id)
            dom_next_dom.append(next_dom_bb)
            dom_target.append(1<<idx)
        else:
            if level>dom_depth[pos]: # keep the deepest level seen for the edge
                dom_depth[pos] = level
            dom_target[pos] |= 1<<idx
        if edge in key_edge_mask: # is key edge
            dom_target[pos] |= key_edge_mask[edge]<<16
            dom_depth[pos] = 0 # depth 0 marks a key edge

print(target_edge)
print(target_id)

# Order the edges by ascending depth. sorted() is stable, so edges with the
# same depth keep their original relative order.
rank_depth = dict(enumerate(dom_depth))
ranking = sorted(rank_depth.items(), key=lambda pair: pair[1])
order = [position for position, _ in ranking]

# Re-create every per-edge list in ranked order.
dom_bits_new = [dom_bits[position] for position in order]
dom_depth_new = [dom_depth[position] for position in order]
dom_bb_id_new = [dom_bb_id[position] for position in order]
dom_next_dom_new = [dom_next_dom[position] for position in order]
dom_target_new = [dom_target[position] for position in order]

# Largest depth over all edges, never below 0.
max_depth_ = max([0] + dom_depth)

def _record(i):
    '''Return the output line for ranked edge i: edge_id:edge_depth,edge_target,bb_id,next_bb_id_dom'''
    fields = (dom_depth_new[i], dom_target_new[i], dom_bb_id_new[i], dom_next_dom_new[i])
    return dom_bits_new[i] + ':' + ','.join(map(str, fields)) + '\n'

target_idx = [] # ranked positions of the edges whose depth is 0

with open('/root/instr-io/dom_bits_depth.txt','w') as f:
    for i in range(len(dom_bits)):
        if dom_depth_new[i] != 0:
            f.write(_record(i))
            continue
        # depth-0 edges are held back, given depth max_depth_+1 and written last
        target_idx.append(i)
        dom_depth_new[i] = max_depth_+1
    for i in target_idx:
        f.write(_record(i))

# release the input files opened at the top of the script
for handle in (f1, f2, f3, f4):
    handle.close()
