han liu committed · commit ff78ef7 · 1 parent: ca2a245

init

Files changed:
- app.py +182 -0
- dataloaders/__init__.py +0 -0
- dataloaders/dataset_utils.py +57 -0
- dataloaders/item_decoder.py +320 -0
- dataloaders/item_encoder.py +534 -0
- models/__init__.py +2 -0
- models/extract_model.py +71 -0
- models/model.py +156 -0
app.py
ADDED
@@ -0,0 +1,182 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
## @Author: liuhan(liuhan@idea.edu.cn)
## @Created: 2022/12/28 11:24:43
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Dict
from logging import basicConfig
import json
import os
import numpy as np
from transformers import AutoTokenizer
import argparse
import copy
import streamlit as st
import time


from models import BagualuIEModel, BagualuIEExtractModel


class BagualuIEPipelines:
    def __init__(self, args: argparse.Namespace) -> None:
        self.args = args
        # load model
        self.model = BagualuIEModel.from_pretrained(args.pretrained_model_root)

        # get tokenizer
        added_token = [f"[unused{i + 1}]" for i in range(99)]
        self.tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_root,
                                                       additional_special_tokens=added_token)

    def predict(self, test_data: List[dict], cuda: bool = True) -> List[dict]:
        """ predict

        Args:
            test_data (List[dict]): test data
            cuda (bool, optional): cuda. Defaults to True.

        Returns:
            List[dict]: result
        """
        result = []
        if cuda:
            self.model = self.model.cuda()
        self.model.eval()

        batch_size = self.args.batch_size
        extract_model = BagualuIEExtractModel(self.tokenizer, self.args)

        for i in range(0, len(test_data), batch_size):
            batch_data = test_data[i: i + batch_size]
            batch_result = extract_model.extract(batch_data, self.model, cuda)
            result.extend(batch_result)
        return result


@st.experimental_memo()
def load_model(model_path):
    parser = argparse.ArgumentParser()

    # pipeline arguments
    group_parser = parser.add_argument_group("piplines args")
    group_parser.add_argument("--pretrained_model_root", default="", type=str)
    group_parser.add_argument("--load_checkpoints_path", default="", type=str)

    group_parser.add_argument("--threshold_ent", default=0.3, type=float)
    group_parser.add_argument("--threshold_rel", default=0.3, type=float)
    group_parser.add_argument("--entity_multi_label", action="store_true", default=True)
    group_parser.add_argument("--relation_multi_label", action="store_true", default=True)

    # data model arguments
    group_parser = parser.add_argument_group("data_model")
    group_parser.add_argument("--batch_size", default=4, type=int)
    group_parser.add_argument("--max_length", default=512, type=int)
    # pytorch_lightning.Trainer arguments
    args = parser.parse_args()
    args.pretrained_model_root = model_path

    model = BagualuIEPipelines(args)
    return model


def main():

    # model = load_model('/cognitive_comp/liuhan/pretrained/uniex_macbert_base_v7.1/')
    model = load_model('IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese')

    st.subheader("Erlangshen-BERT-120M-IE-Chinese Zero-shot 体验")

    st.markdown("""
    Erlangshen-BERT-120M-IE-Chinese是以110M参数的base模型为底座,基于大规模信息抽取数据进行预训练后的模型,
    通过统一的抽取架构设计,可支持few-shot、zero-shot场景下的实体识别、关系三元组抽取任务。
    更多信息见https://github.com/IDEA-CCNL/GTS-Engine/tree/main
    模型效果见https://huggingface.co/IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese
    """)

    st.info("Please input the following information to experience Bagualu-IE「请输入以下信息开始体验 Bagualu-IE...」")
    model_type = st.selectbox('Select task type「选择任务类型」', ['Named Entity Recognition「命名实体识别」', 'Relation Extraction「关系抽取」'])
    if '命名实体识别' in model_type:
        example = st.selectbox('Example', ['Example: 人物信息', 'Example: 财经新闻'])
    else:
        example = st.selectbox('Example', ['Example: 雇佣关系', 'Example: 影视关系'])
    form = st.form("参数设置")
    if '命名实体识别' in model_type:
        if '人物信息' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "姚明,男,汉族,无党派人士,前中国职业篮球运动员。")
            choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "姓名;性别;民族;运动项目;政治面貌")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "寒流吹响华尔街,摩根士丹利、高盛、瑞信三大银行裁员合计超过8千人")
            choice = form.text_input("Please input the choice「请输入抽取实体名称,用中文;分割」", "裁员单位;裁员人数")

    else:
        if '雇佣关系' in example:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "东阳市企业家协会六届一次会员大会上,横店集团董事长、总裁徐永安当选为东阳市企业家协会会长。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "企业|董事长|人物")
        else:
            sentences = form.text_area(
                "Please input the context「请输入句子」",
                "《傲骨贤妻第六季》是一套美国法律剧情电视连续剧,2014年9月29日在CBS上首播。")
            choice = form.text_input("Please input the choice「请输入抽取关系名称,用中文;分割(头实体类型|关系|尾实体类型)」", "影视作品|上映时间|时间")

    form.form_submit_button("Submit「点击一下,开始预测!」")

    if '命名实体识别' in model_type:
        data = [{"task": '实体识别',
                 "text": sentences,
                 "entity_list": [],
                 "choice": choice.split(';'),
                 }]
    else:
        choice = [one.split('|') for one in choice.split(';')]
        data = [{"task": '关系抽取',
                 "text": sentences,
                 "entity_list": [],
                 "choice": choice,
                 }]

    start = time.time()
    # is_cuda = True if torch.cuda.is_available() else False
    # result = model.predict(data, cuda=is_cuda)

    # st.success(f"Prediction is successful, consumes {str(time.time()-start)} seconds")
    # st.json(result[0])

    rs = model.predict(data, False)
    st.success(f"Prediction is successful, consumes {str(time.time() - start)} seconds")
    st.json(rs[0])


if __name__ == "__main__":
    main()
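For reference, a minimal sketch (not part of the committed files) of how BagualuIEPipelines could be driven outside of Streamlit. The argument values simply mirror the defaults defined in load_model above, the checkpoint id is the one hard-coded in main(), and the item layout matches the dicts built in the Streamlit form; the result fields (entity_list, spo_list) are the keys the extractor writes back.

import argparse
from models import BagualuIEModel, BagualuIEExtractModel  # same environment as app.py

# Same settings as load_model(), but without argparse/Streamlit in the way.
args = argparse.Namespace(pretrained_model_root='IDEA-CCNL/Erlangshen-BERT-120M-IE-Chinese',
                          load_checkpoints_path='', threshold_ent=0.3, threshold_rel=0.3,
                          entity_multi_label=True, relation_multi_label=True,
                          batch_size=4, max_length=512)
model = BagualuIEPipelines(args)

ner_item = {"task": "实体识别",
            "text": "姚明,男,汉族,无党派人士,前中国职业篮球运动员。",
            "entity_list": [],
            "choice": ["姓名", "性别", "民族"]}
re_item = {"task": "关系抽取",
           "text": "横店集团董事长、总裁徐永安当选为东阳市企业家协会会长。",
           "entity_list": [],
           "choice": [["企业", "董事长", "人物"]]}

results = model.predict([ner_item, re_item], cuda=False)
print(results[0]["entity_list"])  # [{"entity_text": ..., "entity_type": ..., "score": ..., "entity_index": [...]}, ...]
print(results[1]["spo_list"])     # [{"predicate": ..., "score": ..., "subject": {...}, "object": {...}}, ...]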
dataloaders/__init__.py
ADDED
File without changes
dataloaders/dataset_utils.py
ADDED
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
from typing import List, Dict, Tuple


def get_choice(spo_choice: list) -> tuple:
    """ Pull the relations and entity types out of a relation (SPO) schema.

    Args:
        spo_choice (list): relation schema

    Returns:
        tuple:
            choice_ent (list)
            choice_rel (list)
            choice_head (list)
            choice_tail (list)
            entity2rel (dict)
    """
    choice_head = []
    choice_tail = []
    choice_ent = []
    choice_rel = []
    entity2rel = collections.defaultdict(list)  # "subject|object" -> [relation]

    for head, rel, tail in spo_choice:

        if head not in choice_head:
            choice_head.append(head)
        if tail not in choice_tail:
            choice_tail.append(tail)

        if head not in choice_ent:
            choice_ent.append(head)
        if tail not in choice_ent:
            choice_ent.append(tail)

        if rel not in choice_rel:
            choice_rel.append(rel)

        entity2rel[head, tail].append(rel)

    return choice_ent, choice_rel, choice_head, choice_tail, entity2rel
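A quick illustration (not part of the commit) of what get_choice returns for a two-relation schema; the values below are traced by hand from the loop above.

spo_choice = [["企业", "董事长", "人物"],
              ["影视作品", "上映时间", "时间"]]
ent, rel, head, tail, entity2rel = get_choice(spo_choice)
# ent  == ["企业", "人物", "影视作品", "时间"]
# rel  == ["董事长", "上映时间"]
# head == ["企业", "影视作品"], tail == ["人物", "时间"]
# entity2rel == {("企业", "人物"): ["董事长"], ("影视作品", "时间"): ["上映时间"]}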
dataloaders/item_decoder.py
ADDED
@@ -0,0 +1,320 @@
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from collections import defaultdict
from typing import List, Tuple, Dict
import argparse
import numpy as np
from transformers import PreTrainedTokenizer

from .item_encoder import entity_based_tokenize, get_entity_indices
from .dataset_utils import get_choice


class ItemDecoder(object):
    """ Decoder

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer
        args (TrainingArgumentsIEStd): arguments
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args: argparse.Namespace) -> None:
        self.tokenizer = tokenizer
        self.max_length = args.max_length
        self.threshold_entity = args.threshold_ent
        self.threshold_rel = args.threshold_rel
        self.entity_multi_label = args.entity_multi_label
        self.relation_multi_label = args.relation_multi_label

    def extract_entity_index(self,
                             entity_logits: np.ndarray,
                             ) -> List[Tuple[int, int]]:
        """ extract entity index

        Args:
            entity_logits (np.ndarray): entity_logits

        Returns:
            List[Tuple[int, int]]: result
        """

        l, _, d = entity_logits.shape
        result = []
        for i in range(l):
            for j in range(i, l):
                if self.entity_multi_label:
                    for k in range(d):
                        entity_score = float(entity_logits[i, j, k])
                        if entity_score > self.threshold_entity:
                            result.append((i, j, k, entity_score))

                else:
                    k = np.argmax(entity_logits[i, j])
                    entity_score = float(entity_logits[i, j, k])
                    if entity_score > self.threshold_entity:
                        result.append((i, j, k, entity_score))

        return result

    @staticmethod
    def extract_entity(text: str,
                       entity_idx: List[int],
                       entity_type: str,
                       entity_score: float,
                       text_start_id: int,
                       offset_mapping: List[List[int]]) -> dict:
        """ extract entity

        Args:
            text (str): text
            entity_idx (List[int]): entity indices
            entity_type (str): entity type
            entity_score (float): entity score
            text_start_id (int): text_start_id
            offset_mapping (List[List[int]]): offset mapping

        Returns:
            dict: entity
        """
        entity_start, entity_end = entity_idx[0] - text_start_id, entity_idx[1] - text_start_id

        start_split = offset_mapping[entity_start] if 0 <= entity_start < len(offset_mapping) else []
        end_split = offset_mapping[entity_end] if 0 <= entity_end < len(offset_mapping) else []

        if not start_split or not end_split:
            return None

        start_idx, end_idx = start_split[0], end_split[-1]
        entity_text = text[start_idx: end_idx]

        if not entity_text:
            return None

        entity = {
            "entity_text": entity_text,
            "entity_type": entity_type,
            "score": entity_score,
            "entity_index": [start_idx, end_idx]
        }

        return entity

    def decode_ner(self,
                   text: str,
                   choice: List[str],
                   sample_span_logits: np.ndarray,
                   offset_mapping: List[List[int]]
                   ) -> List[dict]:
        """ NER decode

        Args:
            text (str): text
            choice (List[str]): choice
            sample_span_logits (np.ndarray): sample span_logits
            offset_mapping (List[List[int]]): offset mapping

        Returns:
            List[dict]: decoded entity list
        """
        entity_list = []

        entity_idx_list = self.extract_entity_index(sample_span_logits)

        for entity_start, entity_end, entity_type_idx, entity_score in entity_idx_list:

            entity = self.extract_entity(text,
                                         [entity_start, entity_end],
                                         choice[entity_type_idx],
                                         entity_score,
                                         text_start_id=1,
                                         offset_mapping=offset_mapping)

            if entity is None:
                continue

            if entity not in entity_list:
                entity_list.append(entity)

        return entity_list

    def decode_spo(self,
                   text: str,
                   choice: List[List[str]],
                   sample_span_logits: np.ndarray,
                   offset_mapping: List[List[int]]) -> tuple:
        """ SPO decode

        Args:
            text (str): text
            choice (List[List[str]]): choice
            sample_span_logits (np.ndarray): sample span_logits
            offset_mapping (List[List[int]]): offset mapping

        Returns:
            List[dict]: decoded spo list
            List[dict]: decoded entity list
        """
        spo_list = []
        entity_list = []

        choice_ent, choice_rel, choice_head, choice_tail, entity2rel = get_choice(choice)

        entity_logits = sample_span_logits[:, :, : len(choice_ent)]    # (seq_len, seq_len, num_entity)
        relation_logits = sample_span_logits[:, :, len(choice_ent):]   # (seq_len, seq_len, num_relation)

        entity_idx_list = self.extract_entity_index(entity_logits)

        head_list = []
        tail_list = []
        for entity_start, entity_end, entity_type_idx, entity_score in entity_idx_list:

            entity_type = choice_ent[entity_type_idx]

            entity = self.extract_entity(text,
                                         [entity_start, entity_end],
                                         entity_type,
                                         entity_score,
                                         text_start_id=1,
                                         offset_mapping=offset_mapping)

            if entity is None:
                continue

            if entity_type in choice_head:
                head_list.append((entity_start, entity_end, entity_type, entity))
            if entity_type in choice_tail:
                tail_list.append((entity_start, entity_end, entity_type, entity))

        for head_start, head_end, subject_type, subject_dict in head_list:
            for tail_start, tail_end, object_type, object_dict in tail_list:

                if subject_dict == object_dict:
                    continue

                if (subject_type, object_type) not in entity2rel.keys():
                    continue

                relation_candidates = entity2rel[subject_type, object_type]
                rel_idx = [choice_rel.index(r) for r in relation_candidates]

                so_rel_logits = relation_logits[:, :, rel_idx]

                if self.relation_multi_label:
                    for idx, predicate in enumerate(relation_candidates):
                        rel_score = so_rel_logits[head_start, tail_start, idx] + \
                                    so_rel_logits[head_end, tail_end, idx]
                        predicate_score = float(rel_score / 2)

                        if predicate_score <= self.threshold_rel:
                            continue

                        if subject_dict not in entity_list:
                            entity_list.append(subject_dict)
                        if object_dict not in entity_list:
                            entity_list.append(object_dict)

                        spo = {
                            "predicate": predicate,
                            "score": predicate_score,
                            "subject": subject_dict,
                            "object": object_dict,
                        }

                        if spo not in spo_list:
                            spo_list.append(spo)

                else:

                    hh_idx = np.argmax(so_rel_logits[head_start, head_end])
                    tt_idx = np.argmax(so_rel_logits[tail_start, tail_end])
                    hh_score = so_rel_logits[head_start, tail_start, hh_idx] + so_rel_logits[head_end, tail_end, hh_idx]
                    tt_score = so_rel_logits[head_start, tail_start, tt_idx] + so_rel_logits[head_end, tail_end, tt_idx]

                    predicate = relation_candidates[hh_idx] if hh_score > tt_score else relation_candidates[tt_idx]

                    predicate_score = float(max(hh_score, tt_score) / 2)

                    if predicate_score <= self.threshold_rel:
                        continue

                    if subject_dict not in entity_list:
                        entity_list.append(subject_dict)
                    if object_dict not in entity_list:
                        entity_list.append(object_dict)

                    spo = {
                        "predicate": predicate,
                        "score": predicate_score,
                        "subject": subject_dict,
                        "object": object_dict,
                    }

                    if spo not in spo_list:
                        spo_list.append(spo)

        return spo_list, entity_list

    def decode(self,
               item: Dict,
               span_logits: np.ndarray,
               label_mask: np.ndarray,
               ):
        """ decode

        Args:
            item (Dict): input item holding the task name, choice and text
            span_logits (np.ndarray): sample span_logits
            label_mask (np.ndarray): label_mask

        Raises:
            NotImplementedError: raised if task name is not supported

        Returns:
            List[dict]: decoded entity list
            List[dict]: decoded spo list
        """
        task, choice, text = item["task"], item["choice"], item["text"]
        entity_indices = get_entity_indices(item.get("entity_list", []), item.get("spo_list", []))
        _, offset_mapping = entity_based_tokenize(text, self.tokenizer, entity_indices,
                                                  return_offsets_mapping=True)

        assert span_logits.shape == label_mask.shape

        span_logits = span_logits + (label_mask - 1) * 100000

        spo_list = []
        entity_list = []

        if task in {"实体识别", "抽取任务"}:
            entity_list = self.decode_ner(text,
                                          choice,
                                          span_logits,
                                          offset_mapping)

        elif task in {"关系抽取"}:
            spo_list, entity_list = self.decode_spo(text,
                                                    choice,
                                                    span_logits,
                                                    offset_mapping)

        else:
            raise NotImplementedError

        return entity_list, spo_list
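A toy check (not part of the commit) of the span thresholding in extract_entity_index. The scores are made up, and the argparse.Namespace simply stands in for the real training arguments the decoder expects.

import argparse
import numpy as np

args = argparse.Namespace(max_length=512, threshold_ent=0.3, threshold_rel=0.3,
                          entity_multi_label=True, relation_multi_label=True)
# extract_entity_index never touches the tokenizer, so None is enough here.
decoder = ItemDecoder(tokenizer=None, args=args)

logits = np.zeros((4, 4, 2))   # (seq_len, seq_len, num_entity_types)
logits[0, 1, 0] = 0.9          # span [0, 1] scored 0.9 for entity type 0
logits[2, 3, 1] = 0.5          # span [2, 3] scored 0.5 for entity type 1
print(decoder.extract_entity_index(logits))
# [(0, 1, 0, 0.9), (2, 3, 1, 0.5)]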
dataloaders/item_encoder.py
ADDED
@@ -0,0 +1,534 @@
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=no-member

from typing import List, Tuple, Dict, Union

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizer

from .dataset_utils import get_choice


def get_entity_indices(entity_list: List[dict], spo_list: List[dict]) -> List[List[int]]:
    """ Collect the character-level spans of every entity contained in a sample.

    Args:
        entity_list (List[dict]): entity list
        spo_list (List[dict]): SPO (triple) list

    Returns:
        List[List[int]]: entity spans
    """
    entity_indices = []

    # entity spans from the entity list
    for entity in entity_list:
        entity_index = entity["entity_index"]
        entity_indices.append(entity_index)

    # entity spans from the triples
    for spo in spo_list:
        sub_idx = spo["subject"]["entity_index"]
        obj_idx = spo["object"]["entity_index"]
        entity_indices.append(sub_idx)
        entity_indices.append(obj_idx)

    return entity_indices


def entity_based_tokenize(text: str,
                          tokenizer: PreTrainedTokenizer,
                          enitity_indices: List[Tuple[int, int]],
                          max_len: int = -1,
                          return_offsets_mapping: bool = False) \
        -> Union[List[int], Tuple[List[int], List[Tuple[int, int]]]]:
    """ Entity-aware tokenization: the text is split at entity boundaries so that every
    entity maps to a contiguous run of one or more tokens, while still using the
    pretrained tokenizer's subword vocabulary.

    Args:
        text (str): text
        tokenizer (PreTrainedTokenizer): tokenizer
        enitity_indices (List[Tuple[int, int]]): entity spans
        max_len (int, optional): length limit. Defaults to -1.
        return_offsets_mapping (bool, optional): whether to return offsets_mapping. Defaults to False.

    Returns:
        Union[List[int], Tuple[List[int], List[Tuple[int, int]]]]: token ids
    """
    # derive the split points from the entity spans
    split_points = sorted(list({i for idx in enitity_indices for i in idx} | {0, len(text)}))
    # split the text
    text_parts = []
    for i in range(0, len(split_points) - 1):
        text_parts.append(text[split_points[i]: split_points[i + 1]])

    # encode each part
    bias = 0
    text_ids = []
    offset_mapping = []
    for part in text_parts:

        part_encoded = tokenizer(part, add_special_tokens=False, return_offsets_mapping=True)
        part_ids, part_mapping = part_encoded["input_ids"], part_encoded["offset_mapping"]

        text_ids.extend(part_ids)
        for start, end in part_mapping:
            offset_mapping.append((start + bias, end + bias))

        bias += len(part)

    if max_len > 0:
        text_ids = text_ids[: max_len]

    # optionally return the offsets_mapping
    if return_offsets_mapping:
        return text_ids, offset_mapping
    return text_ids


class ItemEncoder(object):
    """ Item Encoder

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer
        max_length (int): max length
    """
    def __init__(self, tokenizer: PreTrainedTokenizer, max_length: int) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length

    def search_index(self,
                     entity_idx: List[int],
                     offset_mapping: List[Tuple[int, int]],
                     bias: int = 0) -> Tuple[int, int]:
        """ Find the token indices of an entity.

        Args:
            entity_idx (List[int]): entity index
            offset_mapping (List[Tuple[int, int]]): offset mapping
            bias (int): bias

        Returns:
            Tuple[int]: (start_idx, end_idx)
        """
        entity_start, entity_end = entity_idx
        start_idx, end_idx = -1, -1

        for token_idx, (start, end) in enumerate(offset_mapping):
            if start == entity_start:
                start_idx = token_idx
            if end == entity_end:
                end_idx = token_idx
        assert start_idx >= 0 and end_idx >= 0

        return start_idx + bias, end_idx + bias

    @staticmethod
    def get_position_ids(text_len: int,
                         ent_ranges: List,
                         rel_ranges: List) -> np.ndarray:
        """ Build position_ids.

        Args:
            text_len (int): input length
            ent_ranges (List[List[int, int]]): each entity ranges idx
            rel_ranges (List[List[int, int]]): each relation ranges idx.

        Returns:
            np.ndarray: position_ids
        """
        # every segment restarts its positions from 0, @liuhan
        text_pos_ids = list(range(text_len))

        ent_pos_ids, rel_pos_ids = [], []
        for s, e in ent_ranges:
            ent_pos_ids.extend(list(range(e - s)))
        for s, e in rel_ranges:
            rel_pos_ids.extend(list(range(e - s)))
        position_ids = text_pos_ids + ent_pos_ids + rel_pos_ids

        return position_ids

    @staticmethod
    def get_att_mask(input_len: int,
                     ent_ranges: List,
                     rel_ranges: List = None,
                     choice_ent: List[str] = None,
                     choice_rel: List[str] = None,
                     entity2rel: dict = None,
                     full_attent: bool = False) -> np.ndarray:
        """ Build the attention mask; attention between unrelated choices is zeroed out.

        Args:
            input_len (int): input length
            ent_ranges (List[List[int, int]]): each entity ranges idx
            rel_ranges (List[List[int, int]]): each relation ranges idx. Defaults to None.
            choice_ent (List[str], optional): choice entity. Defaults to None.
            choice_rel (List[str], optional): choice relation. Defaults to None.
            entity2rel (dict, optional): entity to relations. Defaults to None.
            full_attent (bool, optional): is full attention or not. Defaults to False.
        Returns:
            np.ndarray: attention mask
        """

        # attention_mask.shape = (input_len, input_len)
        attention_mask = np.ones((input_len, input_len))
        if full_attent and not rel_ranges:  # full attention and no relations: return all ones
            return attention_mask

        # input_ids: [CLS] text [SEP] [unused1] ent1 [unused2] rel1 [unused3] event1
        text_len = ent_ranges[0][0]  # length of the text part
        # zero out text -> label attention: the text cannot see the labels, so the encoding
        # does not depend on the number or order of the entities passed in @liuhan
        attention_mask[:text_len, text_len:] = 0

        # zero out entity-entity and entity-relation attention
        attention_mask[text_len:, text_len:] = 0

        # each entity label attends to itself
        for s, e in ent_ranges:
            attention_mask[s: e, s: e] = 1

        # if there are no relations, return directly
        if not rel_ranges:
            return attention_mask

        # handle the case with relations

        # each relation label attends to itself
        for s, e in rel_ranges:
            attention_mask[s: e, s: e] = 1

        # connect related entity and relation labels
        for head_tail, relations in entity2rel.items():
            for entity_type in head_tail:
                ent_idx = choice_ent.index(entity_type)
                ent_s, _ = ent_ranges[ent_idx]  # ent_s, ent_e
                for relation_type in relations:
                    rel_idx = choice_rel.index(relation_type)
                    rel_s, rel_e = rel_ranges[rel_idx]
                    attention_mask[rel_s: rel_e, ent_s] = 1  # a relation only sees the entity's first [unused1]

        if full_attent:  # full attention with relations: let the text see the relations
            for s, e in rel_ranges:
                attention_mask[: text_len, s: e] = 1

        return attention_mask

    def encode(self,
               text: str,
               task_name: str,
               choice: List[str],
               entity_list: List[dict],
               spo_list: List[dict],
               full_attent: bool = False,
               with_label: bool = True) -> Dict[str, torch.Tensor]:
        """ encode

        Args:
            text (str): text
            task_name (str): task name
            choice (List[str]): choice
            entity_list (List[dict]): entity list
            spo_list (List[dict]): spo list
            full_attent (bool): full attention
            with_label (bool): encoded with label. Defaults to True.

        Returns:
            Dict[str, torch.Tensor]: encoded
        """
        choice_ent, choice_rel, entity2rel = choice, [], {}
        if isinstance(choice, list):
            if isinstance(choice[0], list):  # relation extraction & entity recognition
                choice_ent, choice_rel, _, _, entity2rel = get_choice(choice)
        elif isinstance(choice, dict):
            # event types
            raise ValueError('event extract not supported now!')
        else:
            raise NotImplementedError

        input_ids = []
        text_ids = []  # ids of the text part
        ent_ids = []   # ids of the entity part
        rel_ids = []   # ids of the relation part
        entity_labels_idx = []
        relation_labels_idx = []

        sep_ids = self.tokenizer.encode("[SEP]", add_special_tokens=False)            # ids of [SEP]
        cls_ids = self.tokenizer.encode("[CLS]", add_special_tokens=False)            # ids of [CLS]
        entity_op_ids = self.tokenizer.encode("[unused1]", add_special_tokens=False)  # ids of [unused1]
        relation_op_ids = self.tokenizer.encode("[unused2]", add_special_tokens=False)  # ids of [unused2]

        # encode the task name
        task_ids = self.tokenizer.encode(task_name, add_special_tokens=False)

        # encode the entity labels
        for c in choice_ent:
            c_ids = self.tokenizer.encode(c, add_special_tokens=False)[: self.max_length]
            ent_ids += entity_op_ids + c_ids

        # encode the relation labels
        for c in choice_rel:
            c_ids = self.tokenizer.encode(c, add_special_tokens=False)[: self.max_length]
            rel_ids += relation_op_ids + c_ids

        # encode the text
        entity_indices = get_entity_indices(entity_list, spo_list)
        text_max_len = self.max_length - len(task_ids) - 3
        text_ids, offset_mapping = entity_based_tokenize(text, self.tokenizer, entity_indices,
                                                         max_len=text_max_len,
                                                         return_offsets_mapping=True)
        text_ids = cls_ids + text_ids + sep_ids

        input_ids = text_ids + task_ids + sep_ids + ent_ids + rel_ids

        token_type_ids = [0] * len(text_ids) + [0] * (len(task_ids) + 1) + \
                         [1] * len(ent_ids) + [1] * len(rel_ids)

        entity_labels_idx = [i for i, id_ in enumerate(input_ids) if id_ == entity_op_ids[0]]
        relation_labels_idx = [i for i, id_ in enumerate(input_ids) if id_ == relation_op_ids[0]]

        ent_ranges = []  # span of each entity label
        for i in range(len(entity_labels_idx) - 1):
            ent_ranges.append([entity_labels_idx[i], entity_labels_idx[i + 1]])
        if not relation_labels_idx:
            ent_ranges.append([entity_labels_idx[-1], len(input_ids)])
        else:
            ent_ranges.append([entity_labels_idx[-1], relation_labels_idx[0]])
        assert len(ent_ranges) == len(choice_ent)

        rel_ranges = []  # span of each relation label
        for i in range(len(relation_labels_idx) - 1):
            rel_ranges.append([relation_labels_idx[i], relation_labels_idx[i + 1]])
        if relation_labels_idx:
            rel_ranges.append([relation_labels_idx[-1], len(input_ids)])
        assert len(rel_ranges) == len(choice_rel)

        # positions of all [unused*] tokens
        label_token_idx = entity_labels_idx + relation_labels_idx
        task_num_labels = len(label_token_idx)
        input_len = len(input_ids)
        text_len = len(text_ids)

        # build the attention mask
        attention_mask = self.get_att_mask(input_len,
                                           ent_ranges,
                                           rel_ranges,
                                           choice_ent,
                                           choice_rel,
                                           entity2rel,
                                           full_attent)
        # build the label mask
        label_mask = np.ones((text_len, text_len, task_num_labels))
        for i in range(text_len):
            for j in range(text_len):
                if j < i:
                    for l in range(len(entity_labels_idx)):
                        # the lower triangle of the entity part can be masked
                        label_mask[i, j, l] = 0

        # build position_ids
        position_ids = self.get_position_ids(len(text_ids) + len(task_ids) + 1,
                                             ent_ranges,
                                             rel_ranges)

        assert len(input_ids) == len(position_ids) == len(token_type_ids)

        if not with_label:
            return {
                "input_ids": torch.tensor(input_ids).long(),
                "attention_mask": torch.tensor(attention_mask).float(),
                "position_ids": torch.tensor(position_ids).long(),
                "token_type_ids": torch.tensor(token_type_ids).long(),
                "label_token_idx": torch.tensor(label_token_idx).long(),
                "label_mask": torch.tensor(label_mask).float(),
                "text_len": torch.tensor(text_len).long(),
                "ent_ranges": ent_ranges,
                "rel_ranges": rel_ranges,
            }

        # span_labels for the input, keeping only the text part
        span_labels = np.zeros((text_len, text_len, task_num_labels))

        # convert entities to spans
        for entity in entity_list:

            entity_type = entity["entity_type"]
            entity_index = entity["entity_index"]

            start_idx, end_idx = self.search_index(entity_index, offset_mapping, 1)

            if start_idx < text_len and end_idx < text_len:
                ent_label = choice_ent.index(entity_type)
                span_labels[start_idx, end_idx, ent_label] = 1

        # convert triples to spans
        for spo in spo_list:

            sub_idx = spo["subject"]["entity_index"]
            obj_idx = spo["object"]["entity_index"]

            # start/end token indices of the subject and object entities
            sub_start_idx, sub_end_idx = self.search_index(sub_idx, offset_mapping, 1)
            obj_start_idx, obj_end_idx = self.search_index(obj_idx, offset_mapping, 1)
            # set the entity labels to 1
            if sub_start_idx < text_len and sub_end_idx < text_len:
                sub_label = choice_ent.index(spo["subject"]["entity_type"])
                span_labels[sub_start_idx, sub_end_idx, sub_label] = 1

            if obj_start_idx < text_len and obj_end_idx < text_len:
                obj_label = choice_ent.index(spo["object"]["entity_type"])
                span_labels[obj_start_idx, obj_end_idx, obj_label] = 1

            # for related sub/obj entities, set the relation label at their start/end positions to 1
            if spo["predicate"] in choice_rel:
                pre_label = choice_rel.index(spo["predicate"]) + len(choice_ent)
                if sub_start_idx < text_len and obj_start_idx < text_len:
                    span_labels[sub_start_idx, obj_start_idx, pre_label] = 1
                if sub_end_idx < text_len and obj_end_idx < text_len:
                    span_labels[sub_end_idx, obj_end_idx, pre_label] = 1

        return {
            "input_ids": torch.tensor(input_ids).long(),
            "attention_mask": torch.tensor(attention_mask).float(),
            "position_ids": torch.tensor(position_ids).long(),
            "token_type_ids": torch.tensor(token_type_ids).long(),
            "label_token_idx": torch.tensor(label_token_idx).long(),
            "span_labels": torch.tensor(span_labels).float(),
            "label_mask": torch.tensor(label_mask).float(),
            "text_len": torch.tensor(text_len).long(),
            "ent_ranges": ent_ranges,
            "rel_ranges": rel_ranges,
        }

    def encode_item(self, item: dict, with_label: bool = True) -> Dict[str, torch.Tensor]:  # pylint: disable=unused-argument
        """ encode

        Args:
            item (dict): item
            with_label (bool): encoded with label. Defaults to True.

        Returns:
            Dict[str, torch.Tensor]: encoded
        """
        return self.encode(text=item["text"],
                           task_name=item["task"],
                           choice=item["choice"],
                           entity_list=item.get("entity_list", []),
                           spo_list=item.get("spo_list", []),
                           full_attent=item.get('full_attent', False),
                           with_label=with_label)

    @staticmethod
    def collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Aggregate a batch of data.
        batch = [ins1_dict, ins2_dict, ..., insN_dict]
        batch_data = {"sentence": [ins1_sentence, ins2_sentence, ...],
                      "input_ids": [ins1_input_ids, ins2_input_ids, ...], ...}
        """
        input_ids = nn.utils.rnn.pad_sequence(
            sequences=[encoded["input_ids"] for encoded in batch],
            batch_first=True,
            padding_value=0)

        label_token_idx = nn.utils.rnn.pad_sequence(
            sequences=[encoded["label_token_idx"] for encoded in batch],
            batch_first=True,
            padding_value=0)

        token_type_ids = nn.utils.rnn.pad_sequence(
            sequences=[encoded["token_type_ids"] for encoded in batch],
            batch_first=True,
            padding_value=0)

        position_ids = nn.utils.rnn.pad_sequence(
            sequences=[encoded["position_ids"] for encoded in batch],
            batch_first=True,
            padding_value=0)

        text_len = torch.tensor([encoded["text_len"] for encoded in batch]).long()
        max_text_len = text_len.max()

        batch_size, batch_max_length = input_ids.shape
        _, batch_max_labels = label_token_idx.shape

        attention_mask = torch.zeros((batch_size, batch_max_length, batch_max_length))
        label_mask = torch.zeros((batch_size,
                                  batch_max_length,
                                  batch_max_length,
                                  batch_max_labels))
        for i, encoded in enumerate(batch):
            input_len = encoded["attention_mask"].shape[0]
            attention_mask[i, :input_len, :input_len] = encoded["attention_mask"]
            _, cur_text_len, label_len = encoded['label_mask'].shape
            label_mask[i, :cur_text_len, :cur_text_len, :label_len] = encoded['label_mask']
        label_mask = label_mask[:, :max_text_len, :max_text_len, :]

        batch_data = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "token_type_ids": token_type_ids,
            "label_token_idx": label_token_idx,
            "label_mask": label_mask,
            'text_len': text_len
        }

        if "span_labels" in batch[0].keys():
            span_labels = torch.zeros((batch_size,
                                       batch_max_length,
                                       batch_max_length,
                                       batch_max_labels))
            for i, encoded in enumerate(batch):
                input_len, _, sample_num_labels = encoded["span_labels"].shape
                span_labels[i, :input_len, :input_len, :sample_num_labels] = encoded["span_labels"]
            batch_data["span_labels"] = span_labels[:, :max_text_len, :max_text_len, :]

        return batch_data

    @staticmethod
    def collate_expand(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        """
        Aggregate a batch of data and expand it to full attention.
        batch = [ins1_dict, ins2_dict, ..., insN_dict]
        batch_data = {"sentence": [ins1_sentence, ins2_sentence, ...],
                      "input_ids": [ins1_input_ids, ins2_input_ids, ...], ...}
        """
        mask_atten_batch = ItemEncoder.collate(batch)
        full_atten_batch = ItemEncoder.collate(batch)
        # rewrite the attention mask of full_atten_batch
        atten_mask = full_atten_batch['attention_mask']
        b, _, _ = atten_mask.size()
        for i in range(b):
            ent_ranges, rel_ranges = batch[i]['ent_ranges'], batch[i]['rel_ranges']
            text_len = ent_ranges[0][0]  # length of the text part

            if not rel_ranges:
                assert len(ent_ranges) == 1, 'ent_ranges:%s' % ent_ranges
                s, e = ent_ranges[0]
                atten_mask[i, : text_len, s: e] = 1
            else:
                assert len(rel_ranges) == 1 and len(ent_ranges) <= 2, \
                    'ent_ranges:%s, rel_ranges:%s' % (ent_ranges, rel_ranges)
                s, e = rel_ranges[0]
                atten_mask[i, : text_len, s: e] = 1
        full_atten_batch['attention_mask'] = atten_mask
        collate_batch = {}
        for key, value in mask_atten_batch.items():
            collate_batch[key] = torch.cat((value, full_atten_batch[key]), 0)
        return collate_batch
models/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .model import BagualuIEModel
from .extract_model import BagualuIEExtractModel
models/extract_model.py
ADDED
@@ -0,0 +1,71 @@
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
import copy

from transformers import PreTrainedTokenizer
import argparse
from dataloaders.item_encoder import ItemEncoder
from dataloaders.item_decoder import ItemDecoder
from .model import BagualuIEModel


class BagualuIEExtractModel(object):
    """ BagualuIEExtractModel

    Args:
        tokenizer (PreTrainedTokenizer): tokenizer
        args (TrainingArgumentsIEStd): arguments
    """
    def __init__(self,
                 tokenizer: PreTrainedTokenizer,
                 args: argparse.Namespace) -> None:
        self.encoder = ItemEncoder(tokenizer, args.max_length)
        self.decoder = ItemDecoder(tokenizer, args)

    def extract(self, batch_data: List[dict], model: BagualuIEModel, use_cuda: bool) -> List[dict]:
        """ extract

        Args:
            batch_data (List[dict]): batch of data
            model (BagualuIEModel): model
            use_cuda (bool): whether to run the model on GPU

        Returns:
            List[dict]: batch of data
        """
        if use_cuda:
            model = model.cuda()
        model.eval()

        batch_data = copy.deepcopy(batch_data)
        batch = [self.encoder.encode_item(item, with_label=False) for item in batch_data]
        batch = self.encoder.collate(batch)
        if use_cuda:
            batch = {k: v.cuda() for k, v in batch.items()}

        span_logits = model(**batch).cpu().detach().numpy()
        label_mask = batch["label_mask"].cpu().detach().numpy()

        for i, item in enumerate(batch_data):

            entity_list, spo_list = self.decoder.decode(item,
                                                        span_logits[i],
                                                        label_mask[i])

            item["spo_list"] = spo_list
            item["entity_list"] = entity_list

        return batch_data
models/model.py
ADDED
@@ -0,0 +1,156 @@
# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=no-member

import torch
from torch import nn, Tensor
from transformers import BertPreTrainedModel, BertModel, BertConfig


class Triaffine(nn.Module):
    """ Triaffine module

    Args:
        triaffine_hidden_size (int): Triaffine module hidden size
    """
    def __init__(self, triaffine_hidden_size: int) -> None:
        super().__init__()

        self.triaffine_hidden_size = triaffine_hidden_size

        self.weight_start_end = nn.Parameter(
            torch.zeros(triaffine_hidden_size,
                        triaffine_hidden_size,
                        triaffine_hidden_size))

        nn.init.normal_(self.weight_start_end, mean=0, std=0.1)

    def forward(self,
                start_logits: Tensor,
                end_logits: Tensor,
                cls_logits: Tensor) -> Tensor:
        """forward

        Args:
            start_logits (Tensor): start logits
            end_logits (Tensor): end logits
            cls_logits (Tensor): cls logits

        Returns:
            Tensor: span_logits
        """
        start_end_logits = torch.einsum("bxi,ioj,byj->bxyo",
                                        start_logits,
                                        self.weight_start_end,
                                        end_logits)

        span_logits = torch.einsum("bxyo,bzo->bxyz",
                                   start_end_logits,
                                   cls_logits)

        return span_logits


class MLPLayer(nn.Module):
    """MLP layer

    Args:
        input_size (int): input size
        output_size (int): output size
    """
    def __init__(self, input_size: int, output_size: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=input_size, out_features=output_size)
        self.act = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:  # pylint: disable=invalid-name
        """ forward

        Args:
            x (Tensor): input

        Returns:
            Tensor: output
        """
        x = self.linear(x)
        x = self.act(x)
        return x


class BagualuIEModel(BertPreTrainedModel):
    """ BagualuIEModel

    Args:
        config (BertConfig): config
    """
    def __init__(self, config: BertConfig) -> None:
        super().__init__(config)
        self.bert = BertModel(config)
        self.config = config

        self.triaffine_hidden_size = 128

        self.mlp_start = MLPLayer(self.config.hidden_size,
                                  self.triaffine_hidden_size)
        self.mlp_end = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)
        self.mlp_cls = MLPLayer(self.config.hidden_size,
                                self.triaffine_hidden_size)

        self.triaffine = Triaffine(self.triaffine_hidden_size)

    def forward(self,  # pylint: disable=unused-argument
                input_ids: Tensor,
                attention_mask: Tensor,
                position_ids: Tensor,
                token_type_ids: Tensor,
                text_len: Tensor,
                label_token_idx: Tensor,
                **kwargs) -> Tensor:
        """ forward

        Args:
            input_ids (Tensor): input_ids
            attention_mask (Tensor): attention_mask
            position_ids (Tensor): position_ids
            token_type_ids (Tensor): token_type_ids
            text_len (Tensor): query length
            label_token_idx (Tensor, optional): label_token_idx

        Returns:
            Tensor: span logits
        """

        # bert forward
        hidden_states = self.bert(input_ids=input_ids,
                                  attention_mask=attention_mask,
                                  position_ids=position_ids,
                                  token_type_ids=token_type_ids,
                                  output_hidden_states=True)[0]  # (bsz, seq, dim)

        max_text_len = text_len.max()

        # hidden states for start, end and cls
        hidden_start_end = hidden_states[:, :max_text_len, :]  # representation of the text part
        hidden_cls = hidden_states.gather(1, label_token_idx.unsqueeze(-1)
                                          .repeat(1, 1, self.config.hidden_size))  # (bsz, task, dim)

        # Triaffine
        span_logits = self.triaffine(self.mlp_start(hidden_start_end),
                                     self.mlp_end(hidden_start_end),
                                     self.mlp_cls(hidden_cls)).sigmoid()

        return span_logits