# NOTE(review): removed non-Python artifact lines ("Spaces:", "Runtime error")
# left over from a copy/paste of the original source.
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from transformers import AutoModel
class LinearTokenSelector(nn.Module):
    """Extractive summarizer: a bias-free linear per-token binary classifier
    on top of a pretrained transformer encoder.

    Each token receives log-probabilities over {0: drop, 1: keep}; tokens
    predicted as class 1 are concatenated to form the summary (see
    ``labels_to_summary``).
    """

    def __init__(self, encoder, embedding_size=768):
        """
        Args:
            encoder: a HuggingFace transformer model returning
                ``hidden_states`` when called with
                ``output_hidden_states=True``. May be ``None`` and attached
                later (as done in ``load_checkpoint``).
            embedding_size: hidden size of the encoder's last layer.
        """
        super().__init__()
        self.encoder = encoder
        # Two logits per token; bias-free to stay weight-compatible with
        # previously saved checkpoints.
        self.classifier = nn.Linear(embedding_size, 2, bias=False)

    def forward(self, x):
        """Return per-token log-probabilities of shape (B, S, 2)."""
        output = self.encoder(x, output_hidden_states=True)
        hidden = output["hidden_states"][-1]  # last layer, (B, S, H)
        logits = self.classifier(hidden)
        return F.log_softmax(logits, dim=2)

    def save(self, classifier_path, encoder_path):
        """Persist the classifier head and the encoder separately.

        Only the ``classifier.*`` weights go to ``classifier_path``; the
        encoder is saved in HuggingFace format under ``encoder_path``.
        """
        state = {
            k: v for k, v in self.state_dict().items()
            if k.startswith("classifier")
        }
        torch.save(state, classifier_path)
        self.encoder.save_pretrained(encoder_path)

    def predict(self, texts, tokenizer, device):
        """Summarize a batch of raw strings; returns a list of summaries.

        Args:
            texts: list of input strings.
            tokenizer: HuggingFace tokenizer matching ``self.encoder``.
            device: device the padded batch is moved to.
        """
        input_ids = tokenizer(texts)["input_ids"]
        # NOTE(review): pad_sequence pads with 0; assumes the tokenizer's
        # pad id is 0 (true for BERT-style vocabs) — verify for others.
        input_ids = pad_sequence(
            [torch.tensor(ids) for ids in input_ids], batch_first=True
        ).to(device)
        # Inference only: skip autograd bookkeeping (fix: the original
        # built a graph here for no reason).
        with torch.no_grad():
            logits = self.forward(input_ids)
        argmax_labels = torch.argmax(logits, dim=2)
        return labels_to_summary(input_ids, argmax_labels, tokenizer)
def load_model(model_dir, device="cuda", prefix="best"):
    """Load the checkpoint in ``model_dir/checkpoints`` whose name starts
    with ``prefix``.

    Args:
        model_dir: str or Path of a directory containing ``checkpoints/``.
        device: torch device string for the loaded model.
        prefix: checkpoint-name prefix to match (e.g. "best").

    Returns:
        The loaded ``LinearTokenSelector``.

    Raises:
        FileNotFoundError: if no checkpoint matching ``prefix`` exists
            (fix: the original fell through to an opaque ``NameError``).
    """
    model_dir = Path(model_dir)  # Path(Path(...)) is a no-op, so str or Path is fine
    checkpoint_dir = None
    # Keep the LAST match, preserving the original loop's semantics.
    for candidate in (model_dir / "checkpoints").iterdir():
        if candidate.name.startswith(prefix):
            checkpoint_dir = candidate
    if checkpoint_dir is None:
        raise FileNotFoundError(
            f"no checkpoint starting with {prefix!r} in {model_dir / 'checkpoints'}"
        )
    return load_checkpoint(checkpoint_dir, device=device)
def load_checkpoint(checkpoint_dir, device="cuda"):
    """Rebuild a ``LinearTokenSelector`` from a checkpoint directory.

    Expects ``encoder.bin`` (a HuggingFace ``save_pretrained`` target) and
    ``classifier.bin`` (a torch state dict of the linear head) inside
    ``checkpoint_dir``.
    """
    checkpoint_dir = Path(checkpoint_dir)  # accepts str or Path alike
    encoder = AutoModel.from_pretrained(checkpoint_dir / "encoder.bin").to(device)
    # Hidden size is read off the encoder's word-embedding matrix.
    embedding_size = encoder.state_dict()[
        "embeddings.word_embeddings.weight"
    ].shape[1]
    model = LinearTokenSelector(None, embedding_size).to(device)
    raw_state = torch.load(checkpoint_dir / "classifier.bin", map_location=device)
    # Keep only the head's weights; the encoder is attached separately below.
    head_state = {
        k: v for k, v in raw_state.items() if k.startswith("classifier")
    }
    model.load_state_dict(head_state)
    model.encoder = encoder
    return model.to(device)
def labels_to_summary(input_batch, label_batch, tokenizer):
    """Turn per-token 0/1 labels into summary strings.

    For each (token-ids, labels) pair in the batch, keeps the token ids
    whose label equals 1 and decodes them with ``tokenizer``.

    Returns:
        A list of decoded summary strings, one per batch element.
    """
    summaries = []
    for token_ids, labels in zip(input_batch, label_batch):
        kept = [int(tok) for tok, lab in zip(token_ids, labels) if lab == 1]
        summaries.append(tokenizer.decode(kept, skip_special_tokens=True))
    return summaries