import os
import subprocess
import csv
import openai
import time

# Input/output locations of this CirFix comparison experiment (relative to the
# directory the script is launched from).
LOG_FILENAME = "comparison_with_cirfix.log"
# master file with all responses; opened in append mode, so runs accumulate
MASTER_FILENAME = "all_repairs.txt"
BUG_LOCATIONS_CSV = "cirfix_bugs_locations.csv"
INSTRUCTIONS_DIR = "instructions"
REPAIRS_DIR = "repairs"

# Debug filter: only the CSV data row with this 0-based index is processed.
# Set to None to process every bug in the CSV.
ONLY_BUG_INDEX = 10

# Prompt / sampling configuration.
# At most MAX_NUM_PRE_BUG_LINES - 1 lines of code preceding the bug are used as
# prompt context.
MAX_NUM_PRE_BUG_LINES = 50
# Full sweep used previously: ("-a", "-b", "-c", "-d", "-e").
INSTRUCTION_VARIATIONS = ("-a", "-b")
# Full sweep used previously: (0.1, 0.3, 0.5, 0.7, 0.9).
TEMPERATURES = (0.1,)
COMPLETIONS_PER_CALL = 1
MAX_TOKENS = 200
MODEL = "code-davinci-002"

# Crude client-side throttling (only needed for token-limited LLMs): once this
# many completions have been requested, sleep RATE_LIMIT_PAUSE_S seconds.
RATE_LIMIT_BATCH = 5
RATE_LIMIT_PAUSE_S = 60

# Variations whose instruction text is shared by all bugs (one generic file per
# instruction kind) rather than written per bug.  The shared files are always
# the "variation-a" ones, even for "-b" fix instructions.
_GENERIC_VARIATIONS = {
    "bug": frozenset({"-a"}),
    "fix": frozenset({"-a", "-b"}),
}

_MASTER_SEPARATOR = "\n========================================================\n"


def instruction_path(bug_id, variation, kind):
    """Return the path of the *kind* ("bug" or "fix") instruction file.

    Generic variations map to ``variation-a_<kind>-instruction.txt``; all
    others map to the per-bug ``<bug_id><variation>_<kind>-instruction.txt``.
    """
    if variation in _GENERIC_VARIATIONS[kind]:
        name = f"variation-a_{kind}-instruction.txt"
    else:
        name = f"{bug_id}{variation}_{kind}-instruction.txt"
    return os.path.join(INSTRUCTIONS_DIR, name)


def read_instruction(bug_id, variation, kind):
    """Return the full text of the selected instruction file."""
    with open(instruction_path(bug_id, variation, kind), "r") as f:
        return f.read()


def pre_bug_context(lines, loc_start, max_lines=MAX_NUM_PRE_BUG_LINES):
    """Return the source text immediately preceding the buggy region.

    ``lines`` is the file as returned by ``readlines()``; ``loc_start`` is the
    1-based first buggy line.  At most ``max_lines - 1`` lines are returned,
    ending with the line just before ``loc_start``.
    """
    return "".join(lines[max(0, loc_start - max_lines):loc_start - 1])


def comment_out(lines):
    """Render *lines* as ``//`` comments, starting on a fresh line.

    Lines are expected to keep their trailing newline (as ``readlines()``
    produces).  If the last one lacks it (buggy region at EOF with no final
    newline), a newline is added so whatever is appended next is not swallowed
    into the comment.
    """
    text = "\n//" + "//".join(lines)
    if lines and not lines[-1].endswith("\n"):
        text += "\n"
    return text


def splice_repair(lines, loc_start, loc_end, repair_text):
    """Return the whole file with lines ``loc_start..loc_end`` (1-based,
    inclusive) replaced by *repair_text* plus a newline."""
    return (
        "".join(lines[:loc_start - 1])
        + repair_text
        + "\n"
        + "".join(lines[loc_end:])
    )


def build_prompt(pre_bug_code, buggy_lines, bug_id, variation):
    """Assemble: pre-bug code, bug instruction, commented-out bug, fix instruction."""
    return (
        pre_bug_code
        + read_instruction(bug_id, variation, "bug")
        + comment_out(buggy_lines)
        + read_instruction(bug_id, variation, "fix")
        + "\n"
    )


def repair_bug(bug, log_file, fw_master, examples_generated):
    """Request LLM repairs for one CSV row and write prompts/repairs to disk.

    ``bug`` is a ``csv.DictReader`` row with ID, filepath, line-start and
    line-end columns.  Returns the updated ``examples_generated`` counter used
    for rate limiting.
    """
    bug_id = bug["ID"]
    filepath = bug["filepath"]
    loc_start = int(bug["line-start"])
    loc_end = int(bug["line-end"])

    log_file.write(
        "working on " + bug_id + "," + filepath + ","
        + str(loc_start) + "," + str(loc_end) + "\n"
    )

    with open(filepath) as buggy_file:
        buggy_file_lines = buggy_file.readlines()
    pre_bug_code = pre_bug_context(buggy_file_lines, loc_start)
    buggy_lines = buggy_file_lines[loc_start - 1:loc_end]

    # Stop at the end of the module by default; for a single-line bug also
    # stop at the first newline so only one replacement line is produced.
    stop = ["endmodule"]
    if loc_start == loc_end:
        stop.append("\n")

    for variation in INSTRUCTION_VARIATIONS:
        prompt = build_prompt(pre_bug_code, buggy_lines, bug_id, variation)

        # Keep the exact prompt sent to the LLM for later inspection.
        with open("prompt-to-LLM_" + bug_id + variation + ".v", "w") as fw_prompt:
            fw_prompt.write(prompt)

        for t in TEMPERATURES:
            if examples_generated >= RATE_LIMIT_BATCH:
                print("pausing " + str(RATE_LIMIT_PAUSE_S) + " seconds")
                time.sleep(RATE_LIMIT_PAUSE_S)
                examples_generated = 0

            completion = openai.Completion.create(
                model=MODEL,
                prompt=prompt,
                max_tokens=MAX_TOKENS,
                temperature=t,
                top_p=1,
                n=COMPLETIONS_PER_CALL,
                stop=stop,
            )

            examples_generated += COMPLETIONS_PER_CALL
            print("generated ", str(COMPLETIONS_PER_CALL), " bug= ", bug_id,
                  " examples for i=", variation, " t= ", t)

            for j, example in enumerate(completion.choices):
                filename = ("example" + str(j) + "_" + bug_id + "_i" + variation
                            + "_t-" + str(t) + ".v")
                # Full candidate file: original code with the buggy lines replaced.
                with open(os.path.join(REPAIRS_DIR, filename), "w") as fw:
                    fw.write(splice_repair(buggy_file_lines, loc_start, loc_end,
                                           example.text))

                fw_master.write(filename + "\n")
                fw_master.write(example.text)
                fw_master.write(_MASTER_SEPARATOR)
                fw_master.flush()

    return examples_generated


def main():
    """Generate LLM repair candidates for the bugs listed in the location CSV."""
    openai.api_key = os.getenv("OPENAI_API_KEY")
    # Create the output directory up front so a missing directory does not
    # crash the run only after a (paid) completion has been requested.
    os.makedirs(REPAIRS_DIR, exist_ok=True)

    examples_generated = 0
    with open(LOG_FILENAME, "w") as log_file, open(MASTER_FILENAME, "a") as fw_master:
        with open(BUG_LOCATIONS_CSV, mode="r") as csv_file:
            # each row of the CSV describes one bug
            for i, bug in enumerate(csv.DictReader(csv_file)):
                if ONLY_BUG_INDEX is not None and i != ONLY_BUG_INDEX:
                    continue
                examples_generated = repair_bug(bug, log_file, fw_master,
                                                examples_generated)


if __name__ == "__main__":
    main()
# for design in designs:
#     with open ('../results/bug_locations/'+design+'.csv', mode='r') as csv_file: 
#         bugs_dict = csv.DictReader(csv_file) # each row of bugs_dict is a bug

#         # for each bug
#             # create a directory of the ID if not already present
#             # use the CWE to select the corresponding 'prompt' file
#             # produce the prompt
#             # use LLM to generate the repaired code
#             # fill in post-bug code
#             # write to file in relevant directory

#         # ID,filepath,module,cwe,line-start,line-end
#         for bug in bugs_dict:

#             bug_id = bug["ID"]            # if os.path.exists(bug_id):
#             #     continue
#             filepath = bug["filepath"][12:]
#             cwe = bug["cwe"]
#             loc_start = int(bug["line-start"])
#             if bug_id=='bug10':
#                 loc_start = int(bug["line-end"])
#             loc_end = int(bug["line-end"])

#             if(bug_id!= 'bug10'):
#                 continue

#             # if directory is already present, continue
#             # if os.path.exists(bug_id):
#             #     continue
#             # create directory
#             if ( not os.path.exists(bug_id) ):
#                 print("creating directory", bug_id)
#                 os.mkdir(bug_id)

#             ''' produce the prompt to LLM'''

#             prompt = ""
#             # get pre-bug code
#             max_num_pre_bug_lines = 50
#             min_num_pre_bug_lines = 25
#             # if cwe=='1245':
#             #     max_num_pre_bug_lines = 10
#             buggy_file = open(filepath)
#             buggy_file_lines = buggy_file.readlines() # list of lines
#             if loc_start > max_num_pre_bug_lines:
#                 prompt += ''.join(buggy_file_lines[loc_start-max_num_pre_bug_lines:loc_start-1])
#             else:
#                 prompt += ''.join(buggy_file_lines[:loc_start-1]) # pre-bug code
#             # print(prompt)
#             buggy_file.close()

#             # master file with all responses
#             master_filename = os.path.join(bug_id,"master.txt")
#             fw_master = open( master_filename,'a' )

#             # append instruction variation to the pre-bug code
#             pre_bug_code = prompt
#             # instruction_variations = ['-a','-b','-c','-d','-e']
#             instruction_variations = ['-a','-b','-c','-d']
#             for instruction_variation in instruction_variations:
#                 prompt = pre_bug_code

#                 # append the bug instruction
#                 bug_instruction_file = open( os.path.join("instructions",bug_id+instruction_variation+'_bug-instruction.txt'),"r")
#                 instruction = ''.join(bug_instruction_file.readlines())
#                 bug_instruction_file.close()
#                 prompt += instruction

#                 # append the bug in comments
#                 prompt_buggy_code = '\n//' + '//'.join(buggy_file_lines[loc_start-1:loc_end])
#                 prompt += prompt_buggy_code

#                 # append the fix instruction
#                 fix_instruction_file = open( os.path.join("instructions",bug_id+instruction_variation+'_fix-instruction.txt'),"r")
#                 instruction = ''.join(fix_instruction_file.readlines())
#                 fix_instruction_file.close()
#                 prompt += instruction
#                 prompt += '\n'

#                 # write the prompt-to-LLM to a file
#                 with open( os.path.join(bug_id,"prompt-to-LLM" + instruction_variation + '.v'),'w' ) as fw_prompt:
#                     fw_prompt.write(prompt)
#                 fw_prompt.close()

#                 '''get response from LLM'''
#                 # the parameters are manipulated according to the cwe and the buggy code
#                 stop = ["endmodule"] # default
#                 temp = [0.1, 0.3, 0.5, 0.7, 0.9]
#                 completions_per_call = 5

#                 if (cwe=="1234" or bug_id=='bug4' or bug_id=='bug5' or bug_id=='bug8'):
#                     stop.append('\n')
#                 if (bug_id=='bug3' or bug_id=='bug6' or bug_id == 'bug7' or bug_id == 'bug9'):
#                     stop.append('end')
#                 if (bug_id=='bug10'):
#                     stop.append('endcase')
#                 # print(stop)
#                 for t in temp:

#                     # only needed for using token limited LLMs
#                     if( examples_generated >= 25 ):
#                         print ("pausing " + str(sleep_duration) + " seconds")
#                         # time.sleep(sleep_duration)
#                         examples_generated = 0

#                     # model="code-davinci-002",
#                     completion = openai.Completion.create( 
#                         # model="code-davinci-001", 
#                         engine="codegen",
#                         prompt=prompt,
#                         max_tokens= 200,
#                         temperature= t,
#                         top_p= 1,
#                         n= completions_per_call,
#                         stop=stop
#                     )

#                     examples_generated +=completions_per_call
#                     print ("generated ", str(completions_per_call), " bug= ", bug_id, " examples for i=", instruction_variation, " t= ", t)

#                     # '''output produced file'''
#                     # individual files
#                     for j,example in enumerate(completion.choices):
#                         # print(example)
#                         # create example file
#                         filename = "example"+str(j+10)+ '_i'+instruction_variation+'_t-'+str(t)+".v"
#                         if os.path.exists(os.path.join(bug_id,filename)):
#                             continue
#                         with open( os.path.join(bug_id,filename),'w' ) as fw:

#                             # only need to evaluate the response of the LLM for trailing begins and excessive ends
#                             example_lines = example.text.split('\n')

#                             # if the response is only comments, uncomment
#                             only_comments = True
#                             for line in example_lines:
#                                 if not "//" in line:
#                                     only_comments = False
#                                     break
#                             if only_comments:
#                                 # print(example_lines)
#                                 for i,line in enumerate(example_lines):
#                                     example_lines[i] = line.replace("//","")
#                                     # print(line)
#                                 # print("after fix: \n",example_lines)
#                             if (bug_id == 'bug10'):
#                                 example_lines.append('endcase\n')

#                             if (bug_id != 'bug1' and bug_id != 'bug4' and bug_id != 'bug6' and bug_id != 'bug7' and bug_id != 'bug8'):
#                                 num_begins = 0
#                                 num_ends = 0
#                                 for line in example_lines:
#                                     if ("begin" in line and not("//" in line) ):
#                                         num_begins = num_begins +1
#                                     if ("end" in line and not ("//" in line) and not ("endcase" in line) ):  
#                                         num_ends = num_ends + 1        
#                                 # print("num_begins: "+str(num_begins)+" num_ends: ", str(num_ends))
#                                 end_additions = num_begins - num_ends
#                                 if (end_additions > 0):
#                                     for i in range(end_additions):
#                                         #add end at the end of the file
#                                         example_lines.append("end\n")

#                                 # remove excessive ends if needed
#                                 end_removals = num_ends - num_begins
#                                 # print(end_removals)
#                                 if (end_removals> 0):
#                                     example_lines.reverse()
#                                     for i in range(end_removals):
#                                         # remove an end
#                                         # print("i= "+str(i)+'\n')
#                                         for l_no,l in enumerate(example_lines):
#                                             if ("end" in l and not("//" in l) and not ("endcase" in line) ):
#                                                 # print("need to remove end in line "+ str(l_no))
#                                                 example_lines[l_no] = ' '
#                                                 break
#                                     example_lines.reverse()
                                            
#                             example_lines_text = '\n'.join(example_lines)
#                             write_to_file = ''.join(buggy_file_lines[:loc_start-1]) + example_lines_text + '\n' + ''.join(buggy_file_lines[loc_end:])
#                             fw.write(write_to_file)

#                             fw_master.write(filename + '\n')
#                             fw_master.write('\n'.join(example_lines))
#                             fw_master.write("\n========================================================\n")
#                             fw_master.flush()

#                         fw.close()
#             fw_master.close()

            