| """ |
| Train the cross-encoder re-ranker for legal text. |
| |
| Usage: |
| python train_ce.py |
| |
| Trains a PhoBERT-based cross-encoder on legal article pairs |
| with margin ranking loss for re-ranking TELEN retrieval results. |
| |
| Output: data/checkpoints/telen/cross_encoder_best.pt |
| """ |
| import sys; sys.path.insert(0, ".") |
| sys.stdout.reconfigure(encoding='utf-8') |
| import warnings; warnings.filterwarnings("ignore") |
| import random, numpy as np, torch, torch.nn as nn, torch.nn.functional as F |
| from tqdm import tqdm |
| from collections import defaultdict |
| from transformers import AutoModel, AutoTokenizer |
|
|
| from src.telern.config import DATA_DIR |
| from src.data import load_raw_data, extract_metadata, clean_data |
|
|
| SEED = 42; random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED) |
| device = torch.device("cuda") |
|
|
| |
| print("Loading data...") |
| df = load_raw_data(str(DATA_DIR / "train-00000-of-00001.parquet")) |
| df = extract_metadata(df); df = clean_data(df, min_text_len=10) |
| train_df = df[df["year"] <= 2018] |
| print(f" {len(train_df)} articles, {train_df['law_id'].nunique()} laws") |
|
|
| |
| print("Building pairs...") |
| law_groups = train_df.groupby("law_id") |
| law_ids = list(law_groups.groups.keys()) |
| law_type_to_laws = defaultdict(list) |
| for lid in law_ids: |
| lt = law_groups.get_group(lid).iloc[0]["law_type"] |
| law_type_to_laws[lt].append(lid) |
|
|
| pairs = [] |
| for law_id in tqdm(law_ids, desc=" Pairs"): |
| group = law_groups.get_group(law_id) |
| articles = group.to_dict("records") |
| if len(articles) < 2: continue |
| law_type = articles[0]["law_type"] |
| same_type_laws = [l for l in law_type_to_laws.get(law_type, []) if l != law_id] |
|
|
| for art in articles: |
| q = f"{art['title']}: {art['text'][:400]}" |
| pos = [a for a in articles if a["id"] != art["id"]] |
| if pos: |
| pairs.append((q, f"{random.choice(pos)['title']}: {random.choice(pos)['text'][:400]}", 1.0)) |
| if same_type_laws: |
| neg_art = law_groups.get_group(random.choice(same_type_laws)).iloc[0] |
| pairs.append((q, f"{neg_art['title']}: {neg_art['text'][:400]}", 0.0)) |
| diff = [l for l in law_ids if l != law_id and l not in same_type_laws] |
| if diff: |
| neg_art2 = law_groups.get_group(random.choice(diff)).iloc[0] |
| pairs.append((q, f"{neg_art2['title']}: {neg_art2['text'][:400]}", 0.0)) |
|
|
| n_pos = sum(1 for p in pairs if p[2] == 1.0) |
| if len(pairs) > 60000: |
| pos_pairs = [p for p in pairs if p[2] == 1.0] |
| neg_pairs = [p for p in pairs if p[2] == 0.0] |
| pairs = random.sample(pos_pairs, min(30000, len(pos_pairs))) + random.sample(neg_pairs, min(30000, len(neg_pairs))) |
| print(f" {len(pairs)} pairs ({sum(1 for p in pairs if p[2]==1.0)} pos)") |
|
|
| |
| print("Loading PhoBERT...") |
| tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base-v2") |
| encoder = AutoModel.from_pretrained("vinai/phobert-base-v2").to(device) |
| head = nn.Sequential( |
| nn.Linear(encoder.config.hidden_size, 512), nn.ReLU(), nn.Dropout(0.1), |
| nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.1), |
| nn.Linear(256, 1), |
| ).to(device) |
| opt = torch.optim.AdamW(list(encoder.parameters())+list(head.parameters()), lr=1e-5, weight_decay=0.01) |
|
|
| |
| B, epochs = 16, 10 |
| steps_per_epoch = len(pairs) // B |
| print(f"\nTraining: {epochs} epochs, {steps_per_epoch} steps/epoch") |
| best_loss = float("inf") |
|
|
| for epoch in range(epochs): |
| random.shuffle(pairs) |
| epoch_loss = 0.0 |
| progress = tqdm(range(steps_per_epoch), desc=f" Epoch {epoch+1}/{epochs}") |
| for step in progress: |
| start = (step * B) % max(len(pairs) - B, 1) |
| batch = pairs[start:start + B] |
| queries = [p[0] for p in batch]; docs = [p[1] for p in batch] |
| labels = torch.tensor([p[2] for p in batch], dtype=torch.float, device=device) |
|
|
| enc = tokenizer(queries, docs, padding=True, truncation=True, max_length=256, return_tensors="pt") |
| input_ids = enc["input_ids"].to(device); attention_mask = enc["attention_mask"].to(device) |
| out = encoder(input_ids=input_ids, attention_mask=attention_mask) |
| scores = head(out.last_hidden_state[:, 0, :]).squeeze(-1) |
|
|
| pos_mask = labels == 1; neg_mask = labels == 0 |
| if pos_mask.any() and neg_mask.any(): |
| pos_scores = scores[pos_mask]; neg_scores = scores[neg_mask] |
| loss = F.relu(0.3 - pos_scores.unsqueeze(1) + neg_scores.unsqueeze(0)).mean() |
| else: |
| loss = F.binary_cross_entropy_with_logits(scores, labels) |
|
|
| opt.zero_grad(); loss.backward() |
| torch.nn.utils.clip_grad_norm_(list(encoder.parameters())+list(head.parameters()), 1.0) |
| opt.step() |
| epoch_loss += loss.item() |
| progress.set_postfix({"loss": f"{loss.item():.4f}"}) |
|
|
| avg_loss = epoch_loss / steps_per_epoch |
| print(f" Epoch {epoch+1} avg loss: {avg_loss:.4f}") |
| if avg_loss < best_loss: |
| best_loss = avg_loss |
| torch.save({"encoder": encoder.state_dict(), "head": head.state_dict()}, |
| "data/checkpoints/telen/cross_encoder_best.pt") |
| print(f" Saved (loss={avg_loss:.4f})") |
|
|
| print("\nDone! Model saved to: data/checkpoints/telen/cross_encoder_best.pt") |
|
|