haidang2405 committed
Commit e7cfc32 · verified · 1 Parent(s): 45040b3

first commit

.gitignore ADDED
@@ -0,0 +1,3 @@
+ /data
+ **/__pycache__
+ /dataset
README.md CHANGED
@@ -1,3 +1,215 @@
  ---
- license: mit
  ---
+ # TELEN: Temporal Evolving Legal Embedding Network
+
+ > **Vietnamese legal text embedding with meta-learning for continuous adaptation to new laws.**
+
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/)
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0+-red.svg)](https://pytorch.org/)
+ [![License](https://img.shields.io/badge/license-MIT-green.svg)](LICENSE)
+
  ---
+
+ ## Overview
+
+ TELEN introduces a **novel embedding architecture** designed specifically for Vietnamese legal text retrieval in RAG (Retrieval-Augmented Generation) systems. Unlike conventional static embedding models, TELEN generates embeddings that **adapt dynamically** to the current state of the legal corpus, so new laws can be integrated without retraining the model.
+
+ ### Key Innovations
+
+ 1. **HyperNetwork-Driven Projection**: Instead of using fixed projection weights, a HyperNetwork generates the embedding projection function from the current legal corpus state. When new laws are published, the embedding space adapts automatically.
+
+ 2. **Legal Concept Graph (LCG)**: An evolving knowledge graph whose nodes represent legal entities (laws, key terms) and whose edges encode cross-references, agency hierarchy, temporal sequences, and semantic similarity.
+
+ 3. **State-Adaptive Embeddings**: Embeddings are not static vectors; they are modulated by a learned "legal state vector" that summarizes the entire legal landscape at a given point in time.
+
  ---
+
+ ## Architecture
+
+ ```
+ Legal Text
+     ↓
+ Bi-Encoder (bkai-foundation-models/vietnamese-bi-encoder)
+     ↓
+ Raw Representation [768-dim]
+     ↓
+ ┌───────────────────────────────────────┐
+ │ HyperNetwork(state_vector) → ΔW, Δb   │ ← generated per state, not stored as fixed weights
+ │ Adapted Projection = Base + ΔW·x + Δb │
+ └───────────────────────────────────────┘
+     ↓
+ Legal Concept Graph (GNN)
+     ↓ state_vector
+ State Encoder ← current legal corpus
+     ↓
+ L2-Normalized Embedding [768-dim]
+ ```
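The adapted projection in the diagram can be written out as a short formula. This is a sketch under assumptions, not the confirmed implementation: the low-rank factorization $\Delta W = A B^{\top}$ is inferred from the "Adaptation rank 64" hyperparameter, and the symbols $h$, $s$, $W_0$, $b_0$, $A$, $B$ are introduced here for illustration only.

```latex
% h : raw 768-dim representation from the bi-encoder
% s : legal state vector produced by the state encoder
% W_0, b_0 : base projection; A(s), B(s) : HyperNetwork outputs (assumed low-rank factors)
\[
  z = W_0 h + b_0 + \Delta W(s)\, h + \Delta b(s),
  \qquad
  \Delta W(s) = A(s)\, B(s)^{\top},
  \quad A(s), B(s) \in \mathbb{R}^{768 \times 64}
\]
\[
  e = \frac{z}{\lVert z \rVert_2}
\]
```

Under this assumption the HyperNetwork would only need to emit $2 \cdot 768 \cdot 64 = 98{,}304$ values for $\Delta W$ rather than the $768^2 = 589{,}824$ entries of a full matrix.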
+
+ ## Benchmark Results
+
+ **Test set**: 1,406 Vietnamese legal articles from 2021 (held-out, unseen during training)
+
+ | Model | NDCG@3 | NDCG@5 | NDCG@10 | MRR@3 | MRR@5 | MRR@10 |
+ |---|---|---|---|---|---|---|
+ | **BM25** (lexical) | 0.5164 | 0.5628 | 0.5718 | 0.5016 | 0.5290 | 0.5354 |
+ | **PhoBERT-base-v2** (dense) | 0.4803 | 0.5305 | 0.5738 | 0.4503 | 0.4792 | 0.4961 |
+ | **DEk21** (dense) | 0.6651 | 0.6907 | 0.7286 | 0.6394 | 0.6553 | 0.6734 |
+ | **TELEN** (dense) | **0.8878** | **0.9097** | **0.9132** | **0.8686** | **0.8782** | **0.8782** |
+
+ ### Relative Improvement
+
+ | Baseline | NDCG@3 | NDCG@10 | MRR@10 |
+ |---|---|---|---|
+ | vs PhoBERT (dense) | **+84.9%** | **+59.2%** | **+77.1%** |
+ | vs DEk21 (dense) | **+33.5%** | **+25.3%** | **+30.4%** |
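The relative improvements are ratios of the scores in the benchmark table. The sketch below recomputes the DEk21 row from the rounded table values, using the same formula as the summary loop in `eval.py`; the helper name `relative_improvement` is introduced here for illustration. Recomputing from rounded scores can differ from a reported figure in the last digit, since the reported figures were presumably derived from unrounded scores.

```python
def relative_improvement(new: float, baseline: float) -> float:
    """Percentage gain of `new` over `baseline` (guarded against a zero baseline)."""
    return (new / max(baseline, 1e-6) - 1.0) * 100.0


# Rounded scores copied from the benchmark table.
telen = {"ndcg@3": 0.8878, "ndcg@10": 0.9132, "mrr@10": 0.8782}
dek21 = {"ndcg@3": 0.6651, "ndcg@10": 0.7286, "mrr@10": 0.6734}

gains = {metric: relative_improvement(telen[metric], dek21[metric]) for metric in telen}
for metric, gain in gains.items():
    print(f"TELEN vs DEk21 {metric}: {gain:+.1f}%")
```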
+
+ ---
+
+ ## Quick Start
+
+ ### Installation
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### Inference
+
+ ```python
+ from inference import TELENInference
+
+ # Load model
+ model = TELENInference()
+
+ # Encode legal texts
+ texts = [
+     "Điều 1: Thông tư này quy định về quản lý thuế giá trị gia tăng...",  # "Article 1: This circular regulates value-added tax administration..."
+     "Điều 2: Đối tượng áp dụng là các tổ chức, cá nhân kinh doanh...",  # "Article 2: Applies to organizations and individuals doing business..."
+ ]
+ embeddings = model.encode(texts)  # → [2, 768] normalized vectors
+
+ # Compute similarity
+ similarity = model.similarity(texts[0], texts[1])
+ print(f"Cosine similarity: {similarity:.4f}")
+
+ # Retrieve similar documents (`corpus` is any list of document strings)
+ results = model.retrieve(texts[0], corpus, top_k=10)  # → [(corpus_index, score), ...]
+ ```
+
+ ### Training
+
+ ```bash
+ # Train TELEN from scratch
+ python train.py
+
+ # Train cross-encoder re-ranker (optional, for an extra +2-3% gain)
+ python train_ce.py
+ ```
+
+ ### Evaluation
+
+ ```bash
+ python eval.py
+ ```
+
+ ---
+
+ ## Training Details
+
+ ### Dataset
+ - **Source**: [another-symato/VMTEB-Zalo-legel-retrieval-wseg](https://huggingface.co/datasets/another-symato/VMTEB-Zalo-legel-retrieval-wseg) on Hugging Face
+ - **Content**: 61,425 Vietnamese legal articles: circulars (Thông tư), decrees (Nghị định), laws (Luật), and ordinances (Pháp lệnh)
+ - **Period**: 1999–2021
+ - **Format**: Word-segmented Vietnamese text (underscore-separated compound words)
+
+ ### Training Pipeline
+
+ | Stage | Description | Epochs | Trainable Params |
+ |---|---|---|---|
+ | 1. Contrastive Pretraining | Triplet + InfoNCE loss on same-law article pairs | 5 | ~1M (projection head) |
+ | 2. Meta-Training | HyperNetwork learns to adapt embedding space for future laws | up to 50 (early stopping) | ~4M (HyperNetwork + State Encoder) |
+
+ ### Hyperparameters
+
+ | Parameter | Value |
+ |---|---|
+ | Backbone | `bkai-foundation-models/vietnamese-bi-encoder` |
+ | Embedding dimension | 768 |
+ | Adaptation rank | 64 |
+ | GNN layers | 3 |
+ | Meta N-way, K-shot | 16-way, 5-shot |
+ | Negatives per query | 256 (50% hard + 50% random) |
+ | Temperature | 0.05 |
+ | Optimizer | AdamW with a CosineAnnealingWarmRestarts schedule |
+
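To make the role of the temperature concrete, here is a minimal pure-Python sketch of an InfoNCE loss for one query. It illustrates the general technique named in the training pipeline and is not the repository's training code (`train.py` is not part of this diff); the function name `info_nce` and its arguments are hypothetical.

```python
import math


def info_nce(sim_pos: float, sim_negs: list, temperature: float = 0.05) -> float:
    """InfoNCE loss for one query: -log softmax of the positive among all candidates.

    sim_pos  : cosine similarity between the query and its positive article
    sim_negs : cosine similarities between the query and the negative articles
    """
    logits = [sim_pos / temperature] + [s / temperature for s in sim_negs]
    # Log-sum-exp with the max subtracted for numerical stability.
    m = max(logits)
    log_denominator = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_denominator - logits[0]


# A well-separated positive gives a loss near zero; a tie with one negative gives log(2).
print(info_nce(0.9, [0.1, 0.0]))
print(info_nce(0.5, [0.5]))
```

A low temperature such as 0.05 scales similarity gaps by a factor of 20 before the softmax, so even small margins between the positive and the hardest negatives dominate the loss.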
+ ### Hardware
+ - GPU: NVIDIA RTX 5070 Ti (16 GB VRAM)
+ - Training time: ~8 hours (5 contrastive + 50 meta epochs)
+
+ ---
+
+ ## Continuous Adaptation
+
+ When a new law is published, TELEN adapts without retraining:
+
+ ```python
+ # New law arrives
+ new_articles = [
+     "Điều 1: Luật mới về trí tuệ nhân tạo...",  # "Article 1: New law on artificial intelligence..."
+     "Điều 2: Các nguyên tắc áp dụng AI trong xét xử...",  # "Article 2: Principles for applying AI in adjudication..."
+ ]
+
+ # Update concept graph (milliseconds); assumes `add_new_law` is exposed by the loaded model
+ model.add_new_law("123/2025/l-ai", new_articles)
+
+ # Embedding space automatically adapts via HyperNetwork
+ # All subsequent query embeddings reflect the new legal landscape
+ embeddings = model.encode(["Điều 1: ..."])
+ ```
+
+ ---
+
+ ## Project Structure
+
+ ```
+ law-embedding/
+ ├── dataset/
+ │   └── train-00000-of-00001.parquet   # Training data (61K legal articles)
+ ├── src/
+ │   ├── data.py                        # Data loading utilities
+ │   └── telern/
+ │       ├── config.py                  # Configuration
+ │       ├── model.py                   # TELEN architecture
+ │       ├── concept_graph.py           # Legal Concept Graph + GNN
+ │       ├── hypernetwork.py            # HyperNetwork + StateEncoder
+ │       └── evaluate.py                # Evaluation metrics & baselines
+ ├── data/checkpoints/telen/
+ │   └── telen_best.pt                  # Pretrained model weights
+ ├── train.py                           # Training script
+ ├── train_ce.py                        # Cross-encoder training (optional)
+ ├── eval.py                            # Evaluation script
+ ├── inference.py                       # Inference API
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ---
+
+ ## Citation
+
+ ```bibtex
+ @misc{telen2025,
+   title={TELEN: Temporal Evolving Legal Embedding Network for Vietnamese Law},
+   author={dangdinh},
+   year={2026},
+   publisher={Hugging Face},
+ }
+ ```
+
+ ## License
+
+ MIT License. See the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ - `bkai-foundation-models/vietnamese-bi-encoder`: backbone bi-encoder
+ - `huyydangg/DEk21_hcmute_embedding`: baseline comparison (previous SOTA)
+ - `vinai/phobert-base-v2`: used in cross-encoder re-ranker
eval.py ADDED
@@ -0,0 +1,124 @@
+ """
+ Evaluate TELEN with full benchmarks.
+
+ Metrics: NDCG@3, NDCG@5, NDCG@10, MRR@3, MRR@5, MRR@10
+
+ Baselines:
+   - BM25 (lexical retrieval)
+   - Frozen PhoBERT (vinai/phobert-base-v2)
+   - DEk21 (huyydangg/DEk21_hcmute_embedding)
+   - TELEN (ours)
+
+ Usage:
+     python eval.py
+ """
+ import sys; sys.path.insert(0, ".")
+ sys.stdout.reconfigure(encoding='utf-8')
+ import warnings; warnings.filterwarnings("ignore")
+ import random, numpy as np, torch, torch.nn.functional as F
+ from tqdm import tqdm
+ from collections import defaultdict
+ from sentence_transformers import SentenceTransformer
+ from pyvi import ViTokenizer
+
+ from src.telern.config import TELENConfig
+ from src.telern.model import create_model
+ from src.telern.evaluate import (
+     BM25Baseline, FrozenPhoBERT, prepare_test_data,
+     build_test_queries, build_test_corpus, compute_metrics, evaluate_bm25,
+ )
+
+ SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ config = TELENConfig()
+
+ def wseg(text):
+     return ViTokenizer.tokenize(text.replace("_", " "))
+
+ def evaluate_model(name, encode_fn, queries, corpus, corpus_ids, corpus_law_ids):
+     """Generic evaluation for any embedding model; the query article itself is dropped from its own ranking."""
+     print(f"\n [{name}] Encoding corpus ({len(corpus)} docs)...")
+     c_embs = []
+     for i in range(0, len(corpus), 64):
+         batch = [d["text"] for d in corpus[i:i+64]]
+         embs = encode_fn(batch)
+         if isinstance(embs, np.ndarray): embs = torch.tensor(embs)
+         c_embs.append(embs.cpu())
+     c_embs = torch.cat(c_embs, dim=0)
+
+     print(f" [{name}] Evaluating {len(queries)} queries...")
+     all_m = defaultdict(list)
+     for q in tqdm(queries, desc=f" {name}"):
+         q_emb = encode_fn([q["query_text"]])
+         if isinstance(q_emb, np.ndarray): q_emb = torch.tensor(q_emb)
+         sim = F.cosine_similarity(q_emb.cpu(), c_embs).numpy()
+         rel = np.array([1.0 if corpus_law_ids[j]==q["law_id"] else 0.0 for j in range(len(corpus))])
+         si = sim.argsort()[::-1]; sr = rel[si]
+         for j,cid in enumerate(corpus_ids):
+             if cid==q["query_id"]:
+                 p=np.where(si==j)[0]; sr=np.delete(sr,p[0]) if len(p)>0 else sr; break
+         for k in [3,5,10]:
+             for mn,mv in compute_metrics(sr[:k],[k]).items(): all_m[mn].append(mv)
+     return {n: np.mean(v) for n,v in all_m.items()}
+
+ # ── Data ──
+ test_df = prepare_test_data(config)
+ queries = build_test_queries(test_df, max_queries=300)
+ corpus = build_test_corpus(test_df)
+ corpus_ids = [d["article_id"] for d in corpus]
+ corpus_law_ids = [d["law_id"] for d in corpus]
+ train_df = test_df[test_df["year"] <= config.meta.train_split_year]
+ print(f"Test: {len(queries)} queries, {len(corpus)} docs, {test_df['law_id'].nunique()} laws")
+
+ results = {}
+
+ # ── BM25 ──
+ print("\n[1/4] BM25")
+ results["BM25"] = evaluate_bm25(queries, corpus)
+
+ # ── PhoBERT ──
+ print("\n[2/4] Frozen PhoBERT")
+ phobert = FrozenPhoBERT()
+ results["PhoBERT"] = evaluate_model("PhoBERT", lambda texts: phobert.encode(texts, batch_size=64), queries, corpus, corpus_ids, corpus_law_ids)
+
+ # ── DEk21 ──
+ print("\n[3/4] DEk21 (SOTA)")
+ dek21 = SentenceTransformer("huyydangg/DEk21_hcmute_embedding", device=str(device))
+ results["DEk21"] = evaluate_model("DEk21", lambda texts: dek21.encode([wseg(t) for t in texts], batch_size=64, show_progress_bar=False, normalize_embeddings=True, convert_to_tensor=True), queries, corpus, corpus_ids, corpus_law_ids)
+
+ # ── TELEN ──
+ print("\n[4/4] TELEN (Ours)")
+ telen = create_model(config).to(device)
+ ckpt = torch.load(config.output_dir + "/telen_best.pt", map_location=device, weights_only=False)
+ telen.hypernetwork.load_state_dict(ckpt["hypernetwork"])
+ telen.state_encoder.load_state_dict(ckpt["state_encoder"])
+ telen.base_projection.load_state_dict(ckpt["base_projection"])
+ telen.attn_query.data.copy_(ckpt["attn_query"])
+ if len(train_df) > 0: telen.build_graph(train_df)
+
+ def telen_encode(texts):
+     with torch.no_grad():
+         return telen(texts, use_stochastic=False)["embeddings"].cpu()
+
+ results["TELEN"] = evaluate_model("TELEN", telen_encode, queries, corpus, corpus_ids, corpus_law_ids)
+
+ # ── Summary ──
+ print("\n" + "=" * 75)
+ print("BENCHMARK RESULTS")
+ print("=" * 75)
+ h = f"{'Method':<15}"
+ for m in [3,5,10]: h += f" {'NDCG@'+str(m):>10} {'MRR@'+str(m):>10}"
+ print(h); print("-"*len(h))
+ for name in ["BM25", "PhoBERT", "DEk21", "TELEN"]:
+     r = f"{name:<15}"
+     for m in [3,5,10]: r += f" {results[name][f'ndcg@{m}']:>10.4f} {results[name][f'mrr@{m}']:>10.4f}"
+     print(r)
+
+ print("\n--- Relative Improvement over Baselines ---")
+ for baseline in ["PhoBERT", "DEk21"]:
+     print(f" TELEN vs {baseline}:")
+     for m in [3,5,10]:
+         ni = (results["TELEN"][f"ndcg@{m}"] / max(results[baseline][f"ndcg@{m}"], 1e-6) - 1) * 100
+         mi = (results["TELEN"][f"mrr@{m}"] / max(results[baseline][f"mrr@{m}"], 1e-6) - 1) * 100
+         print(f" NDCG@{m}: {ni:+.1f}% MRR@{m}: {mi:+.1f}%")
+ print("Done!")
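`eval.py` delegates the scoring to `compute_metrics` from `src/telern/evaluate.py`, whose body is not visible in this diff. The sketch below shows the standard definitions of NDCG@k and MRR@k for a ranked list of binary relevance labels, as a reference for what those numbers usually mean. The function names are hypothetical and the repository's implementation may differ, for example in how the ideal ranking is built when it only receives the truncated list `sr[:k]`.

```python
import math


def dcg(rels: list) -> float:
    """Discounted cumulative gain of relevance labels given in rank order."""
    return sum(r / math.log2(i + 2) for i, r in enumerate(rels))


def ndcg_at_k(ranked_rels: list, k: int) -> float:
    """NDCG@k: DCG of the top-k results divided by the DCG of the best possible top-k."""
    ideal_dcg = dcg(sorted(ranked_rels, reverse=True)[:k])
    return dcg(ranked_rels[:k]) / ideal_dcg if ideal_dcg > 0 else 0.0


def mrr_at_k(ranked_rels: list, k: int) -> float:
    """Reciprocal rank of the first relevant result within the top-k, else 0."""
    for rank, r in enumerate(ranked_rels[:k], start=1):
        if r > 0:
            return 1.0 / rank
    return 0.0


# Relevance of retrieved articles in rank order (1.0 = same law as the query).
ranked = [0.0, 1.0, 0.0, 1.0, 0.0]
print(ndcg_at_k(ranked, 3), mrr_at_k(ranked, 3))
```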
inference.py ADDED
@@ -0,0 +1,78 @@
+ """
+ TELEN Inference: encode legal texts to 768-dim embeddings.
+
+ Usage:
+     from inference import TELENInference
+     model = TELENInference()
+     embeddings = model.encode(["Điều 1: Thông tư này quy định về..."])
+     similarity = model.similarity(text1, text2)
+ """
+ import sys; sys.path.insert(0, ".")
+ import torch
+ import torch.nn.functional as F
+ from pyvi import ViTokenizer
+
+ from src.telern.config import TELENConfig
+ from src.telern.model import create_model
+
+
+ class TELENInference:
+     def __init__(self, checkpoint_path: str | None = None):
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.config = TELENConfig()
+         self.model = create_model(self.config).to(self.device)
+
+         if checkpoint_path is None:
+             checkpoint_path = self.config.output_dir + "/telen_best.pt"
+
+         ckpt = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
+         self.model.hypernetwork.load_state_dict(ckpt["hypernetwork"])
+         self.model.state_encoder.load_state_dict(ckpt["state_encoder"])
+         self.model.base_projection.load_state_dict(ckpt["base_projection"])
+         self.model.attn_query.data.copy_(ckpt["attn_query"])
+         self.model.eval()
+
+         print(f"TELEN loaded on {self.device}")
+         print(f" HyperNetwork: {sum(p.numel() for p in self.model.hypernetwork.parameters()):,} params")
+         print(" Ready for inference.")
+
+     def build_graph(self, df):
+         """Build concept graph from a DataFrame with [id, title, text, law_id, law_type, year] columns."""
+         self.model.build_graph(df)
+
+     def encode(self, texts: list, batch_size: int = 64) -> torch.Tensor:
+         """Encode a list of legal texts to 768-dim normalized embeddings."""
+         embeddings = []
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             with torch.no_grad():
+                 result = self.model(batch, use_stochastic=False)
+             embeddings.append(result["embeddings"].cpu())
+         return torch.cat(embeddings, dim=0)
+
+     def similarity(self, text1: str, text2: str) -> float:
+         """Compute cosine similarity between two texts."""
+         emb = self.encode([text1, text2])
+         return F.cosine_similarity(emb[0:1], emb[1:2]).item()
+
+     def retrieve(self, query: str, corpus: list, top_k: int = 10) -> list:
+         """Retrieve the top-k most similar documents as (corpus_index, score) pairs."""
+         query_emb = self.encode([query])
+         corpus_embs = self.encode(corpus)
+         sim = F.cosine_similarity(query_emb, corpus_embs).numpy()
+         top_indices = sim.argsort()[::-1][:top_k]
+         return [(int(i), float(sim[i])) for i in top_indices]
+
+
+ # ── Demo ──
+ if __name__ == "__main__":
+     model = TELENInference()
+
+     # Example queries (q1, q2: VAT circular on imported goods; q3: decree on administrative traffic penalties)
+     q1 = "Điều 1: Thông tư này quy định về quản lý thuế giá trị gia tăng đối với hàng hóa nhập khẩu"
+     q2 = "Điều 2: Đối tượng áp dụng là các tổ chức, cá nhân kinh doanh hàng hóa nhập khẩu"
+     q3 = "Điều 1: Nghị định này quy định về xử phạt vi phạm hành chính trong lĩnh vực giao thông"
+
+     print(f"\nSimilarity test:")
+     print(f" q1 vs q2 (same law): {model.similarity(q1, q2):.4f}")
+     print(f" q1 vs q3 (diff law): {model.similarity(q1, q3):.4f}")
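The ranking step inside `retrieve` can be illustrated without loading the model. The sketch below uses plain Python lists as stand-in embeddings; `cosine` and `top_k` are hypothetical helper names, not part of `inference.py`.

```python
import math


def cosine(a: list, b: list) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either has zero norm)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def top_k(query_vec: list, corpus_vecs: list, k: int = 10) -> list:
    """Return the k best (corpus_index, score) pairs, highest similarity first."""
    scored = [(i, cosine(query_vec, vec)) for i, vec in enumerate(corpus_vecs)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy 2-dimensional "embeddings" standing in for the 768-dim model output.
corpus_vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
hits = top_k([1.0, 0.0], corpus_vecs, k=2)
print(hits)
```

Note that `retrieve` as written re-encodes the whole corpus on every call. For repeated queries against a fixed corpus, it may be worth encoding the corpus once with `encode` and reusing the resulting tensor.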
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch>=2.0.0
+ transformers>=4.40.0
+ sentence-transformers>=3.0.0
+ peft>=0.10.0
+ pandas>=2.0.0
+ pyarrow>=14.0.0
+ scikit-learn>=1.3.0
+ tqdm>=4.65.0
+ numpy>=1.24.0
+ pyvi>=0.1.0
+ accelerate>=0.24.0
src/__init__.py ADDED
File without changes
src/data.py ADDED
@@ -0,0 +1,60 @@
+ """Shared data utilities used by TELEN modules."""
+ import unicodedata
+ import pandas as pd
+
+
+ def load_raw_data(parquet_path: str) -> pd.DataFrame:
+     """Load the raw parquet file."""
+     return pd.read_parquet(parquet_path)
+
+
+ def extract_metadata(df: pd.DataFrame) -> pd.DataFrame:
+     """Extract law_id, article_num, law_type, year from id column."""
+     df = df.copy()
+
+     def parse_id(id_str):
+         if "#" in id_str:
+             parts = id_str.split("#")
+             law_id = parts[0]
+             article_part = parts[1]
+             article_num = int(article_part.split("-")[0])
+         else:
+             law_id = id_str
+             article_num = 0
+         return law_id, article_num
+
+     parsed = df["id"].apply(parse_id)
+     df["law_id"] = parsed.apply(lambda x: x[0])
+     df["article_num"] = parsed.apply(lambda x: x[1])
+
+     def extract_law_type(law_id):
+         parts = law_id.split("/")
+         if len(parts) >= 3:
+             return parts[2].split("-")[-1] if "-" in parts[2] else parts[2]
+         return "unknown"
+
+     df["law_type"] = df["law_id"].apply(extract_law_type)
+
+     def extract_year(law_id):
+         parts = law_id.split("/")
+         if len(parts) >= 2:
+             year_str = parts[1]
+             try:
+                 year = int(year_str)
+                 return year if year >= 100 else year + 1900  # two-digit years are read as 19xx
+             except ValueError:
+                 pass
+         return 1999  # fallback when no year can be parsed
+
+     df["year"] = df["law_id"].apply(extract_year)
+     return df
+
+
+ def clean_data(df: pd.DataFrame, min_text_len: int = 10) -> pd.DataFrame:
+     """Remove short/empty texts and duplicates."""
+     df = df.copy()
+     df = df[df["text"].str.len() >= min_text_len].reset_index(drop=True)
+     df["title"] = df["title"].apply(lambda x: unicodedata.normalize("NFC", str(x)))
+     df["text"] = df["text"].apply(lambda x: unicodedata.normalize("NFC", str(x)))
+     df = df.drop_duplicates(subset=["text"], keep="first").reset_index(drop=True)
+     return df
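To see what `extract_metadata` derives from a single identifier, here is a pandas-free sketch that applies the same rules to one string. The sample id `01/2009/tt-bnn#5` is a made-up illustration of the `number/year/type-agency#article` shape that the parsing logic implies; the actual ids in the dataset are not shown in this diff, and `parse_article_id` is a hypothetical name.

```python
def parse_article_id(article_id: str) -> dict:
    """Apply the same parsing rules as extract_metadata to one id string."""
    law_id, sep, article_part = article_id.partition("#")
    article_num = int(article_part.split("-")[0]) if sep else 0

    parts = law_id.split("/")
    # Third path component, keeping only the token after the last hyphen.
    law_type = parts[2].split("-")[-1] if len(parts) >= 3 else "unknown"

    year = 1999  # same fallback as extract_year
    if len(parts) >= 2:
        try:
            value = int(parts[1])
            year = value if value >= 100 else value + 1900
        except ValueError:
            pass

    return {"law_id": law_id, "article_num": article_num, "law_type": law_type, "year": year}


print(parse_article_id("01/2009/tt-bnn#5"))
```

For an id of this shape, `law_type` resolves to the token after the last hyphen (`bnn` here), which looks more like an issuing-agency code than a document type. That would be consistent with `concept_graph.py` grouping by `law_type` to build its "agency" edges.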
src/telern/__init__.py ADDED
@@ -0,0 +1 @@
+ """TELEN: Temporal Evolving Legal Embedding Network."""
src/telern/concept_graph.py ADDED
@@ -0,0 +1,255 @@
+ """
+ Legal Concept Graph: evolving knowledge backbone of TELEN.
+
+ Nodes: law entities + key terms extracted via TF-IDF
+ Edges: agency, temporal, semantic, cross-reference, term-document
+ GNN: Multi-layer graph convolution (dense adjacency, symmetric normalization)
+ """
+ import re
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from sklearn.feature_extraction.text import TfidfVectorizer
+
+
+ # ═══════════════════════════════════════════════
+ # GNN Layers
+ # ═══════════════════════════════════════════════
+ class GCNLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, dropout=0.1):
+         super().__init__()
+         self.linear = nn.Linear(in_dim, out_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.norm = nn.LayerNorm(out_dim)
+
+     def forward(self, x, adj):
+         deg = adj.sum(dim=1).clamp(min=1)
+         deg_inv_sqrt = deg.pow(-0.5)
+         norm_adj = deg_inv_sqrt.unsqueeze(1) * adj * deg_inv_sqrt.unsqueeze(0)
+         x = norm_adj @ x
+         x = self.linear(x)
+         x = F.relu(x)
+         x = self.dropout(x)
+         x = self.norm(x)
+         return x
+
+
+ class GNNEncoder(nn.Module):
+     def __init__(self, dim, n_layers=3, dropout=0.1):
+         super().__init__()
+         self.layers = nn.ModuleList([GCNLayer(dim, dim, dropout) for _ in range(n_layers)])
+
+     def forward(self, x, adj):
+         for layer in self.layers:
+             x = layer(x, adj) + x  # residual
+         return x
+
+
+ # ═══════════════════════════════════════════════
+ # Legal Concept Graph
+ # ═══════════════════════════════════════════════
+ class LegalConceptGraph(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_dim = config.graph.hidden_dim
+
+         self.node_ids = []
+         self.node_embeddings = None
+         self.edges = {"cross_ref": [], "agency": [], "temporal": [], "semantic": []}
+         self._adj_cached = None
+         self._adj_dirty = True
+
+         self.gnn = GNNEncoder(config.graph.hidden_dim, config.graph.gnn_layers, config.graph.gnn_dropout)
+
+     @property
+     def num_nodes(self):
+         return len(self.node_ids)
+
+     @property
+     def device(self):
+         return self.gnn.layers[0].linear.weight.device
+
+     def add_nodes(self, node_ids, embeddings):
+         if self.node_embeddings is None:
+             self.node_embeddings = embeddings
+         else:
+             self.node_embeddings = torch.cat([self.node_embeddings, embeddings], dim=0)
+         self.node_ids.extend(node_ids)
+         self._adj_dirty = True
+
+     def add_edges(self, edge_type, edges):
+         self.edges[edge_type].extend(edges)
+         self._adj_dirty = True
+
+     def build_adjacency(self):
+         if not self._adj_dirty and self._adj_cached is not None:
+             return self._adj_cached
+         N = self.num_nodes
+         adj = torch.zeros(N, N, device=self.device)
+
+         for edge_type, use in [("cross_ref", self.config.graph.use_cross_ref_edges),
+                                ("agency", self.config.graph.use_agency_edges),
+                                ("temporal", self.config.graph.use_temporal_edges),
+                                ("semantic", self.config.graph.use_semantic_edges)]:
+             if not use or not self.edges[edge_type]:
+                 continue
+             valid = [(s, d, w) for s, d, w in self.edges[edge_type] if s < N and d < N]
+             if not valid:
+                 continue
+             src = torch.tensor([e[0] for e in valid], device=self.device, dtype=torch.long)
+             dst = torch.tensor([e[1] for e in valid], device=self.device, dtype=torch.long)
+             wgt = torch.tensor([e[2] for e in valid], device=self.device, dtype=torch.float)
+             adj.index_put_((src, dst), wgt, accumulate=True)
+             adj.index_put_((dst, src), wgt, accumulate=True)
+
+         adj = adj + torch.eye(N, device=self.device)
+         self._adj_cached = adj
+         self._adj_dirty = False
+         return adj
+
+     def forward(self):
+         dev = self.device
+         if self.node_embeddings.device != dev:
+             self.node_embeddings = self.node_embeddings.to(dev)
+         adj = self.build_adjacency()
+         return self.gnn(self.node_embeddings, adj)
+
+     def to(self, device):
+         super().to(device)
+         if self.node_embeddings is not None:
+             self.node_embeddings = self.node_embeddings.to(device)
+         return self
+
+
+ # ═══════════════════════════════════════════════
+ # Cross-reference extraction
+ # ═══════════════════════════════════════════════
+ CROSS_REF_PATTERNS = [
+     (re.compile(r"(?:theo|theo quy định tại|căn cứ vào|căn cứ)\s+Điều\s+(\d+)\s+(?:của\s+)?(Luật|Bộ luật|Nghị định|Thông tư|Pháp lệnh)\s+([^,.;]+)"), "citation"),
+     (re.compile(r"(Luật|Bộ luật|Nghị định|Thông tư|Pháp lệnh|Quyết định)\s+(?:số\s+)?([\d]+/[\d]+/[\w-]+)"), "reference"),
+     (re.compile(r"sửa đổi[,,]\s*bổ sung\s+(?:một số điều của\s+)?(Luật|Nghị định|Thông tư)\s+([^,.;]+)"), "amendment"),
+     (re.compile(r"(?:thay thế|bãi bỏ)\s+(?:Điều\s+(\d+)\s+(?:của\s+)?)?(Luật|Nghị định|Thông tư)\s+([^,.;]+)"), "replacement"),
+ ]
+
+
+ def extract_key_terms(df, max_terms=200):
+     texts = [(f"{row['title']} {row['text'][:500]}").replace("_", " ")
+              for _, row in df.iterrows()]
+     vectorizer = TfidfVectorizer(max_features=max_terms, ngram_range=(1, 2),
+                                  min_df=3, max_df=0.8, token_pattern=r'(?u)\b\w+\b')
+     tfidf = vectorizer.fit_transform(texts)
+     scores = tfidf.max(axis=0).toarray().flatten()
+     return list(vectorizer.get_feature_names_out()[scores.argsort()[::-1][:max_terms]])
+
+
+ def _law_matches_ref(law_id, ref_text):
+     law_lower = law_id.lower().replace("_", " ").replace("-", " ")
+     ref_lower = ref_text.lower().replace("_", " ").replace("-", " ")
+     parts = law_id.split("/")
+     if len(parts) >= 3:
+         if parts[2].replace("_", " ") in ref_lower: return True
+     if len(parts) >= 2 and parts[1] in ref_lower: return True
+     return False
+
+
+ def build_concept_graph(df, encode_fn, config):
+     """Build enhanced concept graph from training data."""
+     graph = LegalConceptGraph(config)
+     law_groups = df.groupby("law_id")
+     law_ids = sorted(law_groups.groups.keys())
+     N_laws = len(law_ids)
+     print(f" Building graph: {N_laws} law nodes...")
+
+     # Law embeddings
+     embs = []
+     for lid in law_ids:
+         group = law_groups.get_group(lid)
+         texts = [f"{t}: {txt[:300]}" for t, txt in zip(group["title"], group["text"])]
+         embs.append(torch.stack([encode_fn(t) for t in texts[:5]]).mean(dim=0))
+     law_embs = torch.stack(embs)
+     graph.add_nodes(law_ids, law_embs)
+     law_id_to_idx = {lid: i for i, lid in enumerate(law_ids)}
+
+     # Key term nodes
+     print(" Extracting key terms...")
+     key_terms = extract_key_terms(df, max_terms=200)
+     term_embs = torch.stack([encode_fn(t) for t in key_terms])
+     graph.add_nodes([f"TERM:{t}" for t in key_terms], term_embs)
+     print(f" {len(key_terms)} key terms")
+
+     # Agency edges
+     agency_edges = []
+     for _, group in df.groupby("law_type"):
+         same = group["law_id"].unique()
+         for i in range(len(same)):
+             for j in range(i + 1, len(same)):
+                 if same[i] in law_id_to_idx and same[j] in law_id_to_idx:
+                     agency_edges.append((law_id_to_idx[same[i]], law_id_to_idx[same[j]], 0.3))
+     graph.add_edges("agency", agency_edges)
+     print(f" Agency edges: {len(agency_edges)}")
+
+     # Temporal edges
+     temporal_edges = []
+     for _, group in df.groupby("law_type"):
+         yl = group.groupby("year")["law_id"].unique()
+         for y1, y2 in zip(sorted(yl.keys()), sorted(yl.keys())[1:]):
+             for l1 in yl[y1]:
+                 for l2 in yl[y2]:
+                     if l1 in law_id_to_idx and l2 in law_id_to_idx:
+                         temporal_edges.append((law_id_to_idx[l1], law_id_to_idx[l2], 0.2))
+     graph.add_edges("temporal", temporal_edges)
+     print(f" Temporal edges: {len(temporal_edges)}")
+
+     # Semantic edges (chunked k-NN)
+     semantic_k = min(config.graph.semantic_knn, N_laws - 1)
+     semantic_edges = []
+     if N_laws > 1:
+         chunk = 64
+         for i in range(0, N_laws, chunk):
+             end = min(i + chunk, N_laws)
+             sim = F.cosine_similarity(law_embs[i:end].unsqueeze(1), law_embs.unsqueeze(0), dim=2)
+             for j in range(sim.shape[0]):
+                 sim[j, i + j] = float("-inf")
+             vals, idx = sim.topk(k=semantic_k, dim=1)
+             for j in range(sim.shape[0]):
+                 for kk in range(semantic_k):
+                     semantic_edges.append((i + j, idx[j, kk].item(), vals[j, kk].item()))
+     graph.add_edges("semantic", semantic_edges)
+     print(f" Semantic edges: {len(semantic_edges)}")
+
+     # Cross-reference edges
+     cross_ref_edges = []
+     for _, row in df.iterrows():
+         src = row["law_id"]
+         if src not in law_id_to_idx: continue
+         for pattern, etype in CROSS_REF_PATTERNS:
+             for match in pattern.findall(row["text"]):
+                 match_str = " ".join(match).lower() if isinstance(match, tuple) else str(match).lower()
+                 for tgt in law_ids:
+                     if tgt != src and _law_matches_ref(tgt, match_str):
+                         cross_ref_edges.append((law_id_to_idx[src], law_id_to_idx[tgt], 0.5))
+                         break
+     graph.add_edges("cross_ref", cross_ref_edges)
+     print(f" Cross-ref edges: {len(cross_ref_edges)}")
+
+     # Term-document edges
+     term_doc_edges = []
+     law_texts = [(f"{row['title']} {row['text'][:300]}").replace("_", " ")
+                  for _, row in df.iterrows()]
+     vec = TfidfVectorizer(vocabulary=key_terms if key_terms else None)
+     try:
+         tfidf = vec.fit_transform(law_texts)
+         for ti, term in enumerate(key_terms):
+             if ti < tfidf.shape[1]:
+                 col = tfidf[:, ti].toarray().flatten()
+                 for lp in col.argsort()[::-1][:10]:
+                     if col[lp] > 0.1 and lp < N_laws:  # NOTE: lp is an article-row index of df, used here as a law-node index
+                         term_doc_edges.append((N_laws + ti, lp, float(col[lp])))
+     except ValueError:
+         pass
+     graph.add_edges("semantic", term_doc_edges)
+     print(f" Term-doc edges: {len(term_doc_edges)}")
+     print(f" Total: {graph.num_nodes} nodes ({N_laws} laws + {len(key_terms)} terms)")
+
+     return graph, law_id_to_idx
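The normalization in `GCNLayer.forward` is the symmetric form $D^{-1/2} A D^{-1/2}$, with degrees clamped to at least 1. A small pure-Python sketch on a three-node path graph shows what it does to the edge weights; `normalize_adjacency` and `propagate` are hypothetical names, and nested lists stand in for tensors.

```python
def normalize_adjacency(adj: list) -> list:
    """Symmetric normalization D^-1/2 A D^-1/2 with degrees clamped to >= 1."""
    n = len(adj)
    inv_sqrt = [max(sum(row), 1.0) ** -0.5 for row in adj]
    return [[inv_sqrt[i] * adj[i][j] * inv_sqrt[j] for j in range(n)] for i in range(n)]


def propagate(norm_adj: list, x: list) -> list:
    """One message-passing step: each node's feature becomes a weighted sum over its neighbours."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in norm_adj]


# Path graph 0 - 1 - 2 with self-loops, matching the `adj + eye(N)` step in build_adjacency.
adj = [
    [1.0, 1.0, 0.0],
    [1.0, 1.0, 1.0],
    [0.0, 1.0, 1.0],
]
norm = normalize_adjacency(adj)
print(norm)
print(propagate(norm, [1.0, 0.0, 0.0]))
```

The middle node has degree 3 and the end nodes degree 2, so the weight on edge (0, 1) becomes $1/\sqrt{2 \cdot 3}$. High-degree nodes therefore contribute less per edge, which keeps densely connected law types from dominating the aggregated state.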
src/telern/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TELEN configuration."""
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import List
+
+ ROOT = Path("E:/law-embedding")
+ DATA_DIR = ROOT / "dataset"
+ CHECKPOINT_DIR = ROOT / "data" / "checkpoints" / "telen"
+
+
+ @dataclass
+ class GraphConfig:
+     """Legal Concept Graph configuration."""
+     hidden_dim: int = 768
+     gnn_layers: int = 3
+     gnn_dropout: float = 0.1
+     # Edge types
+     use_cross_ref_edges: bool = True
+     use_agency_edges: bool = True
+     use_temporal_edges: bool = True
+     use_semantic_edges: bool = True
+     semantic_knn: int = 10
+     # Concept extraction
+     max_concepts_per_article: int = 8
+     min_tfidf_score: float = 0.05
+
+
+ @dataclass
+ class HyperNetworkConfig:
+     """HyperNetwork that generates projection weights from legal state."""
+     adaptation_rank: int = 64  # Low-rank adaptation
+     hn_hidden_dim: int = 512
+     hn_layers: int = 3
+     dropout: float = 0.1
+     # What the HyperNetwork outputs
+     output_shift: bool = True  # ΔW for projection
+     output_bias: bool = True  # Δb for projection
+     output_variance: bool = True  # log σ² for stochastic embedding
+     min_variance: float = 0.01  # minimum variance
+
+
+ @dataclass
+ class MetaTrainingConfig:
+     """Meta-learning training configuration."""
+     meta_lr: float = 3e-4
+     inner_lr: float = 5e-3
+     meta_batch_size: int = 4  # episodes per meta-update
+     n_query: int = 32  # query articles per episode
+     n_negatives: int = 256  # negative articles per query
+     meta_epochs: int = 50
+     temperature: float = 0.05
+     # Temporal splits for meta-training
+     train_split_year: int = 2018
+     val_split_year: int = 2020
+     # State construction
+     max_state_articles: int = 500  # max articles to include in state
+     # Stochastic embedding
+     kl_weight: float = 0.001  # weight for KL regularization
+     n_mc_samples: int = 1  # Monte Carlo samples during training
+
+
+ @dataclass
+ class TELENConfig:
+     """Full TELEN configuration."""
+     backbone: str = "vinai/phobert-base-v2"
+     hidden_dim: int = 768
+     max_seq_length: int = 480
+     graph: GraphConfig = field(default_factory=GraphConfig)
+     hypernetwork: HyperNetworkConfig = field(default_factory=HyperNetworkConfig)
+     meta: MetaTrainingConfig = field(default_factory=MetaTrainingConfig)
+     output_dir: str = str(CHECKPOINT_DIR)
+     seed: int = 42
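
Because the configuration nests dataclasses through `field(default_factory=...)` and hard-codes a machine-specific `ROOT`, a short sketch of how such a config might be overridden may help. Everything here is an assumption for illustration: `Cfg` and `MetaCfg` are cut-down stand-ins for `TELENConfig` and `MetaTrainingConfig`, and the `TELEN_ROOT` environment variable and `resolve_root` helper are hypothetical (the repository does not define them).

```python
import os
from dataclasses import dataclass, field, replace
from pathlib import Path


def resolve_root(default: str = "E:/law-embedding") -> Path:
    """Use TELEN_ROOT (hypothetical variable) when set, else the hard-coded default."""
    return Path(os.environ.get("TELEN_ROOT", default))


@dataclass
class MetaCfg:  # stand-in for MetaTrainingConfig
    meta_lr: float = 3e-4
    temperature: float = 0.05


@dataclass
class Cfg:  # stand-in for TELENConfig
    hidden_dim: int = 768
    meta: MetaCfg = field(default_factory=MetaCfg)


base = Cfg()
# dataclasses.replace returns a copy, so the base config stays untouched
tuned = replace(base, meta=replace(base.meta, temperature=0.1))
other = Cfg()  # default_factory builds a fresh MetaCfg per instance

print(base.meta.temperature, tuned.meta.temperature)
```

Copying with `replace` instead of mutating in place keeps each experiment's settings separate, and `default_factory` ensures two configs never share one nested object.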
src/telern/evaluate.py ADDED
@@ -0,0 +1,454 @@
+ """
+ Evaluation for TELEN: NDCG@k and MRR@k.
+
+ Metrics:
+   - NDCG@3, NDCG@5, NDCG@10
+   - MRR@3, MRR@5, MRR@10
+
+ Baselines:
+   - BM25 (lexical; approximated with TF-IDF)
+   - Frozen PhoBERT + mean pooling
+   - TELEN (ours)
+
+ Evaluation setup:
+   - Query = article title + first 500 chars of text
+   - Relevant = other articles from the SAME law
+   - Corpus = all articles from test years (held-out)
+ """
+
+ import math
+ import random
+ from collections import defaultdict
+ from pathlib import Path
+ from typing import List, Dict, Tuple
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from transformers import AutoModel, AutoTokenizer
+
+ from .config import TELENConfig, DATA_DIR
+ # model.py exports the factory as `create_model`; alias it to the name used below
+ from .model import TELEN, create_model as create_telen
+ from ..data import load_raw_data, extract_metadata, clean_data
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Metrics
+ # ═══════════════════════════════════════════════════════════
+
+ def dcg_at_k(scores: np.ndarray, k: int) -> float:
+     """Discounted Cumulative Gain at k (gain 2^rel - 1, log2 rank discount)."""
+     scores = np.asarray(scores)[:k]
+     if len(scores) == 0:
+         return 0.0
+     discounts = np.log2(np.arange(2, len(scores) + 2))
+     return float(np.sum((2.0**scores - 1) / discounts))
+
+
+ def ndcg_at_k(scores: np.ndarray, k: int) -> float:
+     """Normalized DCG at k.
+
+     `scores` should be the full ranked relevance list, not a top-k slice,
+     so that the ideal ordering accounts for every relevant item.
+     """
+     ideal = np.sort(scores)[::-1]
+     dcg_val = dcg_at_k(scores, k)
+     idcg_val = dcg_at_k(ideal, k)
+     return dcg_val / idcg_val if idcg_val > 0 else 0.0
+
+
+ def mrr_at_k(scores: np.ndarray, k: int) -> float:
+     """Reciprocal rank of the first relevant result in the top k (0.0 if none).
+
+     Averaging this value over queries gives MRR@k.
+     """
+     scores = np.asarray(scores)[:k]
+     for rank, s in enumerate(scores, start=1):
+         if s > 0:
+             return 1.0 / rank
+     return 0.0
+
+
+ def compute_metrics(
+     relevance_scores: np.ndarray, k_values: List[int] = [3, 5, 10]
+ ) -> Dict[str, float]:
+     """Compute NDCG@k and MRR@k from relevance scores."""
+     metrics = {}
+     for k in k_values:
+         metrics[f"ndcg@{k}"] = ndcg_at_k(relevance_scores, k)
+         metrics[f"mrr@{k}"] = mrr_at_k(relevance_scores, k)
+     return metrics
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Evaluation
+ # ═══════════════════════════════════════════════════════════
+
+ def prepare_test_data(config: TELENConfig):
+     """Prepare test data from held-out years."""
+     print("Loading data...")
+     df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet"))
+     df = extract_metadata(df)
+     df = clean_data(df, min_text_len=10)
+
+     # Test split: articles from test years
+     test_years = range(config.meta.val_split_year + 1, 2025)
+     test_df = df[df["year"].isin(test_years)].reset_index(drop=True)
+
+     print(f" Test set: {len(test_df)} articles from {test_df['law_id'].nunique()} laws")
+     return test_df
+
+
+ def build_test_queries(test_df: pd.DataFrame, max_queries: int = 500) -> List[Dict]:
+     """Build query set from test articles."""
+     # Group by law_id
+     law_groups = test_df.groupby("law_id")
+
+     queries = []
+     for law_id, group in law_groups:
+         articles = group.to_dict("records")
+         if len(articles) < 3:  # Need at least 1 query + 2 relevant
+             continue
+         # Use the first articles of each law as queries
+         for article in articles[:2]:  # Max 2 queries per law
+             queries.append({
+                 "query_id": article["id"],
+                 "query_text": f"{article['title']}: {article['text'][:500]}",
+                 "query_full": article["text"],
+                 "law_id": law_id,
+             })
+
+     if len(queries) > max_queries:
+         queries = random.sample(queries, max_queries)
+
+     print(f" Queries: {len(queries)}")
+     return queries
+
+
+ def build_test_corpus(test_df: pd.DataFrame) -> List[Dict]:
+     """Build corpus of all test articles for retrieval."""
+     corpus = []
+     for _, row in test_df.iterrows():
+         corpus.append({
+             "article_id": row["id"],
+             "text": f"{row['title']}: {row['text'][:500]}",
+             "law_id": row["law_id"],
+         })
+     print(f" Corpus: {len(corpus)} documents")
+     return corpus
+
+
+ def evaluate_telen(
+     model: TELEN,
+     queries: List[Dict],
+     corpus: List[Dict],
+     batch_size: int = 64,
+ ) -> Dict[str, float]:
+     """
+     Evaluate TELEN on retrieval metrics.
+
+     For each query, rank all corpus documents by cosine similarity.
+     Relevance = article is from the same law.
+     """
+     device = next(model.parameters()).device
+     model.eval()
+
+     # Encode corpus
+     print(" Encoding corpus...")
+     corpus_embeddings = []
+     corpus_ids = [doc["article_id"] for doc in corpus]
+     corpus_law_ids = [doc["law_id"] for doc in corpus]
+
+     for i in tqdm(range(0, len(corpus), batch_size), desc=" Corpus"):
+         batch = corpus[i:i + batch_size]
+         texts = [doc["text"] for doc in batch]
+         with torch.no_grad():
+             result = model(texts, use_stochastic=False)
+         corpus_embeddings.append(result["embeddings"].cpu())
+
+     corpus_embeddings = torch.cat(corpus_embeddings, dim=0)  # [N_corpus, d]
+     print(f" Corpus embeddings: {corpus_embeddings.shape}")
+
+     # Evaluate each query
+     all_metrics = defaultdict(list)
+
+     print(" Evaluating queries...")
+     for query in tqdm(queries, desc=" Queries"):
+         # Encode query
+         with torch.no_grad():
+             result = model([query["query_text"]], use_stochastic=False)
+         query_emb = result["embeddings"].cpu()  # [1, d]
+
+         # Cosine similarity with all corpus
+         sim = F.cosine_similarity(
+             query_emb, corpus_embeddings
+         ).numpy()  # [N_corpus]
+
+         # Build relevance scores (1.0 if same law, 0.0 otherwise)
+         relevance = np.array([
+             1.0 if corpus_law_ids[i] == query["law_id"] else 0.0
+             for i in range(len(corpus))
+         ])
+
+         # Rank by similarity
+         sorted_idx = sim.argsort()[::-1]
+         sorted_relevance = relevance[sorted_idx]
+
+         # Remove the query itself from results
+         query_idx_in_corpus = None
+         for i, cid in enumerate(corpus_ids):
+             if cid == query["query_id"]:
+                 query_idx_in_corpus = i
+                 break
+
+         if query_idx_in_corpus is not None:
+             # Remove self-match
+             mask = sorted_idx != query_idx_in_corpus
+             sorted_relevance = sorted_relevance[mask]
+
+         # Pass the full ranking: dcg_at_k / mrr_at_k truncate to k themselves,
+         # and NDCG's ideal ordering must see every relevant document.
+         metrics = compute_metrics(sorted_relevance, [3, 5, 10])
+         for metric_name, value in metrics.items():
+             all_metrics[metric_name].append(value)
+
+     # Average over queries
+     results = {name: np.mean(scores) for name, scores in all_metrics.items()}
+     return results
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Baselines
+ # ═══════════════════════════════════════════════════════════
+
+ class BM25Baseline:
+     """Lexical baseline: TF-IDF cosine scoring used as an approximation of BM25."""
+
+     def __init__(self):
+         self.vectorizer = TfidfVectorizer(
+             max_features=10000,
+             ngram_range=(1, 2),
+             sublinear_tf=True,
+         )
+
+     def fit(self, corpus: List[Dict]):
+         self.corpus = corpus
+         self.doc_texts = [doc["text"] for doc in corpus]
+         self.doc_ids = [doc["article_id"] for doc in corpus]
+         self.doc_law_ids = [doc["law_id"] for doc in corpus]
+         self.tfidf_matrix = self.vectorizer.fit_transform(self.doc_texts)
+
+     def search(self, query_text: str, k: int = 100) -> np.ndarray:
+         """Return corpus indices ranked by score (k is unused; the full ranking is returned)."""
+         query_vec = self.vectorizer.transform([query_text])
+         scores = (self.tfidf_matrix @ query_vec.T).toarray().flatten()
+         sorted_idx = scores.argsort()[::-1]
+         return sorted_idx
+
+
+ def evaluate_bm25(queries: List[Dict], corpus: List[Dict]) -> Dict[str, float]:
+     """Evaluate BM25 baseline."""
+     print(" Building BM25 index...")
+     bm25 = BM25Baseline()
+     bm25.fit(corpus)
+
+     all_metrics = defaultdict(list)
+
+     print(" Evaluating queries...")
+     for query in tqdm(queries, desc=" Queries"):
+         sorted_idx = bm25.search(query["query_text"], k=100)
+
+         # Locate the query's own document so it can be removed
+         doc_ids = bm25.doc_ids
+         query_idx = None
+         for i, did in enumerate(doc_ids):
+             if did == query["query_id"]:
+                 query_idx = i
+                 break
+
+         relevance = np.array([
+             1.0 if bm25.doc_law_ids[i] == query["law_id"] else 0.0
+             for i in sorted_idx
+         ])
+
+         if query_idx is not None:
+             pos = np.where(sorted_idx == query_idx)[0]
+             if len(pos) > 0:
+                 relevance = np.delete(relevance, pos[0])
+
+         # Full ranking in, truncation to k happens inside the metric functions
+         metrics = compute_metrics(relevance, [3, 5, 10])
+         for name, val in metrics.items():
+             all_metrics[name].append(val)
+
+     return {name: np.mean(scores) for name, scores in all_metrics.items()}
+
+
+ class FrozenPhoBERT:
+     """Frozen PhoBERT with mean pooling baseline."""
+
+     def __init__(self, model_name: str = "vinai/phobert-base-v2"):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModel.from_pretrained(model_name)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model = self.model.to(self.device)
+         self.model.eval()
+
+     def encode(self, texts: List[str], batch_size: int = 64) -> torch.Tensor:
+         embeddings = []
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i:i + batch_size]
+             # PhoBERT accepts at most 256 tokens, so truncate to that length
+             encoded = self.tokenizer(
+                 batch, padding=True, truncation=True,
+                 max_length=256, return_tensors="pt",
+             )
+             input_ids = encoded["input_ids"].to(self.device)
+             attention_mask = encoded["attention_mask"].to(self.device)
+             with torch.no_grad():
+                 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+             hidden = outputs.last_hidden_state
+             # Mean pooling over non-padding tokens
+             mask = attention_mask.unsqueeze(-1).float()
+             pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+             pooled = F.normalize(pooled, p=2, dim=1)
+             embeddings.append(pooled.cpu())
+         return torch.cat(embeddings, dim=0)
+
+
+ def evaluate_frozen_phobert(
+     queries: List[Dict], corpus: List[Dict]
+ ) -> Dict[str, float]:
+     """Evaluate frozen PhoBERT baseline."""
+     print(" Loading frozen PhoBERT...")
+     encoder = FrozenPhoBERT()
+
+     print(" Encoding corpus...")
+     corpus_texts = [doc["text"] for doc in corpus]
+     corpus_embeddings = encoder.encode(corpus_texts)
+     corpus_ids = [doc["article_id"] for doc in corpus]
+     corpus_law_ids = [doc["law_id"] for doc in corpus]
+
+     all_metrics = defaultdict(list)
+
+     print(" Evaluating queries...")
+     query_texts = [q["query_text"] for q in queries]
+     query_embeddings = encoder.encode(query_texts)
+
+     for i, query in enumerate(tqdm(queries, desc=" Queries")):
+         query_emb = query_embeddings[i:i+1]
+         sim = F.cosine_similarity(query_emb, corpus_embeddings).numpy()
+
+         relevance = np.array([
+             1.0 if corpus_law_ids[j] == query["law_id"] else 0.0
+             for j in range(len(corpus))
+         ])
+
+         sorted_idx = sim.argsort()[::-1]
+         sorted_relevance = relevance[sorted_idx]
+
+         # Remove self
+         for j, cid in enumerate(corpus_ids):
+             if cid == query["query_id"]:
+                 mask = sorted_idx != j
+                 sorted_relevance = sorted_relevance[mask]
+                 break
+
+         # Full ranking in, truncation to k happens inside the metric functions
+         metrics = compute_metrics(sorted_relevance, [3, 5, 10])
+         for name, val in metrics.items():
+             all_metrics[name].append(val)
+
+     return {name: np.mean(scores) for name, scores in all_metrics.items()}
+
+
+ # ═══════════════════════════════════════════════════════════
+ # Main evaluation entry point
+ # ═══════════════════════════════════════════════════════════
+
+ def run_full_evaluation(
+     config: TELENConfig = None,
+     checkpoint_path: str = None,
+ ):
+     """Run complete evaluation with all baselines and TELEN."""
+     if config is None:
+         config = TELENConfig()
+
+     random.seed(config.seed)
+     np.random.seed(config.seed)
+
+     print("=" * 60)
+     print("TELEN Evaluation")
+     print("=" * 60)
+
+     # Prepare test data
+     test_df = prepare_test_data(config)
+     queries = build_test_queries(test_df, max_queries=300)
+     corpus = build_test_corpus(test_df)
+
+     k_values = [3, 5, 10]
+     results = {}
+
+     # --- Baseline 1: BM25 ---
+     print("\n" + "=" * 40)
+     print("[1/3] BM25 Baseline")
+     print("=" * 40)
+     results["BM25"] = evaluate_bm25(queries, corpus)
+     for m in k_values:
+         print(f" NDCG@{m}: {results['BM25'][f'ndcg@{m}']:.4f} | MRR@{m}: {results['BM25'][f'mrr@{m}']:.4f}")
+
+     # --- Baseline 2: Frozen PhoBERT ---
+     print("\n" + "=" * 40)
+     print("[2/3] Frozen PhoBERT Baseline")
+     print("=" * 40)
+     results["PhoBERT"] = evaluate_frozen_phobert(queries, corpus)
+     for m in k_values:
+         print(f" NDCG@{m}: {results['PhoBERT'][f'ndcg@{m}']:.4f} | MRR@{m}: {results['PhoBERT'][f'mrr@{m}']:.4f}")
+
+     # --- TELEN ---
+     print("\n" + "=" * 40)
+     print("[3/3] TELEN (Ours)")
+     print("=" * 40)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = create_telen(config)
+     model = model.to(device)
+
+     # Load checkpoint if provided
+     if checkpoint_path and Path(checkpoint_path).exists():
+         print(f" Loading checkpoint: {checkpoint_path}")
+         ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
+         model.hypernetwork.load_state_dict(ckpt["hypernetwork"])
+         model.state_encoder.load_state_dict(ckpt["state_encoder"])
+         # TELEN (model.py) names this module `projection`; the checkpoint key is kept as saved
+         model.projection.load_state_dict(ckpt["base_projection"])
+         model.attn_query.data.copy_(ckpt["attn_query"])
+         # Rebuild graph
+         # NOTE: test_df only holds years after val_split_year, so with the default
+         # splits this filter selects no rows; the graph needs training-period articles.
+         model.build_graph(test_df[test_df["year"] <= config.meta.train_split_year])
+
+     results["TELEN"] = evaluate_telen(model, queries, corpus)
+     for m in k_values:
+         print(f" NDCG@{m}: {results['TELEN'][f'ndcg@{m}']:.4f} | MRR@{m}: {results['TELEN'][f'mrr@{m}']:.4f}")
+
+     # --- Summary ---
+     print("\n" + "=" * 60)
+     print("SUMMARY")
+     print("=" * 60)
+     header = f"{'Method':<20}"
+     for m in k_values:
+         header += f" {'NDCG@'+str(m):>12} {'MRR@'+str(m):>12}"
+     print(header)
+     print("-" * len(header))
+
+     for method in ["BM25", "PhoBERT", "TELEN"]:
+         row = f"{method:<20}"
+         for m in k_values:
+             row += f" {results[method][f'ndcg@{m}']:>12.4f} {results[method][f'mrr@{m}']:>12.4f}"
+         print(row)
+
+     # Relative improvement
+     print("\n--- Improvement over PhoBERT ---")
+     for m in k_values:
+         ndcg_imp = (results["TELEN"][f"ndcg@{m}"] / max(results["PhoBERT"][f"ndcg@{m}"], 1e-6) - 1) * 100
+         mrr_imp = (results["TELEN"][f"mrr@{m}"] / max(results["PhoBERT"][f"mrr@{m}"], 1e-6) - 1) * 100
+         print(f" NDCG@{m}: {ndcg_imp:+.1f}% | MRR@{m}: {mrr_imp:+.1f}%")
+
+     return results
+
+
+ if __name__ == "__main__":
+     run_full_evaluation()
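
The metric definitions in this file can be checked by hand on a toy ranking. The following is a dependency-free sketch, not the module's code: it re-implements the same formulas (gain `2^rel - 1`, `log2` rank discount) with the standard library, and the `ranking` list is made-up data. It also shows why the ideal ordering should be computed from the whole ranking: normalizing against only the top-k slice yields a larger NDCG value.

```python
import math
from typing import Sequence


def dcg_at_k(rels: Sequence[float], k: int) -> float:
    """DCG with exponential gain and log2 discount, truncated to the top k."""
    return sum((2.0 ** r - 1) / math.log2(i + 2) for i, r in enumerate(rels[:k]))


def ndcg_at_k(rels: Sequence[float], k: int) -> float:
    """NDCG where the ideal ordering is built from everything in `rels`."""
    idcg = dcg_at_k(sorted(rels, reverse=True), k)
    return dcg_at_k(rels, k) / idcg if idcg > 0 else 0.0


def mrr_at_k(rels: Sequence[float], k: int) -> float:
    """Reciprocal rank of the first relevant item in the top k, else 0."""
    for rank, r in enumerate(rels[:k], start=1):
        if r > 0:
            return 1.0 / rank
    return 0.0


# Toy ranking: 1.0 marks an article from the same law as the query
ranking = [0.0, 0.0, 1.0, 1.0, 0.0, 1.0]

full = ndcg_at_k(ranking, 3)            # ideal built from all six results
truncated = ndcg_at_k(ranking[:3], 3)   # ideal built from the top-3 slice only
rr = mrr_at_k(ranking, 3)

print(f"NDCG@3 (full ranking): {full:.4f}")
print(f"NDCG@3 (top-3 slice):  {truncated:.4f}")
print(f"RR@3: {rr:.4f}")
```

In this example the first relevant article sits at rank 3, so the reciprocal rank is 1/3, while the two NDCG variants differ only in their normalizer.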
src/telern/hypernetwork.py ADDED
@@ -0,0 +1,176 @@
+ """
+ HyperNetwork for TELEN.
+
+ Core innovation: Instead of learning fixed projection weights, the HyperNetwork
+ GENERATES the projection function from the current legal corpus state.
+
+ When new laws arrive → state vector changes → HyperNetwork produces new weights
+ → embedding space adapts WITHOUT retraining.
+
+ Additionally outputs variance for stochastic embeddings (uncertainty-aware retrieval).
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class HyperNetwork(nn.Module):
+     """
+     Generates embedding projection parameters from a legal state vector.
+
+     Given state vector s ∈ R^d, produces:
+       - ΔW: low-rank projection shift (weighted sum of learned rank-1 bases)
+       - Δb: bias shift (weighted sum of learned bias bases)
+       - log_σ²: per-dimension log-variance for stochastic embedding
+
+     Architecture: Instead of generating giant parameter matrices directly,
+     we store a compact set of learned basis vectors and use the HyperNetwork
+     to generate ONLY the combination weights. This is parameter-efficient
+     and forces generalization.
+     """
+
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         hn = config.hypernetwork
+         d = config.hidden_dim
+         r = hn.adaptation_rank
+         hidden = hn.hn_hidden_dim
+
+         # Shared trunk: state → latent code
+         self.trunk = nn.Sequential(
+             nn.Linear(d, hidden),
+             nn.ReLU(),
+             nn.Dropout(hn.dropout),
+             nn.Linear(hidden, hidden),
+             nn.ReLU(),
+             nn.Dropout(hn.dropout),
+             nn.Linear(hidden, hidden),
+             nn.LayerNorm(hidden),
+         )
+
+         # Modulator: latent → combination weights for all outputs
+         self.modulator = nn.Linear(hidden, 2 * r + r + 1)  # A_weights + B_weights + bias_weights + var_context
+
+         # === Learned basis vectors (stored, not generated) ===
+         # For ΔW = Σ_i (w^A_i u_i) ⊗ (w^B_i v_i) where u_i, v_i ∈ R^d
+         self.basis_u = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
+         self.basis_v = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
+
+         # For Δb = Σ_i w^b_i * b_i where b_i ∈ R^d
+         self.basis_b = nn.Parameter(torch.randn(r, d) * 0.01)  # [r, d]
+
+         # Variance head
+         if hn.output_variance:
+             self.head_logvar = nn.Sequential(
+                 nn.Linear(hidden, hidden),
+                 nn.Tanh(),
+                 nn.Linear(hidden, d),
+             )
+         else:
+             self.head_logvar = None
+
+     def forward(self, state_vector: torch.Tensor) -> dict:
+         """
+         Args:
+             state_vector: [d] or [B, d] summarizing current legal landscape
+
+         Returns dict with keys:
+             "shift_matrix": [d, d] or [B, d, d] rank-r projection shift
+             "bias": [d] or [B, d] bias shift
+             "log_variance": [d] or [B, d] log variance for stochastic embedding
+         """
+         squeeze = state_vector.dim() == 1
+         if squeeze:
+             state_vector = state_vector.unsqueeze(0)  # [1, d]
+
+         B, d = state_vector.shape
+         r = self.config.hypernetwork.adaptation_rank
+
+         # Shared representation
+         h = self.trunk(state_vector)  # [B, hidden]
+         modulated = self.modulator(h)  # [B, 2r + r + 1]
+
+         # Split modulation weights (the final var_context column is not used here)
+         w_A = modulated[:, :r]  # [B, r]
+         w_B = modulated[:, r:2*r]  # [B, r]
+         w_bias = modulated[:, 2*r:3*r]  # [B, r]
+
+         # Rank-r shift matrix:
+         # ΔW = Σ_i (w_A[:, i] * basis_u[i]) ⊗ (w_B[:, i] * basis_v[i])
+         u_weighted = w_A.unsqueeze(2) * self.basis_u.unsqueeze(0)  # [B, r, d]
+         v_weighted = w_B.unsqueeze(2) * self.basis_v.unsqueeze(0)  # [B, r, d]
+         # Contract over the rank axis directly; this equals summing r outer
+         # products but avoids materializing a [B, r, d, d] tensor.
+         shift = torch.einsum("brd,bre->bde", u_weighted, v_weighted)  # [B, d, d]
+
+         # Bias
+         bias = (w_bias.unsqueeze(2) * self.basis_b.unsqueeze(0)).sum(dim=1)  # [B, d]
+
+         result = {"shift_matrix": shift, "bias": bias}
+
+         # Log variance
+         if self.head_logvar is not None:
+             logvar = self.head_logvar(h)
+             logvar = torch.clamp(logvar, min=-5.0, max=2.0)
+             result["log_variance"] = logvar
+         else:
+             result["log_variance"] = torch.full((B, d), -3.0, device=h.device)
+
+         if squeeze:
+             result = {k: v.squeeze(0) for k, v in result.items()}
+
+         return result
+
+
+ class StateEncoder(nn.Module):
+     """
+     Encodes the legal concept graph into a compact state vector.
+
+     This is separate from the HyperNetwork so the graph computation
+     can be cached and only updated when the graph changes.
+     """
+
+     def __init__(self, dim: int):
+         super().__init__()
+         self.state_proj = nn.Sequential(
+             nn.Linear(dim, dim * 2),
+             nn.ReLU(),
+             nn.Dropout(0.1),
+             nn.Linear(dim * 2, dim),
+             nn.LayerNorm(dim),
+         )
+
+     def forward(self, node_embeddings: torch.Tensor, node_weights: torch.Tensor = None) -> torch.Tensor:
+         """
+         Args:
+             node_embeddings: [N, d] refined node embeddings from GNN
+             node_weights: [N] optional attention weights
+
+         Returns:
+             state_vector: [d] summarizing the legal landscape
+         """
+         if node_weights is None:
+             # Equal weight if none provided
+             node_weights = torch.ones(
+                 node_embeddings.shape[0], device=node_embeddings.device
+             )
+         node_weights = F.softmax(node_weights, dim=0)
+         pooled = (node_embeddings * node_weights.unsqueeze(1)).sum(dim=0)
+         return self.state_proj(pooled)
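
The low-rank shift built in `HyperNetwork.forward` is a sum of weighted outer products, so applying it to a vector never requires forming the full `d x d` matrix. The sketch below is a small pure-Python check of that identity under stated assumptions: the helper names (`dense_shift`, `apply_dense`, `apply_factored`) and all numbers are illustrative and not taken from the repository.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def dense_shift(w_a: Vector, w_b: Vector, basis_u: Matrix, basis_v: Matrix) -> Matrix:
    """Materialize Delta_W = sum_i (w_a[i] * u_i) outer (w_b[i] * v_i)."""
    d = len(basis_u[0])
    delta = [[0.0] * d for _ in range(d)]
    for a, b, u, v in zip(w_a, w_b, basis_u, basis_v):
        for i in range(d):
            for j in range(d):
                delta[i][j] += (a * u[i]) * (b * v[j])
    return delta


def apply_dense(delta: Matrix, x: Vector) -> Vector:
    """Compute Delta_W @ x with the materialized matrix."""
    return [dot(row, x) for row in delta]


def apply_factored(w_a: Vector, w_b: Vector, basis_u: Matrix, basis_v: Matrix, x: Vector) -> Vector:
    """Compute Delta_W @ x as sum_i w_a[i] * w_b[i] * (v_i . x) * u_i, in O(r * d)."""
    out = [0.0] * len(x)
    for a, b, u, v in zip(w_a, w_b, basis_u, basis_v):
        coeff = a * b * dot(v, x)
        for i in range(len(x)):
            out[i] += coeff * u[i]
    return out


# Rank r = 2, dimension d = 3 (toy values)
w_a, w_b = [0.5, -1.0], [2.0, 0.25]
basis_u = [[1.0, 0.0, 2.0], [0.0, 1.0, -1.0]]
basis_v = [[1.0, 1.0, 0.0], [2.0, 0.0, 1.0]]
x = [1.0, 2.0, 3.0]

via_matrix = apply_dense(dense_shift(w_a, w_b, basis_u, basis_v), x)
via_factors = apply_factored(w_a, w_b, basis_u, basis_v, x)
print(via_matrix, via_factors)
```

Both routes give the same vector; the factored form is the cheaper option when the rank is much smaller than the embedding dimension.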
src/telern/model.py ADDED
@@ -0,0 +1,133 @@
+ """
+ TELEN: Temporal Evolving Legal Embedding Network.
+
+ Bi-encoder backbone + Legal Concept Graph + HyperNetwork projection.
+ Embedding space adapts dynamically to the legal corpus state.
+ """
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoTokenizer
+ from pyvi import ViTokenizer
+
+ from .config import TELENConfig
+ from .hypernetwork import StateEncoder, HyperNetwork
+ from .concept_graph import build_concept_graph
+
+
+ def wseg(text):
+     return ViTokenizer.tokenize(text.replace("_", " "))
+
+
+ class BiEncoder(nn.Module):
+     """Vietnamese bi-encoder backbone with attention pooling."""
+
+     def __init__(self, model_name="bkai-foundation-models/vietnamese-bi-encoder"):
+         super().__init__()
+         self.model = AutoModel.from_pretrained(model_name)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.dim = self.model.config.hidden_size
+         self.attn_query = nn.Parameter(torch.randn(self.dim))
+         self.scale = self.dim ** 0.5
+
+     def forward(self, texts, max_len=480):
+         segmented = [wseg(t) for t in texts]
+         enc = self.tokenizer(segmented, padding=True, truncation=True,
+                              max_length=max_len, return_tensors="pt")
+         input_ids = enc["input_ids"].to(self.attn_query.device)
+         mask = enc["attention_mask"].to(self.attn_query.device)
+         hidden = self.model(input_ids=input_ids, attention_mask=mask).last_hidden_state
+         scores = torch.einsum("bsd,d->bs", hidden, self.attn_query) / self.scale
+         scores = scores.masked_fill(mask == 0, float("-1e9"))
+         weights = F.softmax(scores, dim=1)
+         return torch.einsum("bsd,bs->bd", hidden, weights)
+
+
+ class TELEN(nn.Module):
+     """Temporal Evolving Legal Embedding Network."""
+
+     def __init__(self, config: TELENConfig):
+         super().__init__()
+         self.config = config
+         d = config.hidden_dim
+
+         # Bi-encoder backbone (frozen)
+         self.encoder = BiEncoder()
+         for p in self.encoder.parameters():
+             p.requires_grad = False
+
+         # Projection
+         self.projection = nn.Sequential(nn.Linear(d, d), nn.Tanh())
+         self.proj_norm = nn.LayerNorm(d)
+         self.attn_query = nn.Parameter(torch.randn(d))
+
+         # Graph
+         self.concept_graph = None
+         self.law_id_to_idx = None
+
+         # HyperNetwork
+         self.state_encoder = StateEncoder(d)
+         self.hypernetwork = HyperNetwork(config)
+
+     def _pool(self, hidden, mask):
+         """Attention-weighted pooling (for pre-tokenized inputs)."""
+         scores = torch.einsum("bsd,d->bs", hidden, self.attn_query) / (self.config.hidden_dim ** 0.5)
+         scores = scores.masked_fill(mask == 0, float("-1e9"))
+         weights = F.softmax(scores, dim=1)
+         return torch.einsum("bsd,bs->bd", hidden, weights)
+
+     def encode_text(self, texts):
+         return self.encoder(texts, max_len=self.config.max_seq_length)
+
+     def get_state_vector(self):
+         if self.concept_graph is None or self.concept_graph.num_nodes == 0:
+             return torch.zeros(self.config.hidden_dim, device=self.attn_query.device)
+         refined = self.concept_graph.forward()
+         return self.state_encoder(refined)
+
+     def adapt_embedding(self, raw, state_vec):
+         base = self.projection(raw)
+         hn = self.hypernetwork(state_vec)
+         shift = raw @ hn["shift_matrix"].T + hn["bias"]
+         mean = F.normalize(self.proj_norm(base + shift), p=2, dim=1)
+         result = {"mean": mean, "log_variance": hn.get("log_variance")}
+         if self.config.hypernetwork.output_variance:
+             noise = 0.1 * hn["log_variance"].exp().clamp(min=0.001, max=0.25).sqrt().clamp(max=0.5)
+             result["sample"] = F.normalize(mean + torch.randn_like(mean) * noise, p=2, dim=1)
+         else:
+             result["sample"] = mean
+         return result
+
+     def forward(self, texts, use_stochastic=False):
+         raw = self.encode_text(texts)
+         state = self.get_state_vector()
+         adapted = self.adapt_embedding(raw, state)
+         return {
+             "embeddings": adapted["sample"] if use_stochastic else adapted["mean"],
+             "mean": adapted["mean"],
+             "log_variance": adapted.get("log_variance"),
+             "state_vector": state,
+         }
+
+     def build_graph(self, df):
+         self.concept_graph, self.law_id_to_idx = build_concept_graph(
+             df, lambda t: self.encode_text([t])[0].detach(), self.config,
+         )
+         self.concept_graph = self.concept_graph.to(self.attn_query.device)
+
+     def add_law(self, law_id, articles):
+         if self.concept_graph is None:
+             return
+         if articles:
+             emb = self.encode_text(articles[:5]).mean(dim=0)
+             new_idx = self.concept_graph.num_nodes
+             self.concept_graph.add_nodes([law_id], emb.unsqueeze(0))
+             existing = self.concept_graph.node_embeddings[:-1]
+             if len(existing) > 0:
+                 sim = F.cosine_similarity(emb.unsqueeze(0), existing)
+                 _, top = sim.topk(k=min(10, len(existing)))
+                 self.concept_graph.add_edges("semantic",
+                     [(new_idx, i.item(), sim[i].item()) for i in top])
+
+
+ def create_model(config: TELENConfig) -> TELEN:
+     return TELEN(config)
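
The attention pooling used by `BiEncoder.forward` and `TELEN._pool` (scaled dot-product scores against a learned query, padding masked out, softmax, weighted sum) can be traced on a tiny example. This is a pure-Python sketch for a single sequence, assuming toy inputs; `attention_pool` is a hypothetical name and the values are invented for illustration.

```python
import math
from typing import List, Tuple


def attention_pool(
    hidden: List[List[float]],  # [seq_len][d] token states
    mask: List[int],            # 1 = real token, 0 = padding
    query: List[float],         # [d] pooling query vector
) -> Tuple[List[float], List[float]]:
    """Return (pooled vector, attention weights) for one sequence."""
    d = len(query)
    scale = math.sqrt(d)
    # Scaled dot-product score per token; padding gets a large negative score
    scores = [
        sum(h * q for h, q in zip(tok, query)) / scale if m else -1e9
        for tok, m in zip(hidden, mask)
    ]
    # Numerically stable softmax over the sequence axis
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    pooled = [sum(w * tok[j] for w, tok in zip(weights, hidden)) for j in range(d)]
    return pooled, weights


hidden = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]  # last row is padding
mask = [1, 1, 0]
query = [0.0, 0.0]  # zero query gives equal scores, hence uniform attention

pooled, weights = attention_pool(hidden, mask, query)
print(pooled, weights)
```

With a zero query the two real tokens share the weight equally and the padded position contributes nothing, so the result reduces to a mean over real tokens.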
train.py ADDED
@@ -0,0 +1,293 @@
+ """
+ TELEN: Temporal Evolving Legal Embedding Network (training script).
+
+ Stages:
+   1. Contrastive pretraining (5 epochs): train projection head
+   2. Meta-training (50 epochs): train HyperNetwork + State Encoder
+
+ Usage:
+   python train.py
+ """
+ import sys, os, math, random
+ from pathlib import Path
+ from collections import defaultdict
+ import numpy as np
+ import pandas as pd
+ import torch, torch.nn as nn, torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from tqdm import tqdm
+ from transformers import AutoTokenizer
+ from pyvi import ViTokenizer
+
+ sys.path.insert(0, ".")
+ from src.telern.config import TELENConfig, DATA_DIR
+ from src.telern.model import TELEN, create_model
+ from src.data import load_raw_data, extract_metadata, clean_data
+
+ SEED = 42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ═══════════════════════════════════════════════════════════
+ # Data
+ # ═══════════════════════════════════════════════════════════
+ def prepare_data(config):
+     """Load articles, group them by law, and split years into train / val / test."""
+     df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet"))
+     df = extract_metadata(df)
+     df = clean_data(df, min_text_len=10)
+     articles_by_law = defaultdict(list)
+     laws_by_year = defaultdict(list)
+     for _, row in df.iterrows():
+         articles_by_law[row["law_id"]].append({
+             "id": row["id"], "title": row["title"], "text": row["text"],
+             "law_type": row["law_type"], "year": row["year"],
+         })
+     # A law is assigned to the year of its first article
+     for law_id in articles_by_law:
+         laws_by_year[articles_by_law[law_id][0]["year"]].append(law_id)
+     all_years = sorted(laws_by_year.keys())
+     train_years = [y for y in all_years if y <= config.meta.train_split_year]
+     val_years = [y for y in all_years if config.meta.train_split_year < y <= config.meta.val_split_year]
+     test_years = [y for y in all_years if y > config.meta.val_split_year]
+     return articles_by_law, laws_by_year, train_years, val_years, test_years, df
50
+
+ # ═══════════════════════════════════════════════════════════
+ # Contrastive Dataset
+ # ═══════════════════════════════════════════════════════════
+ class ContrastiveDataset(Dataset):
+     def __init__(self, df, tokenizer, max_len=480):
+         self.df = df.reset_index(drop=True)
+         self.tokenizer = tokenizer
+         self.max_len = max_len
+         self.law_groups = self.df.groupby("law_id")
+         self.law_ids = list(self.law_groups.groups.keys())
+
+     def __len__(self): return len(self.df)
+
+     def __getitem__(self, idx):
+         row = self.df.iloc[idx]; law_id = row["law_id"]
+         wseg = lambda t: ViTokenizer.tokenize(t.replace("_", " "))
+         anchor = wseg(f"{row['title']}: {row['text'][:400]}")
+         # Positive: another article of the same law (falls back to the anchor itself).
+         group_idx = self.law_groups.groups[law_id]
+         others = [i for i in group_idx if i != idx]
+         pos_row = self.df.iloc[random.choice(others)] if others else row
+         positive = wseg(f"{pos_row['title']}: {pos_row['text'][:400]}")
+         # Negative: a random article from a different law.
+         neg_law = random.choice([l for l in self.law_ids if l != law_id])
+         neg_row = self.df.iloc[random.choice(list(self.law_groups.groups[neg_law]))]
+         negative = wseg(f"{neg_row['title']}: {neg_row['text'][:400]}")
+
+         def tok(t): return self.tokenizer(t, truncation=True, max_length=self.max_len, padding="max_length", return_tensors="pt")
+         return {f"{k}_{s}": tok(t)[k].squeeze(0)
+                 for t, s in [(anchor,"a"),(positive,"p"),(negative,"n")]
+                 for k in ["input_ids","attention_mask"]}
+
+ # ═══════════════════════════════════════════════════════════
+ # Stage 1: Contrastive Pretraining
+ # ═══════════════════════════════════════════════════════════
+ def contrastive_pretrain(model, df, config, epochs=5, batch_size=24, lr=3e-5):
+     tokenizer = model.encoder.tokenizer
+     dataset = ContrastiveDataset(df, tokenizer, config.max_seq_length)
+     loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+     trainable = list(model.base_projection.parameters()) + [model.attn_query]
+     opt = torch.optim.AdamW(trainable, lr=lr, weight_decay=0.01)
+     sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs * len(loader))
+
+     print(f" Contrastive pretraining: {epochs} epochs, {len(loader)} batches")
+     model.train(); model.encoder.model.eval()
+
+     for epoch in range(epochs):
+         total = 0.0
+         for batch in tqdm(loader, desc=f" Epoch {epoch+1}/{epochs}"):
+             a_ids=batch["input_ids_a"].to(device); a_mask=batch["attention_mask_a"].to(device)
+             p_ids=batch["input_ids_p"].to(device); p_mask=batch["attention_mask_p"].to(device)
+             n_ids=batch["input_ids_n"].to(device); n_mask=batch["attention_mask_n"].to(device)
+
+             # The encoder stays frozen; only the projection head receives gradients.
+             with torch.no_grad():
+                 ah=model._pool(model.encoder.model(input_ids=a_ids,attention_mask=a_mask).last_hidden_state,a_mask)
+                 ph=model._pool(model.encoder.model(input_ids=p_ids,attention_mask=p_mask).last_hidden_state,p_mask)
+                 nh=model._pool(model.encoder.model(input_ids=n_ids,attention_mask=n_mask).last_hidden_state,n_mask)
+
+             ae=F.normalize(model.base_projection(ah),p=2,dim=1)
+             pe=F.normalize(model.base_projection(ph),p=2,dim=1)
+             ne=F.normalize(model.base_projection(nh),p=2,dim=1)
+
+             trip=F.relu(0.3-(ae*pe).sum(1)+(ae*ne).sum(1)).mean()
+             # In-batch InfoNCE over [anchors, positives, negatives]. An anchor's similarity
+             # to itself is always the maximal logit (1/0.05), so it is masked out rather
+             # than counted as a negative.
+             self_mask=torch.eye(len(a_ids),dtype=torch.bool,device=device)
+             aa=(ae@ae.T).masked_fill(self_mask,float("-inf"))
+             sim=torch.cat([aa,ae@pe.T,ae@ne.T],dim=1)/0.05
+             infonce=F.cross_entropy(sim,torch.arange(len(a_ids),device=device)+len(a_ids))
+             loss=trip+0.5*infonce
+
+             opt.zero_grad(); loss.backward()
+             torch.nn.utils.clip_grad_norm_(trainable,1.0); opt.step(); sched.step()
+             total+=loss.item()
+         print(f" Epoch {epoch+1} avg loss: {total/len(loader):.4f}")
+     print(" Contrastive pretraining complete!")
+     return model
+
+ # ═══════════════════════════════════════════════════════════
+ # Episode building
+ # ═══════════════════════════════════════════════════════════
+ def build_episode(articles_by_law, laws_by_year, state_years, query_year, config):
+     mc = config.meta
+     q_laws = laws_by_year.get(query_year, [])
+     if len(q_laws) < 5: return None
+     sampled = random.sample(q_laws, min(mc.n_query // 4, len(q_laws)))
+     # Query/positive pairs: two distinct articles drawn from the same law.
+     queries, positives, q_types = [], [], set()
+     for lid in sampled:
+         arts = articles_by_law[lid]
+         if len(arts) < 2: continue
+         qi, pi = random.sample(range(len(arts)), 2)
+         queries.append(arts[qi]); positives.append(arts[pi])
+         q_types.add(arts[qi]["law_type"])
+     if len(queries) < 4: return None
+
+     # Hard negatives share a law type with some query; the rest are random negatives.
+     hard_neg, rand_neg = [], []
+     for lid in q_laws:
+         if lid in sampled: continue
+         for a in articles_by_law[lid]:
+             if a["law_type"] in q_types: hard_neg.append(a)
+             else: rand_neg.append(a)
+
+     nh = min(mc.n_negatives // 2, len(hard_neg))
+     nr = min(mc.n_negatives - nh, len(rand_neg))
+     negatives = (random.sample(hard_neg, nh) if nh > 0 else []) + (random.sample(rand_neg, nr) if nr > 0 else [])
+     if len(negatives) < 4: return None
+     return {"queries": queries, "positives": positives, "negatives": negatives}
+
+ # ═══════════════════════════════════════════════════════════
+ # Stage 2: Meta-Training
+ # ═══════════════════════════════════════════════════════════
+ def compute_loss(model, q_texts, p_texts, n_texts, state_vec, temp=0.05):
+     n_q, n_p = len(q_texts), len(p_texts)
+     if n_q == 0 or n_p == 0:
+         return torch.tensor(0.0, device=device, requires_grad=True)
+     all_t = q_texts + p_texts + n_texts
+     raw = model.encode_text(all_t)
+     adapted = model.adapt_embedding(raw, state_vec)
+     emb = adapted["mean"]
+     qe = emb[:n_q]; pe = emb[n_q:n_q+n_p]; ne = emb[n_q+n_p:]
+
+     if n_q == n_p:
+         # Column 0 holds each query's own positive; the remaining columns are negatives.
+         sim = torch.cat([(qe*pe).sum(1).unsqueeze(1)/temp, qe@ne.T/temp], dim=1)
+         loss = F.cross_entropy(sim, torch.zeros(n_q, dtype=torch.long, device=device))
+     else:
+         loss = F.cross_entropy(qe @ torch.cat([pe, ne], dim=0).T / temp,
+                                torch.arange(n_q, device=device).clamp(max=len(pe)-1))
+
+     if model.config.hypernetwork.output_variance:
+         lv = adapted.get("log_variance")
+         # Penalty exp(lv) - lv - 1 is zero at lv = 0 and positive elsewhere.
+         if lv is not None: loss = loss + (lv.exp() - lv - 1).mean() * model.config.meta.kl_weight
+     return loss
+
+ def validate(model, articles_by_law, laws_by_year, val_years, config):
+     model.eval(); losses = []
+     with torch.no_grad():
+         for _ in range(30):
+             qy = random.choice(val_years)
+             if qy not in laws_by_year: continue
+             sy = [y for y in sorted(laws_by_year.keys()) if y < qy]
+             if len(sy) < 3: sy = [y for y in sorted(laws_by_year.keys()) if y <= qy]
+             ep = build_episode(articles_by_law, laws_by_year, sy, qy, config)
+             if ep is None: continue
+             sv = model.get_state_vector()
+             losses.append(compute_loss(model,
+                 [f"{q['title']}: {q['text'][:200]}" for q in ep["queries"]],
+                 [f"{p['title']}: {p['text'][:200]}" for p in ep["positives"]],
+                 [f"{n['title']}: {n['text'][:200]}" for n in ep["negatives"]],
+                 sv, config.meta.temperature).item())
+     return sum(losses)/max(len(losses),1)
+
+ def meta_train(model, articles_by_law, laws_by_year, train_years, val_years, config, epochs=50, patience=10):
+     trainable = (list(model.hypernetwork.parameters()) + list(model.state_encoder.parameters()) +
+                  list(model.base_projection.parameters()) + [model.attn_query])
+     opt = torch.optim.AdamW(trainable, lr=config.meta.meta_lr, weight_decay=1e-4)
+     sched = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10, T_mult=2)
+
+     os.makedirs(config.output_dir, exist_ok=True)
+     best_val, patience_ctr = float("inf"), 0
+
+     for epoch in range(epochs):
+         model.train(); total_loss = 0.0; n_steps = 0
+         steps = config.meta.meta_batch_size * 100
+         progress = tqdm(range(steps), desc=f"Meta Epoch {epoch+1}/{epochs}")
+         for _ in progress:
+             if len(train_years) < 3: break
+             # State = all training years before index si; query = the year at index si.
+             si = random.randint(2, len(train_years)-1)
+             sy, qy = train_years[:si], train_years[si]
+             if qy not in laws_by_year: continue
+             ep = build_episode(articles_by_law, laws_by_year, sy, qy, config)
+             if ep is None: continue
+             sv = model.get_state_vector()
+             loss = compute_loss(model,
+                 [f"{q['title']}: {q['text'][:200]}" for q in ep["queries"]],
+                 [f"{p['title']}: {p['text'][:200]}" for p in ep["positives"]],
+                 [f"{n['title']}: {n['text'][:200]}" for n in ep["negatives"]],
+                 sv, config.meta.temperature)
+             opt.zero_grad(); loss.backward()
+             torch.nn.utils.clip_grad_norm_(trainable, 1.0); opt.step()
+             total_loss += loss.item(); n_steps += 1
+             progress.set_postfix({"loss": f"{loss.item():.4f}"})
+
+         # Average only over episodes that produced a loss; skipped episodes are not counted.
+         avg = total_loss / max(n_steps, 1)
+         print(f" Epoch {epoch+1} avg loss: {avg:.4f}")
+         sched.step()
+
+         vl = validate(model, articles_by_law, laws_by_year, val_years, config)
+         print(f" Val loss: {vl:.4f}")
+
+         if vl < best_val:
+             best_val, patience_ctr = vl, 0
+             torch.save({
+                 "hypernetwork": model.hypernetwork.state_dict(),
+                 "state_encoder": model.state_encoder.state_dict(),
+                 "base_projection": model.base_projection.state_dict(),
+                 "attn_query": model.attn_query,
+                 "epoch": epoch, "val_loss": vl,
+             }, Path(config.output_dir) / "telen_best.pt")
+             print(f" Saved (val_loss={vl:.4f})")
+         else:
+             patience_ctr += 1
+             if patience_ctr >= patience:
+                 print(f" Early stopping at epoch {epoch+1}"); break
+     print("Meta-training complete!")
+     return model
+
+ # ═══════════════════════════════════════════════════════════
+ # Main
+ # ═══════════════════════════════════════════════════════════
+ def main():
+     config = TELENConfig()
+     random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
+     print(f"Device: {device}")
+
+     # Data
+     print("\nLoading data...")
+     articles_by_law, laws_by_year, train_years, val_years, test_years, df = prepare_data(config)
+     print(f" Train: {train_years[0]}-{train_years[-1]} ({len(train_years)}y)")
+     print(f" Val: {val_years[0]}-{val_years[-1]} ({len(val_years)}y)")
+     print(f" Test: {len(test_years)}y")
+
+     # Model
+     print("\nCreating TELEN...")
+     model = create_model(config).to(device)
+     print(f" HyperNetwork: {sum(p.numel() for p in model.hypernetwork.parameters()):,} params")
+
+     # Build graph
+     print("\nBuilding concept graph...")
+     train_df = df[df["year"].isin(train_years)]
+     model.build_graph(train_df)
+     print(f" Graph: {model.concept_graph.num_nodes} nodes")
+
+     # Stage 1
+     print("\n" + "=" * 60)
+     print("Stage 1: Contrastive Pretraining")
+     print("=" * 60)
+     model = contrastive_pretrain(model, train_df, config, epochs=5, batch_size=24, lr=3e-5)
+
+     # Stage 2
+     print("\n" + "=" * 60)
+     print("Stage 2: Meta-Training")
+     print("=" * 60)
+     model = meta_train(model, articles_by_law, laws_by_year, train_years, val_years, config, epochs=50, patience=10)
+
+     print(f"\nDone! Model saved to: {config.output_dir}/telen_best.pt")
+
+ if __name__ == "__main__":
+     main()
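The temporal logic in train.py (the year split in `prepare_data` and the state/query year sampling in `meta_train`) can be checked in isolation. The sketch below is a dependency-free approximation: the helper names `temporal_split` and `sample_episode_years` are hypothetical, and the split years 2018 and 2020 are illustrative assumptions, since the real values come from `config.meta.train_split_year` and `config.meta.val_split_year`, which this diff does not show.

```python
import random


def temporal_split(years, train_split_year, val_split_year):
    """Split years into train/val/test using the same comparisons as prepare_data."""
    years = sorted(set(years))
    train = [y for y in years if y <= train_split_year]
    val = [y for y in years if train_split_year < y <= val_split_year]
    test = [y for y in years if y > val_split_year]
    return train, val, test


def sample_episode_years(train_years, rng):
    """Pick (state_years, query_year) the way the meta-training loop does.

    The query year sits at a random index >= 2, and the state is every
    training year before it, so the state never contains the query year.
    """
    if len(train_years) < 3:
        return None
    si = rng.randint(2, len(train_years) - 1)
    return train_years[:si], train_years[si]


# Illustrative usage with assumed split years (2018 and 2020).
train, val, test = temporal_split(range(2010, 2024), 2018, 2020)
state_years, query_year = sample_episode_years(train, random.Random(0))
print(f"state={state_years[0]}-{state_years[-1]} query={query_year}")
```

Passing an explicit `random.Random` instance keeps the sampling reproducible without touching the global seed that the training script sets.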
train_ce.py ADDED
@@ -0,0 +1,123 @@
+ """
+ Train the cross-encoder re-ranker for legal text.
+
+ Usage:
+     python train_ce.py
+
+ Trains a PhoBERT-based cross-encoder on legal article pairs
+ with margin ranking loss for re-ranking TELEN retrieval results.
+
+ Output: data/checkpoints/telen/cross_encoder_best.pt
+ """
+ import sys; sys.path.insert(0, ".")
+ sys.stdout.reconfigure(encoding='utf-8')
+ import warnings; warnings.filterwarnings("ignore")
+ import random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
+ from tqdm import tqdm
+ from collections import defaultdict
+ from transformers import AutoModel, AutoTokenizer
+
+ from src.telern.config import DATA_DIR
+ from src.data import load_raw_data, extract_metadata, clean_data
+
+ SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
+ # Fall back to CPU when CUDA is unavailable, as train.py does.
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ── Data ──
+ print("Loading data...")
+ df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet"))
+ df = extract_metadata(df); df = clean_data(df, min_text_len=10)
+ train_df = df[df["year"] <= 2018]
+ print(f" {len(train_df)} articles, {train_df['law_id'].nunique()} laws")
+
+ # ── Build pairs ──
+ print("Building pairs...")
+ law_groups = train_df.groupby("law_id")
+ law_ids = list(law_groups.groups.keys())
+ law_type_to_laws = defaultdict(list)
+ for lid in law_ids:
+     lt = law_groups.get_group(lid).iloc[0]["law_type"]
+     law_type_to_laws[lt].append(lid)
+
+ pairs = []
+ for law_id in tqdm(law_ids, desc=" Pairs"):
+     group = law_groups.get_group(law_id)
+     articles = group.to_dict("records")
+     if len(articles) < 2: continue
+     law_type = articles[0]["law_type"]
+     same_type_laws = [l for l in law_type_to_laws.get(law_type, []) if l != law_id]
+
+     for art in articles:
+         q = f"{art['title']}: {art['text'][:400]}"
+         pos = [a for a in articles if a["id"] != art["id"]]
+         if pos:
+             # Draw the positive once so its title and text come from the same article.
+             pos_art = random.choice(pos)
+             pairs.append((q, f"{pos_art['title']}: {pos_art['text'][:400]}", 1.0))
+         # Hard negative: first article of another law with the same law type.
+         if same_type_laws:
+             neg_art = law_groups.get_group(random.choice(same_type_laws)).iloc[0]
+             pairs.append((q, f"{neg_art['title']}: {neg_art['text'][:400]}", 0.0))
+         # Easy negative: first article of a law with a different law type.
+         diff = [l for l in law_ids if l != law_id and l not in same_type_laws]
+         if diff:
+             neg_art2 = law_groups.get_group(random.choice(diff)).iloc[0]
+             pairs.append((q, f"{neg_art2['title']}: {neg_art2['text'][:400]}", 0.0))
+
+ n_pos = sum(1 for p in pairs if p[2] == 1.0)
+ # Cap the dataset at 30,000 positive and 30,000 negative pairs.
+ if len(pairs) > 60000:
+     pos_pairs = [p for p in pairs if p[2] == 1.0]
+     neg_pairs = [p for p in pairs if p[2] == 0.0]
+     pairs = random.sample(pos_pairs, min(30000, len(pos_pairs))) + random.sample(neg_pairs, min(30000, len(neg_pairs)))
+ print(f" {len(pairs)} pairs ({sum(1 for p in pairs if p[2]==1.0)} pos)")
+
+ # ── Model ──
+ print("Loading PhoBERT...")
+ tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
+ encoder = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device)
+ head = nn.Sequential(
+     nn.Linear(encoder.config.hidden_size, 512), nn.ReLU(), nn.Dropout(0.1),
+     nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1),
+     nn.Linear(256, 1),
+ ).to(device)
+ opt = torch.optim.AdamW(list(encoder.parameters())+list(head.parameters()), lr=1e-5, weight_decay=0.01)
+
+ # ── Train ──
+ import os
+ B, epochs = 16, 10
+ steps_per_epoch = len(pairs) // B
+ print(f"\nTraining: {epochs} epochs, {steps_per_epoch} steps/epoch")
+ best_loss = float("inf")
+ # Create the checkpoint directory up front so torch.save cannot fail on a missing path.
+ os.makedirs("data/checkpoints/telen", exist_ok=True)
+
+ for epoch in range(epochs):
+     random.shuffle(pairs)
+     epoch_loss = 0.0
+     progress = tqdm(range(steps_per_epoch), desc=f" Epoch {epoch+1}/{epochs}")
+     for step in progress:
+         start = (step * B) % max(len(pairs) - B, 1)
+         batch = pairs[start:start + B]
+         queries = [p[0] for p in batch]; docs = [p[1] for p in batch]
+         labels = torch.tensor([p[2] for p in batch], dtype=torch.float, device=device)
+
+         enc = tokenizer(queries, docs, padding=True, truncation=True, max_length=256, return_tensors="pt")
+         input_ids = enc["input_ids"].to(device); attention_mask = enc["attention_mask"].to(device)
+         out = encoder(input_ids=input_ids, attention_mask=attention_mask)
+         # Score each (query, doc) pair from the first-token representation.
+         scores = head(out.last_hidden_state[:, 0, :]).squeeze(-1)
+
+         pos_mask = labels == 1; neg_mask = labels == 0
+         if pos_mask.any() and neg_mask.any():
+             # Margin loss over every positive/negative score combination in the batch.
+             pos_scores = scores[pos_mask]; neg_scores = scores[neg_mask]
+             loss = F.relu(0.3 - pos_scores.unsqueeze(1) + neg_scores.unsqueeze(0)).mean()
+         else:
+             # Single-class batch: fall back to binary cross-entropy on the raw scores.
+             loss = F.binary_cross_entropy_with_logits(scores, labels)
+
+         opt.zero_grad(); loss.backward()
+         torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(head.parameters()), 1.0)
+         opt.step()
+         epoch_loss += loss.item()
+         progress.set_postfix({"loss": f"{loss.item():.4f}"})
+
+     avg_loss = epoch_loss / max(steps_per_epoch, 1)
+     print(f" Epoch {epoch+1} avg loss: {avg_loss:.4f}")
+     if avg_loss < best_loss:
+         best_loss = avg_loss
+         torch.save({"encoder": encoder.state_dict(), "head": head.state_dict()},
+                    "data/checkpoints/telen/cross_encoder_best.pt")
+         print(f" Saved (loss={avg_loss:.4f})")
+
+ print("\nDone! Model saved to: data/checkpoints/telen/cross_encoder_best.pt")
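The in-batch margin objective in train_ce.py compares every positive score with every negative score in the batch, which is what the broadcasted expression `F.relu(0.3 - pos_scores.unsqueeze(1) + neg_scores.unsqueeze(0)).mean()` computes. As a rough, dependency-free sketch of that arithmetic (the helper name `pairwise_margin_loss` is hypothetical and plain Python lists stand in for tensors):

```python
def pairwise_margin_loss(pos_scores, neg_scores, margin=0.3):
    """Mean hinge loss over all (positive, negative) score combinations.

    A pair contributes zero once the positive score exceeds the negative
    score by at least `margin`; otherwise it contributes the shortfall.
    """
    if not pos_scores or not neg_scores:
        raise ValueError("need at least one positive and one negative score")
    terms = [max(0.0, margin - p + n) for p in pos_scores for n in neg_scores]
    return sum(terms) / len(terms)


# A well-separated pair costs nothing; tied scores cost exactly the margin.
print(pairwise_margin_loss([1.0], [0.0]))
print(pairwise_margin_loss([0.0], [0.0]))
```

Note that the positives and negatives in a batch generally belong to different queries, so this objective separates the two score populations as a whole rather than ranking documents per query.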