WRAG: Weight-Retrieval-Augmented Generation for Efficient and Updatable Domain Specialization
Description
What is WRAG?
WRAG (Weight-Retrieval-Augmented Generation) is a novel LLM serving architecture that applies the retrieval-augmented generation paradigm to model weights rather than text documents. Instead of loading a complete domain-specialized model, WRAG stores compact parameter-space delta "shards" in an external vector database and retrieves only the relevant ones at inference time, composing them onto a frozen base model on demand. VRAM usage stays constant regardless of how many domain shards are registered — a fundamental departure from both full fine-tuning and Mixture-of-Experts approaches.
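The retrieve-then-compose flow described above can be illustrated with a minimal sketch. It is not the repository's implementation: the `Shard` class, the `retrieve` and `compose` helpers, the cosine-similarity lookup, and the toy domain names are all hypothetical, and plain Python lists stand in for tensors and for the vector database.

```python
from dataclasses import dataclass
from math import sqrt


@dataclass
class Shard:
    """Hypothetical weight-delta shard: a retrieval key plus per-parameter deltas."""
    name: str
    key: list[float]                  # embedding used for retrieval
    delta: dict[str, list[float]]     # parameter name -> delta values


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = sqrt(sum(x * x for x in a)) * sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve(query: list[float], store: list[Shard], k: int = 1) -> list[Shard]:
    """Return the k shards whose keys are most similar to the query embedding."""
    return sorted(store, key=lambda s: cosine(query, s.key), reverse=True)[:k]


def compose(base: dict[str, list[float]], shards: list[Shard]) -> dict[str, list[float]]:
    """Add the retrieved deltas to a copy of the base weights; the base is not mutated."""
    merged = {name: list(values) for name, values in base.items()}
    for shard in shards:
        for name, delta in shard.delta.items():
            merged[name] = [w + d for w, d in zip(merged[name], delta)]
    return merged


# Toy data: one base parameter and two registered shards (illustrative values only).
base = {"mlp.up_proj": [1.0, 1.0]}
store = [
    Shard("cardiology", [1.0, 0.0], {"mlp.up_proj": [0.5, 0.0]}),
    Shard("oncology", [0.0, 1.0], {"mlp.up_proj": [0.0, 0.25]}),
]
selected = retrieve([0.9, 0.1], store, k=1)
weights = compose(base, selected)
```

Only the `k` retrieved shards are ever combined with the base, so in this sketch the working set does not grow as more shards are added to `store`, which is the property the description attributes to WRAG.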
Core Insight
RAG proved you don't need all knowledge compiled into the model — you retrieve it. WRAG extends this one level deeper: you don't need all computation compiled into the model either. Domain expertise becomes retrievable, composable, and updatable without touching the base model.
Key Findings
- A 3B instruction-tuned base model with WRAG outperforms BioMistral-7B (a 7B specialist) on all three medical benchmarks: PubMedQA (66.0% vs 65.5%), MedQA (54.8% vs 45.4%), and MedMCQA (51.4% vs 39.6%), while using 57% less VRAM (6.08 GB vs 14.0 GB)
- SVD-Centered Merging, a novel weight composition strategy that separates domain-specific signal from cross-shard interference noise via centered task vector decomposition, achieves 46.0% on MedQA, surpassing BioMistral-7B (45.4%) even with a non-instruct 3B base
- Alignment-aware shard training: domain shards must target feedforward layers exclusively (up_proj, down_proj, gate_proj) when applied to instruction-tuned models — targeting attention layers catastrophically degrades instruction-following capability by up to 34 percentage points
- Chain-of-thought reasoning shards achieve cross-task reasoning transfer, improving MedQA by 1.2 pp over a strong instruction-tuned baseline by encoding cognitive strategies rather than injecting factual knowledge
- WRAG supports continuous knowledge updates by registering new shards with zero modification to the base model or existing shards
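The alignment-aware constraint in the findings above (feedforward projections only, no attention layers) can be expressed as a simple name filter. This is a hedged sketch: the `feedforward_targets` helper is hypothetical, and the dotted module names are assumed to follow a Llama-style layout; only the three suffixes `up_proj`, `down_proj`, and `gate_proj` come from the description.

```python
# Suffixes named in the description as the only safe shard targets.
FEEDFORWARD_SUFFIXES = ("up_proj", "down_proj", "gate_proj")


def feedforward_targets(module_names: list[str]) -> list[str]:
    """Keep only modules whose final name component is a feedforward projection."""
    return [n for n in module_names if n.rsplit(".", 1)[-1] in FEEDFORWARD_SUFFIXES]


# Hypothetical module names for one decoder layer (assumed naming scheme).
modules = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
    "model.layers.0.mlp.down_proj",
]
targets = feedforward_targets(modules)
```

Matching on the final name component rather than a substring avoids accidentally selecting an attention module whose path happens to contain one of the suffixes.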
Contributions
- A novel LLM serving architecture (WRAG) enabling on-demand domain specialization with constant VRAM and base-model inference speed
- SVD-Centered Merging — a new composition strategy for weight delta fusion
- The alignment-aware shard training principle — a practical constraint with implications for any system applying weight deltas to instruction-tuned models
- A complete reproducible implementation with trained shards, instruction-formatted corpora, and interactive chatbot interfaces
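The description does not spell out the SVD-Centered Merging algorithm, so the following is only one plausible reading of "centered task vector decomposition", not the authors' method. It subtracts the mean delta shared by all shards, keeps the top singular directions of the residuals, and adds the mean back. The function name `center_and_truncate`, the `rank` parameter, and the toy values are assumptions.

```python
import numpy as np


def center_and_truncate(task_vectors: np.ndarray, rank: int) -> np.ndarray:
    """Return denoised task vectors: shared mean plus a rank-limited residual.

    task_vectors has shape (num_shards, num_params); each row is one flattened delta.
    """
    mean = task_vectors.mean(axis=0, keepdims=True)   # component shared by all shards
    centered = task_vectors - mean                    # shard-specific residuals
    u, s, vt = np.linalg.svd(centered, full_matrices=False)
    residual = (u[:, :rank] * s[:rank]) @ vt[:rank]   # keep the top-`rank` directions
    return mean + residual


# Three toy shard deltas over four parameters (illustrative values only).
deltas = np.array([
    [1.0, 0.0, 0.2, 0.0],
    [0.0, 1.0, 0.2, 0.0],
    [0.5, 0.5, 0.2, 0.1],
])
denoised = center_and_truncate(deltas, rank=1)
merged = denoised[[0, 2]].sum(axis=0)   # compose shards 0 and 2 only
```

With `rank` equal to the number of shards the input is reconstructed exactly, and with `rank=0` every shard collapses to the shared mean, so `rank` controls how much shard-specific detail survives. Refer to the repository for the actual procedure.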
Resources
- Code and implementation: github.com/GyetiJulius/WRAG
- Trained shards and corpora: Hugging Face (see GitHub for links)
- Chatbot demo: included in repository
Keywords: large language models, retrieval-augmented generation, weight composition, domain specialization, model merging, parameter-efficient methods, medical NLP, instruction tuning, LLM serving
Files
- Total size: 23.9 kB
- md5:812bc283c808677415c73d180cb5317e
Additional details
Software
- Repository URL: https://github.com/GyetiJulius/WRAG
- Programming language: Python, Shell
- Development Status: Active