from __future__ import annotations
import re
import unicodedata
import pandas as pd
from pathlib import Path
from .utils import ROOT, read_jsonl

# Japanese character class: hiragana and katakana (U+3040-U+30FF), CJK unified
# ideographs (U+4E00-U+9FFF), and the marks U+3005-U+3007 (e.g. 々, 〆).
JP_CHAR = r"\u3040-\u30FF\u4E00-\u9FFF\u3005-\u3007"
# Symbols and punctuation marks, ASCII and Japanese alike.
PUNCTS = ".,;:?!。、「」『』（）()［］[]｛｝{}・…—〜ー・！？”“''"

# Name anchors, covering both the English and the Japanese spellings.
# NOTE(review): anchor_density matches these against NFKC-normalized text, in
# which fullwidth brackets are folded to ASCII, so the fullwidth "［NAME］"
# entry presumably only matters for callers that skip normalization.
ANCHORS = [
    r"\bEmina\b",
    r"\[NAME\]",
    r"エミナ",
    r"［NAME］",
]
# One compiled regex per anchor, in the same order as ANCHORS.
ANCHOR_PATTERNS = list(map(re.compile, ANCHORS))
# A token is either a run of ASCII word characters or one Japanese character.
TOKEN_PATTERN = re.compile(rf"[A-Za-z0-9_]+|[{JP_CHAR}]")

def _normalize(text: str) -> str:
    """Return *text* in NFKC form, unifying fullwidth/halfwidth variants."""
    nfkc_text = unicodedata.normalize("NFKC", text)
    return nfkc_text

def _ja_safe_tokens(text: str):
    """Tokenize *text* cheaply, without a morphological analyzer.

    Each run of ASCII letters, digits and underscores is one token; each
    character in the Japanese class is its own single-character token (a
    rough approximation of words). Everything else is skipped.
    """
    return [hit.group(0) for hit in TOKEN_PATTERN.finditer(text)]

def anchor_density(text: str) -> float:
    """Return the number of anchor matches per token in *text*.

    The text is NFKC-normalized before both the anchor search and the
    tokenization. The token count is clamped to at least 1 so that
    token-free input does not divide by zero.
    """
    normalized = _normalize(text)

    anchor_hits = 0
    for pattern in ANCHOR_PATTERNS:
        anchor_hits += len(pattern.findall(normalized))

    token_count = len(_ja_safe_tokens(normalized))
    return anchor_hits / (token_count or 1)

def style_features(text: str) -> dict:
    """Compute simple surface-style statistics for *text*.

    Returns a dict with two entries:
      punct_ratio: number of characters found in PUNCTS divided by the
                   length of the NFKC-normalized text.
      avg_token:   mean length of the tokens from _ja_safe_tokens.
    Both denominators are clamped to at least 1, so empty input gives 0.0
    for each value.
    """
    normalized = _normalize(text)
    token_lengths = [len(tok) for tok in _ja_safe_tokens(normalized)]
    # Summing booleans counts the characters that are punctuation.
    punct_count = sum(ch in PUNCTS for ch in normalized)
    return {
        "punct_ratio": punct_count / max(1, len(normalized)),
        "avg_token": sum(token_lengths) / max(1, len(token_lengths)),
    }

def extract(in_path: Path = ROOT / "data/logs.jsonl",
            out_path: Path = ROOT / "reports/features.parquet") -> None:
    """Compute per-utterance features from a JSONL log and write them to disk.

    Every record yielded by ``read_jsonl(in_path)`` is indexed by the keys
    ``text``, ``session_id``, ``turn_id`` and ``role``; a missing key
    propagates as the record's own lookup error. Each record becomes one row
    holding the ``style_features`` values, ``anchor_density`` and the three
    identifying fields.

    The table is written as Parquet to ``out_path``. If that write fails for
    any reason, the same table is written as UTF-8 CSV next to it (same name,
    ``.csv`` suffix) and a warning is printed instead.

    Args:
        in_path: JSONL log to read; passed to ``read_jsonl`` unchanged.
        out_path: Destination Parquet file (a ``str`` is accepted as well).
            Its parent directory is created when missing.
    """
    id_keys = ("session_id", "turn_id", "role")
    # Coerce so that a plain string path works with .parent / .with_suffix.
    out_path = Path(out_path)

    rows = []
    for rec in read_jsonl(in_path):  # one line = one utterance
        text = rec["text"]
        feats = style_features(text)
        feats["anchor_density"] = anchor_density(text)
        feats.update({key: rec[key] for key in id_keys})
        rows.append(feats)

    if rows:
        df = pd.DataFrame(rows)
    else:
        # An empty log would otherwise give a frame with no columns at all.
        # Build the header from the same sources as real rows so readers of
        # the output still find every expected column.
        df = pd.DataFrame(
            columns=[*style_features(""), "anchor_density", *id_keys]
        )

    out_path.parent.mkdir(parents=True, exist_ok=True)
    try:
        df.to_parquet(out_path, index=False)
    except Exception as e:
        # Deliberately broad: fallback for environments where Parquet cannot
        # be written (e.g. no engine installed).
        csv_fallback = out_path.with_suffix(".csv")
        df.to_csv(csv_fallback, index=False, encoding="utf-8")
        print(f"[WARN] parquet failed ({e.__class__.__name__}); fallback -> {csv_fallback}")
    else:
        # Reported outside the try body so only the write itself can trigger
        # the CSV fallback.
        print(f"features -> {out_path} ({len(df)} rows)")

# Script entry point: run the extraction with the default input/output paths.
# Because this module uses a relative import (.utils), it has to be executed
# as part of its package (python -m <package>.<module>), not as a bare file.
if __name__ == "__main__":
    extract()
