# -*- coding: utf-8 -*-
"""DAG_Data

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Jzv-TwZ9NYjYQ0df__lk4curwh81irii

# MULTI-PROVIDER BENCHMARK: LATENCY & TOKEN OPTIMIZATION IN MULTI-AGENT SYSTEMS
# Sequential Baseline vs. Optimized DAG (LangGraph + Prompt Caching + MCP Filter)
# Providers: Anthropic · OpenAI  (Google/Gemini arm implemented but not run in this study)
# Target: Google Colab (Python 3.10+)

## WHAT THIS DOES
Runs the IDENTICAL "Sequential Baseline vs Optimized DAG" workflow across a 4-model
matrix, holding the model constant within each run so the measured delta reflects the
ARCHITECTURE (parallel DAG + prompt caching + MCP payload filtering), not model tiering.

    15 prompts × 4 models × 2 systems = 120 rows → OUTPUT_CSV (multi_provider_benchmark_2.csv as configured below)
    4 repeated runs → 480 measured rows total (used for median-based robust estimation)

Each baseline row = 3 LLM calls; each optimized row = 4 LLM calls (3 workers + reducer).
A full live run issues 15 × 4 × (3 + 4) = 420 API calls per run. PLAN COST ACCORDINGLY.
Use DRY_RUN=True to validate the full pipeline and CSV shape for $0 first.

## PER-PROVIDER FACTS (verified against current docs)

TOKEN METADATA — LangChain normalizes providers into the standardized
AIMessage.usage_metadata: input_tokens / output_tokens / total_tokens, plus
input_token_details.{cache_read, cache_creation}. In this field, input_tokens
INCLUDES cached tokens (unlike the raw Anthropic SDK, which excludes them).
Provider-specific raw parse of response_metadata is used as a documented fallback:
  - OpenAI    : response_metadata['token_usage'] → prompt_tokens / completion_tokens;
                cached at token_usage['prompt_tokens_details']['cached_tokens']
  - Anthropic : response_metadata['usage'] → input_tokens / output_tokens;
                cache_read_input_tokens / cache_creation_input_tokens
                (raw Anthropic input_tokens EXCLUDES cache → normalized to inclusive)

CACHING MECHANICS — Anthropic requires an explicit flag; OpenAI is automatic:
  - Anthropic : explicit cache_control={"type":"ephemeral"} on the system block.
                Min cacheable prefix: claude-haiku-4-5 = 4096 tok, claude-sonnet-4-6 = 1024 tok.
  - OpenAI    : AUTOMATIC for prompts ≥1024 tok. No flag exists or is needed.
                We rely on automatic prefix caching (longest previously computed prefix reused).

The shared system block is sized at ~4,621 tokens, clearing every provider's minimum
so caching engages on all models in the optimized arm.

LANGGRAPH FAN-IN — parallel workers return DELTAS ONLY; the accumulated channels
(worker_outputs, all_turn_metrics) use Annotated[..., operator.add] reducers,
so the parallel fan-in never raises InvalidUpdateError.

NOTE: Google/Gemini arm is implemented in code but was not executed in this study.
All reported results cover Anthropic (claude-haiku-4-5, claude-sonnet-4-6) and
OpenAI (gpt-4o-mini, gpt-4o) only. The Google rows remain commented out for
future reproducibility.
"""

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 0: SELF-INSTALLING BOOTSTRAP
# ═══════════════════════════════════════════════════════════════════════════════
import sys
import subprocess
import importlib.util

# Requirement specifiers handed to `pip install` by _bootstrap_install(). The
# complete list (not only the missing entries) is installed whenever any probed
# module is absent.
_PACKAGES_TO_INSTALL = [
    "anthropic>=0.49.0",
    "openai>=1.55.0",
    "langchain-core>=0.3.40",
    "langchain-anthropic>=0.3.0",
    "langchain-openai>=0.3.0",
    "langchain-google-genai>=2.0.0",
    "langgraph>=0.2.0",
    "tenacity>=8.2.0",
    "pandas>=2.0.0",
    "numpy>=1.24.0",
    "matplotlib>=3.7.0",
]

# Distribution name -> importable top-level module name. _bootstrap_install()
# probes the module names with importlib.util.find_spec to decide whether an
# install is needed; the keys only appear in its progress message.
_PROBE_MAP = {
    "anthropic":              "anthropic",
    "openai":                 "openai",
    "langchain-core":         "langchain_core",
    "langchain-anthropic":    "langchain_anthropic",
    "langchain-openai":       "langchain_openai",
    "langchain-google-genai": "langchain_google_genai",
    "langgraph":              "langgraph",
    "tenacity":               "tenacity",
    "pandas":                 "pandas",
    "numpy":                  "numpy",
    "matplotlib":             "matplotlib",
}


def _bootstrap_install() -> None:
    """Make sure the benchmark's third-party dependencies are importable.

    Probes every module named in ``_PROBE_MAP``. When at least one cannot be
    found, a single quiet ``pip install`` is run over the complete
    ``_PACKAGES_TO_INSTALL`` list using the current interpreter.

    Raises:
        RuntimeError: pip exited with a non-zero status (the tail of its
            stderr is printed before raising).
    """
    absent = []
    for dist_name, module_name in _PROBE_MAP.items():
        if importlib.util.find_spec(module_name) is None:
            absent.append(dist_name)

    if not absent:
        print("[BOOTSTRAP] All packages already present — skipping install.")
        return

    print(f"[BOOTSTRAP] Installing {len(absent)} missing package group(s): {absent}")
    print("[BOOTSTRAP] This runs once (~60–90 s in Colab) …\n")

    pip_command = [sys.executable, "-m", "pip", "install", "-q", *_PACKAGES_TO_INSTALL]
    completed = subprocess.run(pip_command, capture_output=True, text=True)
    if completed.returncode != 0:
        stderr_tail = completed.stderr[-2500:]
        print("[BOOTSTRAP] pip stderr (last 2500 chars):\n", stderr_tail)
        raise RuntimeError(
            f"pip install failed (exit {completed.returncode}). "
            "Check connectivity and re-run the cell."
        )

    # Drop the import system's finder caches so packages installed a moment
    # ago are discoverable without restarting the interpreter.
    importlib.invalidate_caches()
    print("[BOOTSTRAP] Installation complete.\n")


# Runs at import time, before the third-party imports further down, so a fresh
# runtime has the packages installed by the time they are imported.
_bootstrap_install()

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 1: IMPORTS & CONFIGURATION
# ═══════════════════════════════════════════════════════════════════════════════
import os
import json
import time
import uuid
import random
import logging
import warnings
import textwrap
import operator
import traceback
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Annotated

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

from langgraph.graph import StateGraph, START, END
from langchain_core.messages import HumanMessage, SystemMessage
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Provider chat-model classes
from langchain_anthropic import ChatAnthropic
from langchain_openai import ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI


# ─── RUN KNOBS ────────────────────────────────────────────────────────────────
# DRY_RUN=True  → no network, deterministic MOCK metrics. Validates the full DAG,
#                 the 120-row matrix (15 prompts × 4 models × 2 systems), token
#                 accounting, caching plumbing, and the CSV/chart pipeline for $0.
#                 Flip to False for the real benchmark.
# False = live run (real provider keys are required); True = mock mode, in which
# get_api_keys() hands back placeholder keys and never prompts.
DRY_RUN = False

MAX_OUTPUT_TOKENS = 1024          # cap on generated tokens per call
RATE_LIMIT_DELAY   = 1.0           # seconds between sequential API calls (be polite)
REQUEST_TIMEOUT   = 120           # seconds per call

# Optional subsetting for cheaper partial runs (None = use the full lists)
PROMPT_LIMIT      = None          # e.g. 3 to run only the first 3 prompts
PROVIDER_FILTER   = None          # e.g. {"anthropic"} to run a single provider

# Output artifact names: per-row results CSV, summary CSV, and chart image.
# NOTE(review): the code that writes these files is outside this section.
OUTPUT_CSV   = "multi_provider_benchmark_2.csv"
SUMMARY_CSV  = "multi_provider_summary_2.csv"
CHART_PNG    = "multi_provider_benchmark_2.png"
MAKE_CHART   = True

# ─── MODEL MATRIX ─────────────────────────────────────────────────────────────
# (provider, model_id, role_label). Each model runs its OWN full baseline + DAG.
# The Google rows stay commented out: that arm was not executed in this study.
MODEL_MATRIX: List[Tuple[str, str, str]] = [
    ("anthropic", "claude-haiku-4-5",   "fast"),
    ("anthropic", "claude-sonnet-4-6",  "strong"),
    ("openai",    "gpt-4o-mini",        "fast"),
    ("openai",    "gpt-4o",             "strong"),
    # ("google",    "gemini-2.5-flash",   "fast"),
    # ("google",    "gemini-2.5-pro",     "strong"),
]

# Minimum cacheable prefix (tokens), keyed by model_id. Our shared system block
# is sized above the largest of these so caching engages everywhere in the opt arm.
CACHE_MIN_TOKENS: Dict[str, int] = {
    "claude-haiku-4-5":  4096,
    "claude-sonnet-4-6": 1024,
    "gpt-4o-mini":       1024,
    "gpt-4o":            1024,
    # "gemini-2.5-flash":  1024,
    # "gemini-2.5-pro":    2048,
}

# Env var name holding each provider's API key. get_api_keys() indexes this dict
# directly, so any provider enabled in MODEL_MATRIX needs an entry here
# (re-enable the "google" line together with the Google rows above).
PROVIDER_ENV: Dict[str, str] = {
    "anthropic": "ANTHROPIC_API_KEY",
    "openai":    "OPENAI_API_KEY",
    # "google":    "GOOGLE_API_KEY",
}

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 2: API KEY SETUP
# ═══════════════════════════════════════════════════════════════════════════════
def get_api_keys(required_providers: set) -> Dict[str, str]:
    """Load each required provider's API key from env, Colab secrets, or a prompt.

    Lookup order per provider: the environment variable named in
    ``PROVIDER_ENV``, then Colab ``userdata`` (only when ``google.colab`` is
    importable), then an interactive prompt. Surrounding whitespace is stripped
    from whichever source supplies the key, and the key is written back to
    ``os.environ`` under the same variable name.

    Args:
        required_providers: Provider names; each must be a key of ``PROVIDER_ENV``.

    Returns:
        Mapping of provider name to API key. When ``DRY_RUN`` is true every
        provider maps to the placeholder ``"dry-run-no-key"`` and nothing is
        read or prompted.

    Raises:
        KeyError: A provider has no entry in ``PROVIDER_ENV``.
        ValueError: No non-empty key could be obtained for a provider.
    """
    if DRY_RUN:
        return {p: "dry-run-no-key" for p in required_providers}

    # Function-scope import: getpass reads the key without echoing it, so the
    # secret is not captured in saved notebook output or terminal scrollback.
    from getpass import getpass

    try:
        from google.colab import userdata  # type: ignore
    except Exception:
        # Deliberately broad: outside Colab this import can fail in more ways
        # than ImportError, and secrets lookup is only a best-effort source.
        userdata = None

    keys: Dict[str, str] = {}
    for provider in sorted(required_providers):
        env_name = PROVIDER_ENV[provider]
        key = os.environ.get(env_name, "").strip()
        if not key and userdata is not None:
            try:
                key = (userdata.get(env_name) or "").strip()
            except Exception:
                # A missing secret or denied notebook access falls through to the prompt.
                key = ""
        if not key:
            key = getpass(f"Enter {env_name} ({provider}): ").strip()
        if not key:
            raise ValueError(f"{env_name} is required for provider '{provider}'.")
        os.environ[env_name] = key
        keys[provider] = key
    return keys

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 3: 15 COMPLEX AI INFRASTRUCTURE RESEARCH PROMPTS
# ═══════════════════════════════════════════════════════════════════════════════
# Evaluation prompt set. Each entry carries three string fields:
#   "id"     — stable identifier, P01 through P15
#   "topic"  — short human-readable label for the prompt
#   "prompt" — the full research question text
# List order matters for partial runs: per the PROMPT_LIMIT comment, a limit of
# N selects the first N entries.
EVAL_PROMPTS: List[Dict[str, str]] = [
    {"id": "P01", "topic": "Speculative Decoding",
     "prompt": ("Derive the expected throughput gain of speculative decoding pairing a 7B draft "
                "model with a 70B target. Model the token acceptance rate alpha as a function of "
                "task entropy, quantify KV-cache overhead per speculative step, and propose a "
                "controller that adapts speculation depth k to live rejection rates. Include the "
                "math and the cost-ratio break-even condition.")},
    {"id": "P02", "topic": "MoE Routing Efficiency",
     "prompt": ("Formalize the expert load-balancing problem for a Mixture-of-Experts transformer "
                "served across 128 GPUs. Define the expert capacity factor, derive when token "
                "dropping occurs, and propose a two-phase (coarse+fine) router minimizing cross-node "
                "traffic while keeping expert utilization above 85%. Compare against Switch and GLaM "
                "with complexity analysis.")},
    {"id": "P03", "topic": "KV-Cache Compression",
     "prompt": ("Design a hierarchical KV-cache eviction policy for 128K-token inference combining "
                "recency-weighted attention scores with semantic clustering of key vectors. Derive "
                "the memory–accuracy Pareto frontier for 2x–16x compression, contrast perplexity "
                "impact on code-gen vs summarization, and compare StreamingLLM, ScissorHands, and "
                "your method.")},
    {"id": "P04", "topic": "Agentic Scaling Laws",
     "prompt": ("Characterize scaling laws for N-agent LLM systems on decomposed subtasks. Relate "
                "agent count, task complexity (minimum description length), coordination overhead "
                "(inter-agent tokens), and output quality. Derive the optimal agent count N* as a "
                "function of task entropy and give a topology-selection criterion (star/ring/all-to-"
                "all) with diminishing-returns analysis.")},
    {"id": "P05", "topic": "Continuous Batching",
     "prompt": ("Analyze continuous (iteration-level) batching versus static batching in a vLLM-style "
                "server. Derive GPU-utilization improvement as a function of Poisson arrival rate "
                "lambda and sequence-length distribution, quantify PagedAttention's fragmentation "
                "reduction, estimate max tokens/sec for an A100-80GB serving a 13B model at fp16, and "
                "propose an SLO-aware preemption policy.")},
    {"id": "P06", "topic": "Prompt-Caching Economics",
     "prompt": ("Build a formal cost model for prompt caching in a multi-tenant deployment. Define "
                "hit rate H as a function of prefix-repetition frequency, 5-minute TTL, and "
                "concurrency; derive the break-even reuse count and prefix length; and model monthly "
                "savings for 10M req/day at 60% prefix overlap. Discuss cache-invalidation strategies "
                "for dynamic context.")},
    {"id": "P07", "topic": "Tensor vs Pipeline Parallelism",
     "prompt": ("Compare tensor parallelism (Megatron) and pipeline parallelism (GPipe/PipeDream) for "
                "serving a 405B model on 64 H100s. Derive communication volume per strategy in terms "
                "of depth D, hidden dim H, and sequence length S, analyze pipeline-bubble overhead "
                "with micro-batch scheduling, and propose a hybrid 3D-parallel config minimizing "
                "batch-size-1 latency.")},
    {"id": "P08", "topic": "RLHF Alignment Tax",
     "prompt": ("Quantify the inference-time 'alignment tax' of RLHF / Constitutional training. "
                "Analyze how safety tuning shifts logit distributions, lengthens refusals, and lowers "
                "speculative-decode acceptance. Propose a sub-5ms inference-time verification layer "
                "external to the main model and formalize its precision–recall trade-off at "
                "production scale.")},
    {"id": "P09", "topic": "Multi-Agent Memory",
     "prompt": ("Design distributed shared memory for a 50-agent system needing short-term working "
                "memory and long-term episodic memory. Formalize the eventual- vs strong-consistency "
                "trade-off, propose a vector-indexed semantic store with ANN retrieval, derive the "
                "read/write latency budget per turn under a 2-second SLO, and model contention with an "
                "M/M/c queue.")},
    {"id": "P10", "topic": "Inference-Time Compute Scaling",
     "prompt": ("Analyze whether inference-time compute (CoT, ToT, majority voting, best-of-N) can "
                "substitute for pre-training scale. Derive the compute-optimal strategy as a function "
                "of difficulty (pass@k), compare token efficiency across CoT/ToT/MCTS, formalize the "
                "quality–latency Pareto frontier under a 10s SLO, and propose an adaptive inference "
                "budget allocator.")},
    {"id": "P11", "topic": "Quantization Trade-offs",
     "prompt": ("Compare INT8, INT4, and FP8 post-training quantization for a 70B model. Derive the "
                "accuracy-vs-memory frontier, analyze outlier-aware schemes (SmoothQuant, AWQ, GPTQ) "
                "and per-channel vs per-tensor scaling, quantify the decode-bandwidth win from smaller "
                "weights, and recommend a scheme for a latency-critical fp8-capable H100 deployment.")},
    {"id": "P12", "topic": "Disaggregated Prefill/Decode",
     "prompt": ("Evaluate prefill/decode disaggregation across separate GPU pools. Given prefill is "
                "compute-bound and decode is memory-bandwidth-bound, derive the optimal pool ratio as "
                "a function of input/output length distribution, model the KV-cache transfer cost over "
                "NVLink/InfiniBand between pools, and analyze the SLO impact on TTFT vs TBT.")},
    {"id": "P13", "topic": "Long-Context Attention",
     "prompt": ("Compare full O(n^2) attention, sliding-window, and linear-attention variants for "
                "1M-token context. Derive memory and compute scaling for each, analyze quality "
                "degradation on retrieval-style 'needle-in-a-haystack' tasks, quantify the Flash-"
                "Attention bandwidth advantage, and propose a hybrid global+local scheme with a "
                "complexity budget.")},
    {"id": "P14", "topic": "Multi-Agent Failure Modes",
     "prompt": ("Construct a failure-mode taxonomy for 10-step agent chains (hallucination, tool "
                "misuse, context overflow, infinite loops, cascading errors). Derive end-to-end "
                "task-completion probability as a function of per-step reliability and chain length, "
                "model the effect of validation/retry layers on the reliability–latency trade-off, "
                "and target 99.9% completion.")},
    {"id": "P15", "topic": "Cross-Provider Routing",
     "prompt": ("Design a cost-aware router across heterogeneous LLM providers with differing price, "
                "latency, caching semantics, and quality. Formalize routing as constrained "
                "optimization over an expected-quality-per-dollar objective with an SLO constraint, "
                "derive when to escalate a low-confidence cheap-tier answer to a strong tier, and "
                "analyze how divergent caching models change the routing calculus.")},
]

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 4: MOCK MCP TOOL SERVER (bloated JSON) + MIDDLEWARE (filtering)
# ═══════════════════════════════════════════════════════════════════════════════
class MockMCPToolServer:
    """Offline stand-in for an MCP search tool that returns deliberately bloated JSON.

    Every value is synthetic: identifiers come from ``uuid`` and scores/metadata
    from the module-level ``random`` generator, so payloads differ between calls.
    No network access is performed.
    """

    @staticmethod
    def search(query: str, num_results: int = 4) -> str:
        """Build a fake search response for ``query``.

        Args:
            query: Search text. It is echoed verbatim in the envelope and
                truncated into each result's title (first 40 characters) and
                abstract (first 30 characters).
            num_results: Number of result records to generate. Zero or a
                negative value yields an empty ``results`` list.

        Returns:
            The response envelope serialized as compact JSON (no whitespace
            after separators), with the records under the ``"results"`` key.
        """
        # Function-scope import: the file only imports `datetime` from the
        # datetime module, and `timezone` is needed for the non-deprecated clock.
        from datetime import timezone

        results = []
        for i in range(num_results):
            results.append({
                "doc_id": str(uuid.uuid4()),
                "request_id": f"mcp-req-{uuid.uuid4().hex[:12]}",
                "shard_key": f"shard-{random.randint(1, 256):03d}",
                "replica_set": f"rs-{random.choice(['primary', 'secondary-1', 'secondary-2'])}",
                # datetime.utcnow() is deprecated since Python 3.12. Take an aware
                # UTC "now" and drop tzinfo so the ISO string keeps its original
                # offset-free format.
                "retrieval_timestamp_utc": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
                "retrieval_latency_ms": round(random.uniform(12.0, 180.0), 2),
                "index_version": f"v{random.randint(1,5)}.{random.randint(0,9)}.{random.randint(0,9)}",
                "embedding_model": "text-embedding-3-large",
                "embedding_dimensions": 3072,
                "cosine_similarity_score": round(random.uniform(0.62, 0.99), 6),
                "bm25_score": round(random.uniform(0.1, 1.0), 6),
                "rrf_score": round(random.uniform(0.01, 0.1), 8),
                "title": f"Research Paper {i+1}: {query[:40]}... [Section {chr(65+i)}]",
                "url": f"https://arxiv.org/abs/2{random.randint(300,500)}.{random.randint(10000,99999)}",
                "doi": f"10.{random.randint(1000,9999)}/arxiv.{random.randint(2300,2500)}.{random.randint(10000,99999)}",
                "abstract": (f"This paper presents a comprehensive analysis of {query[:30]}. We "
                             f"demonstrate improvements over baselines across benchmarks with reduced "
                             f"compute overhead, validated empirically on standard suites."),
                "authors": [f"Author_{j} Lastname_{j}" for j in range(random.randint(2, 6))],
                "institution": random.choice(["MIT CSAIL", "Stanford AI Lab", "DeepMind", "Google Brain", "CMU LTI"]),
                "year": random.randint(2022, 2025),
                "citations": random.randint(0, 2400),
                "venue": random.choice(["NeurIPS", "ICML", "ICLR", "ACL", "EMNLP"]),
                "tags": [f"tag_{t}" for t in random.sample(range(50), 5)],
                "access_control": {"level": "public", "license": "CC-BY-4.0",
                                   "embargo_until": None, "region_restrictions": []},
                "raw_payload_bytes": [random.randint(0, 255) for _ in range(32)],
                "tracking": {"experiment_id": uuid.uuid4().hex,
                             "ab_variant": random.choice(["control", "treatment_a", "treatment_b"]),
                             "user_session": uuid.uuid4().hex,
                             "cdn_edge_node": f"edge-{random.choice(['us-east','eu-west','ap-south'])}-{random.randint(1,5)}"},
            })
        # Envelope-level metadata wrapped around the records.
        envelope = {
            "api_version": "2024-11-01", "service": "mcp-search-v3",
            "request_id": str(uuid.uuid4()), "status": "success", "http_status": 200,
            "query": query, "total_hits": random.randint(1200, 48000),
            "returned_hits": num_results, "processing_time_ms": round(random.uniform(45.0, 320.0), 2),
            "quota_remaining": random.randint(500, 10000),
            "quota_reset_epoch": int(time.time()) + 3600,
            "cache_hit": random.choice([True, False]), "server_region": "us-east-1",
            "load_balancer_id": f"lb-{uuid.uuid4().hex[:8]}", "results": results,
        }
        return json.dumps(envelope, separators=(",", ":"))


class MCPMiddleware:
    """Reduces raw MCP search payloads to a compact plain-text digest."""

    # Per-result fields that survive filtering; every other key is dropped.
    RESULT_ALLOWLIST = {"title", "abstract", "authors", "year", "venue", "cosine_similarity_score"}

    @staticmethod
    def filter_search_payload(raw_json: str) -> str:
        """Render a raw search payload as a short multi-line text summary.

        Args:
            raw_json: JSON text of a search response envelope (an object with
                optional ``status``, ``total_hits`` and ``results`` keys).

        Returns:
            A header line followed by one two-line entry per result, or the
            string ``"ERROR: could not parse MCP payload"`` when the input is
            not JSON text or its root is not an object. Null or mistyped
            per-result fields are rendered with neutral defaults rather than
            raising; result entries that are not objects are skipped.
        """
        try:
            data = json.loads(raw_json)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers non-string input such as None.
            logger.warning("MCPMiddleware parse error: %s", e)
            return "ERROR: could not parse MCP payload"
        if not isinstance(data, dict):
            logger.warning(
                "MCPMiddleware parse error: expected a JSON object, got %s",
                type(data).__name__,
            )
            return "ERROR: could not parse MCP payload"

        status = data.get("status", "unknown")
        total_hits = data.get("total_hits", 0)
        lines = [f"[MCP Search | status={status} | total_hits={total_hits}]"]

        results = data.get("results")
        if not isinstance(results, list):
            results = []  # a null or malformed "results" is treated as no hits
        hits = [hit for hit in results if isinstance(hit, dict)]

        for idx, hit in enumerate(hits, 1):
            r = {k: v for k, v in hit.items() if k in MCPMiddleware.RESULT_ALLOWLIST}

            authors = r.get("authors") or []
            if not isinstance(authors, (list, tuple)):
                authors = [authors]  # a lone author given as a scalar
            author_str = ", ".join(str(a) for a in authors[:2])
            if len(authors) > 2:
                author_str += " et al."

            try:
                similarity = float(r.get("cosine_similarity_score", 0.0))
            except (TypeError, ValueError):
                similarity = 0.0

            abstract = str(r.get("abstract") or "")

            # The trailing "..." is appended unconditionally (even when the
            # abstract is shorter than 220 characters) to keep the digest
            # format identical to what downstream prompts were measured with.
            lines.append(
                f"[{idx}] {r.get('title','N/A')} | {author_str} | "
                f"{r.get('venue','N/A')} {r.get('year','N/A')} | "
                f"sim={similarity:.3f}\n"
                f"    Abstract: {abstract[:220]}..."
            )
        return "\n".join(lines)

    @staticmethod
    def measure_compression(raw_json: str, filtered_str: str) -> Dict[str, int]:
        """Estimate the token savings of the filtered digest versus the raw payload.

        Token counts are approximated as ``len(text) // 4``.

        Args:
            raw_json: The unfiltered payload text.
            filtered_str: The digest produced from it.

        Returns:
            Dict with ``raw_estimated_tokens``, ``filtered_estimated_tokens``,
            ``tokens_saved`` and ``compression_ratio_pct``. The percentage is
            truncated toward zero, is negative when the digest is larger than
            the raw payload, and is 0 when the raw estimate is 0 tokens.
        """
        raw_t = len(raw_json) // 4
        filt_t = len(filtered_str) // 4
        saved = raw_t - filt_t

        # Integer arithmetic avoids float truncation errors such as
        # int((1 - 9/10) * 100) == 9 instead of 10.
        if raw_t == 0:
            pct = 0
        else:
            pct = abs(saved) * 100 // raw_t
            if saved < 0:
                pct = -pct

        return {"raw_estimated_tokens": raw_t, "filtered_estimated_tokens": filt_t,
                "tokens_saved": saved,
                "compression_ratio_pct": pct}

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 5: SHARED SYSTEM CONTEXT BLOCK (sized > every provider's cache minimum)
# ═══════════════════════════════════════════════════════════════════════════════
def build_system_context_block() -> str:
    """Large shared instruction block (~5k tokens) — identical across optimized-arm
    turns so the cache prefix hits on every provider (Anthropic explicit; OpenAI &
    Gemini implicit)."""
    return textwrap.dedent("""
    ===============================================================================
    SYSTEM: AI INFRASTRUCTURE RESEARCH ANALYSIS FRAMEWORK v3.0
    Classification: RESEARCH ASSISTANT | Mode: EXPERT TECHNICAL ANALYSIS
    ===============================================================================

    ROLE DEFINITION
    You are an expert AI Systems Researcher with deep expertise spanning:
    - Large Language Model (LLM) inference optimization and serving systems
    - Distributed systems architecture for ML workloads at scale
    - Hardware-software co-design for AI accelerators (GPUs, TPUs, custom ASICs)
    - Theoretical computer science: complexity analysis, information theory, algorithms
    - Empirical machine learning: experimental design, statistical analysis, ablations
    - Academic writing: peer-reviewed publications, technical reports, grant proposals

    ANALYTICAL FRAMEWORK
    Apply this systematic framework to every AI-infrastructure problem:
    1. FORMAL PROBLEM SPECIFICATION — define the objective with mathematical precision,
       enumerate constraints (memory, compute, latency, throughput, cost), state and
       justify assumptions, and lay out the design-space dimensions.
    2. COMPLEXITY ANALYSIS — give time and space complexity O(.) for key algorithms,
       analyze communication complexity for distributed settings, and locate bottlenecks
       via roofline-model reasoning.
    3. MATHEMATICAL DERIVATION — use standard ML/systems notation, derive closed forms
       where possible, propose principled approximations otherwise, validate by units.
    4. EMPIRICAL GROUNDING — reference relevant benchmarks, calibrate against the
       hardware constants below, and acknowledge measurement uncertainty.
    5. PRACTICAL ENGINEERING — assess implementation complexity, failure modes,
       production readiness, and cost-efficiency at scale ($/token, $/FLOP).

    TECHNICAL VOCABULARY STANDARDS
    INFERENCE SYSTEMS: TTFT (time-to-first-token), TBT (time-between-tokens), ITL
    (inter-token latency), throughput (tokens/sec), SLO, P50/P95/P99 latency, goodput,
    head-of-line blocking, request coalescing, admission control, backpressure.
    PARALLELISM: TP, PP, DP, EP, sequence parallelism, 3D parallelism, micro-batch,
    bubble overhead, activation checkpointing, gradient accumulation, ZeRO/FSDP sharding.
    MEMORY: KV-cache, PagedAttention, memory bandwidth (GB/s), HBM, SRAM, working set,
    memory pressure, eviction policy, recomputation vs. rematerialization.
    TRANSFORMERS: attention complexity O(n^2 d), Flash Attention, linear attention, MHA,
    MQA, GQA, RoPE, ALiBi, sliding-window attention, chunked prefill.
    OPTIMIZATION: speculative decoding, draft model, acceptance rate alpha, speculation
    depth k, Medusa, EAGLE, lookahead decoding, continuous batching, iteration-level
    scheduling, quantization (PTQ/QAT, INT8/INT4/FP8), SmoothQuant, AWQ, GPTQ.
    RETRIEVAL & MEMORY: dense/sparse retrieval, BM25, hybrid search, reciprocal rank
    fusion, HNSW, IVF-PQ, ScaNN, FAISS, recall@k, RAG, semantic deduplication.
    AGENTS: tool use, ReAct, plan-and-execute, reflexion, map-reduce decomposition,
    orchestrator-worker, blackboard, handoff protocols, context compaction, guardrails.

    NOTATION CONVENTIONS (use consistently across derivations)
    - N: number of agents or GPUs (disambiguate per context)
    - D: model depth (layers); H: hidden dim; H_kv: key/value head dim
    - L: sequence length; L_p: prompt length; L_g: generated length
    - B: batch size; b: micro-batch size
    - alpha: speculative acceptance rate; k: speculation depth
    - lambda: arrival rate; mu: service rate; rho = lambda/mu: utilization
    - C: per-token compute (FLOPs); M: memory footprint (bytes)
    - tau: latency budget / SLO; T: throughput (tokens/sec)
    - H(X): Shannon entropy; MDL: minimum description length

    ANALYTICAL PROTOCOLS (apply in order; do not skip)
    P-1 STATE THE OBJECTIVE as argmin/argmax over decision variables subject to explicit
        constraints before any numeric estimate.
    P-2 DIMENSIONAL CHECK: every derived expression must be dimensionally consistent;
        report units on intermediate quantities.
    P-3 BOUND BEFORE BUILDING: give back-of-envelope upper and lower bounds first.
    P-4 BASELINE FIRST: characterize the naive approach and its asymptotic cost so any
        proposed method is quantified relative to it.
    P-5 SENSITIVITY: identify the one or two dominant parameters and report how the
        conclusion changes across their plausible ranges.
    P-6 FALSIFIABILITY: state at least one measurable prediction that would, if violated,
        invalidate the analysis.

    EVALUATION RUBRIC (self-assess each response)
    R-1 Correctness: derivations sound and dimensionally valid.
    R-2 Specificity: estimates grounded in the constants below, not vague.
    R-3 Comparativeness: at least two named baselines quantified, not merely mentioned.
    R-4 Actionability: the roadmap is concrete enough to start engineering.
    R-5 Honesty: assumptions, approximations, and uncertainty stated, not hidden.

    COMMON PITFALLS TO AVOID
    - Confusing throughput (aggregate tokens/sec) with per-request latency.
    - Ignoring the prefill (compute-bound) vs decode (bandwidth-bound) asymmetry.
    - Treating KV-cache as free; it scales with batch size and sequence length.
    - Assuming perfect parallel speedup; communication and synchronization erode it.
    - Reporting mean latency without tail percentiles for SLO-governed systems.
    - Double-counting cached tokens when reasoning about billable input tokens.

    CALIBRATION CONSTANTS (illustrative, for grounding estimates)
    - H100 SXM: 3.35 TB/s HBM3, 989 TFLOP/s FP16, 80GB. NVLink 4.0: 900 GB/s intra-node.
    - A100 SXM: 2.0 TB/s HBM2e, 312 TFLOP/s FP16, 80GB. InfiniBand NDR: ~50 GB/s/port.
    - Flash Attention 2 speedup vs standard: 5-9x for sequence > 2K tokens.
    - Speculative decoding typical acceptance rate: 0.6-0.85 for aligned models.
    - PagedAttention reduces KV fragmentation waste from ~60-80% to <4% in practice.
    - A cached prefix of P tokens reused R times amortizes a 1.25x write cost once
      R * (0.9 * P) > 0.25 * P, i.e. essentially after the second reuse for any P.

    ===============================================================================
    APPENDIX A — ROOFLINE & ARITHMETIC INTENSITY PRIMER (shared reference)
    ===============================================================================
    The roofline model bounds achievable performance by the minimum of two ceilings: the
    compute ceiling (peak FLOP/s) and the memory ceiling (peak bandwidth x arithmetic
    intensity). Arithmetic intensity I = FLOPs / bytes. A kernel is compute-bound when I
    exceeds the machine balance point B = peak_FLOPs / peak_bandwidth, memory-bound below
    it. For transformer inference, prefill and decode sit on opposite sides of this ridge.
    Prefill processes the whole prompt in one pass: each loaded weight is reused across all
    prompt positions, so intensity is high and it is compute-bound. Decode emits one token
    at a time, reloading full weight matrices and the entire KV-cache per token, so it is
    memory-bandwidth-bound. This asymmetry is the single most important fact in inference
    optimization: batching amortizes weight reloads (helping decode throughput), KV-cache
    size directly governs decode latency, and FLOP-reduction tricks that help prefill often
    do nothing for decode. To estimate batch-1 decode tokens/sec for P parameters at p
    bytes/param on bandwidth W: a first-order bound is T ~ W / (P * p), times a realized-
    bandwidth factor of 0.6-0.8, before KV-cache traffic.

    ===============================================================================
    APPENDIX B — QUEUING THEORY CHEAT SHEET (for SLO / concurrency analysis)
    ===============================================================================
    For an M/M/1 queue with arrival rate lambda and service rate mu, utilization rho =
    lambda/mu, mean number in system L = rho/(1-rho), mean time in system W = 1/(mu-lambda).
    Latency diverges hyperbolically as rho approaches one, so SLO-governed systems target
    utilization well below saturation (often 0.6-0.8) and tail latency degrades faster than
    the mean near the knee. For an M/M/c queue (c parallel servers), the Erlang-C formula
    gives the wait probability; expected queue wait is that probability divided by
    (c*mu - lambda). For agent ensembles sharing a memory store, model the store as the
    constrained resource: per-turn budget must cover queue wait plus service, and contention
    rises sharply once aggregate rate approaches c/s. Little's Law (L = lambda * W) holds for
    any stable system and is the fastest cross-check on a latency-throughput-concurrency triple.

    ===============================================================================
    APPENDIX C — SPECULATIVE DECODING ALGEBRA (shared derivation reference)
    ===============================================================================
    With a draft proposing k tokens per step and acceptance rate alpha, the expected accepted
    length per verification step is approximately (1 - alpha^(k+1)) / (1 - alpha), capped by k,
    plus the bonus token the target always contributes. Realized speedup is expected accepted
    length divided by the per-step cost ratio (draft forward passes plus one target verify,
    relative to a single target decode). Speedup is maximized when the draft is cheap and
    alpha is high; it collapses toward one as alpha falls because rejections waste draft
    compute. Acceptance falls as task entropy rises, so a principled controller lowers k when
    measured rejection climbs and raises it when the draft tracks the target well.

    ===============================================================================
    APPENDIX D — MODEL-ROUTING DECISION REFERENCE (shared reference)
    ===============================================================================
    Route by marginal value of capability per unit cost: send a task to the strong tier only
    when the expected quality lift over the fast tier exceeds the cost premium weighted by
    downstream impact. For decomposable work, map-reduce routes high-volume low-complexity
    extraction to the fast tier and reserves the strong tier for the single synthesis step.
    Expected blended cost is the sum over stages of (stage tokens x tier price), so pushing
    token volume to the cheaper tier and minimizing strong-tier tokens is the dominant lever,
    ahead of caching and payload filtering in most realistic mixes. Quality guardrails
    (structured-output validation, confidence-thresholded escalation) keep the routed slice
    from degrading aggregate quality. Always measure the routed slice against an all-strong
    control before trusting the savings.

    ===============================================================================
    APPENDIX E — CONTINUOUS BATCHING & SCHEDULING REFERENCE (shared reference)
    ===============================================================================
    Static batching groups a fixed set of requests and runs them to completion together, so
    the batch is held hostage by its longest sequence and GPU utilization collapses when
    sequence lengths are skewed. Continuous (iteration-level) batching instead schedules at
    the granularity of a single decode step: finished sequences leave the batch and waiting
    requests join immediately, keeping the accelerator saturated. With Poisson arrivals at
    rate lambda and a heavy-tailed length distribution, the utilization gain over static
    batching grows with both the arrival rate and the length variance, because static
    batching wastes proportionally more cycles waiting on stragglers as variance rises.
    PagedAttention complements this by allocating KV-cache in fixed-size blocks (analogous to
    OS virtual-memory paging), which removes the contiguous-allocation requirement, cuts
    fragmentation waste from the 60-80% range down to the low single digits, and lets the
    scheduler pack many more concurrent sequences into the same HBM. A production scheduler
    must also enforce SLOs: chunked prefill interleaves long prompt prefills with ongoing
    decodes so a single large prompt does not stall everyone behind it (head-of-line
    blocking), and a preemption policy decides which low-priority sequences to evict (swapping
    their KV-cache to host memory or recomputing it) when a high-priority request arrives.
    The recompute-vs-swap choice is itself a roofline question: recomputation trades extra
    FLOPs for avoided PCIe traffic, and is preferable when the prompt is short relative to
    interconnect bandwidth, while swapping wins for long prefixes on fast NVLink paths.

    ===============================================================================
    APPENDIX F — QUANTIZATION & MEMORY-FORMAT REFERENCE (shared reference)
    ===============================================================================
    Quantization reduces the bytes-per-parameter term that dominates memory-bound decode, so
    it directly raises decode tokens/sec roughly in proportion to the bit-width reduction
    until a compute or KV-cache ceiling binds. Post-training quantization (PTQ) maps trained
    fp16 weights to lower precision without retraining; quantization-aware training (QAT)
    simulates quantization during training for higher fidelity at greater cost. The central
    difficulty is activation outliers: a few channels carry disproportionately large values
    that, if naively scaled, destroy accuracy. SmoothQuant migrates outlier scale from
    activations into weights so both quantize cleanly; AWQ preserves the most salient weight
    channels at higher precision based on activation statistics; GPTQ performs layer-wise
    second-order error compensation. Per-tensor scaling uses one scale factor per tensor and
    is cheapest but least accurate; per-channel scaling assigns a scale per output channel and
    recovers most of the lost accuracy at modest overhead. Weight-only schemes (e.g. INT4
    weights with fp16 activations) target the decode bandwidth bottleneck specifically and
    often preserve quality better than quantizing activations too. FP8 on Hopper-class
    accelerators is attractive because hardware tensor cores execute it natively, giving both
    a memory and a compute win, whereas INT4 usually requires dequantization into a wider
    type before the matmul, capturing the bandwidth saving but not always a compute saving.

    ===============================================================================
    APPENDIX G — PARALLELISM COMMUNICATION-VOLUME REFERENCE (shared reference)
    ===============================================================================
    Tensor parallelism (TP) shards each weight matrix across devices and inserts an all-reduce
    per transformer block to recombine partial results, so its communication volume scales with
    hidden dimension, sequence length, and the number of layers, and it is therefore most
    viable inside a single high-bandwidth NVLink domain. Pipeline parallelism (PP) splits the
    model by layer across stages and passes only the activation tensor at each stage boundary,
    so its per-boundary communication is small but it introduces a pipeline bubble: at batch
    size one the bubble fraction approaches (stages - 1) / (stages - 1 + microbatches), which
    is why latency-critical batch-1 serving punishes deep pipelines and favors TP within a node.
    Data parallelism (DP) replicates the model and synchronizes gradients (training-time, via
    ring all-reduce whose volume is independent of worker count but whose latency grows with
    it). Expert parallelism (EP) places different MoE experts on different devices and incurs
    an all-to-all token shuffle proportional to tokens routed across the device boundary, which
    is why MoE routing quality and cross-node placement dominate its efficiency. A 3D parallel
    configuration composes TP within a node, PP across a few nodes, and DP across replicas;
    the optimal degree assignment is a joint optimization minimizing the sum of bubble overhead,
    collective-communication time, and memory pressure subject to fitting parameters, KV-cache,
    and activations into aggregate HBM. The dominant heuristic: keep the chatty TP collectives
    on the fastest interconnect, spread PP only as far as the bubble budget allows, and use DP
    for the remaining scale-out where synchronization is least frequent.

    ===============================================================================
    APPENDIX H — RETRIEVAL & VECTOR-INDEX REFERENCE (shared reference)
    ===============================================================================
    Retrieval-augmented generation grounds outputs in an external corpus by embedding queries
    and documents into a shared vector space and fetching nearest neighbors. Dense retrieval
    captures semantic similarity but can miss exact lexical matches; sparse retrieval (BM25)
    captures lexical overlap; hybrid search fuses both, commonly via reciprocal rank fusion,
    which sums reciprocal ranks across retrievers and is robust to score-scale differences.
    Exact nearest-neighbor search is linear in corpus size and infeasible at scale, so
    approximate indexes trade a small recall loss for large speedups: HNSW builds a navigable
    small-world graph giving logarithmic-ish query time at higher memory cost; IVF-PQ clusters
    vectors and product-quantizes residuals for a compact, disk-friendly index; ScaNN combines
    anisotropic quantization with hardware-aware scoring. The operative quality metric is
    recall@k against an exact baseline, and the operative cost metrics are query latency,
    index memory footprint, and build time. Chunking strategy materially affects quality:
    chunks too large dilute the embedding and waste context-window budget, chunks too small
    fragment meaning and inflate the index; overlap mitigates boundary effects at a storage
    cost. Embedding drift across model versions silently degrades recall, so re-embedding on
    model upgrades and monitoring recall@k over time are standard operational hygiene.

    AGENT COORDINATION PROTOCOL
    All agents share this workspace via the cache-prefilled system block. Each worker handles
    one specialized sub-task and writes a structured summary; the reducer synthesizes them.
    Workers must: (1) respect their assigned sub-task slice without duplicating scope,
    (2) surface uncertainty rather than fabricating precise figures, (3) keep each sub-task
    response under ~150 tokens. The reducer reconciles conflicting claims explicitly. Any
    figure not grounded in the calibration constants above must be flagged as an estimate, and
    the reducer must reconcile conflicting worker claims rather than silently averaging them.

    ===============================================================================
    END SYSTEM CONTEXT — BEGIN TASK PROCESSING
    ===============================================================================
    """).strip()

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 6: METRICS DATA STRUCTURES
# ═══════════════════════════════════════════════════════════════════════════════
class TurnMetrics(TypedDict):
    """Token usage and timing for a single LLM call (one "turn").

    Records are produced by ``LLMClient.call`` (live or DRY_RUN mock) and by
    ``empty_turn_metrics`` as a zeroed placeholder.
    """
    model:              str     # model identifier the call was issued against
    wall_clock_s:       float   # seconds measured by the client around the call
    input_tokens:       int     # standardized: INCLUSIVE of cached tokens
    output_tokens:      int
    cache_read_tokens:  int     # input tokens reported as read from the prompt cache
    cache_write_tokens: int     # input tokens reported as written to the prompt cache


def empty_turn_metrics(model: str) -> TurnMetrics:
    """Return a zeroed ``TurnMetrics`` record attributed to ``model``.

    Used as the placeholder turn when an LLM call fails, so that downstream
    aggregation over turns still finds every expected key.
    """
    return TurnMetrics(
        model=model,
        wall_clock_s=0.0,
        input_tokens=0,
        output_tokens=0,
        cache_read_tokens=0,
        cache_write_tokens=0,
    )


class RunMetrics(TypedDict):
    """One benchmark row: a single prompt executed by one system on one model.

    Token fields are sums over the run's per-call ``TurnMetrics``; rows are
    collected by ``run_benchmark`` into the result DataFrame.
    """
    provider:           str
    model:              str
    prompt_id:          str
    system_type:        str             # "baseline" or "optimized"
    total_latency_s:    float           # wall-clock seconds for the whole run
    input_tokens:       int             # summed over turns (inclusive of cached tokens)
    output_tokens:      int
    cache_read_tokens:  int
    cache_write_tokens: int
    total_tokens:       int             # sum of input_tokens + output_tokens over turns
    mcp_tokens_saved:   int             # 0 for the baseline; MCP filter savings for the DAG
    num_llm_calls:      int             # number of turn records aggregated into this row
    error:              Optional[str]   # None on success, otherwise the failure message


def empty_run_metrics(provider, model, prompt_id, system_type, error="") -> RunMetrics:
    """Build an all-zero ``RunMetrics`` row carrying only identity and ``error``.

    Returned by both systems when a run fails, so the result table still has a
    row for the (provider, model, prompt, system) combination.
    """
    zero_counters = dict.fromkeys(
        ("input_tokens", "output_tokens", "cache_read_tokens", "cache_write_tokens",
         "total_tokens", "mcp_tokens_saved", "num_llm_calls"),
        0,
    )
    return RunMetrics(
        provider=provider,
        model=model,
        prompt_id=prompt_id,
        system_type=system_type,
        total_latency_s=0.0,
        **zero_counters,
        error=error,
    )

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 7: PROVIDER-AGNOSTIC LLM CLIENT (LangChain wrappers + token extraction)
# ═══════════════════════════════════════════════════════════════════════════════
def _is_transient(exc: Exception) -> bool:
    s = (str(exc) + " " + type(exc).__name__).lower()
    return any(t in s for t in ("rate", "timeout", "overloaded", "429", "503",
                                "502", "unavailable", "connection", "resourceexhausted"))


class TransientError(Exception):
    """Marker exception for provider failures classified as transient.

    ``LLMClient._invoke`` re-raises errors accepted by ``_is_transient`` as this
    type, and its tenacity policy (``retry_if_exception_type(TransientError)``)
    retries only on it; any other exception propagates on the first attempt.
    """


def extract_usage(provider: str, msg: Any) -> Tuple[int, int, int, int]:
    """
    Return (input_tokens, output_tokens, cache_read, cache_write) for one response.

    PRIMARY: the LangChain standardized ``usage_metadata`` attribute of the message,
    where input_tokens is already INCLUSIVE of cached tokens. It is trusted only
    when it reports a non-zero input or output count.

    FALLBACK: provider-specific parse of ``response_metadata``, normalized to the
    same inclusive convention. Unknown providers (and the currently disabled
    Google arm) yield all zeros.
    """
    def as_count(mapping, key):
        # Missing keys and explicit None both collapse to 0.
        return int(mapping.get(key, 0) or 0)

    standardized = getattr(msg, "usage_metadata", None)
    if standardized:
        tokens_in = as_count(standardized, "input_tokens")
        tokens_out = as_count(standardized, "output_tokens")
        details = standardized.get("input_token_details", {}) or {}
        cached_read = as_count(details, "cache_read")
        cached_write = as_count(details, "cache_creation")
        if tokens_in or tokens_out:
            return tokens_in, tokens_out, cached_read, cached_write

    raw = getattr(msg, "response_metadata", {}) or {}

    if provider == "openai":
        usage = raw.get("token_usage") or raw.get("usage") or {}
        tokens_in = as_count(usage, "prompt_tokens")        # already inclusive of cache
        tokens_out = as_count(usage, "completion_tokens")
        cached_read = as_count(usage.get("prompt_tokens_details") or {}, "cached_tokens")
        return tokens_in, tokens_out, cached_read, 0

    if provider == "anthropic":
        usage = raw.get("usage") or {}
        uncached = as_count(usage, "input_tokens")          # raw API count EXCLUDES cache
        cached_read = as_count(usage, "cache_read_input_tokens")
        cached_write = as_count(usage, "cache_creation_input_tokens")
        tokens_out = as_count(usage, "output_tokens")
        # Fold cache traffic back in so the result matches the inclusive convention.
        return uncached + cached_read + cached_write, tokens_out, cached_read, cached_write

    # Google/Gemini fallback is intentionally disabled (arm not run in this study):
    # if provider == "google":
    #     g = rm.get("usage_metadata") or rm.get("usage") or rm
    #     it = int(g.get("prompt_token_count", 0) or 0)      # already inclusive of cache
    #     ot = int(g.get("candidates_token_count", 0) or 0)
    #     cr = int(g.get("cached_content_token_count", 0) or 0)
    #     return it, ot, cr, 0
    return 0, 0, 0, 0


class LLMClient:
    """One client per (provider, model). Builds the right LangChain chat model,
    applies provider-correct caching, and extracts normalized token usage.

    ``call`` forwards its ``max_tokens`` argument to the model as a per-call
    output cap (and to the DRY_RUN mock), so the different limits requested by
    the baseline and DAG nodes actually take effect instead of every call
    running at the constructor-level ``MAX_OUTPUT_TOKENS``.
    """

    # Invoke-time keyword that overrides the output-token cap, per provider.
    # NOTE(review): no entry for "google" — that arm is not run in this study and
    # its per-call override has not been verified, so Gemini calls keep the
    # constructor-level max_output_tokens.
    _PER_CALL_LIMIT_KWARG: Dict[str, str] = {
        "anthropic": "max_tokens",
        "openai": "max_tokens",
    }

    # Simulated provider-side prompt cache for DRY_RUN, shared by all clients so
    # that a prefix written by one call is "warm" for later calls on the same model.
    _CACHE_STORE: Dict[str, bool] = {}

    def __init__(self, provider: str, model: str, api_key: str):
        self.provider = provider
        self.model = model
        self.api_key = api_key
        # No real chat model is built in DRY_RUN, so no SDK/network is touched.
        self._llm = None if DRY_RUN else self._construct()

    def _construct(self):
        """Instantiate the provider's LangChain chat model.

        Tries the constructor with ``timeout`` first and, if that raises
        TypeError (unsupported kwarg), retries without it.

        Raises:
            ValueError: for an unknown provider name.
            TypeError: if every kwarg variant is rejected by the constructor.
        """
        if self.provider == "anthropic":
            cls = ChatAnthropic
            base = dict(model=self.model, max_tokens=MAX_OUTPUT_TOKENS,
                        api_key=self.api_key)
        elif self.provider == "openai":
            cls = ChatOpenAI
            base = dict(model=self.model, max_tokens=MAX_OUTPUT_TOKENS,
                        api_key=self.api_key)
        elif self.provider == "google":
            cls = ChatGoogleGenerativeAI
            base = dict(model=self.model, max_output_tokens=MAX_OUTPUT_TOKENS,
                        google_api_key=self.api_key)
        else:
            raise ValueError(f"Unknown provider: {self.provider}")
        base["temperature"] = 0

        last_err = None
        # Preferred variant first; fall back to dropping the timeout kwarg.
        for kw in ({**base, "timeout": REQUEST_TIMEOUT}, base):
            try:
                return cls(**kw)
            except TypeError as e:
                last_err = e
        raise last_err

    def _build_messages(self, system_text: str, user_text: str, use_cache: bool):
        """Anthropic caching is EXPLICIT (cache_control block). OpenAI & Gemini cache
        automatically, so their system content stays a plain string."""
        if use_cache and self.provider == "anthropic":
            system_msg = SystemMessage(content=[{
                "type": "text", "text": system_text,
                "cache_control": {"type": "ephemeral"},
            }])
        else:
            system_msg = SystemMessage(content=system_text)
        return [system_msg, HumanMessage(content=user_text)]

    @retry(stop=stop_after_attempt(3),
           wait=wait_exponential(multiplier=1, min=2, max=30),
           retry=retry_if_exception_type(TransientError), reraise=True)
    def _invoke(self, messages, max_tokens: Optional[int] = None):
        """Invoke the chat model, retrying only on transient provider errors.

        Args:
            messages: LangChain message list from ``_build_messages``.
            max_tokens: Optional per-call output cap; applied when the provider
                has an entry in ``_PER_CALL_LIMIT_KWARG``.

        Raises:
            TransientError: when a transient failure persists after all attempts.
            Exception: any non-transient provider error, on the first attempt.
        """
        call_kwargs = {}
        limit_kwarg = self._PER_CALL_LIMIT_KWARG.get(self.provider)
        if max_tokens is not None and limit_kwarg is not None:
            call_kwargs[limit_kwarg] = max_tokens
        try:
            return self._llm.invoke(messages, **call_kwargs)
        except Exception as e:
            if _is_transient(e):
                raise TransientError(str(e)) from e
            raise

    def call(self, system_text: str, user_text: str,
             use_cache: bool = False, max_tokens: int = MAX_OUTPUT_TOKENS
             ) -> Tuple[str, TurnMetrics]:
        """Run one LLM call and return ``(response_text, TurnMetrics)``.

        Args:
            system_text: System prompt (the cacheable prefix when ``use_cache``).
            user_text: User/task prompt.
            use_cache: Request prompt caching (explicit for Anthropic only).
            max_tokens: Output-token cap for this call.

        The reported ``wall_clock_s`` spans message building, the model call and
        any retry back-off inside ``_invoke``.
        """
        t0 = time.perf_counter()

        if DRY_RUN:
            return self._mock_call(system_text, user_text, use_cache, t0, max_tokens)

        messages = self._build_messages(system_text, user_text, use_cache)
        msg = self._invoke(messages, max_tokens)
        wall = time.perf_counter() - t0

        it, ot, cr, cw = extract_usage(self.provider, msg)
        # Non-string content (e.g. a list of content blocks) is stringified as-is.
        text = msg.content if isinstance(msg.content, str) else str(msg.content)
        return text, {"model": self.model, "wall_clock_s": wall, "input_tokens": it,
                      "output_tokens": ot, "cache_read_tokens": cr, "cache_write_tokens": cw}

    # ── offline mock for DRY_RUN (validates plumbing without spending) ──
    def _mock_call(self, system_text, user_text, use_cache, t0,
                   max_tokens: int = MAX_OUTPUT_TOKENS):
        """Fabricate a response and usage figures without calling any provider.

        Token counts are estimated at ~4 characters per token; the output count
        is drawn at random up to ``max_tokens``. Cache behaviour is simulated via
        the class-level ``_CACHE_STORE``: the first eligible call for a
        (provider, model, system prompt) writes the prefix, later ones read it.
        """
        sys_tok = max(len(system_text) // 4, 1)
        usr_tok = max(len(user_text) // 4, 1)
        cap = max(int(max_tokens), 1)
        # Lower bound stays 120 unless the cap itself is smaller.
        out_tok = random.randint(min(120, cap), cap)
        time.sleep(0.002)
        wall = time.perf_counter() - t0
        cache_key = f"{self.provider}:{self.model}:{hash(system_text) & 0xffffffff}"
        cr = cw = 0
        eligible = use_cache and sys_tok >= CACHE_MIN_TOKENS.get(self.model, 1024)
        if eligible:
            if LLMClient._CACHE_STORE.get(cache_key):
                cr = sys_tok                       # warm: read the cached prefix
            else:
                cw = sys_tok                       # cold: write it once
                LLMClient._CACHE_STORE[cache_key] = True
        input_tokens = sys_tok + usr_tok           # standardized inclusive convention
        return (f"[DRY_RUN {self.provider}/{self.model}] mock synthesis output.",
                {"model": self.model, "wall_clock_s": wall, "input_tokens": input_tokens,
                 "output_tokens": out_tok, "cache_read_tokens": cr, "cache_write_tokens": cw})

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 8: SEQUENTIAL BASELINE (unoptimized: 3 serial calls, no cache, raw payloads)
# ═══════════════════════════════════════════════════════════════════════════════
class SequentialBaselineSystem:
    """Unoptimized reference pipeline: three serial LLM calls, no prompt caching,
    raw (unfiltered) MCP search payload.

    Rate-limit pauses between calls are excluded from ``total_latency_s`` so the
    reported latency reflects the pipeline itself rather than client-side
    throttling, and they are skipped entirely in DRY_RUN (as in the harness).
    """

    SIMPLE_SYSTEM = ("You are an expert AI systems researcher. Answer technical questions "
                     "about AI infrastructure with precision and depth.")

    def __init__(self, client: LLMClient):
        self.client = client
        self.mcp = MockMCPToolServer()

    @staticmethod
    def _throttle() -> float:
        """Pause for RATE_LIMIT_DELAY and return the seconds actually slept."""
        if DRY_RUN:
            return 0.0
        t0 = time.perf_counter()
        time.sleep(RATE_LIMIT_DELAY)
        return time.perf_counter() - t0

    def run(self, prompt: str, prompt_id: str) -> RunMetrics:
        """Execute the three-step baseline for one prompt.

        Args:
            prompt: Full evaluation prompt text.
            prompt_id: Identifier copied into the result row.

        Returns:
            A populated ``RunMetrics`` row on success. If any step raises, the
            failure is logged and an all-zero row with ``error`` set is returned;
            the exception is not propagated.
        """
        t_start = time.perf_counter()
        throttled_s = 0.0                       # time spent in rate-limit pauses
        turns: List[TurnMetrics] = []
        try:
            raw = self.mcp.search(prompt[:80], num_results=4)
            ctx = f"Retrieved research context:\n{raw[:2000]}"

            # Step 1: extraction over the raw search payload (its text is not reused).
            _, m1 = self.client.call(
                self.SIMPLE_SYSTEM,
                f"{ctx}\n\nTASK: Extract the key technical claims and data points from the "
                f"above search results relevant to:\n{prompt}",
                use_cache=False, max_tokens=512)
            turns.append(m1)
            throttled_s += self._throttle()

            # Step 2: free-form technical analysis of the question.
            reasoning_text, m2 = self.client.call(
                self.SIMPLE_SYSTEM,
                f"Research Question: {prompt}\n\nProvide technical analysis with mathematical "
                f"formulations, complexity analysis, and comparative evaluation.",
                use_cache=False, max_tokens=MAX_OUTPUT_TOKENS)
            turns.append(m2)
            throttled_s += self._throttle()

            # Step 3: synthesis of the step-2 analysis.
            _, m3 = self.client.call(
                self.SIMPLE_SYSTEM,
                f"Synthesize into a final structured research summary with executive summary, "
                f"key findings, and open problems:\n\n{reasoning_text}",
                use_cache=False, max_tokens=512)
            turns.append(m3)

            return RunMetrics(
                provider=self.client.provider, model=self.client.model,
                prompt_id=prompt_id, system_type="baseline",
                total_latency_s=time.perf_counter() - t_start - throttled_s,
                input_tokens=sum(m["input_tokens"] for m in turns),
                output_tokens=sum(m["output_tokens"] for m in turns),
                cache_read_tokens=sum(m["cache_read_tokens"] for m in turns),
                cache_write_tokens=sum(m["cache_write_tokens"] for m in turns),
                total_tokens=sum(m["input_tokens"] + m["output_tokens"] for m in turns),
                mcp_tokens_saved=0, num_llm_calls=len(turns), error=None)
        except Exception as e:
            logger.error(f"[BASELINE] {self.client.model} {prompt_id}: {e}")
            # Never emit an empty error string: downstream filtering treats "" as success.
            return empty_run_metrics(self.client.provider, self.client.model,
                                     prompt_id, "baseline", str(e) or type(e).__name__)

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 9: OPTIMIZED DAG (LangGraph map-reduce + caching + MCP filtering)
# Parallel workers return DELTAS ONLY; accumulator channels use operator.add reducers.
# All nodes use the SAME benchmarked model (isolates the architecture's effect).
# ═══════════════════════════════════════════════════════════════════════════════
class DAGState(TypedDict):
    """Shared LangGraph state for the optimized map-reduce DAG.

    Fields annotated with ``operator.add`` are accumulator channels: each node
    returns a list delta and the deltas are concatenated. The remaining fields
    are plain values written by the node that produces them.
    """
    prompt:           str                                          # full evaluation prompt
    prompt_id:        str
    system_context:   str                                          # shared system block sent with every call
    mcp_filtered:     str                                          # search payload after MCP filtering
    mcp_tokens_saved: int                                          # tokens removed by the MCP filter
    sub_tasks:        List[str]                                    # set once by orchestrator
    worker_outputs:   Annotated[List[str], operator.add]          # parallel-accumulated
    synthesis:        str                                          # reducer's final text
    all_turn_metrics: Annotated[List[TurnMetrics], operator.add]  # accumulated everywhere
    error:            Optional[str]                                # initialised to None by run()


class OptimizedDAGSystem:
    """Optimized pipeline: MCP fetch+filter → orchestrator → 3 parallel workers → reducer.

    Every LLM node sends the shared system context with ``use_cache=True``.

    Differences from a naive best-effort DAG that matter for benchmark validity:
      * A failed worker/reducer call still lets the graph finish, but the failure
        is recorded and the whole run is reported as an error row, so runs with
        failed calls are never counted as successful samples.
      * The rate-limit pause is taken once, in the reducer before its call, timed
        and subtracted from ``total_latency_s``; it is skipped in DRY_RUN.
    """

    def __init__(self, client: LLMClient, system_context: str):
        self.client = client
        self.system_context = system_context
        self.mcp = MockMCPToolServer()
        self.mw = MCPMiddleware()
        # Per-run scratch state, reset at the top of run(). Kept on the instance so
        # parallel worker nodes can report failures without a dedicated state channel.
        self._node_errors: List[str] = []
        self._throttle_s: float = 0.0
        self.graph = self._build_dag()

    @staticmethod
    def _throttle(seconds: float) -> float:
        """Pause for ``seconds`` and return the time actually slept (0.0 in DRY_RUN)."""
        if DRY_RUN or seconds <= 0:
            return 0.0
        t0 = time.perf_counter()
        time.sleep(seconds)
        return time.perf_counter() - t0

    def node_mcp_fetch_and_filter(self, state: DAGState) -> Dict[str, Any]:
        """Fetch search results for the prompt and reduce them via the MCP middleware."""
        raw = self.mcp.search(state["prompt"][:80], num_results=4)
        filtered = self.mw.filter_search_payload(raw)
        comp = self.mw.measure_compression(raw, filtered)
        return {"mcp_filtered": filtered, "mcp_tokens_saved": comp["tokens_saved"]}

    def node_orchestrator(self, state: DAGState) -> Dict[str, Any]:
        """Split the prompt into three fixed extraction sub-tasks (no LLM call)."""
        p = state["prompt"]
        ctx = state["mcp_filtered"]
        return {"sub_tasks": [
            f"EXTRACTION TASK A — Formalization: extract and formalize the core mathematical "
            f"problem statement, constraints, and variables for: '{p[:120]}'\n\nContext:\n{ctx}",
            f"EXTRACTION TASK B — Literature Mapping: identify key prior work, baseline methods, "
            f"and benchmark comparisons for: '{p[:120]}'\n\nContext:\n{ctx}",
            f"EXTRACTION TASK C — Engineering Constraints: extract practical implementation "
            f"constraints, hardware specs, and cost considerations for: '{p[:120]}'\n\nContext:\n{ctx}",
        ]}

    def _make_worker_node(self, worker_id: int):
        """Build the node function for worker ``worker_id`` (0-based sub-task index)."""
        def worker_node(state: DAGState) -> Dict[str, Any]:
            tasks = state.get("sub_tasks", [])
            if worker_id >= len(tasks):
                return {}
            try:
                text, metrics = self.client.call(
                    state["system_context"], tasks[worker_id],
                    use_cache=True, max_tokens=256)
            except Exception as e:
                # Keep the graph running, but remember the failure so run() can
                # flag the row instead of reporting a zero-token "success".
                self._node_errors.append(
                    f"worker {worker_id + 1}: {str(e) or type(e).__name__}")
                text, metrics = f"[Worker {worker_id} error: {e}]", empty_turn_metrics(self.client.model)
            # Return deltas only; the accumulator channels concatenate them.
            return {"worker_outputs": [f"[Worker {worker_id+1} Output]\n{text}"],
                    "all_turn_metrics": [metrics]}
        worker_node.__name__ = f"worker_node_{worker_id}"
        return worker_node

    def node_reducer(self, state: DAGState) -> Dict[str, Any]:
        """Synthesize all worker outputs into the final answer with one LLM call."""
        # Rate-limit gap between the worker calls and the reducer call. Taken here,
        # on the sequential part of the graph, so it can be timed exactly.
        self._throttle_s += self._throttle(RATE_LIMIT_DELAY * 0.5)

        joined = "\n\n".join(state.get("worker_outputs", []))
        prompt = (
            f"SYNTHESIS TASK — You are the reducer in a map-reduce research pipeline.\n\n"
            f"ORIGINAL QUESTION:\n{state['prompt']}\n\n"
            f"PARALLEL EXTRACTION RESULTS:\n{joined}\n\n"
            f"Integrate into: (1) Executive Summary, (2) Mathematical formulations, "
            f"(3) Comparative evaluation vs baselines, (4) Implementation roadmap, "
            f"(5) Open problems. Be precise and concise.")
        try:
            text, metrics = self.client.call(state["system_context"], prompt,
                                              use_cache=True, max_tokens=MAX_OUTPUT_TOKENS)
        except Exception as e:
            self._node_errors.append(f"reducer: {str(e) or type(e).__name__}")
            text, metrics = f"[Synthesis error: {e}]", empty_turn_metrics(self.client.model)
        return {"synthesis": text, "all_turn_metrics": [metrics]}

    def _build_dag(self):
        """Wire START → mcp_fetch → orchestrator → {worker_0..2} → reducer → END."""
        b = StateGraph(DAGState)
        b.add_node("mcp_fetch", self.node_mcp_fetch_and_filter)
        b.add_node("orchestrator", self.node_orchestrator)
        worker_names = ("worker_0", "worker_1", "worker_2")
        for idx, name in enumerate(worker_names):
            b.add_node(name, self._make_worker_node(idx))
        b.add_node("reducer", self.node_reducer)
        b.add_edge(START, "mcp_fetch")
        b.add_edge("mcp_fetch", "orchestrator")
        for name in worker_names:
            b.add_edge("orchestrator", name)
            b.add_edge(name, "reducer")
        b.add_edge("reducer", END)
        return b.compile()

    def run(self, prompt: str, prompt_id: str) -> RunMetrics:
        """Execute the DAG for one prompt.

        Args:
            prompt: Full evaluation prompt text.
            prompt_id: Identifier copied into the result row.

        Returns:
            A populated ``RunMetrics`` row when every LLM node succeeded. If the
            graph raises, or any worker/reducer call failed, the failure is
            logged and an all-zero row with ``error`` set is returned.
        """
        t_start = time.perf_counter()
        self._node_errors = []
        self._throttle_s = 0.0
        init: DAGState = {"prompt": prompt, "prompt_id": prompt_id,
                          "system_context": self.system_context, "mcp_filtered": "",
                          "mcp_tokens_saved": 0, "sub_tasks": [], "worker_outputs": [],
                          "synthesis": "", "all_turn_metrics": [], "error": None}
        try:
            final = self.graph.invoke(init)
            latency_s = time.perf_counter() - t_start - self._throttle_s

            if self._node_errors:
                failure = "; ".join(self._node_errors)
                logger.error(f"[OPTIMIZED] {self.client.model} {prompt_id}: {failure}")
                return empty_run_metrics(self.client.provider, self.client.model,
                                         prompt_id, "optimized", failure)

            turns: List[TurnMetrics] = final.get("all_turn_metrics", [])
            return RunMetrics(
                provider=self.client.provider, model=self.client.model,
                prompt_id=prompt_id, system_type="optimized",
                total_latency_s=latency_s,
                input_tokens=sum(m["input_tokens"] for m in turns),
                output_tokens=sum(m["output_tokens"] for m in turns),
                cache_read_tokens=sum(m["cache_read_tokens"] for m in turns),
                cache_write_tokens=sum(m["cache_write_tokens"] for m in turns),
                total_tokens=sum(m["input_tokens"] + m["output_tokens"] for m in turns),
                mcp_tokens_saved=final.get("mcp_tokens_saved", 0),
                num_llm_calls=len(turns), error=None)
        except Exception as e:
            logger.error(f"[OPTIMIZED] {self.client.model} {prompt_id}: {traceback.format_exc()}")
            # Never emit an empty error string: downstream filtering treats "" as success.
            return empty_run_metrics(self.client.provider, self.client.model,
                                     prompt_id, "optimized", str(e) or type(e).__name__)

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 10: BENCHMARK HARNESS (4 models × 15 prompts × 2 systems = 120 rows per run)
# ═══════════════════════════════════════════════════════════════════════════════
def run_benchmark(api_keys: Dict[str, str]) -> pd.DataFrame:
    """Run baseline and optimized systems for every (model, prompt) pair.

    The model matrix is restricted by ``PROVIDER_FILTER`` and the prompt list by
    ``PROMPT_LIMIT`` when those are set. Progress is printed per run, and the
    harness pauses ``RATE_LIMIT_DELAY`` between runs except in DRY_RUN.

    Args:
        api_keys: Mapping of provider name to API key; a missing provider is
            passed to the client as an empty string.

    Returns:
        A DataFrame with one row per (model, prompt, system) and the
        ``RunMetrics`` columns. The columns are present even when no run was
        executed (empty matrix or prompt list), so callers can index them safely.
    """
    system_ctx = build_system_context_block()
    ctx_tokens = len(system_ctx) // 4          # rough ~4 chars/token estimate for display

    matrix = [m for m in MODEL_MATRIX if (PROVIDER_FILTER is None or m[0] in PROVIDER_FILTER)]
    prompts = EVAL_PROMPTS if PROMPT_LIMIT is None else EVAL_PROMPTS[:PROMPT_LIMIT]
    total_rows = len(matrix) * len(prompts) * 2
    total_calls = len(matrix) * len(prompts) * (3 + 4)   # 3 baseline + 4 optimized calls

    print("\n" + "=" * 80)
    print("  MULTI-PROVIDER BENCHMARK — Sequential Baseline vs Optimized DAG")
    print(f"  Mode: {'DRY_RUN (mock, $0)' if DRY_RUN else 'LIVE'}  |  "
          f"Models: {len(matrix)}  |  Prompts: {len(prompts)}")
    print(f"  Expected rows: {total_rows}  |  Expected LLM calls: {total_calls}")
    print(f"  Shared system block: ~{ctx_tokens:,} tokens "
          f"(cache mins → Haiku 4096 / Pro 2048 / others 1024)")
    print("=" * 80 + "\n")

    rows: List[RunMetrics] = []
    for mi, (provider, model, role) in enumerate(matrix, 1):
        print(f"┌─ [{mi}/{len(matrix)}] {provider.upper()} · {model} ({role})")
        # One client per model, shared by both systems so the model is held constant.
        client = LLMClient(provider, model, api_keys.get(provider, ""))
        baseline = SequentialBaselineSystem(client)
        optimized = OptimizedDAGSystem(client, system_ctx)

        for pi, info in enumerate(prompts, 1):
            pid = f"{info['id']}"
            print(f"│   [{pi:2d}/{len(prompts)}] {pid} — {info['topic']}")

            br = baseline.run(info["prompt"], pid)
            rows.append(br)
            if br["error"]:
                tag = f"ERR {br['error'][:40]}"
            else:
                tag = f"lat={br['total_latency_s']:.1f}s tok={br['total_tokens']:,}"
            print(f"│        baseline : {tag}")

            if not DRY_RUN:
                time.sleep(RATE_LIMIT_DELAY)

            orr = optimized.run(info["prompt"], pid)
            rows.append(orr)
            if orr["error"]:
                tag = f"ERR {orr['error'][:40]}"
            else:
                denom = max(orr["input_tokens"], 1)      # guard against division by zero
                cache_pct = orr["cache_read_tokens"] / denom * 100
                tag = (f"lat={orr['total_latency_s']:.1f}s tok={orr['total_tokens']:,} "
                       f"cache_read={orr['cache_read_tokens']:,} ({cache_pct:.0f}% of input) "
                       f"mcp_saved={orr['mcp_tokens_saved']}")
            print(f"│        optimized: {tag}")

            if not DRY_RUN:
                time.sleep(RATE_LIMIT_DELAY)
        print("└─ done.\n")

    # Pin the schema so an empty result still exposes every RunMetrics column.
    return pd.DataFrame(rows, columns=list(RunMetrics.__annotations__))

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 11: SUMMARY + CSV EXPORT
# ═══════════════════════════════════════════════════════════════════════════════
def summarize(df: pd.DataFrame) -> pd.DataFrame:
    """Print and return mean metrics per (provider, model, system_type).

    Only rows whose ``error`` value is NaN or the empty string are aggregated.
    Two tables are printed: the per-group means, and for every
    (provider, model) pair that has both a ``baseline`` and an ``optimized``
    row, the baseline ÷ optimized ratios of mean latency and mean total tokens.

    Returns:
        The aggregated frame (one row per group), or an empty DataFrame when
        no error-free rows exist.
    """
    succeeded = df["error"].isna() | df["error"].eq("")
    ok = df.loc[succeeded].copy()
    if ok.empty:
        print("[WARN] No successful runs to summarize.")
        return pd.DataFrame()

    # Output column name -> (source column, aggregation function).
    aggregations = {
        "n": ("prompt_id", "count"),
        "latency_mean": ("total_latency_s", "mean"),
        "input_mean": ("input_tokens", "mean"),
        "output_mean": ("output_tokens", "mean"),
        "cache_read_mean": ("cache_read_tokens", "mean"),
        "total_mean": ("total_tokens", "mean"),
    }
    grouped = ok.groupby(["provider", "model", "system_type"])
    g = grouped.agg(**aggregations).reset_index()

    banner = "=" * 96
    divider = "  " + "-" * 92

    print("\n" + banner)
    print("  SUMMARY — mean per (provider, model, system_type)")
    print(banner)
    print(f"  {'provider':<10}{'model':<20}{'system':<11}{'n':>4}"
          f"{'lat(s)':>9}{'in_tok':>10}{'out_tok':>9}{'cache_rd':>10}{'total':>9}")
    print(divider)
    for row in g.itertuples(index=False):
        labels = f"  {row.provider:<10}{row.model:<20}{row.system_type:<11}{int(row.n):>4}"
        numbers = (f"{row.latency_mean:>9.1f}{row.input_mean:>10.0f}{row.output_mean:>9.0f}"
                   f"{row.cache_read_mean:>10.0f}{row.total_mean:>9.0f}")
        print(labels + numbers)

    # Per-model speedup / token-efficiency (baseline ÷ optimized)
    print("\n  OPTIMIZATION FACTORS (baseline ÷ optimized; >1 = optimized better)")
    print(divider)
    print(f"  {'provider':<10}{'model':<20}{'latency_x':>11}{'token_x':>10}{'opt_cache_rd':>14}")
    print(divider)
    for (prov, model), pair in g.groupby(["provider", "model"]):
        # system_type is unique within a (provider, model) group, so .loc yields one row.
        by_system = pair.set_index("system_type")
        if "baseline" not in by_system.index or "optimized" not in by_system.index:
            continue
        base_row = by_system.loc["baseline"]
        opt_row = by_system.loc["optimized"]

        # A zero denominator reports NaN rather than dividing.
        if opt_row["latency_mean"]:
            lat_x = base_row["latency_mean"] / opt_row["latency_mean"]
        else:
            lat_x = float("nan")
        if opt_row["total_mean"]:
            tok_x = base_row["total_mean"] / opt_row["total_mean"]
        else:
            tok_x = float("nan")
        print(f"  {prov:<10}{model:<20}{lat_x:>11.2f}{tok_x:>10.2f}{opt_row['cache_read_mean']:>14.0f}")
    print(banner + "\n")
    return g


def export_csv(df: pd.DataFrame, summary: pd.DataFrame):
    """Write the per-run results, and the summary if non-empty, to CSV.

    The per-run frame is projected onto a fixed, ordered set of columns with
    export-friendly names and saved to ``OUTPUT_CSV``; latency is rounded to
    three decimals. ``summary`` is saved to ``SUMMARY_CSV`` only when it has
    rows. Both files are written without the index.
    """
    # Source column -> exported column name; dict order is the CSV column order.
    column_names = {
        "provider":           "Provider",
        "model":              "Model",
        "prompt_id":          "TaskID",
        "system_type":        "SystemType",
        "total_latency_s":    "LatencySeconds",
        "input_tokens":       "InputTokens",
        "output_tokens":      "OutputTokens",
        "cache_read_tokens":  "CachedTokens",
        "cache_write_tokens": "CacheCreationTokens",
        "total_tokens":       "TotalTokens",
        "num_llm_calls":      "NumLLMCalls",
        "mcp_tokens_saved":   "MCPTokensSaved",
        "error":              "Error",
    }
    out = df[list(column_names)].rename(columns=column_names)
    out["LatencySeconds"] = out["LatencySeconds"].round(3)

    have_summary = not summary.empty
    out.to_csv(OUTPUT_CSV, index=False)
    if have_summary:
        summary.to_csv(SUMMARY_CSV, index=False)

    print(f"[INFO] Wrote {len(out)} rows → {OUTPUT_CSV}")
    if have_summary:
        print(f"[INFO] Summary → {SUMMARY_CSV}")

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 12: OPTIONAL CHART (grouped latency by model: baseline vs optimized)
# ═══════════════════════════════════════════════════════════════════════════════
def make_chart(df: pd.DataFrame):
    """Save a grouped bar chart of mean latency per model, baseline vs optimized.

    Only rows whose ``error`` value is NaN or the empty string are used. Models
    appear in ``MODEL_MATRIX`` order. Where both bars of a model are positive,
    the baseline/optimized ratio is written above them. The figure is saved to
    ``CHART_PNG``; nothing is saved when there are no error-free rows.
    """
    succeeded = df["error"].isna() | df["error"].eq("")
    ok = df.loc[succeeded].copy()
    if ok.empty:
        print("[WARN] No successful runs to chart.")
        return

    mean_latency = ok.groupby(["model", "system_type"])["total_latency_s"].mean()
    piv = mean_latency.unstack("system_type")
    # Guarantee both series exist so the bar calls below never hit a missing column.
    for system in ("baseline", "optimized"):
        if system not in piv.columns:
            piv[system] = 0.0
    model_order = [name for _, name, _ in MODEL_MATRIX if name in piv.index]
    piv = piv.reindex(model_order)

    x = np.arange(len(piv))
    bar_width = 0.38
    background = "#F8F9FA"
    fig, ax = plt.subplots(figsize=(max(10, len(piv) * 1.4), 6), facecolor=background)
    ax.bar(x - bar_width / 2, piv["baseline"], bar_width,
           color="#E74C3C", label="Baseline (sequential)")
    ax.bar(x + bar_width / 2, piv["optimized"], bar_width,
           color="#2ECC71", label="Optimized (DAG + cache)")

    for pos, base_lat, opt_lat in zip(x, piv["baseline"], piv["optimized"]):
        # NaN compares False here, so models missing either measurement get no label.
        if not (opt_lat > 0 and base_lat > 0):
            continue
        ax.text(pos, max(base_lat, opt_lat) * 1.02, f"{base_lat/opt_lat:.1f}x",
                ha="center", va="bottom", fontsize=8, fontweight="bold",
                color="#1A5276")

    ax.set_xticks(x)
    ax.set_xticklabels(list(piv.index), rotation=20, ha="right", fontsize=8)
    ax.set_ylabel("Mean wall-clock latency per run (s)")
    ax.set_title("Multi-Provider Benchmark — Baseline vs Optimized DAG latency by model",
                 fontweight="bold")
    ax.legend()
    ax.grid(axis="y", alpha=0.3, linestyle="--")

    plt.tight_layout()
    plt.savefig(CHART_PNG, dpi=200, bbox_inches="tight", facecolor=background)
    plt.close()
    size_kb = os.path.getsize(CHART_PNG) // 1024
    print(f"[INFO] Chart → {CHART_PNG} ({size_kb} KB)")

# ═══════════════════════════════════════════════════════════════════════════════
# SECTION 13: MAIN
# ═══════════════════════════════════════════════════════════════════════════════
def main():
    """Run the benchmark end to end and report the outputs actually produced.

    Steps: print the banner, resolve API keys for the providers selected by
    ``PROVIDER_FILTER``, warn about the estimated paid call count on live
    runs, execute ``run_benchmark``, then summarize, export CSVs, optionally
    chart, and print a short preview.

    Any failure in key setup or in the harness is reported and ends the run
    early (the function returns ``None`` in every case); a chart failure is
    only a warning.
    """
    print("\n" + "#" * 80)
    print("  MULTI-PROVIDER MULTI-AGENT BENCHMARK")
    print("  Anthropic · OpenAI · Google  |  Baseline vs Optimized DAG")
    print("#" * 80)

    # Evaluate the provider filter once; both the key lookup and the cost
    # estimate are derived from the same selection.
    active_models = [entry for entry in MODEL_MATRIX
                     if PROVIDER_FILTER is None or entry[0] in PROVIDER_FILTER]
    required = {entry[0] for entry in active_models}
    try:
        api_keys = get_api_keys(required)
    except Exception as e:
        print(f"[FATAL] API key setup failed: {e}")
        return

    if not DRY_RUN:
        n_models = len(active_models)
        # A limit larger than the prompt set cannot produce extra calls.
        n_prompts = (len(EVAL_PROMPTS) if PROMPT_LIMIT is None
                     else min(PROMPT_LIMIT, len(EVAL_PROMPTS)))
        # 7 = 3 baseline calls + 4 optimized calls (3 workers + reducer) per prompt.
        print(f"\n[WARNING] LIVE run will issue ~{n_models * n_prompts * 7} paid API calls "
              f"across {sorted(required)}.\n          Set DRY_RUN=True to validate for free first.\n")

    try:
        df = run_benchmark(api_keys)
    except KeyboardInterrupt:
        print("\n[INTERRUPTED] stopped by user.")
        return
    except Exception:
        print(f"\n[FATAL] harness error:\n{traceback.format_exc()}")
        return

    if df.empty:
        print("[ERROR] No results collected.")
        return

    n_ok = int((df["error"].isna() | (df["error"] == "")).sum())
    print(f"[INFO] Collected {len(df)} rows ({n_ok} successful, {len(df) - n_ok} errored).")

    summary = summarize(df)
    export_csv(df, summary)

    chart_written = False
    if MAKE_CHART:
        try:
            make_chart(df)
            # make_chart returns without saving when no row succeeded.
            chart_written = n_ok > 0
        except Exception as e:
            print(f"[WARN] chart failed: {e}")

    print("\n[PREVIEW] first 8 rows:")
    cols = ["provider", "model", "prompt_id", "system_type", "total_latency_s",
            "input_tokens", "output_tokens", "cache_read_tokens", "total_tokens"]
    pd.set_option("display.width", 140)
    pd.set_option("display.max_columns", 12)
    pd.set_option("display.float_format", "{:.2f}".format)
    print(df[cols].head(8).to_string(index=False))

    # List only the artifacts that were really written, with the real row count
    # (it varies with PROVIDER_FILTER / PROMPT_LIMIT, so it must not be hard-coded).
    print("\n[DONE] Outputs:")
    print(f"  {OUTPUT_CSV}   — {len(df)}-row benchmark (Provider, Model, TaskID, SystemType, "
          f"Latency, Input/Output/Cached tokens, …)")
    if not summary.empty:
        print(f"  {SUMMARY_CSV}  — per (provider, model, system) means")
    if chart_written:
        print(f"  {CHART_PNG}   — latency comparison chart")
    if DRY_RUN:
        print("\n  NOTE: DRY_RUN=True produced MOCK numbers to validate the pipeline. "
              "Set DRY_RUN=False for the real benchmark.")
    print("#" * 80 + "\n")


# Script entry point: run the benchmark only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
  main()