# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-member

import torch
from torch import nn, Tensor
from transformers import BertPreTrainedModel, BertModel, BertConfig


class Triaffine(nn.Module):
    """Triaffine module.

    Args:
        triaffine_hidden_size (int): Triaffine module hidden size
    """

    def __init__(self, triaffine_hidden_size: int) -> None:
        super().__init__()
        self.triaffine_hidden_size = triaffine_hidden_size
        self.weight_start_end = nn.Parameter(
            torch.zeros(triaffine_hidden_size,
                        triaffine_hidden_size,
                        triaffine_hidden_size))
        nn.init.normal_(self.weight_start_end, mean=0, std=0.1)

    def forward(self,
                start_logits: Tensor,
                end_logits: Tensor,
                cls_logits: Tensor) -> Tensor:
        """forward

        Args:
            start_logits (Tensor): start logits, shape (bsz, seq, hidden)
            end_logits (Tensor): end logits, shape (bsz, seq, hidden)
            cls_logits (Tensor): cls logits, shape (bsz, num_labels, hidden)

        Returns:
            Tensor: span_logits, shape (bsz, seq, seq, num_labels)
        """
        # bilinear interaction between start and end representations
        # through the (hidden, hidden, hidden) weight tensor
        start_end_logits = torch.einsum("bxi,ioj,byj->bxyo",
                                        start_logits,
                                        self.weight_start_end,
                                        end_logits)
        # score every (start, end) pair against each label representation
        span_logits = torch.einsum("bxyo,bzo->bxyz",
                                   start_end_logits,
                                   cls_logits)
        return span_logits
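
# Illustrative shape check for Triaffine (a sketch added here for clarity, not
# part of the original module; the tensor names are hypothetical). With batch
# size 2, sequence length 5, 3 labels and triaffine_hidden_size=8:
#
#     tri = Triaffine(triaffine_hidden_size=8)
#     start = torch.randn(2, 5, 8)      # (bsz, seq, hidden)
#     end = torch.randn(2, 5, 8)        # (bsz, seq, hidden)
#     labels = torch.randn(2, 3, 8)     # (bsz, num_labels, hidden)
#     spans = tri(start, end, labels)   # -> shape (2, 5, 5, 3)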


class MLPLayer(nn.Module):
    """MLP layer.

    Args:
        input_size (int): input size
        output_size (int): output size
    """

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
        """forward

        Args:
            x (Tensor): input

        Returns:
            Tensor: output
        """
        x = self.linear(x)
        x = self.act(x)
        return x


class BagualuIEModel(BertPreTrainedModel):
    """BagualuIEModel.

    Args:
        config (BertConfig): config
    """

    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.bert = BertModel(config)
        self.config = config
        self.triaffine_hidden_size = 128
        self.mlp_start = MLPLayer(self.config.hidden_size,
                                  self.triaffine_hidden_size)
        self.mlp_end = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)
        self.mlp_cls = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)
        self.triaffine = Triaffine(self.triaffine_hidden_size)

    def forward(self,  # pylint: disable=unused-argument
                input_ids: Tensor,
                attention_mask: Tensor,
                position_ids: Tensor,
                token_type_ids: Tensor,
                text_len: Tensor,
                label_token_idx: Tensor,
                **kwargs) -> Tensor:
        """forward

        Args:
            input_ids (Tensor): input_ids
            attention_mask (Tensor): attention_mask
            position_ids (Tensor): position_ids
            token_type_ids (Tensor): token_type_ids
            text_len (Tensor): query length
            label_token_idx (Tensor): label_token_idx

        Returns:
            Tensor: span logits
        """
        # BERT forward
        hidden_states = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  position_ids=position_ids,
                                  token_type_ids=token_type_ids,
                                  output_hidden_states=True)[0]  # (bsz, seq, dim)

        max_text_len = text_len.max()

        # gather hidden states for start/end scoring and for the label (cls) tokens
        hidden_start_end = hidden_states[:, :max_text_len, :]  # text-portion representation
        hidden_cls = hidden_states.gather(1, label_token_idx.unsqueeze(-1)
                                          .repeat(1, 1, self.config.hidden_size))  # (bsz, task, dim)

        # Triaffine scoring over (start, end, label), then sigmoid
        span_logits = self.triaffine(self.mlp_start(hidden_start_end),
                                     self.mlp_end(hidden_start_end),
                                     self.mlp_cls(hidden_cls)).sigmoid()

        return span_logits
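

if __name__ == "__main__":
    # Minimal usage sketch (added for illustration, not part of the original
    # file): build a tiny randomly initialised model and run one forward pass
    # with dummy inputs. All sizes and tensor values below are made up.
    config = BertConfig(hidden_size=32,
                        num_hidden_layers=2,
                        num_attention_heads=2,
                        intermediate_size=64)
    model = BagualuIEModel(config)

    bsz, seq_len, num_labels = 2, 16, 3
    input_ids = torch.randint(0, config.vocab_size, (bsz, seq_len))
    attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)
    position_ids = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1)
    token_type_ids = torch.zeros(bsz, seq_len, dtype=torch.long)
    text_len = torch.tensor([10, 12])                    # text length per sample
    label_token_idx = torch.randint(0, seq_len, (bsz, num_labels))

    span_logits = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        token_type_ids=token_type_ids,
                        text_len=text_len,
                        label_token_idx=label_token_idx)
    # expected: (bsz, max_text_len, max_text_len, num_labels) = (2, 12, 12, 3)
    print(span_logits.shape)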