# bagualu-ie/models/model.py
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=no-member
import torch
from torch import nn, Tensor
from transformers import BertPreTrainedModel, BertModel, BertConfig


class Triaffine(nn.Module):
    """Triaffine module.

    Args:
        triaffine_hidden_size (int): Triaffine module hidden size
    """

    def __init__(self, triaffine_hidden_size: int) -> None:
        super().__init__()

        self.triaffine_hidden_size = triaffine_hidden_size

        self.weight_start_end = nn.Parameter(
            torch.zeros(triaffine_hidden_size,
                        triaffine_hidden_size,
                        triaffine_hidden_size))

        nn.init.normal_(self.weight_start_end, mean=0, std=0.1)

    def forward(self,
                start_logits: Tensor,
                end_logits: Tensor,
                cls_logits: Tensor) -> Tensor:
        """forward

        Args:
            start_logits (Tensor): start logits, shape (bsz, seq, hidden)
            end_logits (Tensor): end logits, shape (bsz, seq, hidden)
            cls_logits (Tensor): label (cls) logits, shape (bsz, num_labels, hidden)

        Returns:
            Tensor: span_logits, shape (bsz, seq, seq, num_labels)
        """
        # bilinear interaction between start and end representations:
        # (b, x, i) x (i, o, j) x (b, y, j) -> (b, x, y, o)
        start_end_logits = torch.einsum("bxi,ioj,byj->bxyo",
                                        start_logits,
                                        self.weight_start_end,
                                        end_logits)

        # project each (start, end) pair onto the label representations:
        # (b, x, y, o) x (b, z, o) -> (b, x, y, z)
        span_logits = torch.einsum("bxyo,bzo->bxyz",
                                   start_end_logits,
                                   cls_logits)

        return span_logits
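

# A minimal shape sketch of Triaffine in isolation (batch size, sequence
# length and label count below are illustrative assumptions, not values
# taken from this file):
#
#     triaffine = Triaffine(triaffine_hidden_size=128)
#     start = torch.randn(2, 16, 128)  # (bsz, seq, hidden)
#     end = torch.randn(2, 16, 128)    # (bsz, seq, hidden)
#     cls = torch.randn(2, 4, 128)     # (bsz, num_labels, hidden)
#     spans = triaffine(start, end, cls)  # -> (2, 16, 16, 4)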


class MLPLayer(nn.Module):
    """MLP layer.

    Args:
        input_size (int): input size
        output_size (int): output size
    """

    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
        """forward

        Args:
            x (Tensor): input

        Returns:
            Tensor: output
        """
        x = self.linear(x)
        x = self.act(x)
        return x
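

# A minimal shape sketch for MLPLayer (the 768 is only an illustrative
# assumption for a bert-base hidden size):
#
#     mlp = MLPLayer(input_size=768, output_size=128)
#     out = mlp(torch.randn(2, 16, 768))  # -> (2, 16, 128)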


class BagualuIEModel(BertPreTrainedModel):
    """BagualuIEModel

    Args:
        config (BertConfig): config
    """

    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.bert = BertModel(config)
        self.config = config

        self.triaffine_hidden_size = 128

        self.mlp_start = MLPLayer(self.config.hidden_size,
                                  self.triaffine_hidden_size)
        self.mlp_end = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)
        self.mlp_cls = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)

        self.triaffine = Triaffine(self.triaffine_hidden_size)

    def forward(self,  # pylint: disable=unused-argument
                input_ids: Tensor,
                attention_mask: Tensor,
                position_ids: Tensor,
                token_type_ids: Tensor,
                text_len: Tensor,
                label_token_idx: Tensor,
                **kwargs) -> Tensor:
        """forward

        Args:
            input_ids (Tensor): input_ids
            attention_mask (Tensor): attention_mask
            position_ids (Tensor): position_ids
            token_type_ids (Tensor): token_type_ids
            text_len (Tensor): query length
            label_token_idx (Tensor): label_token_idx

        Returns:
            Tensor: span logits
        """
        # BERT forward
        hidden_states = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  position_ids=position_ids,
                                  token_type_ids=token_type_ids,
                                  output_hidden_states=True)[0]  # (bsz, seq, dim)

        max_text_len = text_len.max()

        # take the hidden states used for start, end and cls
        hidden_start_end = hidden_states[:, :max_text_len, :]  # text-part representation
        hidden_cls = hidden_states.gather(
            1, label_token_idx.unsqueeze(-1).repeat(1, 1, self.config.hidden_size))  # (bsz, task, dim)

        # Triaffine scoring over (start, end, label) triples
        span_logits = self.triaffine(self.mlp_start(hidden_start_end),
                                     self.mlp_end(hidden_start_end),
                                     self.mlp_cls(hidden_cls)).sigmoid()

        return span_logits
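

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original file): the tiny
    # BertConfig and every input shape below are illustrative assumptions,
    # chosen only to show the expected tensor flow through the model.
    cfg = BertConfig(vocab_size=100,
                     hidden_size=32,
                     num_hidden_layers=2,
                     num_attention_heads=2,
                     intermediate_size=64,
                     max_position_embeddings=32)
    model = BagualuIEModel(cfg)
    model.eval()

    bsz, seq_len, n_text, n_labels = 2, 16, 10, 3
    dummy_input_ids = torch.randint(0, cfg.vocab_size, (bsz, seq_len))
    dummy_attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)
    dummy_position_ids = torch.arange(seq_len).unsqueeze(0).repeat(bsz, 1)
    dummy_token_type_ids = torch.zeros(bsz, seq_len, dtype=torch.long)
    dummy_text_len = torch.full((bsz,), n_text, dtype=torch.long)
    # assume the label / task tokens sit right after the text tokens
    dummy_label_token_idx = torch.arange(n_text, n_text + n_labels).unsqueeze(0).repeat(bsz, 1)

    with torch.no_grad():
        spans = model(dummy_input_ids,
                      dummy_attention_mask,
                      dummy_position_ids,
                      dummy_token_type_ids,
                      dummy_text_len,
                      dummy_label_token_idx)
    print(spans.shape)  # expected: (2, 10, 10, 3) = (bsz, text_len, text_len, num_labels)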