### This python script plots two heatmaps in Figure 2 ###


# Install required packages
# pip install pandas seaborn matplotlib scikit-learn unidecode

# Import libraries
import os
import sys
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from unidecode import unidecode
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

# Root data directory, taken from the first command-line argument by
# initialize_global_path(); None until that function has run.
DATA_PATH = None


def initialize_global_path():
    """Set the module-level DATA_PATH from ``sys.argv[1]``.

    Prints a message to stderr and exits with status 1 when no argument is
    given or when the argument does not name an existing directory.
    """
    global DATA_PATH
    if len(sys.argv) < 2:
        # Diagnostics belong on stderr so stdout stays clean for output.
        print("Usage: python script.py <absolute_data_path>", file=sys.stderr)
        sys.exit(1)
    DATA_PATH = sys.argv[1]
    if not os.path.isdir(DATA_PATH):
        print(f"Error: Directory '{DATA_PATH}' does not exist.", file=sys.stderr)
        sys.exit(1)

# Populate DATA_PATH from the command line at import time; the process exits
# with status 1 if the argument is missing or is not an existing directory.
initialize_global_path()

# Read and preprocess data
def load_and_preprocess_data(filepath):
    """Load the keyword Stata file and keep the 'event'-theme keywords.

    Parameters
    ----------
    filepath : str
        Path to a Stata ``.dta`` file with at least the columns
        ``theme``, ``firm`` and ``word``.

    Returns
    -------
    pandas.DataFrame
        Only the rows whose ``theme`` equals ``'event'``, restricted to the
        ``firm`` and ``word`` columns. Firm codes ``'yy'`` and
        ``'sinashow'`` are relabelled ``'YY'`` and ``'Sina Show'``. Each word
        is whitespace-stripped and transliterated to ASCII with ``unidecode``.
    """
    df = pd.read_stata(filepath)
    # A single .loc selection (not chained df[mask][cols]) plus an explicit
    # copy gives an independent frame that the assignments below may modify.
    df = df.loc[df['theme'] == 'event', ['firm', 'word']].copy()
    # Map raw firm codes to the display labels used on the heatmap axes.
    df['firm'] = df['firm'].replace({'yy': 'YY', 'sinashow': 'Sina Show'})
    # NOTE(review): assumes every 'word' is a string; a missing value becomes
    # NaN after .str.strip() and would make unidecode fail — confirm data.
    df['word'] = df['word'].str.strip().apply(unidecode)
    return df

# Extract word lists
def get_word_lists(df, categories):
    """Group the ``word`` column by firm.

    Returns a dict with one entry per name in ``categories``, in that order.
    Each entry maps the firm name to the list of ``word`` values from rows
    whose ``firm`` equals that name. A firm with no rows maps to an empty list.
    """
    firm_column = df['firm']
    grouped = {}
    for firm_name in categories:
        rows_for_firm = firm_column == firm_name
        grouped[firm_name] = df.loc[rows_for_firm, 'word'].tolist()
    return grouped

# Calculate similarity matrix
def calculate_similarity_matrix(word_lists, similarity_measure='jaccard'):
    """Return the pairwise similarity matrix between the firms' keyword lists.

    Parameters
    ----------
    word_lists : dict
        Maps each category (firm) name to its list of keywords. Rows and
        columns of the result follow the dict's iteration order.
    similarity_measure : {'jaccard', 'max', 'cosine'}
        'jaccard' : |A & B| / |A | B| on the keyword sets.
        'max'     : max(|A & B| / |A|, |A & B| / |B|), the larger share of
                    one firm's keyword set that also appears in the other's.
        'cosine'  : mean TF-IDF cosine similarity over all keyword pairs,
                    with one vectorizer fitted per pair of firms.

    Returns
    -------
    numpy.ndarray
        An (n, n) float matrix with 1.0 on the diagonal. For 'jaccard' and
        'max', a ratio whose denominator is zero (empty keyword sets) counts
        as 0.0 instead of raising ZeroDivisionError.

    Raises
    ------
    ValueError
        If ``similarity_measure`` is not one of the supported names.
    """
    if similarity_measure not in ('jaccard', 'max', 'cosine'):
        # Previously an unknown name silently produced an identity matrix.
        raise ValueError(f"Unknown similarity_measure: {similarity_measure!r}")

    def _ratio(numerator, denominator):
        # Empty keyword sets would otherwise divide by zero.
        return numerator / denominator if denominator else 0.0

    categories = list(word_lists.keys())
    n = len(categories)
    similarity_matrix = np.zeros((n, n))
    # Build each keyword set once instead of once per (i, j) pair.
    word_sets = {cat: set(word_lists[cat]) for cat in categories}

    for i, cat_i in enumerate(categories):
        for j, cat_j in enumerate(categories):
            if i == j:
                similarity_matrix[i, j] = 1.0
                continue
            if similarity_measure == 'jaccard':
                set_i, set_j = word_sets[cat_i], word_sets[cat_j]
                similarity_matrix[i, j] = _ratio(len(set_i & set_j), len(set_i | set_j))
            elif similarity_measure == 'max':
                set_i, set_j = word_sets[cat_i], word_sets[cat_j]
                intersection = len(set_i & set_j)
                similarity_matrix[i, j] = max(_ratio(intersection, len(set_i)),
                                              _ratio(intersection, len(set_j)))
            else:  # 'cosine'
                words_i, words_j = word_lists[cat_i], word_lists[cat_j]
                vectorizer = TfidfVectorizer()
                tfidf_matrix = vectorizer.fit_transform(words_i + words_j)
                # Rows [0, len(words_i)) belong to firm i; the rest to firm j.
                cosine_sim = cosine_similarity(tfidf_matrix[:len(words_i)],
                                               tfidf_matrix[len(words_i):])
                similarity_matrix[i, j] = cosine_sim.mean()
    return similarity_matrix

# Plot and save heatmap
def plot_and_save_heatmap(matrix, categories, title, save_path):
    """Draw ``matrix`` as an annotated heatmap and save it as a PNG.

    Parameters
    ----------
    matrix : array-like
        Square similarity matrix whose rows and columns are ordered like
        ``categories``.
    categories : list of str
        Tick labels used on both axes.
    title : str
        Name of the similarity measure (e.g. ``'Jaccard'``). It appears in
        the plot title and, lower-cased with spaces replaced by underscores,
        in the file name ``Figure2_heatmap_<title>.png``.
    save_path : str
        Existing directory to write the PNG into.
    """
    fig = plt.figure(figsize=(8, 6))
    sns.set(font_scale=1.2)
    sns.heatmap(matrix, annot=True, cmap='Blues', linewidths=0.5,
                xticklabels=categories, yticklabels=categories)
    # Use the ``title`` argument rather than a module-level variable, so the
    # function works no matter where it is called from.
    plt.title('Similarity Matrix ({})'.format(title))
    file_tag = title.lower().replace(' ', '_')
    output_file = os.path.join(save_path, f"Figure2_heatmap_{file_tag}.png")
    plt.savefig(output_file)
    # Close the figure created above so repeated calls don't accumulate figures.
    plt.close(fig)


# ---------------------------------------------------------------------------
# Main execution: load keywords, compute each similarity matrix, save heatmaps
# ---------------------------------------------------------------------------

# Input file and output directory, both resolved under the data root.
filepath = os.path.join(DATA_PATH, 'data/SVP_keywords.dta')
output_path = os.path.join(DATA_PATH, 'output')
os.makedirs(output_path, exist_ok=True)  # no error if it already exists

# Firm labels (as relabelled during preprocessing), in heatmap axis order,
# and the similarity measures to plot.
firm_categories = ['YY', '9158', 'Sina Show']
similarity_measures = ['jaccard', 'max']

df = load_and_preprocess_data(filepath)
word_lists = get_word_lists(df, firm_categories)

for measure in similarity_measures:
    similarity_matrix = calculate_similarity_matrix(
        word_lists,
        similarity_measure=measure,
    )
    plot_and_save_heatmap(
        similarity_matrix,
        firm_categories,
        measure.capitalize(),
        output_path,
    )