Commit 2691c77 by haidang2405 · verified · 1 Parent(s): 73847bc

Update new

Files changed (2)
  1. README.md +12 -22
  2. eval.py +55 -10
README.md CHANGED
@@ -1,19 +1,3 @@
- ---
- license: mit
- datasets:
- - another-symato/VMTEB-Zalo-legel-retrieval-wseg
- language:
- - vi
- base_model:
- - bkai-foundation-models/vietnamese-bi-encoder
- pipeline_tag: feature-extraction
- tags:
- - feature-extraction
- - embedding
- - legal
- - law
- - vietnamese
- ---
  # TELEN: Temporal Evolving Legal Embedding Network

  > **Vietnamese legal text embedding with meta-learning for continuous adaptation to new laws.**
@@ -66,15 +50,21 @@ L2-Normalized Embedding [768-dim]
  | Model | NDCG@3 | NDCG@5 | NDCG@10 | MRR@3 | MRR@5 | MRR@10 |
  |---|---|---|---|---|---|---|
  | **BM25** (lexical) | 0.6753 | 0.7173 | 0.7250 | 0.6683 | 0.6928 | 0.6990 |
- | **PhoBERT-base-v2** (dense) | 0.5866 | 0.6360 | 0.6505 | 0.5657 | 0.5970 | 0.6059 |
- | **DEk21** (dense) | 0.7900 | 0.8127 | 0.8344 | 0.7660 | 0.7785 | 0.7865 |
- | **TELEN** (dense) | 0.9036 | 0.9138 | 0.9132 | 0.8830 | 0.8878 | 0.8878 |
- | **TELEN + CE re-rank** (dense) | **0.9346** | **0.9339** | **0.9238** | **0.9199** | **0.9223** | **0.9223** |

  ### Relative Improvement

  | Baseline | NDCG@3 | NDCG@5 | NDCG@10 | MRR@10 |
  |---|---|---|---|---|
  | vs PhoBERT | **+59.3%** | **+46.8%** | **+42.0%** | **+52.2%** |
  | vs DEk21 | **+18.3%** | **+14.9%** | **+10.7%** | **+17.3%** |

@@ -124,7 +114,7 @@ python train_ce.py
  ### Evaluation

  ```bash
- # Full benchmark (TELEN vs BM25/PhoBERT/DEk21)
  python eval.py

  # TELEN + Cross-encoder re-ranking (MRR-optimized)
@@ -232,4 +222,4 @@ MIT License -- see [LICENSE](LICENSE) file for details.
  ## Acknowledgments

  - `bkai-foundation-models/vietnamese-bi-encoder` -- backbone bi-encoder
- - `huyydangg/DEk21_hcmute_embedding` -- baseline comparison - `vinai/phobert-base-v2` -- used in cross-encoder re-ranker

  # TELEN: Temporal Evolving Legal Embedding Network

  > **Vietnamese legal text embedding with meta-learning for continuous adaptation to new laws.**

  | Model | NDCG@3 | NDCG@5 | NDCG@10 | MRR@3 | MRR@5 | MRR@10 |
  |---|---|---|---|---|---|---|
  | **BM25** (lexical) | 0.6753 | 0.7173 | 0.7250 | 0.6683 | 0.6928 | 0.6990 |
+ | **PhoBERT-base-v2** (monolingual dense) | 0.5866 | 0.6360 | 0.6505 | 0.5657 | 0.5970 | 0.6059 |
+ | **multilingual-E5-base** (multilingual dense) | 0.4675 | 0.4888 | 0.5157 | 0.4327 | 0.4452 | 0.4573 |
+ | **BAAI/bge-m3** (multilingual dense, 1024d) | 0.4668 | 0.5129 | 0.5452 | 0.4407 | 0.4657 | 0.4802 |
+ | **DEk21** (legal dense) | 0.7900 | 0.8127 | 0.8344 | 0.7660 | 0.7785 | 0.7865 |
+ | **TELEN** (adaptive dense) | 0.9036 | 0.9138 | 0.9132 | 0.8830 | 0.8878 | 0.8878 |
+ | **TELEN + CE re-rank** (adaptive dense) | **0.9346** | **0.9339** | **0.9238** | **0.9199** | **0.9223** | **0.9223** |
+
+ > **Key insight:** Multilingual SOTA models (multilingual-E5, BGE-M3) score **below even BM25** on Vietnamese legal text, confirming that domain and language specialization trumps generic multilingual pre-training for legal retrieval.

  ### Relative Improvement

  | Baseline | NDCG@3 | NDCG@5 | NDCG@10 | MRR@10 |
  |---|---|---|---|---|
+ | vs multilingual-E5 | **+93.3%** | **+86.9%** | **+77.1%** | **+94.1%** |
+ | vs BGE-M3 | **+93.6%** | **+78.2%** | **+67.5%** | **+84.9%** |
  | vs PhoBERT | **+59.3%** | **+46.8%** | **+42.0%** | **+52.2%** |
  | vs DEk21 | **+18.3%** | **+14.9%** | **+10.7%** | **+17.3%** |
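
The Relative Improvement rows can be checked against the results table with the same formula that eval.py prints, `(TELEN / baseline - 1) * 100`. Below is a minimal sketch for NDCG@3, with values copied from the table; `relative_gain` is a hypothetical helper name, not a function from the repository.

```python
# NDCG@3 values copied from the results table above.
ndcg3 = {
    "multilingual-E5": 0.4675,
    "BGE-M3": 0.4668,
    "PhoBERT": 0.5866,
    "DEk21": 0.7900,
    "TELEN": 0.9036,
    "TELEN + CE re-rank": 0.9346,
}


def relative_gain(ours: float, baseline: float) -> float:
    """Percentage gain of `ours` over `baseline`, guarded against division by zero."""
    return (ours / max(baseline, 1e-6) - 1) * 100


for baseline in ["multilingual-E5", "BGE-M3", "PhoBERT", "DEk21"]:
    plain = relative_gain(ndcg3["TELEN"], ndcg3[baseline])
    reranked = relative_gain(ndcg3["TELEN + CE re-rank"], ndcg3[baseline])
    print(f"vs {baseline:<16} TELEN: {plain:+.1f}%   TELEN + CE: {reranked:+.1f}%")
```

Recomputed this way, the `+93.3%` and `+93.6%` figures in the two added rows correspond to plain TELEN, while the existing `+59.3%` (PhoBERT) and `+18.3%` (DEk21) figures only come out with TELEN + CE re-rank. The table therefore appears to mix two reference systems.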

  ### Evaluation

  ```bash
+ # Full benchmark (TELEN vs BM25/PhoBERT/mE5/BGE-M3/DEk21)
  python eval.py

  # TELEN + Cross-encoder re-ranking (MRR-optimized)

  ## Acknowledgments

  - `bkai-foundation-models/vietnamese-bi-encoder` -- backbone bi-encoder
+ - `huyydangg/DEk21_hcmute_embedding` -- baseline comparison - `vinai/phobert-base-v2` -- used in cross-encoder re-ranker
eval.py CHANGED
@@ -6,6 +6,8 @@ Metrics: NDCG@3, NDCG@5, NDCG@10, MRR@3, MRR@5, MRR@10
  Baselines:
  - BM25 (lexical retrieval)
  - Frozen PhoBERT (vinai/phobert-base-v2)
  - DEk21 (huyydangg/DEk21_hcmute_embedding)
  - TELEN (ours)

@@ -19,6 +21,7 @@ import random, numpy as np, torch, torch.nn.functional as F
  from tqdm import tqdm
  from collections import defaultdict
  from sentence_transformers import SentenceTransformer
  from pyvi import ViTokenizer

  from src.telern.config import TELENConfig
@@ -35,13 +38,15 @@ config = TELENConfig()
  def wseg(text):
      return ViTokenizer.tokenize(text.replace("_", " "))

- def evaluate_model(name, encode_fn, queries, corpus, corpus_ids, corpus_law_ids):
      """Generic evaluation for any embedding model."""
      print(f"\n [{name}] Encoding corpus ({len(corpus)} docs)...")
      c_embs = []
      for i in range(0, len(corpus), 64):
          batch = [d["text"] for d in corpus[i:i+64]]
-         embs = encode_fn(batch)
          if isinstance(embs, np.ndarray): embs = torch.tensor(embs)
          c_embs.append(embs.cpu())
      c_embs = torch.cat(c_embs, dim=0)
@@ -73,21 +78,59 @@ print(f"Test: {len(queries)} queries, {len(corpus)} docs, {test_df['law_id'].nun
  results = {}

  # ── BM25 ──
- print("\n[1/4] BM25")
  results["BM25"] = evaluate_bm25(queries, corpus)

  # ── PhoBERT ──
- print("\n[2/4] Frozen PhoBERT")
  phobert = FrozenPhoBERT()
  results["PhoBERT"] = evaluate_model("PhoBERT", lambda texts: phobert.encode(texts, batch_size=64), queries, corpus, corpus_ids, corpus_law_ids)

  # ── DEk21 ──
- print("\n[3/4] DEk21 (SOTA)")
  dek21 = SentenceTransformer("huyydangg/DEk21_hcmute_embedding", device=device)
  results["DEk21"] = evaluate_model("DEk21", lambda texts: dek21.encode([wseg(t) for t in texts], batch_size=64, show_progress_bar=False, normalize_embeddings=True, convert_to_tensor=True), queries, corpus, corpus_ids, corpus_law_ids)

  # ── TELEN ──
- print("\n[4/4] TELEN (Ours)")
  telen = create_model(config).to(device)
  ckpt = torch.load(config.output_dir + "/telen_best.pt", map_location=device, weights_only=False)
  telen.hypernetwork.load_state_dict(ckpt["hypernetwork"])
@@ -109,14 +152,16 @@ print("=" * 75)
  h = f"{'Method':<15}"
  for m in [3,5,10]: h += f" {'NDCG@'+str(m):>10} {'MRR@'+str(m):>10}"
  print(h); print("-"*len(h))
- for name in ["BM25", "PhoBERT", "DEk21", "TELEN"]:
-     r = f"{name:<15}"
      for m in [3,5,10]: r += f" {results[name][f'ndcg@{m}']:>10.4f} {results[name][f'mrr@{m}']:>10.4f}"
      print(r)

  print("\n--- Relative Improvement over Baselines ---")
- for baseline in ["PhoBERT", "DEk21"]:
-     print(f" TELEN vs {baseline}:")
      for m in [3,5,10]:
          ni = (results["TELEN"][f"ndcg@{m}"] / max(results[baseline][f"ndcg@{m}"], 1e-6) - 1) * 100
          mi = (results["TELEN"][f"mrr@{m}"] / max(results[baseline][f"mrr@{m}"], 1e-6) - 1) * 100

  Baselines:
  - BM25 (lexical retrieval)
  - Frozen PhoBERT (vinai/phobert-base-v2)
+ - multilingual-E5-base (intfloat/multilingual-e5-base)
+ - BGE-M3 (BAAI/bge-m3)
  - DEk21 (huyydangg/DEk21_hcmute_embedding)
  - TELEN (ours)
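
The docstring names NDCG@k and MRR@k, but this diff does not show how eval.py computes them. As a point of reference, here is a minimal sketch that assumes binary relevance (a document is either relevant or not). The helper names `mrr_at_k` and `ndcg_at_k` and the toy document IDs are hypothetical.

```python
import math


def mrr_at_k(ranked_ids, relevant_ids, k):
    """Reciprocal rank of the first relevant document within the top k, else 0."""
    for rank, doc_id in enumerate(ranked_ids[:k], start=1):
        if doc_id in relevant_ids:
            return 1.0 / rank
    return 0.0


def ndcg_at_k(ranked_ids, relevant_ids, k):
    """Binary-relevance NDCG: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        1.0 / math.log2(rank + 1)
        for rank, doc_id in enumerate(ranked_ids[:k], start=1)
        if doc_id in relevant_ids
    )
    ideal_hits = min(len(relevant_ids), k)
    idcg = sum(1.0 / math.log2(rank + 1) for rank in range(1, ideal_hits + 1))
    return dcg / idcg if idcg > 0 else 0.0


# Toy query: the only relevant article is ranked second.
ranked = ["art_7", "art_3", "art_9"]
relevant = {"art_3"}
print(mrr_at_k(ranked, relevant, k=3))   # 0.5
print(ndcg_at_k(ranked, relevant, k=3))  # 1 / log2(3), about 0.6309
```

Averaging these per-query values over all test queries would give table-style scores. Whether eval.py judges relevance per article or per law (`corpus_law_ids`) cannot be determined from the hunks shown.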

  from tqdm import tqdm
  from collections import defaultdict
  from sentence_transformers import SentenceTransformer
+ from transformers import AutoModel, AutoTokenizer
  from pyvi import ViTokenizer

  from src.telern.config import TELENConfig

  def wseg(text):
      return ViTokenizer.tokenize(text.replace("_", " "))

+ def evaluate_model(name, encode_fn, queries, corpus, corpus_ids, corpus_law_ids, corpus_encode_fn=None):
      """Generic evaluation for any embedding model."""
+     if corpus_encode_fn is None:
+         corpus_encode_fn = encode_fn
      print(f"\n [{name}] Encoding corpus ({len(corpus)} docs)...")
      c_embs = []
      for i in range(0, len(corpus), 64):
          batch = [d["text"] for d in corpus[i:i+64]]
+         embs = corpus_encode_fn(batch)
          if isinstance(embs, np.ndarray): embs = torch.tensor(embs)
          c_embs.append(embs.cpu())
      c_embs = torch.cat(c_embs, dim=0)
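
The new `corpus_encode_fn` argument lets a model encode queries and passages differently, while existing call sites that pass only `encode_fn` keep their symmetric behaviour. Here is a small sketch of that fallback pattern with dummy encoders; `encode_both` and the two toy lambdas are hypothetical stand-ins, not code from eval.py.

```python
def encode_both(queries, corpus, encode_fn, corpus_encode_fn=None):
    """Encode queries with encode_fn and passages with corpus_encode_fn.

    When corpus_encode_fn is omitted, both sides share encode_fn, which
    matches the behaviour before this change.
    """
    if corpus_encode_fn is None:
        corpus_encode_fn = encode_fn
    return encode_fn(queries), corpus_encode_fn(corpus)


# Toy "encoders" that only add E5-style prefixes, to make the routing visible.
as_query = lambda texts: ["query: " + t for t in texts]
as_passage = lambda texts: ["passage: " + t for t in texts]

print(encode_both(["q1"], ["doc1"], as_query))              # both sides use as_query
print(encode_both(["q1"], ["doc1"], as_query, as_passage))  # passages use as_passage
```

Defaulting to `None` rather than requiring a second function keeps the PhoBERT and DEk21 calls unchanged.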
 
  results = {}

  # ── BM25 ──
+ print("\n[1/6] BM25")
  results["BM25"] = evaluate_bm25(queries, corpus)

  # ── PhoBERT ──
+ print("\n[2/6] Frozen PhoBERT")
  phobert = FrozenPhoBERT()
  results["PhoBERT"] = evaluate_model("PhoBERT", lambda texts: phobert.encode(texts, batch_size=64), queries, corpus, corpus_ids, corpus_law_ids)

  # ── DEk21 ──
+ print("\n[3/6] DEk21 (legal SOTA)")
  dek21 = SentenceTransformer("huyydangg/DEk21_hcmute_embedding", device=device)
  results["DEk21"] = evaluate_model("DEk21", lambda texts: dek21.encode([wseg(t) for t in texts], batch_size=64, show_progress_bar=False, normalize_embeddings=True, convert_to_tensor=True), queries, corpus, corpus_ids, corpus_law_ids)

+ # ── multilingual-E5-base ──
+ print("\n[4/6] multilingual-E5-base")
+ e5_tokenizer = AutoTokenizer.from_pretrained("intfloat/multilingual-e5-base")
+ e5_model = AutoModel.from_pretrained("intfloat/multilingual-e5-base").to(device)
+ e5_model.eval()
+ def e5_encode(texts, prefix="query: "):
+     prefixed = [prefix + t for t in texts]
+     enc = e5_tokenizer(prefixed, padding=True, truncation=True, max_length=512, return_tensors="pt")
+     with torch.no_grad():
+         hidden = e5_model(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)).last_hidden_state
+     mask = enc["attention_mask"].unsqueeze(-1).float().to(device)
+     pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
+     return F.normalize(pooled, p=2, dim=1)
+ results["multilingual-e5"] = evaluate_model("mE5",
+     lambda texts: e5_encode(texts), # queries: "query: " prefix
+     queries, corpus, corpus_ids, corpus_law_ids,
+     corpus_encode_fn=lambda texts: e5_encode(texts, prefix="passage: "))
+ del e5_model, e5_tokenizer; torch.cuda.empty_cache()
+
+ # ── BGE-M3 ──
+ print("\n[5/6] BAAI/bge-m3")
+ bge_tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-m3")
+ bge_model = AutoModel.from_pretrained("BAAI/bge-m3").to(device)
+ bge_model.eval()
+ def bge_encode(texts, add_prefix=True):
+     if add_prefix:
+         texts = ["Represent this sentence for searching relevant passages: " + t for t in texts]
+     enc = bge_tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
+     with torch.no_grad():
+         hidden = bge_model(input_ids=enc["input_ids"].to(device), attention_mask=enc["attention_mask"].to(device)).last_hidden_state
+     cls_emb = hidden[:, 0, :]
+     return F.normalize(cls_emb, p=2, dim=1)
+ results["bge-m3"] = evaluate_model("BGE-M3",
+     lambda texts: bge_encode(texts, add_prefix=True), # queries: with instruction
+     queries, corpus, corpus_ids, corpus_law_ids,
+     corpus_encode_fn=lambda texts: bge_encode(texts, add_prefix=False)) # passages: no prefix
+ del bge_model, bge_tokenizer; torch.cuda.empty_cache()
+
  # ── TELEN ──
+ print("\n[6/6] TELEN (Ours)")
  telen = create_model(config).to(device)
  ckpt = torch.load(config.output_dir + "/telen_best.pt", map_location=device, weights_only=False)
  telen.hypernetwork.load_state_dict(ckpt["hypernetwork"])
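
The two added baselines pool token states differently. `e5_encode` averages the last hidden states over non-padding tokens, while `bge_encode` takes the first (CLS) token; both then L2-normalize. The sketch below restates the two strategies in plain Python for a single toy sequence, so it runs without downloading a model. eval.py itself does this on batched torch tensors, and the helper names `l2_normalize`, `mean_pool` and `cls_pool` are hypothetical.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit L2 norm (a tiny floor avoids division by zero)."""
    norm = math.sqrt(sum(x * x for x in vec)) or 1e-9
    return [x / norm for x in vec]


def mean_pool(token_states, attention_mask):
    """Average token states over positions whose mask is 1, then normalize."""
    dim = len(token_states[0])
    kept = [s for s, m in zip(token_states, attention_mask) if m == 1]
    count = max(len(kept), 1e-9)
    return l2_normalize([sum(s[d] for s in kept) / count for d in range(dim)])


def cls_pool(token_states):
    """Use the first token's state as the sentence embedding, then normalize."""
    return l2_normalize(token_states[0])


# One toy sequence: 3 token positions with 2-dim states; the last position is padding.
states = [[3.0, 4.0], [1.0, 0.0], [9.0, 9.0]]
mask = [1, 1, 0]

print(mean_pool(states, mask))  # mean of the first two states, [2.0, 2.0], normalized
print(cls_pool(states))         # [0.6, 0.8]
```

The padded position `[9.0, 9.0]` has no effect on the mean-pooled vector, which is why `e5_encode` multiplies by the attention mask before summing.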
 
  h = f"{'Method':<15}"
  for m in [3,5,10]: h += f" {'NDCG@'+str(m):>10} {'MRR@'+str(m):>10}"
  print(h); print("-"*len(h))
+ for name in ["BM25", "PhoBERT", "multilingual-e5", "bge-m3", "DEk21", "TELEN"]:
+     display = {"multilingual-e5": "mE5-base", "bge-m3": "BGE-M3"}.get(name, name)
+     r = f"{display:<15}"
      for m in [3,5,10]: r += f" {results[name][f'ndcg@{m}']:>10.4f} {results[name][f'mrr@{m}']:>10.4f}"
      print(r)

  print("\n--- Relative Improvement over Baselines ---")
+ for baseline in ["PhoBERT", "multilingual-e5", "bge-m3", "DEk21"]:
+     display = {"multilingual-e5": "mE5-base", "bge-m3": "BGE-M3"}.get(baseline, baseline)
+     print(f" TELEN vs {display}:")
      for m in [3,5,10]:
          ni = (results["TELEN"][f"ndcg@{m}"] / max(results[baseline][f"ndcg@{m}"], 1e-6) - 1) * 100
          mi = (results["TELEN"][f"mrr@{m}"] / max(results[baseline][f"mrr@{m}"], 1e-6) - 1) * 100