"""Token-level extractive selector: an encoder plus a linear keep/drop head."""

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoModel


class LinearTokenSelector(nn.Module):
    """Binary per-token classifier on top of an encoder's last hidden states.

    Label ``1`` marks a token that is kept in the summary (see
    ``labels_to_summary``); any other label drops it.
    """

    def __init__(self, encoder, embedding_size=768):
        """Store *encoder* and create the bias-free 2-way linear head.

        Args:
            encoder: Module called as ``encoder(x, output_hidden_states=True)``
                whose output exposes ``["hidden_states"]``. May be ``None``
                and attached later (``load_checkpoint`` does this).
            embedding_size: Size of the last dimension of the hidden states
                fed to the linear head.
        """
        super().__init__()
        self.encoder = encoder
        self.classifier = nn.Linear(embedding_size, 2, bias=False)

    def forward(self, x):
        """Return per-token log-probabilities over the two labels.

        Args:
            x: Batch of token ids passed straight to the encoder.

        Returns:
            Tensor of log-probabilities, normalised over dimension 2
            (the label dimension).
        """
        output = self.encoder(x, output_hidden_states=True)
        hidden = output["hidden_states"][-1]  # B * S * H
        scores = self.classifier(hidden)
        return F.log_softmax(scores, dim=2)

    def save(self, classifier_path, encoder_path):
        """Persist the linear head and the encoder separately.

        Only state-dict entries whose key starts with ``"classifier"`` are
        written to *classifier_path* with ``torch.save``; the encoder is
        written with its own ``save_pretrained(encoder_path)``.
        """
        classifier_state = {
            key: value
            for key, value in self.state_dict().items()
            if key.startswith("classifier")
        }
        torch.save(classifier_state, classifier_path)
        self.encoder.save_pretrained(encoder_path)

    def predict(self, texts, tokenizer, device):
        """Summarise *texts* by keeping the tokens predicted as label ``1``.

        Args:
            texts: A sequence of strings. A single string is also accepted
                and treated as a batch of one.
            tokenizer: Callable returning a mapping with ``"input_ids"`` and
                providing ``decode(ids, skip_special_tokens=True)``.
            device: Device the padded id tensor is moved to.

        Returns:
            A list with one decoded summary string per input text (empty
            list for an empty batch).

        Note:
            The module's train/eval mode is not changed here; callers that
            want deterministic output should call ``.eval()`` first.
        """
        # A bare string would be tokenized into a flat id list and break the
        # per-sequence padding below, so normalise it to a one-element batch.
        texts = [texts] if isinstance(texts, str) else list(texts)
        if not texts:
            return []

        encoded = tokenizer(texts)["input_ids"]
        lengths = [len(ids) for ids in encoded]

        # NOTE(review): sequences are right-padded with id 0 and no attention
        # mask is handed to the encoder, so padded positions are visible to
        # it. Kept as is to match existing behaviour -- confirm this matches
        # how the checkpoint was trained before changing it.
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in encoded], batch_first=True
        ).to(device)

        # Inference only: the result is plain strings, no graph is needed.
        with torch.no_grad():
            log_probs = self.forward(input_ids)
        argmax_labels = torch.argmax(log_probs, dim=2)

        # Drop the padded tail of every row so padding can never be selected
        # into a summary.
        trimmed_ids = [row[:n] for row, n in zip(input_ids, lengths)]
        trimmed_labels = [row[:n] for row, n in zip(argmax_labels, lengths)]
        return labels_to_summary(trimmed_ids, trimmed_labels, tokenizer)


def load_model(model_dir, device="cuda", prefix="best"):
    """Load the checkpoint under ``model_dir/checkpoints`` matching *prefix*.

    Args:
        model_dir: Directory (``str`` or ``Path``) containing ``checkpoints``.
        device: Device forwarded to ``load_checkpoint``.
        prefix: Entries whose name starts with this prefix are candidates.

    Returns:
        The ``LinearTokenSelector`` returned by ``load_checkpoint``.

    Raises:
        FileNotFoundError: If no entry in ``checkpoints`` starts with
            *prefix*.
    """
    checkpoints_dir = Path(model_dir) / "checkpoints"
    wanted = str(prefix)
    # Sort so the choice is deterministic when several entries match;
    # directory iteration order is otherwise arbitrary.
    candidates = sorted(
        p for p in checkpoints_dir.iterdir() if p.name.startswith(wanted)
    )
    if not candidates:
        raise FileNotFoundError(
            f"No checkpoint starting with {wanted!r} found in {checkpoints_dir}"
        )
    return load_checkpoint(candidates[-1], device=device)


def load_checkpoint(checkpoint_dir, device="cuda"):
    """Rebuild a ``LinearTokenSelector`` from a checkpoint directory.

    Reads the encoder from ``checkpoint_dir/encoder.bin`` via
    ``AutoModel.from_pretrained`` and the linear head from
    ``checkpoint_dir/classifier.bin`` via ``torch.load``.

    Args:
        checkpoint_dir: Checkpoint directory (``str`` or ``Path``).
        device: Device the assembled model is moved to.

    Returns:
        The assembled ``LinearTokenSelector`` on *device*.
    """
    checkpoint_dir = Path(checkpoint_dir)
    encoder_path = checkpoint_dir / "encoder.bin"
    classifier_path = checkpoint_dir / "classifier.bin"

    encoder = AutoModel.from_pretrained(encoder_path).to(device)
    # The head's input width is read from the encoder's word-embedding table.
    embedding_size = encoder.state_dict()[
        "embeddings.word_embeddings.weight"
    ].shape[1]

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    classifier_state = torch.load(classifier_path, map_location=device)
    classifier_state = {
        key: value
        for key, value in classifier_state.items()
        if key.startswith("classifier")
    }

    # Build the selector without an encoder first so the strict state-dict
    # load only has to match the classifier keys, then attach the encoder.
    model = LinearTokenSelector(None, embedding_size)
    model.load_state_dict(classifier_state)
    model.encoder = encoder
    return model.to(device)


def labels_to_summary(input_batch, label_batch, tokenizer):
    """Decode, for each sequence, the tokens whose label equals ``1``.

    Args:
        input_batch: Iterable of token-id sequences (tensors or lists).
        label_batch: Iterable of label sequences aligned with *input_batch*.
        tokenizer: Object providing ``decode(ids, skip_special_tokens=True)``.

    Returns:
        A list with one decoded string per sequence.
    """
    summaries = []
    for input_ids, labels in zip(input_batch, label_batch):
        # Convert each row to Python scalars once instead of indexing and
        # comparing tensor elements one by one.
        token_ids = torch.as_tensor(input_ids).tolist()
        keep_flags = torch.as_tensor(labels).tolist()
        selected = [
            token_id
            for token_id, flag in zip(token_ids, keep_flags)
            if flag == 1
        ]
        summaries.append(tokenizer.decode(selected, skip_special_tokens=True))
    return summaries