"""
Generate the emotion-history correlation network figure.

Project structure after decompression:
    project_root/
    ├── code/
    ├── data/
    ├── figures/        # created automatically when scripts are run
    ├── README.md
    └── requirements.txt

This script reads:
    data/correlation_network.xlsx

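The input file should be the combined correlation table generated by
build_correlation_network.py. It must contain at least the columns
"Comparison" (formatted as "node_a ~ node_b", e.g. "joy ~ nostalgia"),
"Correlation Coefficient", and "p-value"; an optional "game" column is
carried through to the edge attributes.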

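This script applies threshold filtering for network visualization
(only positive correlations can pass):
    - Basic emotion to complex emotion / historical-understanding edges:
      r > 0.30 and p < 0.05
    - Complex emotion to historical-understanding edges:
      r > 0.50 and p < 0.05
Edges of any other type (e.g. basic emotion to basic emotion) are not drawn.
Subcategory-level correlations are first mapped onto their aggregated nodes;
when several map to the same pair of nodes, the strongest one is kept.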

This script writes:
    figures/emotion_history_correlation_network.png
"""

from pathlib import Path

import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D




# =========================
# 1. Path settings
# =========================
# Project root = parent of the directory holding this script
# (the script is expected to live in <root>/code/).
BASE_DIR = Path(__file__).resolve().parents[1]
DATA_DIR = BASE_DIR.joinpath("data")
FIGURES_DIR = BASE_DIR.joinpath("figures")

INPUT_FILE = DATA_DIR.joinpath("correlation_network.xlsx")
OUTPUT_FILE = FIGURES_DIR.joinpath("emotion_history_correlation_network.png")

# Create the output folder up front so the final savefig cannot fail on it.
FIGURES_DIR.mkdir(exist_ok=True)


# =========================
# 2. Node categories
# =========================
# Aggregated node names for the three layers of the network. A node's layer
# decides its edge type (and therefore its r threshold) and its colour.
basic_emotions = {
    "joy",
    "anger",
    "surprise",
    "sadness",
    "hate",
}

complex_emotions = {
    "awe_and_reflection",
    "educated",
    "frustration_and_accomplishment",
    "negative_emotion",
    "nostalgia",
}

historical_dimensions = {
    "awareness_of_cultural_heritage_protection",
    "awareness_of_difference_between_games_and_archaeology",
    "historical_authenticity",
    "multiplicity_and_complexity_of_perspective",
    "perspective_through_player_choices",
}


# =========================
# 3. Mapping from subcategories to network nodes
# =========================
# Each complex-emotion node (key) absorbs its own name plus all of its
# subcategory variable names (values).
complex_mapping = {
    "nostalgia": [
        "nostalgia",
        "cultural_heritage_nostalgia",
        "personal_memory_nostalgia",
        "collective_identity_nostalgia",
        "golden_age_nostalgia",
        "romanticized_nostalgia",
        "lost_culture_nostalgia",
    ],
    "frustration_and_accomplishment": [
        "frustration_and_accomplishment",
        "failure_related_emotions",
        "unfair_difficulty_frustration",
        "exploration_restriction_frustration",
        "level_challenge_achievement",
        "exploration_achievement",
        "character_progression_achievement",
        "historical_constraint_frustration",
    ],
    "educated": [
        "educated",
        "satisfaction_from_learning",
        "aha_moment_or_epiphany",
        "immersed_in_historical_learning",
        "emotionally_moved_by_learning",
    ],
    "negative_emotion": [
        "negative_emotion",
        "fear_of_violence_or_war",
        "rejection_of_imperialism_or_colonialism",
        "guilt_over_historical_injustice",
        "frustration_with_narrative_injustice",
        "sorrow_for_historical_tragedies",
    ],
    "awe_and_reflection": [
        "awe_and_reflection",
        "awe_at_the_grandeur_of_civilization",
        "reflection_on_the_passage_of_time",
        "reevaluating_personal_historical_perspectives",
        "respect_for_cultural_diversity",
        "moral_reflection_on_history",
    ],
}


def build_subcategory_lookup(mapping):
    """
    Invert a ``{category: [subcategory, ...]}`` mapping into ``{subcategory: category}``.

    Args:
        mapping: dict whose values are iterables of subcategory names.

    Returns:
        A new dict mapping every subcategory name to its category.

    Raises:
        ValueError: if one subcategory is listed under two different
            categories. A plain inversion would silently let the last category
            win and misattribute that subcategory's correlations.
    """
    lookup = {}
    for category, subcols in mapping.items():
        for subcol in subcols:
            previous = lookup.get(subcol)
            if previous is not None and previous != category:
                raise ValueError(
                    f"Subcategory {subcol!r} is mapped to both "
                    f"{previous!r} and {category!r}"
                )
            lookup[subcol] = category
    return lookup


complex_lookup = build_subcategory_lookup(complex_mapping)


def map_node(raw_node: str):
    """
    Collapse a raw variable/subcategory name onto its aggregated network node.

    The name is converted to ``str`` and stripped, then resolved in this order:
      1. an exact basic-emotion name is returned unchanged;
      2. a complex-emotion name or subcategory is resolved via ``complex_lookup``;
      3. a historical-understanding dimension matches either its aggregated
         name exactly or its raw (capitalised) subcategory prefix.

    Returns None when the name belongs to no known node, so callers can skip it.
    """
    name = str(raw_node).strip()

    if name in basic_emotions:
        return name

    if name in complex_lookup:
        return complex_lookup[name]

    # One-off exception: this subcategory carries the cultural-heritage prefix
    # but is assigned to the games-vs-archaeology dimension. It must be decided
    # before the generic prefix rules below would claim it.
    if name == "Awareness_of_protection_of_culture_heritage_Archaeological_Perspective":
        return "awareness_of_difference_between_games_and_archaeology"

    # (aggregated node name, raw subcategory prefix), tried in order.
    prefix_rules = (
        (
            "historical_authenticity",
            "Historical_authenticity_",
        ),
        (
            "perspective_through_player_choices",
            "Historical_Perspective_through_Player_Choices_",
        ),
        (
            "multiplicity_and_complexity_of_perspective",
            "Multiplicity_and_Complexity_of_Historical_Perspective_",
        ),
        (
            "awareness_of_cultural_heritage_protection",
            "Awareness_of_protection_of_culture_heritage_",
        ),
        (
            "awareness_of_difference_between_games_and_archaeology",
            "Awareness_the_difference_of_games_and_archaeology_",
        ),
    )

    for target, prefix in prefix_rules:
        if name == target or name.startswith(prefix):
            return target

    return None


def split_comparison(comparison: str):
    """
    Break a comparison label like 'joy ~ nostalgia' into its two raw node names.

    The input is converted to ``str`` first. Whitespace around each side is
    stripped. Returns ``(None, None)`` unless the text contains exactly one '~'.
    """
    text = str(comparison)

    if text.count("~") != 1:
        return None, None

    left, _, right = text.partition("~")
    return left.strip(), right.strip()


def classify_edge(node1: str, node2: str):
    """
    Label an edge by which node layers its two endpoints touch.

    Returns one of "basic_complex", "basic_history", "complex_history"
    (checked in that order), or "other" for any remaining combination
    (e.g. two nodes from the same layer or an unknown node).
    """
    endpoints = {node1, node2}

    touches_basic = bool(endpoints & basic_emotions)
    touches_complex = bool(endpoints & complex_emotions)
    touches_history = bool(endpoints & historical_dimensions)

    if touches_basic and touches_complex:
        return "basic_complex"
    if touches_basic and touches_history:
        return "basic_history"
    if touches_complex and touches_history:
        return "complex_history"
    return "other"


# =========================
# 4. Read and filter correlation table
# =========================
df = pd.read_excel(INPUT_FILE)
df.columns = [str(col).strip() for col in df.columns]

required_columns = ["Comparison", "Correlation Coefficient", "p-value"]
missing_columns = [col for col in required_columns if col not in df.columns]

if missing_columns:
    raise ValueError(f"Missing required columns in {INPUT_FILE.name}: {missing_columns}")

# Minimum correlation (strict ">") per edge type. Edge types not listed here
# (e.g. basic-basic, i.e. "other") are never drawn.
edge_r_thresholds = {
    "basic_complex": 0.30,
    "basic_history": 0.30,
    "complex_history": 0.50,
}
p_threshold = 0.05

# Column presence is the same for every row, so check it once.
has_game_column = "game" in df.columns

filtered_edges = []

for _, row in df.iterrows():
    raw_node1, raw_node2 = split_comparison(row["Comparison"])

    if raw_node1 is None or raw_node2 is None:
        continue

    node1 = map_node(raw_node1)
    node2 = map_node(raw_node2)

    # Skip unmapped variables and correlations between two subcategories of
    # the same aggregated node (they would become self-loops).
    if node1 is None or node2 is None or node1 == node2:
        continue

    # NOTE(review): non-numeric cells (e.g. a p-value written as "<0.001")
    # are coerced to NaN and the row is dropped; confirm the input never uses
    # such notation.
    r = pd.to_numeric(row["Correlation Coefficient"], errors="coerce")
    p = pd.to_numeric(row["p-value"], errors="coerce")

    if pd.isna(r) or pd.isna(p):
        continue

    edge_type = classify_edge(node1, node2)
    threshold = edge_r_thresholds.get(edge_type)

    if threshold is None:
        continue

    # Positive correlations only, consistent with the interpretation in the manuscript.
    if r > threshold and p < p_threshold:
        filtered_edges.append(
            {
                "node1": node1,
                "node2": node2,
                "Correlation Coefficient": r,
                "p-value": p,
                "edge_type": edge_type,
                "game": row["game"] if has_game_column else "combined",
            }
        )


edge_df = pd.DataFrame(
    filtered_edges,
    columns=[
        "node1",
        "node2",
        "Correlation Coefficient",
        "p-value",
        "edge_type",
        "game",
    ],
)

if edge_df.empty:
    raise ValueError(
        "No edges passed the filtering thresholds. "
        "Please check the correlation table, node mappings, or threshold values."
    )

# If multiple subcategory-level correlations map to the same aggregated edge,
# keep the strongest positive correlation.
# The network is undirected, so the de-duplication key must not depend on
# which side of "~" a node appeared on. Otherwise "a ~ b" and "b ~ a" both
# survive, and when they are added to the undirected nx.Graph the later
# (weaker) row overwrites the stronger edge's weight. The kept row's own
# node1/node2 orientation is preserved.
# A stable sort makes the choice among equal correlations deterministic
# (the row that appeared first in the input wins).
edge_df["pair_key"] = [
    tuple(sorted(pair)) for pair in zip(edge_df["node1"], edge_df["node2"])
]
edge_df = edge_df.sort_values(
    "Correlation Coefficient", ascending=False, kind="mergesort"
)
edge_df = edge_df.drop_duplicates(subset=["pair_key"], keep="first")
edge_df = edge_df.drop(columns="pair_key")

print(f"Number of retained network edges: {len(edge_df)}")


# =========================
# 5. Build network
# =========================
# One undirected edge per retained row, in edge_df row order. The correlation
# becomes the "weight" attribute, which the sizing and layout sections read.
G = nx.Graph()
G.add_edges_from(
    (
        record["node1"],
        record["node2"],
        {
            "weight": record["Correlation Coefficient"],
            "p_value": record["p-value"],
            "edge_type": record["edge_type"],
            "game": record["game"],
        },
    )
    for record in edge_df.to_dict("records")
)


# =========================
# 6. Label wrapping
# =========================
# Display text per node. Long names get an embedded "\n" so they render on two
# lines; any node not listed here is labelled with its own name.
label_map = {
    # Wrapped onto two lines.
    "awareness_of_difference_between_games_and_archaeology": (
        "awareness_of_difference\nbetween_games_and_archaeology"
    ),
    "awareness_of_cultural_heritage_protection": (
        "awareness_of_cultural\nheritage_protection"
    ),
    "multiplicity_and_complexity_of_perspective": (
        "multiplicity_and_complexity\nof_perspective"
    ),
    "perspective_through_player_choices": (
        "perspective_through\nplayer_choices"
    ),
    "frustration_and_accomplishment": (
        "frustration_and\naccomplishment"
    ),
    # Listed explicitly but shown unchanged.
    "historical_authenticity": "historical_authenticity",
    "negative_emotion": "negative_emotion",
    "awe_and_reflection": "awe_and_reflection",
}

labels = {}
for graph_node in G:
    labels[graph_node] = label_map.get(graph_node, graph_node)


# =========================
# 7. Colors
# =========================
color_base = "#FDB462"       # Basic emotions
color_complex = "#66C2A5"    # Complex emotions
color_hist = "#8E6CAB"       # Historical understanding
color_default = "#BDBDBD"    # Nodes outside the three categories

# Checked in order; the first category containing the node sets its colour.
category_palette = (
    (basic_emotions, color_base),
    (complex_emotions, color_complex),
    (historical_dimensions, color_hist),
)

node_colors = [
    next(
        (color for members, color in category_palette if node in members),
        color_default,
    )
    for node in G.nodes()
]


# =========================
# 8. Node size by weighted degree
# =========================
# Weighted degree = sum of the correlation weights on a node's edges.
weighted_degrees = dict(G.degree(weight="weight"))
# Marker area: a 1200 floor plus 900 per unit of weighted degree, in G.nodes() order.
node_sizes = [weighted_degrees[n] * 900 + 1200 for n in G.nodes()]


# =========================
# 9. Edge width by correlation coefficient
# =========================
edges = list(G.edges(data=True))
# Line width = 1 + 6 * r. The list follows G.edges() order, which is the
# order nx.draw_networkx_edges uses when no edgelist is passed.
edge_widths = [attrs["weight"] * 6 + 1 for _u, _v, attrs in edges]


# =========================
# 10. Layout
# =========================
# NOTE(review): kamada_kawai_layout treats the `weight` attribute as an edge
# *length* (weighted shortest paths), so stronger correlations end up farther
# apart rather than closer. Confirm this is intended before changing it; the
# hand-tuned offsets below were chosen for the current layout.
pos = nx.kamada_kawai_layout(G, weight="weight")

# Hand-tuned (dx, dy) offsets added to individual node positions.
manual_shift = {
    "hate": (-0.18, 0.00),
    "joy": (0.14, -0.10),
    "frustration_and_accomplishment": (0.10, -0.06),
    "historical_authenticity": (0.12, -0.01),
    "awareness_of_difference_between_games_and_archaeology": (-0.05, 0.02),
    "perspective_through_player_choices": (-0.02, 0.06),
    "multiplicity_and_complexity_of_perspective": (0.08, 0.03),
}

for shifted_node in manual_shift:
    if shifted_node not in pos:
        continue  # node did not survive filtering, nothing to move
    shift_x, shift_y = manual_shift[shifted_node]
    pos[shifted_node][0] += shift_x
    pos[shifted_node][1] += shift_y


# =========================
# 11. Label positions
# =========================
# Labels are offset vertically from their node: historical dimensions +0.03,
# complex emotions -0.03, everything else +0.025 (layout coordinates).
label_pos = {
    node: (
        x,
        (y + 0.03) if node in historical_dimensions
        else (y - 0.03) if node in complex_emotions
        else (y + 0.025),
    )
    for node, (x, y) in pos.items()
}


# =========================
# 12. Plot
# =========================
# A white 16x10 figure with a single axes; edges, nodes and labels are all
# drawn onto that axes.
_, ax = plt.subplots(figsize=(16, 10), facecolor="white")
ax.set_facecolor("white")

nx.draw_networkx_edges(
    G,
    pos,
    ax=ax,
    width=edge_widths,
    edge_color="#8F8F8F",
    alpha=0.60,
)

nx.draw_networkx_nodes(
    G,
    pos,
    ax=ax,
    node_color=node_colors,
    node_size=node_sizes,
    alpha=0.95,
    edgecolors="white",
    linewidths=1.8,
)

# Labels are drawn at the offset label_pos coordinates, each on a
# semi-transparent rounded white box.
label_box = dict(
    facecolor="white",
    edgecolor="none",
    alpha=0.75,
    boxstyle="round,pad=0.15",
)
nx.draw_networkx_labels(
    G,
    label_pos,
    ax=ax,
    labels=labels,
    font_size=11,
    font_color="black",
    font_family="sans-serif",
    bbox=label_box,
)


# =========================
# 13. Legend
# =========================
# One circular marker per node category; color="w" keeps the connecting line
# in the legend handle white.
legend_entries = (
    ("Basic emotions", color_base),
    ("Complex emotions", color_complex),
    ("Historical understanding", color_hist),
)

legend_elements = [
    Line2D(
        [0], [0],
        marker="o",
        color="w",
        label=entry_text,
        markerfacecolor=entry_color,
        markersize=12,
    )
    for entry_text, entry_color in legend_entries
]

plt.legend(handles=legend_elements, loc="upper left", frameon=False, fontsize=11)


# =========================
# 14. Title and save
# =========================
final_ax = plt.gca()
final_ax.set_title(
    "Emotion-History Understanding Correlation Network",
    fontsize=18,
    pad=18,
)

# Hide the axis decorations and pad the autoscaled limits by 20% per axis.
final_ax.axis("off")
final_ax.margins(0.20)
plt.tight_layout()

plt.savefig(OUTPUT_FILE, dpi=600, bbox_inches="tight")
plt.close()

print(f"Figure saved to: {OUTPUT_FILE}")