#%%
import sys
import os
import numpy
import re
from fuzzywuzzy import fuzz
from importlib import reload
import glob
from pybtex.database.input import bibtex
import pybtex.errors
import requests
from bs4 import BeautifulSoup
import networkx as nx
from networkx.algorithms import community
import json
import datetime
import time
from collections import OrderedDict
from progress.bar import ChargingBar


#%% some pre-settings
username = "marwan@pik-potsdam.de" # required for crossref requests (use your email address)
citationsFile = '../Data/citations.json'
cacheFile = '../Data/cache_crossref.json'

maxRetries = 20 # maximum number of API requests before fail

start_time = time.time()

# LaTeX markup -> plain text substitutions, applied in this order by
# replace_all() to author names and titles taken from rp.bib. Order matters:
# e.g. "\\c" is stripped before "\\c " is tried, and braces are removed before
# the accent commands are handled.
# Backslashes are written escaped ("\\l"): the former unescaped spellings
# ("\l", "\H", "\c", "\o", "\k") were invalid escape sequences that only
# worked by accident and raise a SyntaxWarning on Python >= 3.12.
dic = {
    # accented letters written with \k inside braces, then the braces themselves
    "{\\k c}": "c", "{\\k a}": "a", "{": "", "}": "",
    # accent / special-letter commands (a real carriage return is dropped too)
    "\\r": "", "\r": "", "\\l": "", "\\H": "", "\\c": "", "\\o": "o", "\\k": "",
    "\\O": "O", "\\ae": "ae", "\\c ": "", "\\u ": "", "\\.": "", "\\. ": "",
    "\\^": "", "\\v ": "", "\\v": "", "\\` ": "", "\\' ": "", "\\`": "",
    "\\'": "", '\\"': "", "\\~": "",
    '\\"a': "a", '\\"u': "u", '\\"o': "o", "\\ss": "ss", "\\'a": "a",
    "\\'o": "o", "\\'e": "e", "\\'i": "i", "\\i": "i", "\\l ": "l",
    # escaped backslash
    "\\\\": "",
}

def replace_all(text, dic):
    """Apply every ``old -> new`` substitution in *dic* to *text*.

    Substitutions run one after another in the dictionary's iteration order,
    so each one operates on the output of the previous ones.
    """
    cleaned = text
    for pattern, substitute in dic.items():
        cleaned = cleaned.replace(pattern, substitute)
    return cleaned


#%% define some functions

# get citations from crossref
def get_crossref_data(doi=None, title=None, author=None, journal=None, *,
                      fetch_missing_titles=False, timeout=60):
    """Return the reference list of a paper as deposited at Crossref.

    The paper is identified by its DOI if one is given; otherwise an XML query
    built from title, first-author last name and journal (fuzzy matching) is
    sent to the Crossref query servlet instead.

    Parameters
    ----------
    doi : str or None
        DOI of the citing paper.
    title, author, journal : str or None
        Only used when ``doi`` is missing; empty/None values are left out of
        the query.
    fetch_missing_titles : bool, keyword-only
        If True, cited works that have a DOI but no title get their title via
        one extra Crossref request each. Off by default: the original check
        (``title == ''``) never matched, since missing titles are None, so
        this lookup was effectively never performed.
    timeout : float, keyword-only
        Timeout in seconds for each HTTP request.

    Returns
    -------
    list of dict
        One dict per cited work with keys 'title', 'author' (lower-cased),
        'year' and 'doi'; unknown values are None. Empty if Crossref answers
        with a non-200 status.
    """
    from xml.sax.saxutils import escape

    base_url = 'https://doi.crossref.org/servlet/query'

    if doi:
        params = {'usr': username, 'format': 'unixsd', 'id': doi}
    else:
        # Values are XML-escaped and the whole document is passed via
        # `params` (URL-encoded by requests), so characters such as "&", "#"
        # or "<" in a title can no longer split or truncate the query.
        query = (
            '<?xml version = "1.0" encoding="UTF-8"?>'
            '<query_batch xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" version="2.0" '
            'xmlns="http://www.crossref.org/qschema/2.0" '
            'xsi:schemaLocation="http://www.crossref.org/qschema/2.0 '
            'http://www.crossref.org/qschema/crossref_query_input2.0.xsd">'
            '<head><email_address>support@crossref.org</email_address>'
            '<doi_batch_id>ABC_123_fff</doi_batch_id> </head> <body> '
            '<query enable-multiple-hits="false" secondary-query="author-title" key="key1">'
        )
        if title:
            query += f'<article_title match="fuzzy">{escape(title)}</article_title>'
        if author:
            query += f'<author match="fuzzy">{escape(author)}</author>'
        if journal:
            query += f'<journal_title match="fuzzy">{escape(journal)}</journal_title>'
        query += '</query></body></query_batch>'
        params = {'usr': username, 'qdata': query}

    print('>> crossref request')
    r = requests.get(base_url, params=params, timeout=timeout)
    if r.status_code != 200:
        print(f'   crossref request failed with status code {r.status_code}')
        return []

    # every <citation> element is one entry of the paper's reference list
    entries = BeautifulSoup(r.text, features="xml").find_all("citation")

    citations = []
    bar = ChargingBar('Citations', max=len(entries))
    for entry in entries:
        bar.next()

        cited_author = _crossref_element_text(entry, 'author')
        if cited_author is not None:
            cited_author = cited_author.lower()
        # Crossref's citation schema spells the year tag "cYear"; XML tag
        # lookup is case-sensitive, so the lower-case spelling is a fallback.
        cited_year = (_crossref_element_text(entry, 'cYear')
                      or _crossref_element_text(entry, 'cyear'))
        cited_doi = _crossref_element_text(entry, 'doi')
        cited_title = _crossref_element_text(entry, 'article_title')

        # no title found: optionally get it from crossref; the citation is
        # kept (with title None) even if that lookup fails, since its DOI
        # alone may still match an rp.bib entry
        if fetch_missing_titles and cited_doi and not cited_title:
            cited_title = _crossref_title_for_doi(cited_doi, timeout)

        citations.append({'title': cited_title, 'author': cited_author,
                          'year': cited_year, 'doi': cited_doi})
    bar.finish()

    return citations


def _crossref_element_text(node, tag):
    """Return the stripped text of the first *tag* element below *node*, or None if absent/empty."""
    element = node.find(tag)
    if element is None:
        return None
    text = element.get_text(strip=True)
    return text or None


def _crossref_title_for_doi(doi, timeout=60):
    """Fetch the title of the work with the given DOI from Crossref; None on failure."""
    r = requests.get('https://doi.crossref.org/servlet/query',
                     params={'usr': username, 'format': 'unixsd', 'id': doi},
                     timeout=timeout)
    if r.status_code != 200:
        return None
    try:
        # CDATA markers are stripped textually before parsing; leftover '>'
        # and ']' characters from them are removed from the title afterwards
        title_tag = BeautifulSoup(r.text.replace('<![CDATA[', ''), features="xml").find('title')
        return title_tag.contents[0].replace('>', '').replace(']', '')
    except (AttributeError, IndexError, TypeError):
        # no <title> element, empty element, or non-text first child
        return None


# get citations from semanticscholar
def get_semanticscholar_data(doi=None, title=None, year=None):
    """Return the reference list of a paper from the Semantic Scholar Graph API.

    With a DOI, the paper's ``/references`` endpoint is queried directly.
    Without one, the paper is first looked up via ``/paper/search`` using
    *title* (and *year*), and the references of the first hit are used.

    DOI and year of each cited paper come from get_doi_from_paper_id() and
    are cached in the module-level dict ``crossrefID`` (keyed by Semantic
    Scholar paper ID); the cache is written to ``cacheFile`` whenever new
    entries were added.

    Returns
    -------
    list of dict
        One dict per cited paper with keys 'title', 'author' (name of the
        first author), 'year' and 'doi'; unknown values are None. Empty if the
        paper cannot be found or the requests fail.
    """
    n_cached = len(crossrefID)  # remember size to detect new cache entries

    base_url = 'https://api.semanticscholar.org/graph/v1/paper/'
    fields = 'title,authors,citationCount'
    print('>> semanticscholar request')

    if doi:
        paper_id = doi
    else:
        # no DOI: search the paper by title (and year) as an intermediate step
        if not title:
            return []  # the search endpoint needs a query string
        search_params = {'fields': fields, 'query': title}
        if year:
            search_params['year'] = year
        hits = _semanticscholar_data(base_url + 'search', search_params, delay=2)
        if not hits:
            return []  # paper not found: there are no references to fetch
        paper_id = hits[0]['paperId']

    # The references endpoint is paginated (default page size 100); request
    # the maximum page size so longer reference lists are not cut off.
    references = _semanticscholar_data(
        f'{base_url}{paper_id}/references',
        {'fields': fields, 'limit': 1000},
        delay=1,
    )

    citations = []
    bar = ChargingBar('Citations', max=len(references))
    n_requests = 0
    for entry in references:
        bar.next()
        cited = entry.get('citedPaper')
        if not cited or not cited.get('paperId'):
            continue
        cited_id = cited['paperId']

        if cited_id in crossrefID:
            doi_year = crossrefID[cited_id]
        else:
            n_requests += 1  # throttle: pause after every 10 API requests
            if n_requests > 10:
                time.sleep(2)
                n_requests = 0
            cited_doi, cited_year = get_doi_from_paper_id(cited_id)
            doi_year = {'doi': cited_doi, 'year': cited_year}
            # (None, None) is also what a failed request returns; don't make
            # such a possibly transient failure permanent by caching it
            if cited_doi is not None or cited_year is not None:
                crossrefID[cited_id] = doi_year

        authors = cited.get('authors')
        author = authors[0]['name'] if authors and authors[0].get('name') else None
        citations.append({'title': cited.get('title') or None, 'author': author,
                          'year': doi_year['year'], 'doi': doi_year['doi']})

    bar.finish()

    # if crossrefID has new entries, store it
    if len(crossrefID) > n_cached:
        print("    - write cache")
        with open(cacheFile, 'w') as f:
            json.dump(crossrefID, f, indent=4)

    return citations


def _semanticscholar_data(url, params, delay, timeout=60):
    """GET a Semantic Scholar endpoint and return the 'data' list of its JSON reply.

    Replies with HTTP 429 (too many requests) or 504 are retried up to
    ``maxRetries`` times, pausing *delay* seconds after each. Returns [] if
    the request ultimately fails or the reply is not JSON with a 'data' key.
    """
    r = None
    for _ in range(maxRetries):
        r = requests.get(url, params=params, timeout=timeout)
        if r.status_code not in (429, 504):
            break
        print('   .. try again')
        time.sleep(delay)

    if r is None or r.status_code != 200:
        status = None if r is None else r.status_code
        print(f'   semanticscholar request failed (status code {status})')
        return []
    try:
        payload = r.json()
    except ValueError:  # non-JSON body
        return []
    return payload.get('data', []) if isinstance(payload, dict) else []


# get DOI for semanticscholar paper ID
def get_doi_from_paper_id(paper_id):
    """Look up DOI and publication year of a Semantic Scholar paper.

    Queries ``https://api.semanticscholar.org/v1/paper/{paper_id}``. Replies
    with HTTP 429 (too many requests) or 504 are retried up to ``maxRetries``
    times with a one-second pause in between (previously 429s were retried
    immediately, which cannot help against rate limiting).

    Returns
    -------
    tuple
        ``(doi, year)`` taken from the JSON reply (either may be None), or
        ``(None, None)`` if the request ultimately fails.
    """
    full_url = f'https://api.semanticscholar.org/v1/paper/{paper_id}'

    r = None
    for _ in range(maxRetries):
        r = requests.get(full_url, timeout=60)
        if r.status_code not in (429, 504):
            break
        time.sleep(1)

    if r is None or r.status_code != 200:
        # Handle the case where the API request fails
        status = None if r is None else r.status_code
        print(f"API request failed with status code {status}.")
        return None, None

    data = r.json()
    return data.get('doi'), data.get('year')


# get number of citations from semanticscholar
def get_semanticscholar_numcitations(doi=None):
    """Return Semantic Scholar citation counts for a batch of papers.

    Parameters
    ----------
    doi : list of str or None
        Paper identifiers for the ``/paper/batch`` endpoint (the commented-out
        caller in this file passes ``"DOI:<doi>"`` strings).

    Returns
    -------
    list of int
        One count per identifier, in input order; 0 for papers that are
        unknown or have no count. Empty list for empty/None input.

    Raises
    ------
    requests.HTTPError
        If the batch request fails. Previously an error reply (a JSON dict)
        was iterated as if it were the result list, silently producing zeros
        that no longer lined up with the input.
    """
    if not doi:
        return []

    r = requests.post(
        'https://api.semanticscholar.org/graph/v1/paper/batch',
        params={'fields': 'citationCount,title'},
        json={"ids": doi},
        timeout=60,
    )
    r.raise_for_status()

    # unknown papers come back as null entries
    return [(entry.get('citationCount') or 0) if entry else 0 for entry in r.json()]


#%% import data

# import rp.bib
pybtex.errors.set_strict_mode(True)
parser = bibtex.Parser()
bibdata = parser.parse_file("../rp.bib")
labels = sorted(bibdata.entries.keys())

# import results of previous runs (bib label -> list of cited bib labels), so
# that papers processed before can be skipped. Only a missing file starts a
# fresh result set: an unreadable/corrupt file now raises instead of being
# silently replaced by {} and then overwritten, which would lose all results.
try:
    with open(citationsFile, 'r') as f:
        citations = json.load(f)
except FileNotFoundError:
    citations = {}

# import crossref cache (Semantic Scholar paper ID -> {'doi': ..., 'year': ...})
try:
    with open(cacheFile, 'r') as f:
        crossrefID = json.load(f)
except FileNotFoundError:
    # Handle the case where the file does not exist
    crossrefID = {}


#%% prepare list of all labels = node list
# Processing window for the collection loop below, counted as 1-based
# positions in `labels`. For development, maxCnt can be set to a smaller
# number (e.g. 3000) to stop early; here all entries are allowed.
minCnt = 3386 # start entry (skip previous entries)
maxCnt = len(bibdata.entries)

nodeID = {}
paperyear = {}
#citations = {} # should be uncommented only if new citations file should be created

# node index = position in the sorted label list; publication year as int,
# with entries still "in press" dated 2022
for position, label in enumerate(labels):
    nodeID[label] = position
    raw_year = bibdata.entries[label].fields["year"]
    paperyear[label] = int(raw_year.replace('in press', '2022'))


#%% collect data from crossref/ semanticscholar

def _save_citations(path, data):
    """Write the citation data to *path* as JSON.

    Keys are written in sorted order and every citation list is sorted in
    place. The JSON goes to a temporary file first and is then moved over
    *path*, so interrupting the script mid-write cannot truncate the results
    collected so far.
    """
    sorted_dict = OrderedDict(sorted(data.items()))
    for key in sorted_dict:
        sorted_dict[key].sort()
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(sorted_dict, f, indent=4)
    os.replace(tmp_path, path)


def _as_int(value):
    """Return *value* converted to int, or None if it cannot be converted."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


# Lookup tables over rp.bib, built once instead of being recomputed for every
# retrieved citation:
#   doi_to_label      : lower-cased DOI -> first label (in sorted order) with it
#   normalized_titles : label -> lower-cased title with LaTeX markup removed
doi_to_label = {}
normalized_titles = {}
for bib_id2 in labels:
    fields2 = bibdata.entries[bib_id2].fields
    if fields2.get('doi'):
        doi_to_label.setdefault(fields2['doi'].lower(), bib_id2)
    if 'title' in fields2:
        normalized_titles[bib_id2] = replace_all(fields2['title'].lower(), dic)

cnt = 0
for bib_id in labels:

    # cnt = 1-based position of bib_id in labels; stop after exactly maxCnt
    # entries (the check used to run before the increment, allowing one more)
    cnt += 1
    if cnt > maxCnt:
        break
    if cnt < minCnt:
        continue

    print("(" + str(cnt) + "/" + str(len(labels)) + ") ==========================================")
    print(f"\nCurrent paper: \033[1m" + bib_id + "\033[0m")

    # skip papers with a non-empty citation list from a previous run
    if citations.get(bib_id):
        print('paper already included')
        continue

    # prepare the queries: last name of first author, DOI, title, journal, year
    entry = bibdata.entries[bib_id]
    fields = entry.fields
    author = None
    if 'author' in entry.persons:
        author = replace_all(str(entry.persons['author'][0].last_names[0]), dic)
    doi = fields.get('doi')
    title = replace_all(str(fields['title']), dic) if 'title' in fields else None
    journal = fields.get('journal')
    year = fields.get('year')

    # get references via semanticscholar and crossref
    c1 = get_semanticscholar_data(doi, title, year)
    c2 = get_crossref_data(doi, title, author, journal)
    c = c1 + c2

    # go through each cited paper and get its bibtex label
    citationslist = []
    for i, cited in enumerate(c):
        cited_doi = (cited['doi'] or '').lower()
        cited_year = cited['year'] or ''
        cited_title = cited['title'] or ''

        # 1) found by DOI, then stop
        bib_id2 = doi_to_label.get(cited_doi) if cited_doi else None
        if bib_id2 is not None:
            if bib_id2 not in citationslist:
                print("   DOI found citation: (" + str(i) + "/" + str(len(c)) + ") " + bib_id2)
                citationslist.append(bib_id2)
            continue

        # 2) not found by DOI: take the first entry with a very similar title
        if not cited_title:
            continue  # fuzz.ratio against an empty string is 0, never a match
        cited_title = cited_title.lower()
        for bib_id2, bib_title in normalized_titles.items():
            if bib_id2 in citationslist:
                continue
            fuzzyRatio = fuzz.ratio(bib_title, cited_title)
            if fuzzyRatio <= 95:
                continue

            # skip duplicate publication
            if bib_id2 == "eckmann1995" and "eckmann87" in citationslist:
                break

            print(f"   Fuzzy found citation [{fuzzyRatio}]: ({str(i)}/{str(len(c))}) {bib_id2}")

            # accept the title match only if the years agree (or the cited
            # year is unknown); an "in press" or unparsable year never agrees
            bib_year = bibdata.entries[bib_id2].fields.get("year", '')
            cited_year_int = _as_int(cited_year)
            if cited_year == '':
                citationslist.append(bib_id2)
                print("   no year available")
            elif cited_year_int is not None and cited_year_int == _as_int(bib_year):
                citationslist.append(bib_id2)
                print(f"   {cited_year} is matching with {bib_year}")
            else:
                print(f"   {cited_year} not matching with {bib_year}")
            break

    citations[bib_id] = citationslist

    # save intermediate results after every second processed paper
    if cnt % 2 == 0:
        _save_citations(citationsFile, citations)
        print("<< Write citationsFile")

_save_citations(citationsFile, citations)

end_time = time.time()
script_time = end_time - start_time

print(f"Search/ download of citations in {script_time} sec.")



# #%% number of citations from semanticscholar 
# # (alternative to simply count from retrieved data)
# # NOT USED!
# numCitations = {}
# doilist = []
# biblist = []
# cnt = 0
# bar = ChargingBar('get number of citations', max=len(labels))
# for bib_id in labels:
#     bar.next()
#     if bib_id in numCitations:
#         continue
#     if "doi" in bibdata.entries[bib_id].fields:
#         doilist.append("DOI:{}".format(bibdata.entries[bib_id].fields['doi']))
#         biblist.append(bib_id)
#         cnt += 1
#     if cnt > 490:
#         c = get_semanticscholar_numcitations(doilist)
#         for bib, citation_count in zip(biblist, c):
#             numCitations[bib] = citation_count
#         cnt = 0
#         doilist = []
#         biblist = []
# bar.finish()


