#!/usr/bin/env python3
"""Aggregate annotated discussion files into a single JSON for the visualizer.

Globs annotation/*_annotated.json, walks every post to compute statistics
(code frequency, co-occurrence, per-document counts), and outputs a combined
JSON structure to stdout.  Used by build_visualizer.sh to inject data into
annotation_explorer.html.
"""

import json, sys
from collections import Counter, defaultdict
from itertools import combinations
from pathlib import Path

BASE = Path(__file__).parent
ANNOTATION = BASE / "annotation"


def walk_posts(threads):
    """Yield every post in the thread tree, depth-first (parents before replies)."""
    # Explicit stack instead of recursion; reversing keeps sibling order intact,
    # since items are popped from the right end.
    stack = list(reversed(threads))
    while stack:
        post = stack.pop()
        yield post
        stack.extend(reversed(post.get("responses", [])))


def count_corpus_posts():
    """Count total posts in data/*.json (pre-filtering baseline).

    Handles two corpus layouts:
      * Reddit dumps: top-level "comments" key -> 1 (the OP) + every nested comment.
      * HN dumps: top-level "results" key -> the length of each result's
        "comments" list (the story itself is not counted).

    Returns the total post count, or None when the data directory is missing
    or contains no countable posts (callers treat None as "no baseline").
    """
    data_dir = BASE / "data"
    if not data_dir.is_dir():
        return None

    def count_reddit_comments(comments):
        # Recursively count a Reddit comment tree (each node plus its replies).
        n = 0
        for c in comments:
            n += 1
            n += count_reddit_comments(c.get("replies", []))
        return n

    total = 0
    for p in sorted(data_dir.glob("*.json")):
        # Explicit encoding: read_text() defaults to the locale encoding,
        # which mis-decodes UTF-8 corpus files on Windows / C-locale systems.
        d = json.loads(p.read_text(encoding="utf-8"))
        if "comments" in d:
            # Reddit: OP + nested comment tree
            total += 1 + count_reddit_comments(d["comments"])
        elif "results" in d:
            # HN: list of {story, comments: [...]}
            for r in d["results"]:
                if isinstance(r, dict):
                    total += len(r.get("comments", []))
    return total if total > 0 else None


def main():
    """Aggregate all annotated documents and write one combined JSON to stdout.

    Optional argv[1] overrides the codebook path. The output is
    {"documents": [...], "stats": {...}}, consumed by build_visualizer.sh.
    """
    if len(sys.argv) > 1:
        codebook_path = Path(sys.argv[1])
    else:
        codebook_path = BASE / "codebook/codebook_v8.json"
    codebook = json.loads(codebook_path.read_text())
    total_corpus_docs = len(codebook.get("sources", {}))

    annotated_paths = sorted(ANNOTATION.glob("*_annotated.json"))
    if not annotated_paths:
        # No annotations yet: emit an empty-but-well-formed structure.
        empty = {
            "documents": [],
            "stats": {
                "total_posts": 0, "coded_posts": 0,
                "total_corpus_docs": total_corpus_docs, "annotated_docs": 0,
                "code_frequency": {}, "doc_code_counts": {},
                "cooccurrence": {}, "doc_post_counts": {},
            },
        }
        json.dump(empty, sys.stdout)
        return

    documents = []
    code_freq = Counter()       # code -> global frequency
    pair_counts = Counter()     # "a|||b" -> co-occurrence count
    doc_code_counts = {}        # doc_id -> {code: count}
    doc_post_counts = {}        # doc_id -> {"total": n, "coded": n}
    total_posts = 0
    coded_posts = 0

    for annotated_path in annotated_paths:
        doc = json.loads(annotated_path.read_text())
        documents.append(doc)
        doc_id = doc["doc_id"]

        local_codes = Counter()
        n_total = 0
        n_coded = 0

        for post in walk_posts(doc.get("threads", [])):
            n_total += 1
            codes = post.get("codes") or []
            if not codes:
                continue
            n_coded += 1
            code_freq.update(codes)
            local_codes.update(codes)
            # Each unordered pair counted at most once per post.
            for a, b in combinations(sorted(set(codes)), 2):
                pair_counts[f"{a}|||{b}"] += 1

        total_posts += n_total
        coded_posts += n_coded
        doc_code_counts[doc_id] = dict(local_codes)
        doc_post_counts[doc_id] = {"total": n_total, "coded": n_coded}

    # Deterministic document order for the visualizer.
    documents.sort(key=lambda d: d["doc_id"])

    # Raw-corpus baseline (may be None when data/ is absent or empty).
    corpus_total = count_corpus_posts()

    stats = {
        "total_posts": total_posts,
        "coded_posts": coded_posts,
        "total_corpus_docs": total_corpus_docs,
        "annotated_docs": len(documents),
        "code_frequency": dict(code_freq.most_common()),
        "doc_code_counts": doc_code_counts,
        "cooccurrence": dict(pair_counts.most_common()),
        "doc_post_counts": doc_post_counts,
    }
    if corpus_total is not None:
        stats["corpus_total_posts"] = corpus_total

    json.dump({"documents": documents, "stats": stats}, sys.stdout, ensure_ascii=False)


# Standard script entry guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
