import argparse
import json
import re
from pathlib import Path
from typing import Dict, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
from huggingface_hub.utils import HfHubHTTPError

from .modules.base import InstructBase
from .modules.evaluator import Evaluator, greedy_search
from .modules.layers import LstmSeq2SeqEncoder
from .modules.span_rep import SpanRepLayer
from .modules.token_rep import TokenRepLayer
|
|
class GLiNER(InstructBase, PyTorchModelHubMixin): |
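    """Generalist NER model that matches span representations against
    entity-type prompt embeddings.

    Entity types are prepended to the input as a textual prompt
    (``<<ENT>> type_1 <<ENT>> type_2 ... <<SEP>>``); candidate spans are then
    scored against the embedding of each prompted type, which lets the model
    predict arbitrary label sets at inference time.
    """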
|
    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # special tokens used to build the entity-type prompt
        self.entity_token = "<<ENT>>"
        self.sep_token = "<<SEP>>"

        # token-level representations from the backbone transformer;
        # the two special tokens are added to its vocabulary
        self.token_rep_layer = TokenRepLayer(
            model_name=config.model_name,
            fine_tune=config.fine_tune,
            subtoken_pooling=config.subtoken_pooling,
            hidden_size=config.hidden_size,
            add_tokens=[self.entity_token, self.sep_token],
        )

        # bidirectional LSTM over the word representations
        self.rnn = LstmSeq2SeqEncoder(
            input_size=config.hidden_size,
            hidden_size=config.hidden_size // 2,
            num_layers=1,
            bidirectional=True,
        )

        # builds a representation for every candidate span (width <= max_width)
        self.span_rep_layer = SpanRepLayer(
            span_mode=config.span_mode,
            hidden_size=config.hidden_size,
            max_width=config.max_width,
            dropout=config.dropout,
        )

        # projects entity-type embeddings before scoring against spans
        self.prompt_rep_layer = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 4),
            nn.Dropout(config.dropout),
            nn.ReLU(),
            nn.Linear(config.hidden_size * 4, config.hidden_size),
        )
|
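    # Prompt prepended to each training example:
    #   <<ENT>> type_1 <<ENT>> type_2 ... <<ENT>> type_m <<SEP>> token_1 ... token_n
    # The embeddings at the <<ENT>> positions are read back as entity-type
    # representations.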
    def compute_score_train(self, x):
        span_idx = x['span_idx'] * x['span_mask'].unsqueeze(-1)

        new_length = x['seq_length'].clone()
        new_tokens = []
        all_len_prompt = []
        num_classes_all = []

        # prepend the entity-type prompt to the tokens of each example
        for i in range(len(x['tokens'])):
            all_types_i = list(x['classes_to_id'][i].keys())

            entity_prompt = []
            num_classes_all.append(len(all_types_i))

            for entity_type in all_types_i:
                entity_prompt.append(self.entity_token)
                entity_prompt.append(entity_type)
            entity_prompt.append(self.sep_token)

            tokens_p = entity_prompt + x['tokens'][i]

            # update the sequence length to account for the prompt
            new_length[i] = new_length[i] + len(entity_prompt)

            new_tokens.append(tokens_p)

            all_len_prompt.append(len(entity_prompt))

        # mask out padded class slots (examples may have fewer types than the batch max)
        max_num_classes = max(num_classes_all)
        entity_type_mask = torch.arange(max_num_classes).unsqueeze(0).expand(
            len(num_classes_all), -1).to(x['span_mask'].device)
        entity_type_mask = entity_type_mask < torch.tensor(num_classes_all).unsqueeze(-1).to(
            x['span_mask'].device)

        bert_output = self.token_rep_layer(new_tokens, new_length)
        word_rep_w_prompt = bert_output["embeddings"]
        mask_w_prompt = bert_output["mask"]

        # split the output into word representations and entity-type representations
        word_rep = []
        mask = []
        entity_type_rep = []
        for i in range(len(x['tokens'])):
            prompt_entity_length = all_len_prompt[i]

            word_rep.append(word_rep_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])
            mask.append(mask_w_prompt[i, prompt_entity_length:prompt_entity_length + x['seq_length'][i]])

            # drop the trailing <<SEP>>, then keep every other position (the <<ENT>> tokens)
            entity_rep = word_rep_w_prompt[i, :prompt_entity_length - 1]
            entity_rep = entity_rep[0::2]
            entity_type_rep.append(entity_rep)

        word_rep = pad_sequence(word_rep, batch_first=True)
        mask = pad_sequence(mask, batch_first=True)
        entity_type_rep = pad_sequence(entity_type_rep, batch_first=True)

        word_rep = self.rnn(word_rep, mask)
        span_rep = self.span_rep_layer(word_rep, span_idx)

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)
        num_classes = entity_type_rep.shape[1]

        # (B, L, K, D) spans x (B, C, D) types -> (B, L, K, C) scores
        scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return scores, num_classes, entity_type_mask
|
|
|
    def forward(self, x):
        scores, num_classes, entity_type_mask = self.compute_score_train(x)
        batch_size = scores.shape[0]

        logits_label = scores.view(-1, num_classes)
        labels = x["span_label"].view(-1)
        mask_label = labels != -1  # -1 marks invalid (padded) spans
        labels.masked_fill_(~mask_label, 0)

        # one-hot encode, then drop column 0 (the non-entity class)
        labels_one_hot = torch.zeros(labels.size(0), num_classes + 1, dtype=torch.float32).to(scores.device)
        labels_one_hot.scatter_(1, labels.unsqueeze(1), 1)
        labels_one_hot = labels_one_hot[:, 1:]

        # multi-label classification: one binary loss per (span, type) pair
        all_losses = F.binary_cross_entropy_with_logits(logits_label, labels_one_hot,
                                                        reduction='none')

        # zero out losses for padded class slots
        masked_loss = all_losses.view(batch_size, -1, num_classes) * entity_type_mask.unsqueeze(1)
        all_losses = masked_loss.view(-1, num_classes)

        mask_label = mask_label.unsqueeze(-1).expand_as(all_losses)

        # weight positive (span, type) pairs twice as much as negatives
        weight_c = labels_one_hot + 1

        all_losses = all_losses * mask_label.float() * weight_c
        return all_losses.sum()
|
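    # Unlike training, evaluation uses a single label set shared by the whole
    # batch, so one prompt is built and prepended to every example.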
    def compute_score_eval(self, x, device):
        assert isinstance(x['classes_to_id'], dict), "classes_to_id must be a dict"

        span_idx = (x['span_idx'] * x['span_mask'].unsqueeze(-1)).to(device)

        all_types = list(x['classes_to_id'].keys())

        # build the prompt: <<ENT>> type_1 <<ENT>> type_2 ... <<SEP>>
        entity_prompt = []
        for entity_type in all_types:
            entity_prompt.append(self.entity_token)
            entity_prompt.append(entity_type)
        entity_prompt.append(self.sep_token)

        prompt_entity_length = len(entity_prompt)

        tokens_p = [entity_prompt + tokens for tokens in x['tokens']]
        seq_length_p = x['seq_length'] + prompt_entity_length

        out = self.token_rep_layer(tokens_p, seq_length_p)

        word_rep_w_prompt = out["embeddings"]
        mask_w_prompt = out["mask"]

        # remove the prompt positions from the word representations
        word_rep = word_rep_w_prompt[:, prompt_entity_length:, :]
        mask = mask_w_prompt[:, prompt_entity_length:]

        # drop the trailing <<SEP>>, then keep the <<ENT>> positions as type embeddings
        entity_type_rep = word_rep_w_prompt[:, :prompt_entity_length - 1, :]
        entity_type_rep = entity_type_rep[:, 0::2, :]

        entity_type_rep = self.prompt_rep_layer(entity_type_rep)

        word_rep = self.rnn(word_rep, mask)

        span_rep = self.span_rep_layer(word_rep, span_idx)

        local_scores = torch.einsum('BLKD,BCD->BLKC', span_rep, entity_type_rep)

        return local_scores
|
    @torch.no_grad()
    def predict(self, x, flat_ner=False, threshold=0.5):
        self.eval()
        local_scores = self.compute_score_eval(x, device=next(self.parameters()).device)
        spans = []
        for i, _ in enumerate(x["tokens"]):
            local_i = local_scores[i]
            # indices (start, width, class) where the sigmoid score clears the threshold
            wh_i = [idx.tolist() for idx in torch.where(torch.sigmoid(local_i) > threshold)]
            span_i = []
            for s, k, c in zip(*wh_i):
                if s + k < len(x["tokens"][i]):
                    span_i.append((s, s + k, x["id_to_classes"][c + 1], local_i[s, k, c]))
            # resolve overlapping spans greedily (highest score first),
            # keeping (start, end, label) triples
            span_i = greedy_search(span_i, flat_ner)
            spans.append(span_i)
        return spans
|
    def predict_entities(self, text, labels, flat_ner=True, threshold=0.5):
        # tokenize while recording character offsets, so predicted token
        # spans can be mapped back to the original text
        tokens = []
        start_token_idx_to_text_idx = []
        end_token_idx_to_text_idx = []
        for match in re.finditer(r'\w+(?:[-_]\w+)*|\S', text):
            tokens.append(match.group())
            start_token_idx_to_text_idx.append(match.start())
            end_token_idx_to_text_idx.append(match.end())

        input_x = {"tokenized_text": tokens, "ner": None}
        x = self.collate_fn([input_x], labels)
        output = self.predict(x, flat_ner=flat_ner, threshold=threshold)

        entities = []
        for start_token_idx, end_token_idx, ent_type in output[0]:
            start_text_idx = start_token_idx_to_text_idx[start_token_idx]
            end_text_idx = end_token_idx_to_text_idx[end_token_idx]
            entities.append({
                "start": start_text_idx,
                "end": end_text_idx,
                "text": text[start_text_idx:end_text_idx],
                "label": ent_type,
            })
        return entities
|
    def evaluate(self, test_data, flat_ner=False, threshold=0.5, batch_size=12, entity_types=None):
        self.eval()
        data_loader = self.create_dataloader(test_data, batch_size=batch_size, entity_types=entity_types,
                                             shuffle=False)
        device = next(self.parameters()).device
        all_preds = []
        all_trues = []
        for x in data_loader:
            # move batch tensors to the model's device
            for k, v in x.items():
                if isinstance(v, torch.Tensor):
                    x[k] = v.to(device)
            batch_predictions = self.predict(x, flat_ner, threshold)
            all_preds.extend(batch_predictions)
            all_trues.extend(x["entities"])
        evaluator = Evaluator(all_trues, all_preds)
        out, f1 = evaluator.evaluate()
        return out, f1
|
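    # Loading supports two checkpoint layouts: legacy single-file checkpoints
    # ("gliner_base.pt" / "gliner_multi.pt", which bundle config and weights)
    # and the current layout ("pytorch_model.bin" plus "gliner_config.json").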
    @classmethod
    def _from_pretrained(
        cls,
        *,
        model_id: str,
        revision: Optional[str],
        cache_dir: Optional[Union[str, Path]],
        force_download: bool,
        proxies: Optional[Dict],
        resume_download: bool,
        local_files_only: bool,
        token: Union[str, bool, None],
        map_location: str = "cpu",
        strict: bool = False,
        **model_kwargs,
    ):
        # 1. Try the legacy format: a single .pt file holding config and weights.
        filenames = ["gliner_base.pt", "gliner_multi.pt"]
        for filename in filenames:
            model_file = Path(model_id) / filename
            if not model_file.exists():
                try:
                    model_file = hf_hub_download(
                        repo_id=model_id,
                        filename=filename,
                        revision=revision,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        resume_download=resume_download,
                        token=token,
                        local_files_only=local_files_only,
                    )
                except HfHubHTTPError:
                    continue  # file not in this repo; try the next candidate
            dict_load = torch.load(model_file, map_location=torch.device(map_location))
            config = dict_load["config"]
            state_dict = dict_load["model_weights"]
            config.model_name = (
                "microsoft/deberta-v3-base" if filename == "gliner_base.pt"
                else "microsoft/mdeberta-v3-base"
            )
            model = cls(config)
            model.load_state_dict(state_dict, strict=strict, assign=True)
            model.to(map_location)
            return model

        # 2. Fall back to the current format: separate weights and JSON config.
        from .train import load_config_as_namespace  # deferred import: not needed at module load time

        model_file = Path(model_id) / "pytorch_model.bin"
        if not model_file.exists():
            model_file = hf_hub_download(
                repo_id=model_id,
                filename="pytorch_model.bin",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config_file = Path(model_id) / "gliner_config.json"
        if not config_file.exists():
            config_file = hf_hub_download(
                repo_id=model_id,
                filename="gliner_config.json",
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                token=token,
                local_files_only=local_files_only,
            )
        config = load_config_as_namespace(config_file)
        model = cls(config)
        state_dict = torch.load(model_file, map_location=torch.device(map_location))
        model.load_state_dict(state_dict, strict=strict, assign=True)
        model.to(map_location)
        return model
|
    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        *,
        config: Optional[Union[dict, "DataclassInstance"]] = None,
        repo_id: Optional[str] = None,
        push_to_hub: bool = False,
        **push_to_hub_kwargs,
    ) -> Optional[str]:
        """
        Save weights in a local directory.

        Args:
            save_directory (`str` or `Path`):
                Path to the directory in which the model weights and configuration will be saved.
            config (`dict` or `DataclassInstance`, *optional*):
                Model configuration specified as a key/value dictionary or a dataclass instance.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it.
            repo_id (`str`, *optional*):
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Defaults to the
                folder name if not provided.
            push_to_hub_kwargs:
                Additional keyword arguments passed along to the [`~ModelHubMixin.push_to_hub`] method.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # save model weights
        torch.save(self.state_dict(), save_directory / "pytorch_model.bin")

        # save config, converting an argparse.Namespace to a plain dict first
        if config is None:
            config = self.config
        if config is not None:
            if isinstance(config, argparse.Namespace):
                config = vars(config)
            (save_directory / "gliner_config.json").write_text(json.dumps(config, indent=2))

        # optionally push to the Hugging Face Hub
        if push_to_hub:
            kwargs = push_to_hub_kwargs.copy()
            if config is not None:
                kwargs["config"] = config
            if repo_id is None:
                repo_id = save_directory.name
            return self.push_to_hub(repo_id=repo_id, **kwargs)
        return None
|
    def to(self, device):
        super().to(device)

        # flair reads its device from a module-level global;
        # keep it in sync with the model's device
        import flair
        flair.device = device
        return self
|
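# Usage sketch (illustrative, not part of the module; the import path and
# checkpoint id below are assumptions):
#
#     from gliner import GLiNER
#
#     model = GLiNER.from_pretrained("urchade/gliner_base")
#     entities = model.predict_entities(
#         "Steve Jobs co-founded Apple in Cupertino.",
#         labels=["person", "organization", "location"],
#         threshold=0.5,
#     )
#     for ent in entities:
#         print(ent["start"], ent["end"], ent["text"], ent["label"])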