first commit
Browse files- .gitignore +3 -0
- README.md +213 -1
- eval.py +124 -0
- inference.py +78 -0
- requirements.txt +11 -0
- src/__init__.py +0 -0
- src/data.py +60 -0
- src/telern/__init__.py +1 -0
- src/telern/concept_graph.py +255 -0
- src/telern/config.py +72 -0
- src/telern/evaluate.py +454 -0
- src/telern/hypernetwork.py +176 -0
- src/telern/model.py +133 -0
- train.py +293 -0
- train_ce.py +123 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/data
|
| 2 |
+
**/__pycache__
|
| 3 |
+
/dataset
|
README.md
CHANGED
|
@@ -1,3 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# TELEN: Temporal Evolving Legal Embedding Network
|
| 2 |
+
|
| 3 |
+
> **Vietnamese legal text embedding with meta-learning for continuous adaptation to new laws.**
|
| 4 |
+
|
| 5 |
+
[Python](https://www.python.org/)
|
| 6 |
+
[PyTorch](https://pytorch.org/)
|
| 7 |
+
[License: MIT](LICENSE)
|
| 8 |
+
|
| 9 |
---
|
| 10 |
+
|
| 11 |
+
## Overview
|
| 12 |
+
|
| 13 |
+
TELEN introduces a **novel embedding architecture** designed specifically for Vietnamese legal text retrieval in RAG (Retrieval-Augmented Generation) systems. Unlike conventional static embedding models, TELEN generates embeddings that **adapt dynamically** to the current state of the legal corpus — enabling seamless integration of new laws without retraining.
|
| 14 |
+
|
| 15 |
+
### Key Innovations
|
| 16 |
+
|
| 17 |
+
1. **HyperNetwork-Driven Projection** — Instead of fixed projection weights, a HyperNetwork generates the embedding projection function from the current legal corpus state. When new laws are published, the embedding space adapts automatically.
|
| 18 |
+
|
| 19 |
+
2. **Legal Concept Graph (LCG)** — An evolving knowledge graph where nodes represent legal entities (laws, key terms) and edges encode cross-references, agency hierarchy, temporal sequences, and semantic similarity.
|
| 20 |
+
|
| 21 |
+
3. **State-Adaptive Embeddings** — Embeddings are not static vectors but are modulated by a learned "legal state vector" that summarizes the entire legal landscape at any point in time.
|
| 22 |
+
|
| 23 |
---
|
| 24 |
+
|
| 25 |
+
## Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
Legal Text
|
| 29 |
+
β
|
| 30 |
+
Bi-Encoder (bkai-foundation-models/vietnamese-bi-encoder)
|
| 31 |
+
β
|
| 32 |
+
Raw Representation [768-dim]
|
| 33 |
+
β
|
| 34 |
+
┌─────────────────────────────────────┐
|
| 35 |
+
│ HyperNetwork(state_vector) → ΔW, Δb │ ← Generated, not learned!
|
| 36 |
+
│ Adapted Projection = Base + ΔW·x + Δb │
|
| 37 |
+
└─────────────────────────────────────┘
|
| 38 |
+
β
|
| 39 |
+
Legal Concept Graph (GNN)
|
| 40 |
+
β state_vector
|
| 41 |
+
State Encoder β current legal corpus
|
| 42 |
+
β
|
| 43 |
+
L2-Normalized Embedding [768-dim]
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
## Benchmark Results
|
| 47 |
+
|
| 48 |
+
**Test set**: 1,406 Vietnamese legal articles from 2021 (held-out, unseen during training)
|
| 49 |
+
|
| 50 |
+
| Model | NDCG@3 | NDCG@5 | NDCG@10 | MRR@3 | MRR@5 | MRR@10 |
|
| 51 |
+
|---|---|---|---|---|---|---|
|
| 52 |
+
| **BM25** (bm25) | 0.5164 | 0.5628 | 0.5718 | 0.5016 | 0.5290 | 0.5354 |
|
| 53 |
+
| **PhoBERT-base-v2** (dense) | 0.4803 | 0.5305 | 0.5738 | 0.4503 | 0.4792 | 0.4961 |
|
| 54 |
+
| **DEk21** (dense) | 0.6651 | 0.6907 | 0.7286 | 0.6394 | 0.6553 | 0.6734 |
|
| 55 |
+
| **TELEN** (dense) | **0.8878** | **0.9097** | **0.9132** | **0.8686** | **0.8782** | **0.8782** |
|
| 56 |
+
|
| 57 |
+
### Relative Improvement
|
| 58 |
+
|
| 59 |
+
| Baseline | NDCG@3 | NDCG@10 | MRR@10 |
|
| 60 |
+
|---|---|---|---|
|
| 61 |
+
| vs PhoBERT (dense) | **+84.9%** | **+59.2%** | **+77.1%** |
|
| 62 |
+
| vs DEk21 (dense) | **+33.5%** | **+25.3%** | **+30.4%** |
|
| 63 |
+
|
| 64 |
+
---
|
| 65 |
+
|
| 66 |
+
## Quick Start
|
| 67 |
+
|
| 68 |
+
### Installation
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
pip install -r requirements.txt
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
### Inference
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
from inference import TELENInference
|
| 78 |
+
|
| 79 |
+
# Load model
|
| 80 |
+
model = TELENInference()
|
| 81 |
+
|
| 82 |
+
# Encode legal texts
|
| 83 |
+
texts = [
|
| 84 |
+
"Điều 1: Thông tư này quy định về quản lý thuế giá trị gia tăng...",
|
| 85 |
+
"Điều 2: Đối tượng áp dụng là các tổ chức, cá nhân kinh doanh...",
|
| 86 |
+
]
|
| 87 |
+
embeddings = model.encode(texts) # → [2, 768] normalized vectors
|
| 88 |
+
|
| 89 |
+
# Compute similarity
|
| 90 |
+
similarity = model.similarity(texts[0], texts[1])
|
| 91 |
+
print(f"Cosine similarity: {similarity:.4f}")
|
| 92 |
+
|
| 93 |
+
# Retrieve similar documents
|
| 94 |
+
results = model.retrieve(texts[0], corpus, top_k=10)
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Training
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
# Train TELEN from scratch
|
| 101 |
+
python train.py
|
| 102 |
+
|
| 103 |
+
# Train cross-encoder re-ranker (optional, for extra +2-3% gain)
|
| 104 |
+
python train_ce.py
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
### Evaluation
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
python eval.py
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
---
|
| 114 |
+
|
| 115 |
+
## Training Details
|
| 116 |
+
|
| 117 |
+
### Dataset
|
| 118 |
+
- **Source**: [another-symato/VMTEB-Zalo-legel-retrieval-wseg](https://huggingface.co/datasets/another-symato/VMTEB-Zalo-legel-retrieval-wseg) on HuggingFace
|
| 119 |
+
- **Content**: 61,425 Vietnamese legal articles (Thông tư, Nghị định, Luật, Pháp lệnh)
|
| 120 |
+
- **Period**: 1999–2021
|
| 121 |
+
- **Format**: Word-segmented Vietnamese text (underscore-separated compound words)
|
| 122 |
+
|
| 123 |
+
### Training Pipeline
|
| 124 |
+
|
| 125 |
+
| Stage | Description | Epochs | Trainable Params |
|
| 126 |
+
|---|---|---|---|
|
| 127 |
+
| 1. Contrastive Pretraining | Triplet + InfoNCE loss on same-law article pairs | 5 | ~1M (projection head) |
|
| 128 |
+
| 2. Meta-Training | HyperNetwork learns to adapt embedding space for future laws | 50 (early stop) | ~4M (HyperNetwork + State Encoder) |
|
| 129 |
+
|
| 130 |
+
### Hyperparameters
|
| 131 |
+
|
| 132 |
+
| Parameter | Value |
|
| 133 |
+
|---|---|
|
| 134 |
+
| Backbone | `bkai-foundation-models/vietnamese-bi-encoder` |
|
| 135 |
+
| Embedding dimension | 768 |
|
| 136 |
+
| Adaptation rank | 64 |
|
| 137 |
+
| GNN layers | 3 |
|
| 138 |
+
| Meta N-way, K-shot | 16-way, 5-shot |
|
| 139 |
+
| Negatives per query | 256 (50% hard + 50% random) |
|
| 140 |
+
| Temperature | 0.05 |
|
| 141 |
+
| Optimizer | AdamW + CosineAnnealingWarmRestarts |
|
| 142 |
+
|
| 143 |
+
### Hardware
|
| 144 |
+
- GPU: NVIDIA RTX 5070 Ti (16GB VRAM)
|
| 145 |
+
- Training time: ~8 hours (5 contrastive + 50 meta epochs)
|
| 146 |
+
|
| 147 |
+
---
|
| 148 |
+
|
| 149 |
+
## Continuous Adaptation
|
| 150 |
+
|
| 151 |
+
When a new law is published, TELEN adapts without retraining:
|
| 152 |
+
|
| 153 |
+
```python
|
| 154 |
+
# New law arrives
|
| 155 |
+
new_articles = [
|
| 156 |
+
"Điều 1: Luật mới về trí tuệ nhân tạo...",
|
| 157 |
+
"Điều 2: Các nguyên tắc áp dụng AI trong xét xử...",
|
| 158 |
+
]
|
| 159 |
+
|
| 160 |
+
# Update concept graph (milliseconds)
|
| 161 |
+
model.add_new_law("123/2025/l-ai", new_articles)
|
| 162 |
+
|
| 163 |
+
# Embedding space automatically adapts via HyperNetwork
|
| 164 |
+
# All subsequent query embeddings reflect the new legal landscape
|
| 165 |
+
embeddings = model.encode(["Điều 1: ..."])
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
---
|
| 169 |
+
|
| 170 |
+
## Project Structure
|
| 171 |
+
|
| 172 |
+
```
|
| 173 |
+
law-embedding/
|
| 174 |
+
βββ dataset/
|
| 175 |
+
β βββ train-00000-of-00001.parquet # Training data (61K legal articles)
|
| 176 |
+
βββ src/
|
| 177 |
+
β βββ data.py # Data loading utilities
|
| 178 |
+
β βββ telern/
|
| 179 |
+
β βββ config.py # Configuration
|
| 180 |
+
β βββ model.py # TELEN architecture
|
| 181 |
+
β βββ concept_graph.py # Legal Concept Graph + GNN
|
| 182 |
+
β βββ hypernetwork.py # HyperNetwork + StateEncoder
|
| 183 |
+
β βββ evaluate.py # Evaluation metrics & baselines
|
| 184 |
+
βββ data/checkpoints/telen/
|
| 185 |
+
β βββ telen_best.pt # Pretrained model weights
|
| 186 |
+
βββ train.py # Training script
|
| 187 |
+
βββ train_ce.py # Cross-encoder training (optional)
|
| 188 |
+
βββ eval.py # Evaluation script
|
| 189 |
+
βββ inference.py # Inference API
|
| 190 |
+
βββ requirements.txt
|
| 191 |
+
βββ README.md
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
## Citation
|
| 197 |
+
|
| 198 |
+
```bibtex
|
| 199 |
+
@misc{telen2025,
|
| 200 |
+
title={TELEN: Temporal Evolving Legal Embedding Network for Vietnamese Law},
|
| 201 |
+
author={dangdinh},
|
| 202 |
+
year={2026},
|
| 203 |
+
publisher={Huggingface},
|
| 204 |
+
}
|
| 205 |
+
```
|
| 206 |
+
|
| 207 |
+
## License
|
| 208 |
+
|
| 209 |
+
MIT License — see [LICENSE](LICENSE) file for details.
|
| 210 |
+
|
| 211 |
+
## Acknowledgments
|
| 212 |
+
|
| 213 |
+
- `bkai-foundation-models/vietnamese-bi-encoder` — backbone bi-encoder
|
| 214 |
+
- `huyydangg/DEk21_hcmute_embedding` — baseline comparison (previous SOTA)
|
| 215 |
+
- `vinai/phobert-base-v2` — used in cross-encoder re-ranker
|
eval.py
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate TELEN with full benchmarks.
|
| 3 |
+
|
| 4 |
+
Metrics: NDCG@3, NDCG@5, NDCG@10, MRR@3, MRR@5, MRR@10
|
| 5 |
+
|
| 6 |
+
Baselines:
|
| 7 |
+
- BM25 (lexical retrieval)
|
| 8 |
+
- Frozen PhoBERT (vinai/phobert-base-v2)
|
| 9 |
+
- DEk21 (huyydangg/DEk21_hcmute_embedding)
|
| 10 |
+
- TELEN (ours)
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
python eval.py
|
| 14 |
+
"""
|
| 15 |
+
import sys; sys.path.insert(0, ".")
|
| 16 |
+
sys.stdout.reconfigure(encoding='utf-8')
|
| 17 |
+
import warnings; warnings.filterwarnings("ignore")
|
| 18 |
+
import random, numpy as np, torch, torch.nn.functional as F
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
from pyvi import ViTokenizer
|
| 23 |
+
|
| 24 |
+
from src.telern.config import TELENConfig
|
| 25 |
+
from src.telern.model import create_model
|
| 26 |
+
from src.telern.evaluate import (
|
| 27 |
+
BM25Baseline, FrozenPhoBERT, prepare_test_data,
|
| 28 |
+
build_test_queries, build_test_corpus, compute_metrics, evaluate_bm25,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
# Seed every RNG this script touches so repeated runs are comparable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = TELENConfig()
|
| 34 |
+
|
| 35 |
+
def wseg(text):
    """Re-segment *text* with ``ViTokenizer``.

    Underscores are replaced by spaces first, and the resulting plain text
    is handed to ``ViTokenizer.tokenize``.
    """
    plain_text = text.replace("_", " ")
    return ViTokenizer.tokenize(plain_text)
|
| 37 |
+
|
| 38 |
+
def evaluate_model(name, encode_fn, queries, corpus, corpus_ids, corpus_law_ids, batch_size=64):
    """Score an embedding model on the same-law retrieval benchmark.

    A corpus document counts as relevant to a query when it carries the same
    ``law_id``. The query's own article (matched through ``query_id``) is
    removed from the ranking before metrics are computed.

    Args:
        name: Label used in the progress output.
        encode_fn: Callable mapping a list of texts to a ``torch.Tensor`` or
            ``numpy.ndarray`` holding one embedding per row.
        queries: Sequence of dicts with ``query_text``, ``law_id`` and
            ``query_id`` keys.
        corpus: List of dicts with a ``text`` key.
        corpus_ids: Article id of each corpus entry, aligned with ``corpus``.
        corpus_law_ids: Law id of each corpus entry, aligned with ``corpus``.
        batch_size: Number of corpus documents passed to ``encode_fn`` at once.

    Returns:
        Dict mapping every metric name returned by ``compute_metrics`` to its
        mean over all queries.
    """
    print(f"\n [{name}] Encoding corpus ({len(corpus)} docs)...")
    chunks = []
    for start in range(0, len(corpus), batch_size):
        batch = [doc["text"] for doc in corpus[start:start + batch_size]]
        embs = encode_fn(batch)
        if isinstance(embs, np.ndarray):
            embs = torch.tensor(embs)
        chunks.append(embs.cpu())
    c_embs = torch.cat(chunks, dim=0)

    # Query-independent lookups are built once instead of once per query.
    law_id_arr = np.asarray(corpus_law_ids, dtype=object)
    first_index = {}
    for idx, cid in enumerate(corpus_ids):
        # Keep the first occurrence, matching the original linear scan + break.
        first_index.setdefault(cid, idx)

    print(f" [{name}] Evaluating {len(queries)} queries...")
    all_m = defaultdict(list)
    for q in tqdm(queries, desc=f" {name}"):
        q_emb = encode_fn([q["query_text"]])
        if isinstance(q_emb, np.ndarray):
            q_emb = torch.tensor(q_emb)
        sim = F.cosine_similarity(q_emb.cpu(), c_embs).numpy()
        rel = (law_id_arr == q["law_id"]).astype(float)

        # Rank by descending similarity, then drop the query's own article.
        # Filtering the index array cannot leave the relevance vector as
        # ``None``, unlike the previous conditional np.delete expression.
        order = sim.argsort()[::-1]
        self_idx = first_index.get(q["query_id"])
        if self_idx is not None:
            order = order[order != self_idx]
        sr = rel[order]

        for k in [3, 5, 10]:
            for mn, mv in compute_metrics(sr[:k], [k]).items():
                all_m[mn].append(mv)
    return {n: np.mean(v) for n, v in all_m.items()}
|
| 63 |
+
|
| 64 |
+
# ββ Data ββ
|
| 65 |
+
test_df = prepare_test_data(config)
queries = build_test_queries(test_df, max_queries=300)
corpus = build_test_corpus(test_df)
# Parallel id lists, aligned position-by-position with ``corpus``.
corpus_ids = [doc["article_id"] for doc in corpus]
corpus_law_ids = [doc["law_id"] for doc in corpus]
# NOTE(review): the graph-building frame is sliced out of ``test_df`` by year;
# confirm that prepare_test_data() returns rows from the training period too.
train_df = test_df[test_df["year"] <= config.meta.train_split_year]
print(f"Test: {len(queries)} queries, {len(corpus)} docs, {test_df['law_id'].nunique()} laws")

results = {}

# ── BM25 ──
print("\n[1/4] BM25")
results["BM25"] = evaluate_bm25(queries, corpus)

# ── PhoBERT ──
print("\n[2/4] Frozen PhoBERT")
phobert = FrozenPhoBERT()


def _phobert_encode(texts):
    """Encode *texts* with the frozen PhoBERT baseline."""
    return phobert.encode(texts, batch_size=64)


results["PhoBERT"] = evaluate_model(
    "PhoBERT", _phobert_encode, queries, corpus, corpus_ids, corpus_law_ids
)

# ── DEk21 ──
print("\n[3/4] DEk21 (SOTA)")
dek21 = SentenceTransformer("huyydangg/DEk21_hcmute_embedding", device=device)


def _dek21_encode(texts):
    """Word-segment *texts* with ``wseg`` and encode them with DEk21."""
    segmented = [wseg(t) for t in texts]
    return dek21.encode(
        segmented,
        batch_size=64,
        show_progress_bar=False,
        normalize_embeddings=True,
        convert_to_tensor=True,
    )


results["DEk21"] = evaluate_model(
    "DEk21", _dek21_encode, queries, corpus, corpus_ids, corpus_law_ids
)
|
| 88 |
+
|
| 89 |
+
# ββ TELEN ββ
|
| 90 |
+
print("\n[4/4] TELEN (Ours)")
telen = create_model(config).to(device)
# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects; only load checkpoints from a trusted source.
ckpt = torch.load(config.output_dir + "/telen_best.pt", map_location=device, weights_only=False)
telen.hypernetwork.load_state_dict(ckpt["hypernetwork"])
telen.state_encoder.load_state_dict(ckpt["state_encoder"])
telen.base_projection.load_state_dict(ckpt["base_projection"])
telen.attn_query.data.copy_(ckpt["attn_query"])
# Switch to inference mode before any encoding happens. torch.no_grad() only
# disables autograd; dropout layers stay active until eval() is called.
telen.eval()
if len(train_df) > 0:
    telen.build_graph(train_df)
    # Modules created while building the graph start in training mode, so
    # re-assert eval mode (a no-op if build_graph attaches nothing new).
    telen.eval()


def telen_encode(texts):
    """Encode *texts* with TELEN, without gradients, and return CPU tensors."""
    with torch.no_grad():
        return telen(texts, use_stochastic=False)["embeddings"].cpu()


results["TELEN"] = evaluate_model("TELEN", telen_encode, queries, corpus, corpus_ids, corpus_law_ids)
|
| 104 |
+
|
| 105 |
+
# ββ Summary ββ
|
| 106 |
+
# Print the benchmark table, then TELEN's relative gain over each dense baseline.
_CUTOFFS = (3, 5, 10)
_RULE = "=" * 75

print("\n" + _RULE)
print("BENCHMARK RESULTS")
print(_RULE)

header = f"{'Method':<15}" + "".join(
    f" {'NDCG@'+str(k):>10} {'MRR@'+str(k):>10}" for k in _CUTOFFS
)
print(header)
print("-" * len(header))

for method in ("BM25", "PhoBERT", "DEk21", "TELEN"):
    scores = results[method]
    cells = "".join(
        f" {scores[f'ndcg@{k}']:>10.4f} {scores[f'mrr@{k}']:>10.4f}" for k in _CUTOFFS
    )
    print(f"{method:<15}" + cells)

print("\n--- Relative Improvement over Baselines ---")
ours = results["TELEN"]
for baseline in ("PhoBERT", "DEk21"):
    theirs = results[baseline]
    print(f" TELEN vs {baseline}:")
    for k in _CUTOFFS:
        # The denominator is floored at 1e-6 so a zero baseline cannot divide by zero.
        ndcg_gain = (ours[f"ndcg@{k}"] / max(theirs[f"ndcg@{k}"], 1e-6) - 1) * 100
        mrr_gain = (ours[f"mrr@{k}"] / max(theirs[f"mrr@{k}"], 1e-6) - 1) * 100
        print(f" NDCG@{k}: {ndcg_gain:+.1f}% MRR@{k}: {mrr_gain:+.1f}%")
print("Done!")
|
inference.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TELEN Inference — encode legal texts to 768-dim embeddings.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
from inference import TELENInference
|
| 6 |
+
model = TELENInference()
|
| 7 |
+
embeddings = model.encode(["Điều 1: Thông tư này quy định về..."])
|
| 8 |
+
similarity = model.similarity(text1, text2)
|
| 9 |
+
"""
|
| 10 |
+
import sys; sys.path.insert(0, ".")
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from pyvi import ViTokenizer
|
| 14 |
+
|
| 15 |
+
from src.telern.config import TELENConfig
|
| 16 |
+
from src.telern.model import create_model
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TELENInference:
    """Inference wrapper around a TELEN model restored from a checkpoint.

    The constructor builds the model from a default ``TELENConfig``, loads the
    ``hypernetwork``, ``state_encoder``, ``base_projection`` and ``attn_query``
    entries of the checkpoint, and puts the model in eval mode.
    """

    def __init__(self, checkpoint_path: str = None):
        """Build the model and load weights.

        Args:
            checkpoint_path: Path to the checkpoint file. When ``None``,
                ``<config.output_dir>/telen_best.pt`` is used.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = TELENConfig()
        self.model = create_model(self.config).to(self.device)

        if checkpoint_path is None:
            checkpoint_path = self.config.output_dir + "/telen_best.pt"

        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects; only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.model.hypernetwork.load_state_dict(ckpt["hypernetwork"])
        self.model.state_encoder.load_state_dict(ckpt["state_encoder"])
        self.model.base_projection.load_state_dict(ckpt["base_projection"])
        self.model.attn_query.data.copy_(ckpt["attn_query"])
        self.model.eval()

        print(f"TELEN loaded on {self.device}")
        print(f" HyperNetwork: {sum(p.numel() for p in self.model.hypernetwork.parameters()):,} params")
        print(f" Ready for inference.")

    def build_graph(self, df):
        """Build concept graph from a DataFrame with [id, title, text, law_id, law_type, year] columns."""
        self.model.build_graph(df)
        # Modules created while building the graph start in training mode;
        # re-assert eval mode so dropout stays disabled for inference.
        self.model.eval()

    def encode(self, texts: list, batch_size: int = 64) -> torch.Tensor:
        """Encode a list of legal texts to 768-dim normalized embeddings.

        Args:
            texts: Non-empty list of strings. An empty list is not supported
                because ``torch.cat`` has nothing to concatenate.
            batch_size: Number of texts per forward pass; must be >= 1.

        Returns:
            CPU tensor with one embedding row per input text.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        embeddings = []
        with torch.no_grad():
            for start in range(0, len(texts), batch_size):
                batch = texts[start:start + batch_size]
                result = self.model(batch, use_stochastic=False)
                embeddings.append(result["embeddings"].cpu())
        return torch.cat(embeddings, dim=0)

    def similarity(self, text1: str, text2: str) -> float:
        """Compute cosine similarity between two texts."""
        emb = self.encode([text1, text2])
        return F.cosine_similarity(emb[0:1], emb[1:2]).item()

    def retrieve(self, query: str, corpus: list, top_k: int = 10) -> list:
        """Retrieve top-k most similar documents from a corpus.

        Args:
            query: Query text.
            corpus: List of document texts.
            top_k: Maximum number of hits to return.

        Returns:
            List of ``(corpus_index, cosine_similarity)`` tuples, best first.
            Empty when ``corpus`` is empty or ``top_k`` is not positive.
        """
        # An empty corpus used to crash inside encode(), and a negative top_k
        # silently sliced from the wrong end of the ranking.
        if not corpus or top_k <= 0:
            return []
        query_emb = self.encode([query])
        corpus_embs = self.encode(corpus)
        sim = F.cosine_similarity(query_emb, corpus_embs).numpy()
        top_indices = sim.argsort()[::-1][:top_k]
        return [(int(i), float(sim[i])) for i in top_indices]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# ββ Demo ββ
|
| 68 |
+
if __name__ == "__main__":
    model = TELENInference()

    # Example queries. The Vietnamese literals are restored from mis-encoded
    # text. q1/q2 read like articles of one tax circular; q3 is from an
    # unrelated traffic-penalty decree.
    q1 = "Điều 1: Thông tư này quy định về quản lý thuế giá trị gia tăng đối với hàng hóa nhập khẩu"
    q2 = "Điều 2: Đối tượng áp dụng là các tổ chức, cá nhân kinh doanh hàng hóa nhập khẩu"
    q3 = "Điều 1: Nghị định này quy định về xử phạt vi phạm hành chính trong lĩnh vực giao thông"

    print(f"\nSimilarity test:")
    print(f" q1 vs q2 (same law): {model.similarity(q1, q2):.4f}")
    print(f" q1 vs q3 (diff law): {model.similarity(q1, q3):.4f}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.0.0
|
| 2 |
+
transformers>=4.40.0
|
| 3 |
+
sentence-transformers>=3.0.0
|
| 4 |
+
peft>=0.10.0
|
| 5 |
+
pandas>=2.0.0
|
| 6 |
+
pyarrow>=14.0.0
|
| 7 |
+
scikit-learn>=1.3.0
|
| 8 |
+
tqdm>=4.65.0
|
| 9 |
+
numpy>=1.24.0
|
| 10 |
+
pyvi>=0.1.0
|
| 11 |
+
accelerate>=0.24.0
|
src/__init__.py
ADDED
|
File without changes
|
src/data.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared data utilities used by TELEN modules."""
|
| 2 |
+
import unicodedata
|
| 3 |
+
import pandas as pd
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_raw_data(parquet_path: str) -> pd.DataFrame:
    """Read the parquet file at *parquet_path* into a DataFrame."""
    frame = pd.read_parquet(parquet_path)
    return frame
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def extract_metadata(df: pd.DataFrame) -> pd.DataFrame:
    """Extract law_id, article_num, law_type, year from id column.

    Ids are parsed as ``<law_id>#<article>-<rest>``, and ``law_id`` itself is
    split on ``/`` to read the year (second segment) and the law type (third
    segment).

    Args:
        df: Frame with a string ``id`` column. It is not modified.

    Returns:
        A copy of ``df`` with ``law_id``, ``article_num``, ``law_type`` and
        ``year`` columns added.
    """
    df = df.copy()

    def parse_id(id_str):
        """Split an id into ``(law_id, article_num)``."""
        if "#" not in id_str:
            return id_str, 0
        parts = id_str.split("#")
        try:
            article_num = int(parts[1].split("-")[0])
        except ValueError:
            # A non-numeric article part used to abort the whole load; fall
            # back to 0, the value already used for ids without "#".
            article_num = 0
        return parts[0], article_num

    parsed = df["id"].apply(parse_id)
    df["law_id"] = parsed.apply(lambda pair: pair[0])
    df["article_num"] = parsed.apply(lambda pair: pair[1])

    def extract_law_type(law_id):
        """Return the last ``-`` separated token of the third ``/`` segment."""
        parts = law_id.split("/")
        if len(parts) >= 3:
            # split() returns the whole string when there is no "-", so no
            # separate branch is needed for that case.
            # NOTE(review): for a code such as "tt-btc" this yields "btc";
            # confirm that the trailing token is the intended law type.
            return parts[2].split("-")[-1]
        return "unknown"

    df["law_type"] = df["law_id"].apply(extract_law_type)

    def extract_year(law_id):
        """Return the year in the second ``/`` segment, defaulting to 1999."""
        parts = law_id.split("/")
        if len(parts) >= 2:
            try:
                year = int(parts[1])
            except ValueError:
                return 1999
            # NOTE(review): values below 100 are treated as 19xx; confirm no
            # two-digit years from 2000 onwards occur in the data.
            return year if year >= 100 else year + 1900
        return 1999

    df["year"] = df["law_id"].apply(extract_year)
    return df
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def clean_data(df: pd.DataFrame, min_text_len: int = 10) -> pd.DataFrame:
    """Filter out short texts, NFC-normalize title/text, and de-duplicate.

    Rows whose ``text`` is shorter than ``min_text_len`` are dropped first.
    ``title`` and ``text`` are then converted with ``str()`` and normalized to
    NFC, and rows with a duplicate ``text`` are removed, keeping the first.
    The input frame is left untouched and the result has a fresh 0..n-1 index.
    """

    def _nfc(value):
        return unicodedata.normalize("NFC", str(value))

    working = df.copy()
    long_enough = working["text"].str.len() >= min_text_len
    working = working[long_enough].reset_index(drop=True)

    for column in ("title", "text"):
        working[column] = working[column].apply(_nfc)

    deduplicated = working.drop_duplicates(subset=["text"], keep="first")
    return deduplicated.reset_index(drop=True)
|
src/telern/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""TELEN: Temporal Evolving Legal Embedding Network."""
|
src/telern/concept_graph.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Legal Concept Graph — evolving knowledge backbone of TELEN.
|
| 3 |
+
|
| 4 |
+
Nodes: law entities + key terms extracted via TF-IDF
|
| 5 |
+
Edges: agency, temporal, semantic, cross-reference, term-document
|
| 6 |
+
GNN: Multi-layer sparse graph convolution
|
| 7 |
+
"""
|
| 8 |
+
import re
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 16 |
+
# GNN Layers
|
| 17 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 18 |
+
class GCNLayer(nn.Module):
    """Single graph-convolution step over a dense adjacency matrix.

    ``forward`` scales ``adj`` symmetrically by the inverse square root of its
    row sums (floored at 1), aggregates node features with it, and then
    applies linear -> ReLU -> dropout -> LayerNorm.
    """

    def __init__(self, in_dim, out_dim, dropout=0.1):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x, adj):
        # Row sums are clamped to at least 1 so the inverse root stays finite.
        degree = adj.sum(dim=1).clamp(min=1)
        scale = degree.pow(-0.5)
        normalized_adj = scale.unsqueeze(1) * adj * scale.unsqueeze(0)

        aggregated = normalized_adj @ x
        activated = F.relu(self.linear(aggregated))
        return self.norm(self.dropout(activated))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class GNNEncoder(nn.Module):
    """Stack of ``n_layers`` same-width ``GCNLayer`` blocks with residual links."""

    def __init__(self, dim, n_layers=3, dropout=0.1):
        super().__init__()
        stack = [GCNLayer(dim, dim, dropout) for _ in range(n_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self, x, adj):
        hidden = x
        for gcn in self.layers:
            # Residual connection: add each layer's output to its own input.
            hidden = gcn(hidden, adj) + hidden
        return hidden
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 49 |
+
# Legal Concept Graph
|
| 50 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 51 |
+
class LegalConceptGraph(nn.Module):
    """Graph store (node ids, node embeddings, typed weighted edges) plus a GNN.

    Node embeddings and edges are plain attributes rather than parameters or
    buffers. The dense adjacency matrix is cached and rebuilt whenever nodes
    or edges change, or when the module has moved to another device.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_dim = config.graph.hidden_dim

        self.node_ids = []
        self.node_embeddings = None
        # Each edge is a (src_index, dst_index, weight) tuple.
        self.edges = {"cross_ref": [], "agency": [], "temporal": [], "semantic": []}
        self._adj_cached = None
        self._adj_dirty = True

        self.gnn = GNNEncoder(config.graph.hidden_dim, config.graph.gnn_layers, config.graph.gnn_dropout)

    @property
    def num_nodes(self):
        """Number of nodes currently registered."""
        return len(self.node_ids)

    @property
    def device(self):
        """Device of the GNN weights, used for every tensor built here."""
        return self.gnn.layers[0].linear.weight.device

    def add_nodes(self, node_ids, embeddings):
        """Append nodes and their embedding rows, invalidating the adjacency cache."""
        if self.node_embeddings is None:
            self.node_embeddings = embeddings
        else:
            self.node_embeddings = torch.cat([self.node_embeddings, embeddings], dim=0)
        self.node_ids.extend(node_ids)
        self._adj_dirty = True

    def add_edges(self, edge_type, edges):
        """Append ``(src, dst, weight)`` edges of one of the known edge types.

        Raises:
            KeyError: If ``edge_type`` is not a key of ``self.edges``.
        """
        self.edges[edge_type].extend(edges)
        self._adj_dirty = True

    def build_adjacency(self):
        """Return the dense symmetric adjacency matrix with self-loops.

        Edge types disabled in ``config.graph`` are skipped, and edges whose
        endpoints fall outside ``[0, num_nodes)`` are ignored. Weights of
        repeated edges accumulate.
        """
        device = self.device
        cached = self._adj_cached
        # Reuse the cache only if it is clean AND still on the current device;
        # a stale-device cache would otherwise break the matmul in forward().
        if not self._adj_dirty and cached is not None and cached.device == device:
            return cached

        n_nodes = self.num_nodes
        adj = torch.zeros(n_nodes, n_nodes, device=device)
        graph_cfg = self.config.graph
        enabled = [
            ("cross_ref", graph_cfg.use_cross_ref_edges),
            ("agency", graph_cfg.use_agency_edges),
            ("temporal", graph_cfg.use_temporal_edges),
            ("semantic", graph_cfg.use_semantic_edges),
        ]
        for edge_type, use in enabled:
            if not use or not self.edges[edge_type]:
                continue
            # The lower bound stops negative indices wrapping to another node.
            valid = [(s, d, w) for s, d, w in self.edges[edge_type]
                     if 0 <= s < n_nodes and 0 <= d < n_nodes]
            if not valid:
                continue
            src = torch.tensor([e[0] for e in valid], device=device, dtype=torch.long)
            dst = torch.tensor([e[1] for e in valid], device=device, dtype=torch.long)
            wgt = torch.tensor([e[2] for e in valid], device=device, dtype=torch.float)
            # Write both directions so the matrix is symmetric.
            adj.index_put_((src, dst), wgt, accumulate=True)
            adj.index_put_((dst, src), wgt, accumulate=True)

        adj = adj + torch.eye(n_nodes, device=device)
        self._adj_cached = adj
        self._adj_dirty = False
        return adj

    def forward(self):
        """Run the GNN over the stored node embeddings and return the result.

        ``add_nodes`` must have been called at least once beforehand.
        """
        dev = self.device
        if self.node_embeddings.device != dev:
            self.node_embeddings = self.node_embeddings.to(dev)
        adj = self.build_adjacency()
        return self.gnn(self.node_embeddings, adj)

    def to(self, *args, **kwargs):
        """Move the module and its node embeddings; accepts ``nn.Module.to`` arguments."""
        super().to(*args, **kwargs)
        if self.node_embeddings is not None:
            self.node_embeddings = self.node_embeddings.to(*args, **kwargs)
        # The cached adjacency may now be on the wrong device or dtype.
        self._adj_cached = None
        self._adj_dirty = True
        return self
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 126 |
+
# Cross-reference extraction
|
| 127 |
+
# ββββββββββββββββββββββββββββοΏ½οΏ½ββββββββββββββββββ
|
| 128 |
+
# (compiled pattern, category) pairs used to spot references to other legal
# documents. The Vietnamese literals are restored from mis-encoded text, which
# could not match real input. "Điều" is accepted in either case because the
# original casing could not be recovered. Patterns are written in NFC form.
CROSS_REF_PATTERNS = [
    # "theo / căn cứ ... Điều <n> [của] <document type> <name>"
    (re.compile(r"(?:theo|theo quy định tại|căn cứ vào|căn cứ)\s+[Đđ]iều\s+(\d+)\s+(?:của\s+)?(Luật|Bộ luật|Nghị định|Thông tư|Pháp lệnh)\s+([^,.;]+)"), "citation"),
    # "<document type> [số] <number>/<year>/<code>"
    (re.compile(r"(Luật|Bộ luật|Nghị định|Thông tư|Pháp lệnh|Quyết định)\s+(?:số\s+)?([\d]+/[\d]+/[\w-]+)"), "reference"),
    # "sửa đổi, bổ sung [một số điều của] <document type> <name>"
    (re.compile(r"sửa đổi[,，]\s*bổ sung\s+(?:một số [Đđ]iều của\s+)?(Luật|Nghị định|Thông tư)\s+([^,.;]+)"), "amendment"),
    # "thay thế / bãi bỏ [Điều <n> [của]] <document type> <name>"
    (re.compile(r"(?:thay thế|bãi bỏ)\s+(?:[Đđ]iều\s+(\d+)\s+(?:của\s+)?)?(Luật|Nghị định|Thông tư)\s+([^,.;]+)"), "replacement"),
]
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def extract_key_terms(df, max_terms=200):
    """Return up to ``max_terms`` uni/bi-gram terms ranked by peak TF-IDF score.

    Each row of ``df`` contributes one document made of its title plus the
    first 500 characters of its text, with underscores turned into spaces.
    A term's score is its maximum TF-IDF value over all documents; terms are
    returned in descending score order.
    """
    documents = [
        f"{title} {body[:500]}".replace("_", " ")
        for title, body in zip(df["title"], df["text"])
    ]
    vectorizer = TfidfVectorizer(
        max_features=max_terms,
        ngram_range=(1, 2),
        min_df=3,
        max_df=0.8,
        token_pattern=r'(?u)\b\w+\b',
    )
    weights = vectorizer.fit_transform(documents)
    peak_scores = weights.max(axis=0).toarray().flatten()
    ranking = peak_scores.argsort()[::-1][:max_terms]
    vocabulary = vectorizer.get_feature_names_out()
    return list(vocabulary[ranking])
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def _law_matches_ref(law_id, ref_text):
    """Heuristically decide whether ``ref_text`` mentions the law ``law_id``.

    ``law_id`` is split on "/" and its second and third components are
    searched for as substrings of ``ref_text``. Both sides are normalised the
    same way first (lower-cased, "_" and "-" replaced by spaces), so an
    upper-case or hyphenated component in the id can still match. Empty
    components are ignored because the empty string is a substring of
    everything.

    Returns:
        True if either component occurs in the reference text, else False
        (including when ``law_id`` has fewer than two "/"-separated parts).
    """
    def _normalize(value):
        return value.lower().replace("_", " ").replace("-", " ")

    ref_norm = _normalize(ref_text)
    parts = [piece.strip() for piece in _normalize(law_id).split("/")]
    # Only components 1 and 2 are used; component 0 is deliberately skipped.
    return any(piece and piece in ref_norm for piece in parts[1:3])
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def build_concept_graph(df, encode_fn, config):
    """Build the concept graph (law nodes + key-term nodes) from training data.

    Steps, in order:
      1. One node per ``law_id``; its embedding is the mean of ``encode_fn``
         over the first five "title: text[:300]" strings of that law.
      2. One ``TERM:<term>`` node per key term from ``extract_key_terms``.
      3. "agency" edges (weight 0.3) between every pair of laws sharing a
         ``law_type`` (the grouping column really is ``law_type``).
      4. "temporal" edges (weight 0.2) between laws of consecutive years
         within a ``law_type``.
      5. "semantic" edges: cosine k-NN between law embeddings, computed in
         chunks of 64 rows, with ``k = min(config.graph.semantic_knn, N - 1)``.
      6. "cross_ref" edges (weight 0.5) from regex hits of
         ``CROSS_REF_PATTERNS`` resolved with ``_law_matches_ref``.
      7. Term-to-law edges weighted by TF-IDF, also added under the
         "semantic" edge type.

    Args:
        df: article table with ``law_id``, ``law_type``, ``year``, ``title``
            and ``text`` columns.
        encode_fn: callable mapping one string to an embedding tensor.
        config: forwarded to ``LegalConceptGraph``; ``config.graph.semantic_knn``
            is read here.

    Returns:
        ``(graph, law_id_to_idx)`` where ``law_id_to_idx`` maps each law id to
        its node index (laws are indexed in sorted-id order).
    """
    graph = LegalConceptGraph(config)
    law_groups = df.groupby("law_id")
    law_ids = sorted(law_groups.groups.keys())
    N_laws = len(law_ids)
    print(f" Building graph: {N_laws} law nodes...")

    # Law embeddings
    embs = []
    for lid in law_ids:
        group = law_groups.get_group(lid)
        texts = [f"{t}: {txt[:300]}" for t, txt in zip(group["title"], group["text"])]
        embs.append(torch.stack([encode_fn(t) for t in texts[:5]]).mean(dim=0))
    law_embs = torch.stack(embs)
    graph.add_nodes(law_ids, law_embs)
    law_id_to_idx = {lid: i for i, lid in enumerate(law_ids)}

    # Key term nodes
    # NOTE(review): term node indices below assume add_nodes appends, so that
    # term i lands at index N_laws + i - confirm against LegalConceptGraph.
    print(" Extracting key terms...")
    key_terms = extract_key_terms(df, max_terms=200)
    term_embs = torch.stack([encode_fn(t) for t in key_terms])
    graph.add_nodes([f"TERM:{t}" for t in key_terms], term_embs)
    print(f" {len(key_terms)} key terms")

    # Agency edges: every pair of laws that share a law_type
    agency_edges = []
    for _, group in df.groupby("law_type"):
        members = [law_id_to_idx[lid] for lid in group["law_id"].unique()
                   if lid in law_id_to_idx]
        for a in range(len(members)):
            for b in range(a + 1, len(members)):
                agency_edges.append((members[a], members[b], 0.3))
    graph.add_edges("agency", agency_edges)
    print(f" Agency edges: {len(agency_edges)}")

    # Temporal edges: laws of one year -> laws of the next recorded year
    temporal_edges = []
    for _, group in df.groupby("law_type"):
        yl = group.groupby("year")["law_id"].unique()
        years = sorted(yl.keys())
        for y1, y2 in zip(years, years[1:]):
            for l1 in yl[y1]:
                for l2 in yl[y2]:
                    if l1 in law_id_to_idx and l2 in law_id_to_idx:
                        temporal_edges.append((law_id_to_idx[l1], law_id_to_idx[l2], 0.2))
    graph.add_edges("temporal", temporal_edges)
    print(f" Temporal edges: {len(temporal_edges)}")

    # Semantic edges (chunked k-NN)
    semantic_k = min(config.graph.semantic_knn, N_laws - 1)
    semantic_edges = []
    if N_laws > 1:
        chunk = 64
        for i in range(0, N_laws, chunk):
            end = min(i + chunk, N_laws)
            sim = F.cosine_similarity(law_embs[i:end].unsqueeze(1), law_embs.unsqueeze(0), dim=2)
            # Exclude each law from its own neighbour list.
            for j in range(sim.shape[0]):
                sim[j, i + j] = float("-inf")
            vals, idx = sim.topk(k=semantic_k, dim=1)
            for j, (neighbours, weights) in enumerate(zip(idx.tolist(), vals.tolist())):
                for tgt, weight in zip(neighbours, weights):
                    semantic_edges.append((i + j, tgt, weight))
    graph.add_edges("semantic", semantic_edges)
    print(f" Semantic edges: {len(semantic_edges)}")

    # Cross-reference edges
    cross_ref_edges = []
    for src, text in zip(df["law_id"], df["text"]):
        if src not in law_id_to_idx:
            continue
        for pattern, etype in CROSS_REF_PATTERNS:
            for match in pattern.findall(text):
                match_str = " ".join(match).lower() if isinstance(match, tuple) else str(match).lower()
                # Link to the first law (in sorted-id order) that matches.
                for tgt in law_ids:
                    if tgt != src and _law_matches_ref(tgt, match_str):
                        cross_ref_edges.append((law_id_to_idx[src], law_id_to_idx[tgt], 0.5))
                        break
    graph.add_edges("cross_ref", cross_ref_edges)
    print(f" Cross-ref edges: {len(cross_ref_edges)}")

    # Term-document edges
    term_doc_edges = []
    doc_texts = [f"{title} {text[:300]}".replace("_", " ")
                 for title, text in zip(df["title"], df["text"])]
    # TF-IDF rows are ARTICLES, not laws: translate every row to the node
    # index of the law it belongs to instead of using the row number directly.
    doc_law_idx = [law_id_to_idx.get(lid) for lid in df["law_id"]]
    # Tokenise exactly like extract_key_terms so that bigram and
    # single-character terms in the fixed vocabulary can actually be counted.
    vec = TfidfVectorizer(vocabulary=key_terms if key_terms else None,
                          ngram_range=(1, 2), token_pattern=r'(?u)\b\w+\b')
    try:
        tfidf = vec.fit_transform(doc_texts)
    except ValueError as err:
        # Best effort: the graph is still usable without term-document edges.
        print(f" Term-doc edges skipped: {err}")
    else:
        tfidf = tfidf.tocsc()
        for ti in range(min(len(key_terms), tfidf.shape[1])):
            col = tfidf[:, ti].toarray().flatten()
            best = {}  # law node index -> strongest score among its top articles
            for row in col.argsort()[::-1][:10]:
                score = float(col[row])
                law = doc_law_idx[row]
                if score > 0.1 and law is not None and score > best.get(law, 0.0):
                    best[law] = score
            term_doc_edges.extend((N_laws + ti, law, score) for law, score in best.items())
    # NOTE(review): reuses the "semantic" edge type; if add_edges replaces
    # rather than extends, this would drop the k-NN edges added above - confirm.
    graph.add_edges("semantic", term_doc_edges)
    print(f" Term-doc edges: {len(term_doc_edges)}")
    print(f" Total: {graph.num_nodes} nodes ({N_laws} laws + {len(key_terms)} terms)")

    return graph, law_id_to_idx
|
src/telern/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TELEN configuration."""
|
| 2 |
+
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import List
|
| 5 |
+
|
| 6 |
+
# Project root. Defaults to the original hard-coded checkout location; set the
# TELEN_ROOT environment variable to run from another path or machine.
ROOT = Path(os.environ.get("TELEN_ROOT", "E:/law-embedding"))
# Directory holding the dataset files.
DATA_DIR = ROOT / "dataset"
# Default location for TELEN checkpoints.
CHECKPOINT_DIR = ROOT / "data" / "checkpoints" / "telen"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@dataclass
class GraphConfig:
    """Settings for the legal concept graph and its GNN."""

    # --- GNN shape (presumably consumed by the graph module; confirm there) ---
    hidden_dim: int = 768
    gnn_layers: int = 3
    gnn_dropout: float = 0.1

    # --- Edge-type switches ---
    use_cross_ref_edges: bool = True
    use_agency_edges: bool = True
    use_temporal_edges: bool = True
    use_semantic_edges: bool = True
    # Neighbours per law for the semantic k-NN edges.
    semantic_knn: int = 10

    # --- Concept extraction ---
    max_concepts_per_article: int = 8
    min_tfidf_score: float = 0.05
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class HyperNetworkConfig:
    """Settings for the HyperNetwork that generates projection weights from legal state."""

    adaptation_rank: int = 64  # rank of the low-rank adaptation
    hn_hidden_dim: int = 512
    hn_layers: int = 3
    dropout: float = 0.1

    # --- Which quantities the HyperNetwork emits ---
    output_shift: bool = True  # delta-W applied to the projection
    output_bias: bool = True  # delta-b applied to the projection
    output_variance: bool = True  # log sigma^2 for the stochastic embedding
    min_variance: float = 0.01  # lower bound on the variance
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
class MetaTrainingConfig:
    """Hyper-parameters for meta-learning."""

    meta_lr: float = 3e-4
    inner_lr: float = 5e-3
    meta_batch_size: int = 4  # episodes per meta-update
    n_query: int = 32  # query articles per episode
    n_negatives: int = 256  # negative articles per query
    meta_epochs: int = 50
    temperature: float = 0.05

    # --- Temporal splits ---
    # Evaluation rebuilds the graph from years <= train_split_year and takes
    # its test set from years after val_split_year.
    train_split_year: int = 2018
    val_split_year: int = 2020

    # --- State construction ---
    max_state_articles: int = 500  # cap on articles included in the state

    # --- Stochastic embedding ---
    kl_weight: float = 0.001  # weight of the KL regularisation term
    n_mc_samples: int = 1  # Monte Carlo samples drawn during training
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class TELENConfig:
    """Top-level TELEN configuration bundling every sub-config."""

    backbone: str = "vinai/phobert-base-v2"
    hidden_dim: int = 768
    max_seq_length: int = 480

    # Each instance gets its own sub-config objects (default_factory).
    graph: GraphConfig = field(default_factory=GraphConfig)
    hypernetwork: HyperNetworkConfig = field(default_factory=HyperNetworkConfig)
    meta: MetaTrainingConfig = field(default_factory=MetaTrainingConfig)

    output_dir: str = str(CHECKPOINT_DIR)
    seed: int = 42
|
src/telern/evaluate.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation for TELEN: NDCG@k and MRR@k.
|
| 3 |
+
|
| 4 |
+
Metrics:
|
| 5 |
+
- NDCG@3, NDCG@5, NDCG@10
|
| 6 |
+
- MRR@3, MRR@5, MRR@10
|
| 7 |
+
|
| 8 |
+
Baselines:
|
| 9 |
+
- BM25 (lexical; approximated here with TF-IDF scoring)
|
| 10 |
+
- Frozen PhoBERT + mean pooling
|
| 11 |
+
- TELEN (ours)
|
| 12 |
+
|
| 13 |
+
Evaluation setup:
|
| 14 |
+
- Query = article title + first 500 chars of text
|
| 15 |
+
- Relevant = other articles from the SAME law
|
| 16 |
+
- Corpus = all articles from test years (held-out)
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import random
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import List, Dict, Tuple
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pandas as pd
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn.functional as F
|
| 29 |
+
from tqdm import tqdm
|
| 30 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 31 |
+
from transformers import AutoModel, AutoTokenizer
|
| 32 |
+
|
| 33 |
+
from .config import TELENConfig, DATA_DIR
|
| 34 |
+
from .model import TELEN, create_telen
|
| 35 |
+
from ..data import load_raw_data, extract_metadata, clean_data
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 39 |
+
# Metrics
|
| 40 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
|
| 42 |
+
def dcg_at_k(scores: np.ndarray, k: int) -> float:
    """Return the Discounted Cumulative Gain of the first ``k`` entries.

    ``scores`` holds graded relevance in ranked order. The item at 1-based
    rank ``r`` contributes ``(2**score - 1) / log2(r + 1)``. An empty
    (or fully truncated) ranking yields 0.0.
    """
    gains = np.asarray(scores)[:k]
    n_ranked = len(gains)
    if n_ranked == 0:
        return 0.0
    log_discounts = np.log2(np.arange(2, n_ranked + 2))
    return np.sum((2.0**gains - 1) / log_discounts)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def ndcg_at_k(scores: np.ndarray, k: int) -> float:
    """Return DCG@k divided by the ideal DCG@k.

    The ideal ranking is ``scores`` sorted in descending order, so the
    normaliser only knows about the items present in ``scores``; pass the
    full ranked list, not a pre-truncated one. Returns 0.0 when the ideal
    DCG is zero.
    """
    actual = dcg_at_k(scores, k)
    best_possible = dcg_at_k(np.sort(scores)[::-1], k)
    if best_possible > 0:
        return actual / best_possible
    return 0.0
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def mrr_at_k(scores: np.ndarray, k: int) -> float:
    """Return the reciprocal rank of the first relevant item within the top ``k``.

    An item counts as relevant when its score is greater than zero. Returns
    0.0 if none of the first ``k`` entries is relevant. Averaging over
    queries (the "mean" in MRR) is left to the caller.
    """
    top = np.asarray(scores)[:k]
    hit_positions = np.flatnonzero(top > 0)
    if hit_positions.size == 0:
        return 0.0
    return 1.0 / (int(hit_positions[0]) + 1)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def compute_metrics(
    relevance_scores: np.ndarray, k_values: List[int] = [3, 5, 10]
) -> Dict[str, float]:
    """Return ``{"ndcg@k": ..., "mrr@k": ...}`` for every cutoff in ``k_values``.

    ``relevance_scores`` is the relevance of the retrieved items in ranked
    order. The default ``k_values`` list is only read, never modified.
    """
    metrics = {}
    for cutoff in k_values:
        metrics.update({
            f"ndcg@{cutoff}": ndcg_at_k(relevance_scores, cutoff),
            f"mrr@{cutoff}": mrr_at_k(relevance_scores, cutoff),
        })
    return metrics
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 80 |
+
# Evaluation
|
| 81 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
|
| 83 |
+
def prepare_test_data(config: TELENConfig):
    """Load the dataset and return the held-out (test-year) articles.

    The parquet file under ``DATA_DIR`` is run through ``load_raw_data``,
    ``extract_metadata`` and ``clean_data`` (``min_text_len=10``). Test years
    run from ``config.meta.val_split_year + 1`` up to and including 2024;
    the upper bound is hard-coded. The returned frame has a fresh 0..n-1 index.
    """
    print("Loading data...")
    parquet_path = str(DATA_DIR / "train-00000-of-00001.parquet")
    articles = clean_data(extract_metadata(load_raw_data(parquet_path)), min_text_len=10)

    # Test split: articles from test years
    first_test_year = config.meta.val_split_year + 1
    held_out = articles["year"].isin(range(first_test_year, 2025))
    test_df = articles[held_out].reset_index(drop=True)

    print(f" Test set: {len(test_df)} articles from {test_df['law_id'].nunique()} laws")
    return test_df
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_test_queries(test_df: pd.DataFrame, max_queries: int = 500) -> List[Dict]:
    """Build the query set: up to two query articles per sufficiently large law.

    Laws with fewer than three articles are skipped (a query needs at least
    two other articles of the same law to retrieve). The query text is the
    title plus the first 500 characters of the article text. If more than
    ``max_queries`` queries result, a random sample of that size is drawn
    with the ``random`` module.
    """
    queries = []
    for law_id, law_articles in test_df.groupby("law_id"):
        if len(law_articles) < 3:
            continue
        # Only the first two articles of each law become queries.
        for article in law_articles.head(2).to_dict("records"):
            queries.append({
                "query_id": article["id"],
                "query_text": f"{article['title']}: {article['text'][:500]}",
                "query_full": article["text"],
                "law_id": law_id,
            })

    if len(queries) > max_queries:
        queries = random.sample(queries, max_queries)

    print(f" Queries: {len(queries)}")
    return queries
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def build_test_corpus(test_df: pd.DataFrame) -> List[Dict]:
    """Turn every test article into a retrieval document.

    Each document carries the article id, the owning law id and a text made
    of the title plus the first 500 characters of the article body.
    """
    corpus = [
        {
            "article_id": record["id"],
            "text": f"{record['title']}: {record['text'][:500]}",
            "law_id": record["law_id"],
        }
        for record in test_df.to_dict("records")
    ]
    print(f" Corpus: {len(corpus)} documents")
    return corpus
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def evaluate_telen(
    model: TELEN,
    queries: List[Dict],
    corpus: List[Dict],
    batch_size: int = 64,
) -> Dict[str, float]:
    """
    Evaluate TELEN on retrieval metrics.

    For each query, rank all corpus documents by cosine similarity.
    Relevance = article is from the same law. The query's own article is
    removed from the ranking before scoring.

    Args:
        model: called as ``model(texts, use_stochastic=False)`` and expected
            to return a dict with an ``"embeddings"`` tensor.
        queries: dicts with ``query_id``, ``query_text`` and ``law_id``.
        corpus: dicts with ``article_id``, ``text`` and ``law_id``.
        batch_size: number of corpus documents encoded per forward pass.

    Returns:
        Mean NDCG@k and MRR@k over all queries for k in (3, 5, 10).
    """
    model.eval()

    corpus_ids = [doc["article_id"] for doc in corpus]
    corpus_law_ids = [doc["law_id"] for doc in corpus]
    # id -> first position in the corpus, built once instead of scanning the
    # whole corpus for every query.
    corpus_pos = {}
    for pos, cid in enumerate(corpus_ids):
        corpus_pos.setdefault(cid, pos)

    # Encode corpus
    print(" Encoding corpus...")
    chunks = []
    for i in tqdm(range(0, len(corpus), batch_size), desc=" Corpus"):
        texts = [doc["text"] for doc in corpus[i:i + batch_size]]
        with torch.no_grad():
            result = model(texts, use_stochastic=False)
        chunks.append(result["embeddings"].cpu())

    corpus_embeddings = torch.cat(chunks, dim=0)  # [N_corpus, d]
    print(f" Corpus embeddings: {corpus_embeddings.shape}")

    # Evaluate each query
    k_values = [3, 5, 10]
    all_metrics = defaultdict(list)

    print(" Evaluating queries...")
    for query in tqdm(queries, desc=" Queries"):
        with torch.no_grad():
            result = model([query["query_text"]], use_stochastic=False)
        query_emb = result["embeddings"].cpu()  # [1, d]

        # Cosine similarity with all corpus documents
        sim = F.cosine_similarity(query_emb, corpus_embeddings).numpy()  # [N_corpus]

        # 1.0 if same law, 0.0 otherwise
        relevance = np.array([
            1.0 if law == query["law_id"] else 0.0 for law in corpus_law_ids
        ])

        sorted_idx = sim.argsort()[::-1]
        sorted_relevance = relevance[sorted_idx]

        # Remove the query itself from results
        query_pos = corpus_pos.get(query["query_id"])
        if query_pos is not None:
            sorted_relevance = sorted_relevance[sorted_idx != query_pos]

        # Score the FULL ranking: the metric functions truncate to k on their
        # own, and the ideal DCG has to see every relevant document. Feeding
        # only the top-k slice made NDCG compare the top-k against itself.
        for metric_name, value in compute_metrics(sorted_relevance, k_values).items():
            all_metrics[metric_name].append(value)

    # Average over queries
    return {name: np.mean(scores) for name, scores in all_metrics.items()}
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 217 |
+
# Baselines
|
| 218 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 219 |
+
|
| 220 |
+
class BM25Baseline:
    """Lexical baseline that stands in for BM25 with TF-IDF scoring.

    Documents and queries are vectorised with a sublinear-TF, uni+bi-gram
    ``TfidfVectorizer`` (at most 10000 features); a document's score is the
    dot product of its vector with the query vector.
    """

    def __init__(self):
        self.vectorizer = TfidfVectorizer(
            max_features=10000,
            ngram_range=(1, 2),
            sublinear_tf=True,
        )

    def fit(self, corpus: List[Dict]):
        """Index ``corpus`` (dicts with ``text``, ``article_id``, ``law_id``)."""
        self.corpus = corpus
        self.doc_texts = [entry["text"] for entry in corpus]
        self.doc_ids = [entry["article_id"] for entry in corpus]
        self.doc_law_ids = [entry["law_id"] for entry in corpus]
        self.tfidf_matrix = self.vectorizer.fit_transform(self.doc_texts)

    def search(self, query_text: str, k: int = 100) -> np.ndarray:
        """Return ALL document positions ordered by descending score.

        ``k`` is accepted but not applied: the full ranking is returned.
        """
        query_vector = self.vectorizer.transform([query_text])
        doc_scores = (self.tfidf_matrix @ query_vector.T).toarray().flatten()
        return doc_scores.argsort()[::-1]
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def evaluate_bm25(queries: List[Dict], corpus: List[Dict]) -> Dict[str, float]:
    """Evaluate the BM25 (TF-IDF) baseline.

    Each query is scored against the whole corpus, its own article is removed
    from the ranking, and relevance is "same law as the query".

    Returns:
        Mean NDCG@k and MRR@k over all queries for k in (3, 5, 10).
    """
    print(" Building BM25 index...")
    bm25 = BM25Baseline()
    bm25.fit(corpus)

    # id -> first position in the index, built once instead of per query.
    doc_pos = {}
    for pos, did in enumerate(bm25.doc_ids):
        doc_pos.setdefault(did, pos)

    k_values = [3, 5, 10]
    all_metrics = defaultdict(list)

    print(" Evaluating queries...")
    for query in tqdm(queries, desc=" Queries"):
        # search() returns the complete ranking; the full list is needed
        # below so that the ideal DCG sees every relevant document.
        sorted_idx = bm25.search(query["query_text"], k=100)

        relevance = np.array([
            1.0 if bm25.doc_law_ids[i] == query["law_id"] else 0.0
            for i in sorted_idx
        ])

        # Remove self
        query_idx = doc_pos.get(query["query_id"])
        if query_idx is not None:
            relevance = relevance[sorted_idx != query_idx]

        # Score the full ranking (the metrics truncate to k themselves);
        # scoring only the top-k slice inflated NDCG.
        for name, val in compute_metrics(relevance, k_values).items():
            all_metrics[name].append(val)

    return {name: np.mean(scores) for name, scores in all_metrics.items()}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class FrozenPhoBERT:
    """Frozen PhoBERT with mean pooling baseline."""

    def __init__(self, model_name: str = "vinai/phobert-base-v2", max_length: int = 480):
        """Load tokenizer and model, move the model to GPU if available.

        ``max_length`` is the requested truncation length; it is lowered when
        the loaded model cannot address that many positions.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)
        self.model.eval()

        # Never tokenise beyond what the position embeddings can index, or
        # long inputs fail inside the embedding lookup. Two slots are held
        # back because RoBERTa-style encoders (PhoBERT is one) start position
        # ids after the padding index; for other encoders this is merely
        # conservative.
        position_limit = getattr(self.model.config, "max_position_embeddings", None)
        if position_limit is not None:
            max_length = min(max_length, position_limit - 2)
        self.max_length = max_length

    def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:
        """Return L2-normalised mean-pooled embeddings, one row per text (on CPU)."""
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            encoded = self.tokenizer(
                batch, padding=True, truncation=True,
                max_length=self.max_length, return_tensors="pt",
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)
            with torch.no_grad():
                outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
            hidden = outputs.last_hidden_state
            # Mean pooling over real (non-padding) tokens only
            mask = attention_mask.unsqueeze(-1).float()
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
            pooled = F.normalize(pooled, p=2, dim=1)
            embeddings.append(pooled.cpu())
        if not embeddings:
            # No texts: return an empty matrix instead of failing in torch.cat.
            return torch.empty(0, self.model.config.hidden_size)
        return torch.cat(embeddings, dim=0)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def evaluate_frozen_phobert(
    queries: List[Dict], corpus: List[Dict]
) -> Dict[str, float]:
    """Evaluate the frozen PhoBERT baseline.

    Corpus and queries are embedded with ``FrozenPhoBERT``; each query ranks
    the corpus by cosine similarity, its own article is removed, and
    relevance is "same law as the query".

    Returns:
        Mean NDCG@k and MRR@k over all queries for k in (3, 5, 10).
    """
    print(" Loading frozen PhoBERT...")
    encoder = FrozenPhoBERT()

    print(" Encoding corpus...")
    corpus_embeddings = encoder.encode([doc["text"] for doc in corpus])
    corpus_law_ids = [doc["law_id"] for doc in corpus]
    # id -> first position in the corpus, built once instead of per query.
    corpus_pos = {}
    for pos, doc in enumerate(corpus):
        corpus_pos.setdefault(doc["article_id"], pos)

    k_values = [3, 5, 10]
    all_metrics = defaultdict(list)

    print(" Evaluating queries...")
    query_embeddings = encoder.encode([q["query_text"] for q in queries])

    for i, query in enumerate(tqdm(queries, desc=" Queries")):
        query_emb = query_embeddings[i:i+1]
        sim = F.cosine_similarity(query_emb, corpus_embeddings).numpy()

        relevance = np.array([
            1.0 if law == query["law_id"] else 0.0 for law in corpus_law_ids
        ])

        sorted_idx = sim.argsort()[::-1]
        sorted_relevance = relevance[sorted_idx]

        # Remove self
        query_pos = corpus_pos.get(query["query_id"])
        if query_pos is not None:
            sorted_relevance = sorted_relevance[sorted_idx != query_pos]

        # Score the full ranking (the metrics truncate to k themselves);
        # scoring only the top-k slice made the ideal DCG ignore relevant
        # documents ranked below k and inflated NDCG.
        for name, val in compute_metrics(sorted_relevance, k_values).items():
            all_metrics[name].append(val)

    return {name: np.mean(scores) for name, scores in all_metrics.items()}
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 361 |
+
# Main evaluation entry point
|
| 362 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 363 |
+
|
| 364 |
+
def run_full_evaluation(
    config: TELENConfig = None,
    checkpoint_path: str = None,
):
    """Run complete evaluation with all baselines and TELEN.

    Args:
        config: evaluation configuration; a default ``TELENConfig`` is used
            when omitted.
        checkpoint_path: optional TELEN checkpoint. When it is missing or not
            given, TELEN is evaluated as created by ``create_telen``.

    Returns:
        ``{"BM25": ..., "PhoBERT": ..., "TELEN": ...}``, each a dict of mean
        NDCG@k / MRR@k values for k in (3, 5, 10).
    """
    if config is None:
        config = TELENConfig()

    random.seed(config.seed)
    np.random.seed(config.seed)

    print("=" * 60)
    print("TELEN Evaluation")
    print("=" * 60)

    # Prepare test data
    test_df = prepare_test_data(config)
    queries = build_test_queries(test_df, max_queries=300)
    corpus = build_test_corpus(test_df)

    k_values = [3, 5, 10]
    results = {}

    def _print_metrics(method):
        # One line per cutoff for the given method.
        for m in k_values:
            print(f" NDCG@{m}: {results[method][f'ndcg@{m}']:.4f} | MRR@{m}: {results[method][f'mrr@{m}']:.4f}")

    # --- Baseline 1: BM25 ---
    print("\n" + "=" * 40)
    print("[1/3] BM25 Baseline")
    print("=" * 40)
    results["BM25"] = evaluate_bm25(queries, corpus)
    _print_metrics("BM25")

    # --- Baseline 2: Frozen PhoBERT ---
    print("\n" + "=" * 40)
    print("[2/3] Frozen PhoBERT Baseline")
    print("=" * 40)
    results["PhoBERT"] = evaluate_frozen_phobert(queries, corpus)
    _print_metrics("PhoBERT")

    # --- TELEN ---
    print("\n" + "=" * 40)
    print("[3/3] TELEN (Ours)")
    print("=" * 40)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_telen(config)
    model = model.to(device)

    # Load checkpoint if provided
    if checkpoint_path and Path(checkpoint_path).exists():
        print(f" Loading checkpoint: {checkpoint_path}")
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.hypernetwork.load_state_dict(ckpt["hypernetwork"])
        model.state_encoder.load_state_dict(ckpt["state_encoder"])
        model.base_projection.load_state_dict(ckpt["base_projection"])
        model.attn_query.data.copy_(ckpt["attn_query"])

        # Rebuild graph from the TRAINING period. test_df only contains years
        # after val_split_year, so filtering it by train_split_year always
        # produced an empty frame; the training rows have to come from the
        # full dataset (same pipeline as prepare_test_data).
        full_df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet"))
        full_df = extract_metadata(full_df)
        full_df = clean_data(full_df, min_text_len=10)
        train_df = full_df[full_df["year"] <= config.meta.train_split_year].reset_index(drop=True)
        if len(train_df) > 0:
            model.build_graph(train_df)
        else:
            print(f" Warning: no articles with year <= {config.meta.train_split_year}; graph not rebuilt")
    elif checkpoint_path:
        print(f" Warning: checkpoint not found: {checkpoint_path} (evaluating TELEN without trained weights)")

    results["TELEN"] = evaluate_telen(model, queries, corpus)
    _print_metrics("TELEN")

    # --- Summary ---
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    header = f"{'Method':<20}"
    for m in k_values:
        header += f" {'NDCG@'+str(m):>12} {'MRR@'+str(m):>12}"
    print(header)
    print("-" * len(header))

    for method in ["BM25", "PhoBERT", "TELEN"]:
        row = f"{method:<20}"
        for m in k_values:
            row += f" {results[method][f'ndcg@{m}']:>12.4f} {results[method][f'mrr@{m}']:>12.4f}"
        print(row)

    # Relative improvement (denominator floored to avoid division by zero)
    print("\n--- Improvement over PhoBERT ---")
    for m in k_values:
        ndcg_imp = (results["TELEN"][f"ndcg@{m}"] / max(results["PhoBERT"][f"ndcg@{m}"], 1e-6) - 1) * 100
        mrr_imp = (results["TELEN"][f"mrr@{m}"] / max(results["PhoBERT"][f"mrr@{m}"], 1e-6) - 1) * 100
        print(f" NDCG@{m}: {ndcg_imp:+.1f}% | MRR@{m}: {mrr_imp:+.1f}%")

    return results


if __name__ == "__main__":
    run_full_evaluation()
|
src/telern/hypernetwork.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HyperNetwork for TELEN.
|
| 3 |
+
|
| 4 |
+
Core innovation: Instead of learning fixed projection weights, the HyperNetwork
|
| 5 |
+
GENERATES the projection function from the current legal corpus state.
|
| 6 |
+
|
| 7 |
+
When new laws arrive → state vector changes → HyperNetwork produces new weights
→ embedding space adapts WITHOUT retraining.
|
| 9 |
+
|
| 10 |
+
Additionally outputs variance for stochastic embeddings (uncertainty-aware retrieval).
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class HyperNetwork(nn.Module):
    """
    Generate embedding-projection parameters from a legal state vector.

    Given a state vector s in R^d, produces:
      - delta_W: rank-r projection shift,
                 sum_i (w^A_i * u_i) outer (w^B_i * v_i)
      - delta_b: bias shift, sum_i w^b_i * b_i
      - log sigma^2: per-dimension log-variance for stochastic embeddings

    Architecture: instead of generating full parameter matrices directly,
    the module stores a compact set of learned basis vectors and only
    generates the combination weights from the state vector.
    """

    def __init__(self, config):
        """
        Args:
            config: object exposing ``hidden_dim`` and a ``hypernetwork``
                sub-config with ``adaptation_rank``, ``hn_hidden_dim``,
                ``dropout`` and ``output_variance``.
        """
        super().__init__()
        self.config = config
        hn = config.hypernetwork
        d = config.hidden_dim
        r = hn.adaptation_rank
        hidden = hn.hn_hidden_dim

        # Shared trunk: state -> latent code
        self.trunk = nn.Sequential(
            nn.Linear(d, hidden),
            nn.ReLU(),
            nn.Dropout(hn.dropout),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(hn.dropout),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
        )

        # Modulator: latent -> combination weights, laid out as
        # [A weights (r) | B weights (r) | bias weights (r) | var_context (1)].
        # The trailing "var_context" column is not read by forward(); the
        # width is kept so existing checkpoints still load.
        self.modulator = nn.Linear(hidden, 2 * r + r + 1)

        # === Learned basis vectors (stored, not generated) ===
        # delta_W is assembled from r pairs (u_i, v_i), each in R^d.
        self.basis_u = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
        self.basis_v = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]

        # delta_b = sum_i w^b_i * b_i where b_i in R^d
        self.basis_b = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]

        # Variance head
        if hn.output_variance:
            self.head_logvar = nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.Tanh(),
                nn.Linear(hidden, d),
            )
        else:
            self.head_logvar = None

    def forward(self, state_vector: torch.Tensor) -> dict:
        """
        Args:
            state_vector: [d] or [B, d] summarizing the current legal landscape.

        Returns dict with keys:
            "shift_matrix": [d, d] or [B, d, d]  rank-r projection shift
            "bias":         [d] or [B, d]        bias shift
            "log_variance": [d] or [B, d]        log variance, clamped to
                            [-5, 2]; constant -3 when the variance head is
                            disabled
        """
        squeeze = state_vector.dim() == 1
        if squeeze:
            state_vector = state_vector.unsqueeze(0)  # [1, d]

        B, d = state_vector.shape
        r = self.config.hypernetwork.adaptation_rank

        # Shared representation
        h = self.trunk(state_vector)   # [B, hidden]
        modulated = self.modulator(h)  # [B, 3r + 1]

        # Split modulation weights
        w_A = modulated[:, :r]              # [B, r]
        w_B = modulated[:, r:2 * r]         # [B, r]
        w_bias = modulated[:, 2 * r:3 * r]  # [B, r]

        # Rank-r shift: sum_i (w_A[:, i] * u_i) outer (w_B[:, i] * v_i).
        # Contracting over r inside the einsum avoids materializing the
        # [B, r, d, d] tensor of individual outer products.
        u_weighted = w_A.unsqueeze(2) * self.basis_u.unsqueeze(0)  # [B, r, d]
        v_weighted = w_B.unsqueeze(2) * self.basis_v.unsqueeze(0)  # [B, r, d]
        shift = torch.einsum("brd,bre->bde", u_weighted, v_weighted)  # [B, d, d]

        # Bias: weighted sum of the bias bases.
        bias = w_bias @ self.basis_b  # [B, d]

        result = {"shift_matrix": shift, "bias": bias}

        # Log variance
        if self.head_logvar is not None:
            logvar = self.head_logvar(h)
            logvar = torch.clamp(logvar, min=-5.0, max=2.0)
            result["log_variance"] = logvar
        else:
            # Match the trunk's dtype so downstream arithmetic does not mix
            # precisions when the module runs in half/bfloat16.
            result["log_variance"] = torch.full(
                (B, d), -3.0, device=h.device, dtype=h.dtype
            )

        if squeeze:
            result = {k: v.squeeze(0) for k, v in result.items()}

        return result
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class StateEncoder(nn.Module):
    """
    Pool graph node embeddings into a single state vector.

    Kept separate from the HyperNetwork so the graph-side computation can
    be cached and refreshed only when the graph itself changes.
    """

    def __init__(self, dim: int):
        super().__init__()
        widened = dim * 2
        layers = [
            nn.Linear(dim, widened),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(widened, dim),
            nn.LayerNorm(dim),
        ]
        self.state_proj = nn.Sequential(*layers)

    def forward(self, node_embeddings: torch.Tensor, node_weights: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            node_embeddings: [N, d] node embeddings.
            node_weights: [N] optional unnormalized weights; they are passed
                through a softmax over the node axis. Omitted -> uniform.

        Returns:
            state_vector: [d] projection of the weighted node average.
        """
        if node_weights is None:
            # All-ones logits become a uniform distribution after softmax.
            n_nodes = node_embeddings.shape[0]
            node_weights = torch.ones(n_nodes, device=node_embeddings.device)

        attention = F.softmax(node_weights, dim=0).unsqueeze(1)  # [N, 1]
        pooled = (node_embeddings * attention).sum(dim=0)        # [d]
        return self.state_proj(pooled)
|
src/telern/model.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TELEN: Temporal Evolving Legal Embedding Network.
|
| 3 |
+
|
| 4 |
+
Bi-encoder backbone + Legal Concept Graph + HyperNetwork projection.
|
| 5 |
+
Embedding space adapts dynamically to the legal corpus state.
|
| 6 |
+
"""
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from transformers import AutoModel, AutoTokenizer
|
| 11 |
+
from pyvi import ViTokenizer
|
| 12 |
+
|
| 13 |
+
from .config import TELENConfig
|
| 14 |
+
from .hypernetwork import StateEncoder, HyperNetwork
|
| 15 |
+
from .concept_graph import build_concept_graph
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def wseg(text):
    """Replace underscores in ``text`` with spaces, then run pyvi's ``ViTokenizer.tokenize`` on it."""
    cleaned = text.replace("_", " ")
    return ViTokenizer.tokenize(cleaned)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class BiEncoder(nn.Module):
    """Vietnamese bi-encoder backbone with attention pooling."""

    def __init__(self, model_name="bkai-foundation-models/vietnamese-bi-encoder"):
        """Load the pretrained backbone and tokenizer and create the pooling query."""
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.dim = self.model.config.hidden_size
        self.attn_query = nn.Parameter(torch.randn(self.dim))
        self.scale = self.dim ** 0.5

    def forward(self, texts, max_len=480):
        """
        Encode raw strings into pooled sentence vectors.

        Args:
            texts: sequence of raw strings; each is word-segmented with
                ``wseg`` before tokenization.
            max_len: truncation length passed to the tokenizer.

        Returns:
            [len(texts), dim] tensor of attention-pooled hidden states
            (an empty [0, dim] tensor when ``texts`` is empty).
        """
        if len(texts) == 0:
            # The tokenizer cannot batch an empty list; return an empty batch.
            return self.attn_query.new_zeros((0, self.dim))

        segmented = [wseg(t) for t in texts]
        enc = self.tokenizer(segmented, padding=True, truncation=True,
                             max_length=max_len, return_tensors="pt")
        target = self.attn_query.device
        input_ids = enc["input_ids"].to(target)
        mask = enc["attention_mask"].to(target)
        hidden = self.model(input_ids=input_ids, attention_mask=mask).last_hidden_state

        scores = torch.einsum("bsd,d->bs", hidden, self.attn_query) / self.scale
        # Use the dtype's own minimum instead of a literal -1e9, which is
        # not representable in float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)
        return torch.einsum("bsd,bs->bd", hidden, weights)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TELEN(nn.Module):
    """
    Temporal Evolving Legal Embedding Network.

    Combines a frozen bi-encoder, a learned base projection, and a
    HyperNetwork-generated shift conditioned on the concept-graph state.
    """

    def __init__(self, config: TELENConfig):
        super().__init__()
        self.config = config
        d = config.hidden_dim

        # Bi-encoder backbone (frozen: every encoder parameter has grads off)
        self.encoder = BiEncoder()
        for p in self.encoder.parameters():
            p.requires_grad = False

        # Projection
        self.projection = nn.Sequential(nn.Linear(d, d), nn.Tanh())
        self.proj_norm = nn.LayerNorm(d)
        self.attn_query = nn.Parameter(torch.randn(d))

        # Graph (populated by build_graph)
        self.concept_graph = None
        self.law_id_to_idx = None

        # HyperNetwork
        self.state_encoder = StateEncoder(d)
        self.hypernetwork = HyperNetwork(config)

    @property
    def base_projection(self):
        """Alias of ``projection``.

        The training script in this repository addresses the projection
        head as ``model.base_projection``. A property (rather than a second
        registered submodule) keeps ``state_dict`` keys unchanged.
        """
        return self.projection

    def _pool(self, hidden, mask):
        """Attention-weighted pooling (for pre-tokenized inputs)."""
        scores = torch.einsum("bsd,d->bs", hidden, self.attn_query) / (self.config.hidden_dim ** 0.5)
        # dtype-aware fill value: a literal -1e9 is not representable in float16.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)
        return torch.einsum("bsd,bs->bd", hidden, weights)

    def encode_text(self, texts):
        """Encode raw strings with the backbone, using its own pooling query."""
        return self.encoder(texts, max_len=self.config.max_seq_length)

    def get_state_vector(self):
        """Return the [d] state vector; zeros when no graph (or an empty graph) is attached."""
        if self.concept_graph is None or self.concept_graph.num_nodes == 0:
            return torch.zeros(self.config.hidden_dim, device=self.attn_query.device)
        refined = self.concept_graph.forward()
        return self.state_encoder(refined)

    def adapt_embedding(self, raw, state_vec):
        """
        Project raw embeddings and apply the state-conditioned shift.

        Args:
            raw: [N, d] pooled backbone embeddings.
            state_vec: state vector passed to the HyperNetwork.

        Returns dict with:
            "mean": [N, d] L2-normalized adapted embedding
            "log_variance": HyperNetwork log-variance output
            "sample": noisy, re-normalized embedding when the variance head
                is enabled; otherwise identical to "mean"
        """
        base = self.projection(raw)
        hn = self.hypernetwork(state_vec)
        shift = raw @ hn["shift_matrix"].T + hn["bias"]
        mean = F.normalize(self.proj_norm(base + shift), p=2, dim=1)
        result = {"mean": mean, "log_variance": hn.get("log_variance")}
        if self.config.hypernetwork.output_variance:
            # Std-dev derived from the clamped variance, scaled down by 0.1.
            noise = 0.1 * hn["log_variance"].exp().clamp(min=0.001, max=0.25).sqrt().clamp(max=0.5)
            result["sample"] = F.normalize(mean + torch.randn_like(mean) * noise, p=2, dim=1)
        else:
            result["sample"] = mean
        return result

    def forward(self, texts, use_stochastic=False):
        """Embed ``texts``; ``use_stochastic`` selects the sampled embedding over the mean."""
        raw = self.encode_text(texts)
        state = self.get_state_vector()
        adapted = self.adapt_embedding(raw, state)
        return {
            "embeddings": adapted["sample"] if use_stochastic else adapted["mean"],
            "mean": adapted["mean"],
            "log_variance": adapted.get("log_variance"),
            "state_vector": state,
        }

    def build_graph(self, df):
        """Build the concept graph from ``df`` and move it to the model's device."""
        self.concept_graph, self.law_id_to_idx = build_concept_graph(
            df, lambda t: self.encode_text([t])[0].detach(), self.config,
        )
        self.concept_graph = self.concept_graph.to(self.attn_query.device)

    def add_law(self, law_id, articles):
        """
        Append a new law node and connect it to its nearest existing nodes.

        The node embedding is the mean encoding of the first five articles.
        Up to ten "semantic" edges are added to the most cosine-similar
        existing nodes. Does nothing without a graph or without articles.
        """
        if self.concept_graph is None:
            return
        if not articles:
            return

        # Graph node features are not trained through the encoder; encode
        # without autograd, consistent with the detach in build_graph.
        with torch.no_grad():
            emb = self.encode_text(articles[:5]).mean(dim=0)

        new_idx = self.concept_graph.num_nodes
        self.concept_graph.add_nodes([law_id], emb.unsqueeze(0))
        # NOTE(review): self.law_id_to_idx is not updated here; confirm
        # whether add_nodes maintains that mapping.
        existing = self.concept_graph.node_embeddings[:-1]
        if len(existing) > 0:
            sim = F.cosine_similarity(emb.unsqueeze(0), existing)
            _, top = sim.topk(k=min(10, len(existing)))
            self.concept_graph.add_edges(
                "semantic",
                [(new_idx, i.item(), sim[i].item()) for i in top],
            )
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def create_model(config: TELENConfig) -> TELEN:
    """Factory helper: construct a TELEN model from ``config``."""
    model = TELEN(config)
    return model
|
train.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TELEN: Temporal Evolving Legal Embedding Network — Training Script.
|
| 3 |
+
|
| 4 |
+
Stages:
|
| 5 |
+
1. Contrastive pretraining (5 epochs) — train projection head
|
| 6 |
+
2. Meta-training (50 epochs) — train HyperNetwork + State Encoder
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python train.py
|
| 10 |
+
"""
|
| 11 |
+
import sys, os, math, random
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from collections import defaultdict
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pandas as pd
|
| 16 |
+
import torch, torch.nn as nn, torch.nn.functional as F
|
| 17 |
+
from torch.utils.data import DataLoader, Dataset
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
from transformers import AutoTokenizer
|
| 20 |
+
from pyvi import ViTokenizer
|
| 21 |
+
|
| 22 |
+
sys.path.insert(0, ".")
|
| 23 |
+
from src.telern.config import TELENConfig, DATA_DIR
|
| 24 |
+
from src.telern.model import TELEN, create_model
|
| 25 |
+
from src.data import load_raw_data, extract_metadata, clean_data
|
| 26 |
+
|
| 27 |
+
SEED = 42
|
| 28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
+
|
| 30 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 31 |
+
# Data
|
| 32 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
def prepare_data(config):
    """
    Load the parquet corpus and index it by law and by year.

    Returns:
        (articles_by_law, laws_by_year, train_years, val_years, test_years, df)
        where ``articles_by_law`` maps law_id -> list of article dicts,
        ``laws_by_year`` maps the year of each law's first article -> law ids,
        and the three year lists are split at ``config.meta.train_split_year``
        and ``config.meta.val_split_year``.
    """
    parquet_path = DATA_DIR / "train-00000-of-00001.parquet"
    df = load_raw_data(str(parquet_path))
    df = extract_metadata(df)
    df = clean_data(df, min_text_len=10)

    article_fields = ("id", "title", "text", "law_type", "year")
    articles_by_law = defaultdict(list)
    for _, row in df.iterrows():
        record = {field: row[field] for field in article_fields}
        articles_by_law[row["law_id"]].append(record)

    # A law is filed under the year of its first listed article.
    laws_by_year = defaultdict(list)
    for law_id in articles_by_law:
        first_year = articles_by_law[law_id][0]["year"]
        laws_by_year[first_year].append(law_id)

    train_cut = config.meta.train_split_year
    val_cut = config.meta.val_split_year
    all_years = sorted(laws_by_year.keys())
    train_years = [y for y in all_years if y <= train_cut]
    val_years = [y for y in all_years if train_cut < y <= val_cut]
    test_years = [y for y in all_years if y > val_cut]

    return articles_by_law, laws_by_year, train_years, val_years, test_years, df
|
| 50 |
+
|
| 51 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 52 |
+
# Contrastive Dataset
|
| 53 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
class ContrastiveDataset(Dataset):
    """
    Triplet dataset over legal articles.

    Each item is (anchor, positive, negative): the positive is another
    article of the anchor's law (the anchor itself if the law has only one
    article) and the negative is a random article from a different law.
    """

    def __init__(self, df, tokenizer, max_len=480):
        """
        Args:
            df: frame with at least ``law_id``, ``title`` and ``text`` columns.
            tokenizer: callable tokenizer returning "input_ids" and
                "attention_mask" tensors.
            max_len: padding/truncation length.

        Raises:
            ValueError: if ``df`` holds fewer than two distinct laws, since
                no negative could ever be drawn.
        """
        self.df = df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.law_groups = self.df.groupby("law_id")
        self.law_ids = list(self.law_groups.groups.keys())
        if len(self.law_ids) < 2:
            raise ValueError(
                "ContrastiveDataset needs at least two distinct law_id values "
                "to sample negatives"
            )
        # Row positions per law, materialized once. After reset_index the
        # group labels coincide with iloc positions.
        self._rows_by_law = {
            law_id: list(rows) for law_id, rows in self.law_groups.groups.items()
        }
        self._law_pos = {law_id: pos for pos, law_id in enumerate(self.law_ids)}

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _render(row):
        """Format one article as segmented 'title: text[:400]'."""
        raw = f"{row['title']}: {row['text'][:400]}"
        return ViTokenizer.tokenize(raw.replace("_", " "))

    def _other_law(self, law_id):
        """Draw a law id uniformly from all laws except ``law_id`` in O(1)."""
        pick = random.randrange(len(self.law_ids) - 1)
        if pick >= self._law_pos[law_id]:
            pick += 1  # skip over the anchor's own slot
        return self.law_ids[pick]

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        law_id = row["law_id"]

        siblings = [i for i in self._rows_by_law[law_id] if i != idx]
        pos_row = self.df.iloc[random.choice(siblings)] if siblings else row
        neg_row = self.df.iloc[random.choice(self._rows_by_law[self._other_law(law_id)])]

        item = {}
        for suffix, source in (("a", row), ("p", pos_row), ("n", neg_row)):
            # Tokenize each text once and read both tensors from the result.
            enc = self.tokenizer(
                self._render(source),
                truncation=True,
                max_length=self.max_len,
                padding="max_length",
                return_tensors="pt",
            )
            for key in ("input_ids", "attention_mask"):
                item[f"{key}_{suffix}"] = enc[key].squeeze(0)
        return item
|
| 80 |
+
|
| 81 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 82 |
+
# Stage 1: Contrastive Pretraining
|
| 83 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 84 |
+
def contrastive_pretrain(model, df, config, epochs=5, batch_size=24, lr=3e-5):
    """
    Stage 1: train the projection head with triplet + InfoNCE losses.

    The backbone runs under ``no_grad`` in eval mode; only the projection
    head receives gradients.

    Args:
        model: TELEN model (already on ``device``).
        df: training articles, passed to ContrastiveDataset.
        config: provides ``max_seq_length``.
        epochs, batch_size, lr: optimization settings.

    Returns:
        The same ``model`` instance, trained in place.
    """
    tokenizer = model.encoder.tokenizer
    dataset = ContrastiveDataset(df, tokenizer, config.max_seq_length)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # TELEN registers the head as ``projection``.
    projection = model.projection
    # NOTE(review): model.attn_query is listed as trainable but is only used
    # inside the no_grad block below, so it receives no gradient here.
    trainable = list(projection.parameters()) + [model.attn_query]
    opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.01)
    # Guard T_max against an empty loader.
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(epochs * len(loader), 1))

    print(f" Contrastive pretraining: {epochs} epochs, {len(loader)} batches")
    model.train()
    model.encoder.model.eval()

    def pooled(batch, suffix):
        """Run the frozen backbone on one triplet slot and pool the hidden states."""
        ids = batch[f"input_ids_{suffix}"].to(device)
        mask = batch[f"attention_mask_{suffix}"].to(device)
        hidden = model.encoder.model(input_ids=ids, attention_mask=mask).last_hidden_state
        return model._pool(hidden, mask)

    for epoch in range(epochs):
        total = 0.0
        for batch in tqdm(loader, desc=f" Epoch {epoch+1}/{epochs}"):
            with torch.no_grad():
                ah = pooled(batch, "a")
                ph = pooled(batch, "p")
                nh = pooled(batch, "n")

            ae = F.normalize(projection(ah), p=2, dim=1)
            pe = F.normalize(projection(ph), p=2, dim=1)
            ne = F.normalize(projection(nh), p=2, dim=1)

            trip = F.relu(0.3 - (ae * pe).sum(1) + (ae * ne).sum(1)).mean()

            # InfoNCE over [anchors | positives | negatives]. Each anchor's
            # similarity with itself is always the maximum logit, so it is
            # masked out instead of competing with the true positive.
            n = ae.size(0)
            sim = ae @ torch.cat([ae, pe, ne], dim=0).T / 0.05
            self_mask = torch.zeros_like(sim, dtype=torch.bool)
            self_mask[:, :n] = torch.eye(n, dtype=torch.bool, device=sim.device)
            sim = sim.masked_fill(self_mask, float("-inf"))
            infonce = F.cross_entropy(sim, torch.arange(n, device=sim.device) + n)
            loss = trip + 0.5 * infonce

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            sched.step()
            total += loss.item()
        print(f" Epoch {epoch+1} avg loss: {total/max(len(loader), 1):.4f}")
    print(" Contrastive pretraining complete!")
    return model
|
| 123 |
+
|
| 124 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 125 |
+
# Episode building
|
| 126 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 127 |
+
def build_episode(articles_by_law, laws_by_year, state_years, query_year, config):
    """
    Sample one retrieval episode from the laws of ``query_year``.

    Queries and positives are two distinct articles of the same sampled law.
    Negatives come from the year's other laws: articles sharing a law type
    with any query fill the "hard" pool, the rest the "random" pool.

    Args:
        articles_by_law: law_id -> list of article dicts (needs "law_type").
        laws_by_year: year -> list of law ids.
        state_years: accepted for interface compatibility; not read here.
        query_year: year to draw the episode from.
        config: provides ``meta.n_query`` and ``meta.n_negatives``.

    Returns:
        dict with "queries", "positives", "negatives", or None when the
        year has fewer than 5 laws, fewer than 4 usable query laws, or
        fewer than 4 negatives.
    """
    meta_cfg = config.meta
    year_laws = laws_by_year.get(query_year, [])
    if len(year_laws) < 5:
        return None

    n_pick = min(meta_cfg.n_query // 4, len(year_laws))
    picked = random.sample(year_laws, n_pick)

    queries = []
    positives = []
    query_types = set()
    for law in picked:
        law_articles = articles_by_law[law]
        if len(law_articles) >= 2:
            q_idx, p_idx = random.sample(range(len(law_articles)), 2)
            queries.append(law_articles[q_idx])
            positives.append(law_articles[p_idx])
            query_types.add(law_articles[q_idx]["law_type"])
    if len(queries) < 4:
        return None

    hard_pool = []
    easy_pool = []
    for law in year_laws:
        if law in picked:
            continue
        for article in articles_by_law[law]:
            pool = hard_pool if article["law_type"] in query_types else easy_pool
            pool.append(article)

    # Up to half the budget from the hard pool; the remainder from the rest.
    n_hard = min(meta_cfg.n_negatives // 2, len(hard_pool))
    n_easy = min(meta_cfg.n_negatives - n_hard, len(easy_pool))
    negatives = []
    if n_hard > 0:
        negatives = negatives + random.sample(hard_pool, n_hard)
    if n_easy > 0:
        negatives = negatives + random.sample(easy_pool, n_easy)
    if len(negatives) < 4:
        return None

    return {"queries": queries, "positives": positives, "negatives": negatives}
|
| 153 |
+
|
| 154 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 155 |
+
# Stage 2: Meta-Training
|
| 156 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 157 |
+
def compute_loss(model, q_texts, p_texts, n_texts, state_vec, temp=0.05):
    """
    Contrastive episode loss, plus a variance regularizer when enabled.

    All texts are encoded in one batch and adapted with ``state_vec``.
    With equally many queries and positives, each query is scored against
    its own positive (class 0) and all negatives. Otherwise queries are
    scored against positives+negatives, with targets clamped to the last
    positive index.

    Returns:
        Scalar loss tensor; a zero tensor (requires_grad=True) when either
        the query or positive list is empty.
    """
    num_q = len(q_texts)
    num_p = len(p_texts)
    if num_q == 0 or num_p == 0:
        return torch.tensor(0.0, device=device, requires_grad=True)

    everything = q_texts + p_texts + n_texts
    adapted = model.adapt_embedding(model.encode_text(everything), state_vec)
    emb = adapted["mean"]
    q_emb = emb[:num_q]
    p_emb = emb[num_q:num_q + num_p]
    n_emb = emb[num_q + num_p:]

    if num_q == num_p:
        # Column 0 holds the paired positive; remaining columns are negatives.
        pos_logit = (q_emb * p_emb).sum(1).unsqueeze(1) / temp
        neg_logits = q_emb @ n_emb.T / temp
        logits = torch.cat([pos_logit, neg_logits], dim=1)
        targets = torch.zeros(num_q, dtype=torch.long, device=device)
    else:
        logits = q_emb @ torch.cat([p_emb, n_emb], dim=0).T / temp
        targets = torch.arange(num_q, device=device).clamp(max=len(p_emb) - 1)
    loss = F.cross_entropy(logits, targets)

    if model.config.hypernetwork.output_variance:
        log_var = adapted.get("log_variance")
        if log_var is not None:
            # exp(lv) - lv - 1 is zero at lv = 0 and positive elsewhere.
            penalty = (log_var.exp() - log_var - 1).mean()
            loss = loss + penalty * model.config.meta.kl_weight
    return loss
|
| 178 |
+
|
| 179 |
+
def validate(model, articles_by_law, laws_by_year, val_years, config):
    """
    Estimate validation loss over up to 30 randomly drawn episodes.

    Switches the model to eval mode and runs without gradients. Draws for
    which no episode can be built are skipped.

    Returns:
        Mean loss over the evaluated episodes (0.0 if none were evaluated).
    """

    def render(articles):
        """Format articles as 'title: text[:200]' strings."""
        return [f"{a['title']}: {a['text'][:200]}" for a in articles]

    model.eval()
    losses = []
    with torch.no_grad():
        for _ in range(30):
            query_year = random.choice(val_years)
            if query_year not in laws_by_year:
                continue

            known_years = sorted(laws_by_year.keys())
            state_years = [y for y in known_years if y < query_year]
            if len(state_years) < 3:
                state_years = [y for y in known_years if y <= query_year]

            episode = build_episode(articles_by_law, laws_by_year, state_years, query_year, config)
            if episode is None:
                continue

            state_vec = model.get_state_vector()
            episode_loss = compute_loss(
                model,
                render(episode["queries"]),
                render(episode["positives"]),
                render(episode["negatives"]),
                state_vec,
                config.meta.temperature,
            )
            losses.append(episode_loss.item())

    return sum(losses) / max(len(losses), 1)
|
| 196 |
+
|
| 197 |
+
def meta_train(model, articles_by_law, laws_by_year, train_years, val_years, config, epochs=50, patience=10):
    """
    Stage 2: episodic meta-training of the HyperNetwork and state encoder.

    Each step picks a query year from ``train_years`` (index >= 2), builds
    an episode, and takes one optimizer step. After every epoch the model
    is validated; the best checkpoint is written to
    ``config.output_dir / "telen_best.pt"`` and training stops after
    ``patience`` epochs without improvement.

    Returns:
        The same ``model`` instance in its final state (not necessarily the
        best checkpoint, which lives on disk).
    """
    # TELEN registers the head as ``projection``; the checkpoint key below
    # stays "base_projection" so existing loaders keep working.
    projection = model.projection
    trainable = (list(model.hypernetwork.parameters()) + list(model.state_encoder.parameters()) +
                 list(projection.parameters()) + [model.attn_query])
    opt = torch.optim.AdamW(trainable, lr=config.meta.meta_lr, weight_decay=1e-4)
    sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2)

    os.makedirs(config.output_dir, exist_ok=True)
    best_val, patience_ctr = float("inf"), 0

    def render(articles):
        """Format articles as 'title: text[:200]' strings."""
        return [f"{a['title']}: {a['text'][:200]}" for a in articles]

    for epoch in range(epochs):
        model.train()
        # Keep the frozen backbone deterministic (no dropout), as stage 1 does.
        model.encoder.model.eval()

        total_loss = 0.0
        n_updates = 0
        steps = config.meta.meta_batch_size * 100
        progress = tqdm(range(steps), desc=f"Meta Epoch {epoch+1}/{epochs}")
        for _ in progress:
            if len(train_years) < 3:
                break
            si = random.randint(2, len(train_years) - 1)
            sy, qy = train_years[:si], train_years[si]
            if qy not in laws_by_year:
                continue
            ep = build_episode(articles_by_law, laws_by_year, sy, qy, config)
            if ep is None:
                continue
            sv = model.get_state_vector()
            loss = compute_loss(
                model,
                render(ep["queries"]),
                render(ep["positives"]),
                render(ep["negatives"]),
                sv, config.meta.temperature,
            )
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable, 1.0)
            opt.step()
            total_loss += loss.item()
            n_updates += 1
            progress.set_postfix({"loss": f"{loss.item():.4f}"})

        # Average over the steps that actually ran; skipped draws would
        # otherwise deflate the reported loss.
        avg = total_loss / max(n_updates, 1)
        print(f" Epoch {epoch+1} avg loss: {avg:.4f}")
        sched.step()

        vl = validate(model, articles_by_law, laws_by_year, val_years, config)
        print(f" Val loss: {vl:.4f}")

        if vl < best_val:
            best_val, patience_ctr = vl, 0
            torch.save({
                "hypernetwork": model.hypernetwork.state_dict(),
                "state_encoder": model.state_encoder.state_dict(),
                "base_projection": projection.state_dict(),
                "attn_query": model.attn_query,
                "epoch": epoch, "val_loss": vl,
            }, Path(config.output_dir) / "telen_best.pt")
            print(f" Saved (val_loss={vl:.4f})")
        else:
            patience_ctr += 1
            if patience_ctr >= patience:
                print(f" Early stopping at epoch {epoch+1}")
                break
    print("Meta-training complete!")
    return model
|
| 251 |
+
|
| 252 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 253 |
+
# Main
|
| 254 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 255 |
+
def main():
    """Run the full pipeline: seed, load data, build the graph, then stages 1 and 2."""

    def banner(title):
        """Print a stage header framed by 60-character rules."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    config = TELENConfig()
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    print(f"Device: {device}")

    # Data
    print("\nLoading data...")
    (articles_by_law, laws_by_year,
     train_years, val_years, test_years, df) = prepare_data(config)
    print(f" Train: {train_years[0]}-{train_years[-1]} ({len(train_years)}y)")
    print(f" Val: {val_years[0]}-{val_years[-1]} ({len(val_years)}y)")
    print(f" Test: {len(test_years)}y")

    # Model
    print("\nCreating TELEN...")
    model = create_model(config).to(device)
    hn_params = sum(p.numel() for p in model.hypernetwork.parameters())
    print(f" HyperNetwork: {hn_params:,} params")

    # Build graph from training-year articles only
    print("\nBuilding concept graph...")
    train_df = df[df["year"].isin(train_years)]
    model.build_graph(train_df)
    print(f" Graph: {model.concept_graph.num_nodes} nodes")

    # Stage 1
    banner("Stage 1: Contrastive Pretraining")
    model = contrastive_pretrain(model, train_df, config, epochs=5, batch_size=24, lr=3e-5)

    # Stage 2
    banner("Stage 2: Meta-Training")
    model = meta_train(model, articles_by_law, laws_by_year, train_years, val_years,
                       config, epochs=50, patience=10)

    print(f"\nDone! Model saved to: {config.output_dir}/telen_best.pt")
|
| 291 |
+
|
| 292 |
+
if __name__ == "__main__":
|
| 293 |
+
main()
|
train_ce.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train the cross-encoder re-ranker for legal text.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python train_ce.py
|
| 6 |
+
|
| 7 |
+
Trains a PhoBERT-based cross-encoder on legal article pairs
|
| 8 |
+
with margin ranking loss for re-ranking TELEN retrieval results.
|
| 9 |
+
|
| 10 |
+
Output: data/checkpoints/telen/cross_encoder_best.pt
|
| 11 |
+
"""
|
| 12 |
+
import sys; sys.path.insert(0, ".")
|
| 13 |
+
sys.stdout.reconfigure(encoding='utf-8')
|
| 14 |
+
import warnings; warnings.filterwarnings("ignore")
|
| 15 |
+
import random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from transformers import AutoModel, AutoTokenizer
|
| 19 |
+
|
| 20 |
+
from src.telern.config import DATA_DIR
|
| 21 |
+
from src.data import load_raw_data, extract_metadata, clean_data
|
| 22 |
+
|
| 23 |
+
# Seed every RNG the script draws from (Python, NumPy, PyTorch) so that pair
# sampling and head initialisation are repeatable from run to run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to CPU so the script still starts on machines
# without CUDA instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 25 |
+
|
| 26 |
+
# ── Data ──
# Run the parquet dump through the project helpers
# (load_raw_data -> extract_metadata -> clean_data) and keep the rows whose
# `year` is at most 2018 as the training split.
print("Loading data...")
parquet_path = DATA_DIR / "train-00000-of-00001.parquet"
df = clean_data(
    extract_metadata(load_raw_data(str(parquet_path))),
    min_text_len=10,
)
train_df = df.loc[df["year"] <= 2018]
n_train_laws = train_df["law_id"].nunique()
print(f" {len(train_df)} articles, {n_train_laws} laws")
|
| 32 |
+
|
| 33 |
+
# ── Build pairs ──
# Every article of a law with at least two articles becomes a query
# "<title>: <first 400 chars of text>" and is paired with:
#   * one positive  - a different article drawn from the same law      (1.0)
#   * one negative  - the first article of another law, same law_type  (0.0)
#   * one negative  - the first article of a law of another law_type   (0.0)
print("Building pairs...")


def _format_article(article):
    """Render an article (dict record or DataFrame row) as '<title>: <text[:400]>'."""
    return f"{article['title']}: {article['text'][:400]}"


law_groups = train_df.groupby("law_id")
law_ids = list(law_groups.groups.keys())

# Single pass over the laws: remember each law's first article (it is the only
# row ever used as a negative) and index the law ids by their law_type.
first_article = {}
law_type_to_laws = defaultdict(list)
for lid in law_ids:
    first_row = law_groups.get_group(lid).iloc[0]
    first_article[lid] = first_row
    law_type_to_laws[first_row["law_type"]].append(lid)

pairs = []
for law_id in tqdm(law_ids, desc=" Pairs"):
    articles = law_groups.get_group(law_id).to_dict("records")
    if len(articles) < 2:
        continue  # a positive needs a second article from the same law
    law_type = articles[0]["law_type"]
    same_type_laws = [l for l in law_type_to_laws.get(law_type, []) if l != law_id]
    # Both negative pools depend only on the law, not on the article, so they
    # are built once per law; the set keeps the exclusion test O(1).
    same_type_set = set(same_type_laws)
    other_type_laws = [l for l in law_ids if l != law_id and l not in same_type_set]

    for art in articles:
        q = _format_article(art)
        pos = [a for a in articles if a["id"] != art["id"]]
        if pos:
            # Draw the positive once so that its title and text belong to the
            # same article (two independent draws could mix two articles).
            pos_art = random.choice(pos)
            pairs.append((q, _format_article(pos_art), 1.0))
        if same_type_laws:
            neg_same = first_article[random.choice(same_type_laws)]
            pairs.append((q, _format_article(neg_same), 0.0))
        if other_type_laws:
            neg_other = first_article[random.choice(other_type_laws)]
            pairs.append((q, _format_article(neg_other), 0.0))

# Cap the training set at 60k pairs, sampling up to 30k of each label.
if len(pairs) > 60000:
    pos_pairs = [p for p in pairs if p[2] == 1.0]
    neg_pairs = [p for p in pairs if p[2] == 0.0]
    pairs = (
        random.sample(pos_pairs, min(30000, len(pos_pairs)))
        + random.sample(neg_pairs, min(30000, len(neg_pairs)))
    )
print(f" {len(pairs)} pairs ({sum(1 for p in pairs if p[2]==1.0)} pos)")
|
| 69 |
+
|
| 70 |
+
# ── Model ──
# Pretrained PhoBERT encoder plus a small MLP head that reduces an
# encoder-width vector to a single score; one AdamW optimizer updates both.
print("Loading PhoBERT...")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
encoder = AutoModel.from_pretrained("vinai/phobert-base-v2")
encoder = encoder.to(device)

hidden_dim = encoder.config.hidden_size
head_layers = [
    nn.Linear(hidden_dim, 512),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 1),
]
head = nn.Sequential(*head_layers).to(device)

trainable_params = [*encoder.parameters(), *head.parameters()]
opt = torch.optim.AdamW(trainable_params, lr=1e-5, weight_decay=0.01)
|
| 80 |
+
|
| 81 |
+
# ── Train ──
# Fine-tune encoder + head for a fixed number of epochs and keep the
# checkpoint of the epoch with the lowest average training loss.
import os

B, epochs = 16, 10
CHECKPOINT_PATH = "data/checkpoints/telen/cross_encoder_best.pt"

steps_per_epoch = len(pairs) // B
if steps_per_epoch == 0:
    # Without this guard the loop would not run and the per-epoch average
    # below would divide by zero.
    raise ValueError(f"Need at least {B} training pairs, got {len(pairs)}")
print(f"\nTraining: {epochs} epochs, {steps_per_epoch} steps/epoch")
best_loss = float("inf")

# Built once: the same parameter list is clipped on every step.
clip_params = list(encoder.parameters()) + list(head.parameters())

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

# from_pretrained() hands the encoder back in eval mode (dropout disabled);
# switch both modules to training mode before fine-tuning.
encoder.train()
head.train()

for epoch in range(epochs):
    random.shuffle(pairs)
    epoch_loss = 0.0
    progress = tqdm(range(steps_per_epoch), desc=f" Epoch {epoch+1}/{epochs}")
    for step in progress:
        # step < len(pairs) // B guarantees start + B <= len(pairs), so every
        # batch is full and no wrap-around is needed. (A modulo here would map
        # the last batch back to index 0 whenever len(pairs) % B == 0.)
        start = step * B
        batch = pairs[start:start + B]
        queries = [p[0] for p in batch]
        docs = [p[1] for p in batch]
        labels = torch.tensor([p[2] for p in batch], dtype=torch.float, device=device)

        # Cross-encoder input: query and document are tokenized as one pair.
        enc = tokenizer(queries, docs, padding=True, truncation=True, max_length=256, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attention_mask = enc["attention_mask"].to(device)
        out = encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Score each pair from the hidden state of its first token.
        scores = head(out.last_hidden_state[:, 0, :]).squeeze(-1)

        pos_mask = labels == 1
        neg_mask = labels == 0
        if pos_mask.any() and neg_mask.any():
            # Margin ranking over every (positive, negative) combination in
            # the batch: each positive should outscore each negative by 0.3.
            pos_scores = scores[pos_mask]
            neg_scores = scores[neg_mask]
            loss = F.relu(0.3 - pos_scores.unsqueeze(1) + neg_scores.unsqueeze(0)).mean()
        else:
            # Single-label batch: no ranking pairs exist, fall back to BCE.
            loss = F.binary_cross_entropy_with_logits(scores, labels)

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(clip_params, 1.0)
        opt.step()
        epoch_loss += loss.item()
        progress.set_postfix({"loss": f"{loss.item():.4f}"})

    avg_loss = epoch_loss / steps_per_epoch
    print(f" Epoch {epoch+1} avg loss: {avg_loss:.4f}")
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({"encoder": encoder.state_dict(), "head": head.state_dict()},
                   CHECKPOINT_PATH)
        print(f" Saved (loss={avg_loss:.4f})")

print(f"\nDone! Model saved to: {CHECKPOINT_PATH}")
|