telen / train_ce.py
haidang2405's picture
first commit
e7cfc32 verified
"""
Train the cross-encoder re-ranker for legal text.
Usage:
python train_ce.py
Trains a PhoBERT-based cross-encoder on legal article pairs
with margin ranking loss for re-ranking TELEN retrieval results.
Output: data/checkpoints/telen/cross_encoder_best.pt
"""
import sys; sys.path.insert(0, ".")
sys.stdout.reconfigure(encoding='utf-8')
import warnings; warnings.filterwarnings("ignore")
import random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F
from tqdm import tqdm
from collections import defaultdict
from transformers import AutoModel, AutoTokenizer
from src.telern.config import DATA_DIR
from src.data import load_raw_data, extract_metadata, clean_data
SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = torch.device("cuda")
# ── Data ──
print("Loading data...")
df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet"))
df = extract_metadata(df); df = clean_data(df, min_text_len=10)
train_df = df[df["year"] <= 2018]
print(f" {len(train_df)} articles, {train_df['law_id'].nunique()} laws")
# ── Build pairs ──
print("Building pairs...")
law_groups = train_df.groupby("law_id")
law_ids = list(law_groups.groups.keys())
law_type_to_laws = defaultdict(list)
for lid in law_ids:
lt = law_groups.get_group(lid).iloc[0]["law_type"]
law_type_to_laws[lt].append(lid)
pairs = []
for law_id in tqdm(law_ids, desc=" Pairs"):
group = law_groups.get_group(law_id)
articles = group.to_dict("records")
if len(articles) < 2: continue
law_type = articles[0]["law_type"]
same_type_laws = [l for l in law_type_to_laws.get(law_type, []) if l != law_id]
for art in articles:
q = f"{art['title']}: {art['text'][:400]}"
pos = [a for a in articles if a["id"] != art["id"]]
if pos:
pairs.append((q, f"{random.choice(pos)['title']}: {random.choice(pos)['text'][:400]}", 1.0))
if same_type_laws:
neg_art = law_groups.get_group(random.choice(same_type_laws)).iloc[0]
pairs.append((q, f"{neg_art['title']}: {neg_art['text'][:400]}", 0.0))
diff = [l for l in law_ids if l != law_id and l not in same_type_laws]
if diff:
neg_art2 = law_groups.get_group(random.choice(diff)).iloc[0]
pairs.append((q, f"{neg_art2['title']}: {neg_art2['text'][:400]}", 0.0))
n_pos = sum(1 for p in pairs if p[2] == 1.0)
if len(pairs) > 60000:
pos_pairs = [p for p in pairs if p[2] == 1.0]
neg_pairs = [p for p in pairs if p[2] == 0.0]
pairs = random.sample(pos_pairs, min(30000, len(pos_pairs))) + random.sample(neg_pairs, min(30000, len(neg_pairs)))
print(f" {len(pairs)} pairs ({sum(1 for p in pairs if p[2]==1.0)} pos)")
# ── Model ──
print("Loading PhoBERT...")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2")
encoder = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device)
head = nn.Sequential(
nn.Linear(encoder.config.hidden_size, 512), nn.ReLU(), nn.Dropout(0.1),
nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1),
nn.Linear(256, 1),
).to(device)
opt = torch.optim.AdamW(list(encoder.parameters())+list(head.parameters()), lr=1e-5, weight_decay=0.01)
# ── Train ──
B, epochs = 16, 10
steps_per_epoch = len(pairs) // B
print(f"\nTraining: {epochs} epochs, {steps_per_epoch} steps/epoch")
best_loss = float("inf")
for epoch in range(epochs):
random.shuffle(pairs)
epoch_loss = 0.0
progress = tqdm(range(steps_per_epoch), desc=f" Epoch {epoch+1}/{epochs}")
for step in progress:
start = (step * B) % max(len(pairs) - B, 1)
batch = pairs[start:start + B]
queries = [p[0] for p in batch]; docs = [p[1] for p in batch]
labels = torch.tensor([p[2] for p in batch], dtype=torch.float, device=device)
enc = tokenizer(queries, docs, padding=True, truncation=True, max_length=256, return_tensors="pt")
input_ids = enc["input_ids"].to(device); attention_mask = enc["attention_mask"].to(device)
out = encoder(input_ids=input_ids, attention_mask=attention_mask)
scores = head(out.last_hidden_state[:, 0, :]).squeeze(-1)
pos_mask = labels == 1; neg_mask = labels == 0
if pos_mask.any() and neg_mask.any():
pos_scores = scores[pos_mask]; neg_scores = scores[neg_mask]
loss = F.relu(0.3 - pos_scores.unsqueeze(1) + neg_scores.unsqueeze(0)).mean()
else:
loss = F.binary_cross_entropy_with_logits(scores, labels)
opt.zero_grad(); loss.backward()
torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(head.parameters()), 1.0)
opt.step()
epoch_loss += loss.item()
progress.set_postfix({"loss": f"{loss.item():.4f}"})
avg_loss = epoch_loss / steps_per_epoch
print(f" Epoch {epoch+1} avg loss: {avg_loss:.4f}")
if avg_loss < best_loss:
best_loss = avg_loss
torch.save({"encoder": encoder.state_dict(), "head": head.state_dict()},
"data/checkpoints/telen/cross_encoder_best.pt")
print(f" Saved (loss={avg_loss:.4f})")
print("\nDone! Model saved to: data/checkpoints/telen/cross_encoder_best.pt")