#%%
"""Create citation network for the RP bibliography.
"""

#%%
import sys
import os
import numpy
import re
from fuzzywuzzy import fuzz
from importlib import reload
import glob
from pybtex.database.input import bibtex
import pybtex.errors
import requests
from bs4 import BeautifulSoup
import networkx as nx
from networkx.algorithms import community
import json
import datetime
import igraph as ig
from progress.bar import ChargingBar
from matplotlib import pyplot as plt


#%% some pre-settings
currentYear = datetime.datetime.now().year

# LaTeX-accent -> plain-text substitutions, consumed by replace_all().
# Keys are applied in insertion order, so ORDER MATTERS: an earlier, shorter
# key can consume text a later key was meant to match (e.g. "\\l" is applied
# before "\\l ", and "\\'" before "\\'a"). Backslashes are spelled out
# explicitly ("\\k" rather than the invalid escape "\k") to avoid the
# SyntaxWarning that invalid escape sequences raise on Python 3.12+.
dic = {
    "{\\k c}": "c",
    "{\\k a}": "a",
    "{": "",
    "}": "",
    "\\r": "",
    "\r": "",      # literal carriage return
    "\\l": "",
    "\\H": "",
    "\\c": "",
    "\\o": "o",
    "\\k": "",
    "\\O": "O",
    "\\ae": "",
    "\\c ": "",
    "\\u ": "",
    "\\.": "",
    "\\. ": "",
    "\\^": "",
    "\\v ": "",
    "\\v": "",
    "\\` ": "",
    "\\' ": "",
    "\\`": "",
    "\\'": "",
    '\\"': "",
    "\\~": "",
    '\\"a': "a",
    '\\"u': "u",
    '\\"o': "o",
    "\\ss": "ss",
    "\\'a": "a",
    "\\'o": "o",
    "\\'e": "e",
    "\\'i": "i",
    "\\i": "i",
    "\\l ": "l",
    "\\\\": "",
}

def replace_all(text, dic):
    """Apply every ``old -> new`` substitution in *dic* to *text*.

    Substitutions run one after another in the mapping's iteration order,
    each operating on the output of the previous one, so an earlier key can
    consume characters that a later key would otherwise have matched.
    """
    result = text
    for pattern, substitute in dic.items():
        result = result.replace(pattern, substitute)
    return result


#%% import data

# import rp.bib; strict mode makes pybtex raise on problems in the .bib file
# instead of only reporting them
pybtex.errors.set_strict_mode(True)
parser = bibtex.Parser()
bibdata = parser.parse_file("../rp.bib")
labels = sorted(bibdata.entries.keys())

# import results: citations[citing_key] is the list of keys that paper cites.
# A context manager closes the file even if json.load raises; JSON is read as
# UTF-8 rather than the platform-default encoding.
with open('../Data/citations.json', 'r', encoding='utf-8') as f:
    citations = json.load(f)


#%% prepare list of all labels = node list
ignore = {}      # bib key -> 1 if the paper is excluded from the network, else 0
nodeID = {}      # bib key -> row/column index into the adjacency matrices
paperyear = {}   # bib key -> publication year as int
labelsSelected = []

maxCnt = len(bibdata.entries)+1

# create a list of publication years
for idx, bib_id in enumerate(labels):
    fields = bibdata.entries[bib_id].fields
    ignore[bib_id] = 0
    paperyear_ = fields["year"]

    # handle papers in press (any capitalisation): count them as this year
    if paperyear_.strip().lower() == "in press":
        fields["year"] = currentYear
        paperyear_ = currentYear
    year = int(paperyear_)

    # ignore old papers and entries annotated as software or related work;
    # .get() because not every entry carries an "annote" field
    if year < 1987 or fields.get("annote") in ('Software', 'Related'):
        ignore[bib_id] = 1
        print(f'ignored {bib_id}')

    nodeID[bib_id] = idx
    paperyear[bib_id] = year
    


#%% sanity check:
# Compare the keys of citations.json and rp.bib in both directions and report
# the totals plus any keys present on only one side.
bib_key_set = set(labels)
cited_key_set = set(citations.keys())
difference_keys1 = cited_key_set - bib_key_set   # in citations, missing from rp.bib
difference_keys2 = bib_key_set - cited_key_set   # in rp.bib, missing from citations
rpbibN, citationsN = len(labels), len(citations)
print("\n".join([
    "Sanity check:",
    f"  Total references in rp.bib: {rpbibN}.",
    f"  Total references in citations: {citationsN}.",
    f"  Keys: {difference_keys1} are not in rp.bib.",
    f"  Keys: {difference_keys2} are not in citations.",
]))


#%% prepare citation network
unique_years = sorted(set(paperyear.values()))
# year -> position in unique_years (O(1) instead of list.index per paper)
year_index = {year: pos for pos, year in enumerate(unique_years)}

n_papers = len(labels)
mat = numpy.zeros((n_papers, n_papers), int)     # mat[cited, citing] = 1
mat2 = numpy.zeros((n_papers, n_papers), int)    # weighted links, filled in a later cell
# same links as mat, split by the publication year of the citing paper
matTnet = numpy.zeros((n_papers, n_papers, len(unique_years)), int)

citationTime = {}   # cited key -> list of (citing year - cited year)

# find citations of papers and create network
# citations[bib_id2] lists the keys that paper bib_id2 cites
for bib_id2 in labels:
    if bib_id2 not in citations:
        print(f"{bib_id2} is missing in citations file")
        continue
    index_of_citation_year = year_index[paperyear[bib_id2]]
    for bib_id in citations[bib_id2]:
        # a referenced key that is not in rp.bib has no node: report and skip
        # instead of crashing with KeyError
        if bib_id not in nodeID:
            print(f"{bib_id2} cites {bib_id}, which is not in rp.bib")
            continue
        if paperyear[bib_id] < 1987 or ignore[bib_id]:
            continue

        # check if cited paper is younger than citing one
        # (allow 1 year publication in advance)
        if paperyear[bib_id2] - paperyear[bib_id] < -1:
            print(f"paper {bib_id2} cites {bib_id}")
            continue

        mat[nodeID[bib_id], nodeID[bib_id2]] = 1
        matTnet[nodeID[bib_id], nodeID[bib_id2], index_of_citation_year] = 1
        citationTime.setdefault(bib_id, []).append(paperyear[bib_id2] - paperyear[bib_id])

# write time between publication and citation in file
with open("../Data/citations_citationTime.txt", "w", encoding="utf-8") as f:
    f.write("label; duration\n")
    for bib_id, durations in citationTime.items():
        f.write(f"{bib_id};  {sorted(durations)}\n")




#%% some statistics

# most cited papers: row sums of mat[cited, citing] = number of citing papers
citationsPerPaper = numpy.sum(mat, axis=1)

# Get the sorted version and the sorting indices
sorting_indices = numpy.argsort(citationsPerPaper)[::-1]  # [::-1] for descending order
citationsPerPaperSorted = citationsPerPaper[sorting_indices]

# Labels in the same descending-citation order (row i of mat is labels[i])
topN = 20
top_labels = [labels[i] for i in sorting_indices]


# paper citations and publication year
with open('../Data/citations_years.txt', "w", encoding="utf-8") as f:
    for author, numCitations in zip(top_labels, citationsPerPaperSorted):
        if ignore[author]:
            continue
        f.write(f"{author}, {numCitations}, {paperyear[author]}\n")
        #f.write(f"{author}, {numCitations}, {paperyear[author]}, {weight[nodeID[author]]}\n") # this works only when the weight is calculated in the code below


# citations over year over all papers
citationsPerYear = numpy.sum(matTnet, axis=(0, 1))

# citations over year for each individual paper: shape (papers, years)
citationsPerYearPerPaper = numpy.sum(matTnet, axis=1)
max_indices = numpy.argsort(citationsPerYearPerPaper, axis=0)[::-1]  # descending per year
topN = 5

top_cit = [[0] * topN for _ in unique_years]
top_labels = [[0] * topN for _ in unique_years]

# prepare a ranking of papers based on the number of citations for each year.
# Row i of matTnet belongs to labels[i] (nodeID[labels[i]] == i), so the
# index is used directly; the former labels[i+1] reported the wrong paper and
# the accompanying bound check dropped the last paper.
for j in range(len(unique_years)):
    for rank, i in enumerate(max_indices[:topN, j]):
        top_labels[j][rank] = labels[i]
        top_cit[j][rank] = citationsPerYearPerPaper[i, j]





# plt.plot(unique_years,top_cit)
# plt.show()
# 
# 
# plt.plot(unique_years,citationsPerYearPerPaper[329,:])
# plt.show()




#%% disruptive measure
# For each non-ignored focal paper, score every non-ignored paper published
# more than 5 years later:
#   +1 if it cites the focal paper but none of the focal paper's references,
#   -1 if it cites the focal paper and at least one of its references,
#    0 otherwise.
# The sum is divided by the number of papers scored.
# NOTE(review): the CD index of Funk & Owen-Smith divides by the number of
# papers citing the focal paper or its references; here the denominator counts
# every later, non-ignored paper. Confirm this is intended.
CD = {}
bar = ChargingBar('Calculate disruptive measure', max=len(labels), suffix='%(percent).0f%% - %(eta)ds')
for bib_id in labels:
    bar.next()

    if ignore[bib_id]:
        continue
    focal = nodeID[bib_id]
    # mat[cited, citing]: column `focal` marks the papers the focal paper cites.
    # flatnonzero gives a 1-D index array, so mat[predecessors, col] stays 1-D
    # (np.where's tuple produced a 2-D slice that builtin any() cannot reduce).
    predecessors = numpy.flatnonzero(mat[:, focal] == 1)
    score = 0
    cnt = 0
    for bib_id2 in labels:
        if paperyear[bib_id] < paperyear[bib_id2]-5 and not ignore[bib_id2]:
            other = nodeID[bib_id2]
            b = int(mat[predecessors, other].any())  # cites predecessor
            f_ = int(mat[focal, other])  # cites focal paper
            score += -2*b*f_ + f_
            cnt += 1
    CD[bib_id] = score / cnt if cnt else 0
bar.finish()


sorted_CD = dict(sorted(CD.items(), key=lambda item: item[1], reverse=True))
with open('../Data/citations_CD.txt', "w", encoding="utf-8") as f:
    for author, cd_value in sorted_CD.items():
        f.write(f"{author}, {cd_value}\n")


 
#%% create weighted links
indegree = {}   # node index -> number of papers citing it (row sum of mat)
weight = {}     # node index -> that count discounted by exp(-0.07 * age in years)

for bib_id in labels:
    # skip if it is related or software (or otherwise ignored)
    if ignore[bib_id]:
        continue
    node = nodeID[bib_id]
    # papers dated in the future count as published this year
    y = min(paperyear[bib_id], currentYear)
    w = numpy.exp(-0.07 * (currentYear - y))
    n_cited = mat[node, :].sum()
    indegree[node] = n_cited
    weight[node] = n_cited * w
    # mat[:, node] marks the papers this one cites; each such link is
    # weighted by how often this (citing) paper is itself cited
    mat2[:, node] = mat[:, node] * n_cited

   


#%% create network
# NOTE(review): mat2 holds integer counts > 1; python-igraph's Graph.Adjacency
# turns an entry k into k parallel edges. If a single weighted edge is wanted,
# Graph.Weighted_Adjacency is the alternative. Confirm which Gephi import expects.
G = ig.Graph.Adjacency(mat2.tolist(), mode="directed")

# Add nodes and set attributes with igraph
bar = ChargingBar('Node attributes', max=len(paperyear), suffix='%(percent).0f%% - %(eta)ds')

for node, vertex in enumerate(G.vs):
    bar.next()
    bib_id = labels[node]
    if ignore[bib_id]:
        continue
    vertex['year'] = paperyear[bib_id]
    #vertex['position'] = 10*numpy.random.randn() * (max(indegree.values()) - indegree[node])/max(indegree.values())
    #vertex['position'] = 10*numpy.random.randn() * (max(weight.values()) - weight[node])/max(weight.values())
    # random scalar coordinate; the logistic factor shrinks its spread towards 0
    # as the paper's weight grows past ~80. randn() without arguments returns a
    # plain float (randn(1) would store a 1-element array as the attribute).
    vertex['position'] = 10 * numpy.random.randn() * (1. - 1. / (1 + numpy.exp(-.8 * (weight[node] - 80))))
    vertex['name'] = bib_id
    vertex['wdeg'] = weight[node]
    vertex['indegree'] = indegree[node]
bar.finish()

# remove nodes which are not linked (this includes every ignored paper,
# whose rows and columns in mat2 are all zero)
remove = [node for node, degree in enumerate(G.degree()) if degree == 0]

G.delete_vertices(remove)

#%% save network
# Export the graph as gzip-compressed GraphML via igraph.
graphml_path = "../Data/citations.graphml.gz"
print("save network as graphml (for loading with Gephy 0.9.2)")
G.write_graphmlz(graphml_path)



#from scipy import sparse
#print(sparse.csr_matrix(mat))

