Upload 7 files
- code/app.py +26 -8
- code/args.py +21 -0
- code/do_predict.py +187 -0
- code/items_dataset.py +153 -0
- code/models.py +53 -0
- code/prediction.py +92 -0
- code/rank.ipynb +201 -0
code/app.py
CHANGED
@@ -1,14 +1,32 @@
-import
-import
+from flask import Flask, request, jsonify, make_response, render_template
+from do_predict import predict_single
 
-
-
-
-app = flask.Flask(__name__, template_folder="static")
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
 
 @app.route("/")
 def index():
-
+    return render_template("index.html")
+
+
+@app.before_request
+def before():
+    # handle preflight
+    if request.method == "OPTIONS":
+        resp = make_response()
+        resp.headers["Access-Control-Allow-Origin"] = "*"
+        resp.headers["Access-Control-Allow-Methods"] = "GET, POST"
+        resp.headers["Access-Control-Allow-Headers"] = "Content-Type"
+        return resp
+
+
+@app.post("/api/predict_single")
+def api_predict_single():
+    text = request.json["text"]
+    result = predict_single(text)
+    resp = jsonify(result)
+    resp.headers["Access-Control-Allow-Origin"] = "*"
+    return resp
+
 
 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=
+    app.run(host="0.0.0.0", port=22222)
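For reference, a minimal client-side sketch of the new /api/predict_single endpoint; the localhost URL is an illustrative assumption based on app.run(host="0.0.0.0", port=22222), and the short example text is taken from the test string in do_predict.py:

# Hypothetical client call; adjust the host/port to wherever the app runs.
import requests

resp = requests.post(
    "http://localhost:22222/api/predict_single",
    json={"text": "貓咪犯錯後,怎麼懲罰它才最有效呢?"},
)
# predict_single returns [[segment, score], ...] covering the whole input;
# score 0 marks unhighlighted segments.
print(resp.json())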
code/args.py
ADDED
@@ -0,0 +1,21 @@
class args():
    DATA_PATH = "../Dataset/"
    SAVE_MODEL_PATH = "model/"

    #pre_model_name = "bert-base-chinese"
    #pre_model_name = "hfl/chinese-macbert-base"
    pre_model_name = "hfl/chinese-roberta-wwm-ext"
    save_model_name = "roberta_crf"

    LOG_DIR = "../log/long_term/" + save_model_name + "/"

    use_crf = False
    label_dict = {"O": 0, "B": 1, "I": 2}
    epoch_num = 10
    batch_size = 2
    label_size = 3
    max_length = 512

class config():
    hidden_dropout_prob = 0.1
    hidden_size = 768
code/do_predict.py
ADDED
@@ -0,0 +1,187 @@
from args import args, config
from items_dataset import items_dataset
from torch.utils.data import DataLoader
from models import Model_Crf, Model_Softmax
from transformers import AutoTokenizer
from tqdm import tqdm
import prediction
import torch
import math

# load the two trained models once at import time
directory = args.SAVE_MODEL_PATH
model_name = "roberta_CRF.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model_crf = Model_Crf(config).to(device)
model_crf.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))

model_name = "roberta_softmax.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model_roberta = Model_Softmax(config).to(device)
model_roberta.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))


def prepare_span_data(dataset):
    for sample in dataset:
        spans = items_dataset.cal_agreement_span(
            None,
            agreement_table=sample["predict_sentence_table"],
            min_agree=1,
            max_agree=2,
        )
        sample["span_labels"] = spans
        sample["original_text"] = sample["text_a"]
        del sample["text_a"]


def rank_spans(test_loader, device, model, reverse=True):
    """Calculate each span's probability as e**(average per-token log likelihood)."""
    model.eval()
    result = []

    for i, test_batch in enumerate(tqdm(test_loader)):
        batch_text = test_batch["batch_text"]
        input_ids = test_batch["input_ids"].to(device)
        token_type_ids = test_batch["token_type_ids"].to(device)
        attention_mask = test_batch["attention_mask"].to(device)
        labels = test_batch["labels"]
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]
        output = model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=None,
            crf_mask=crf_mask,
        )
        output = torch.nn.functional.softmax(output[0], dim=-1)

        # build the result for every sample
        sample_id = 0
        sample_result = {
            "original_text": test_batch["batch_text"][sample_id],
            "span_ranked": [],
        }
        for batch_id in range(len(sample_mapping)):
            # check whether this overflow chunk starts a new sample
            change_sample = sample_id != sample_mapping[batch_id]
            if change_sample:
                sample_id = sample_mapping[batch_id]
                result.append(sample_result)
                sample_result = {
                    "original_text": test_batch["batch_text"][sample_id],
                    "span_ranked": [],
                }

            encoded_spans = items_dataset.cal_agreement_span(
                None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
            )
            for encoded_span in encoded_spans:
                # calculate the span's average log likelihood
                span_length = encoded_span[1] - encoded_span[0]
                span_prob_table = torch.log(output[batch_id][encoded_span[0]:encoded_span[1]])
                if not change_sample and encoded_span[0] == 0 and batch_id != 0:
                    # span crosses two overflow chunks: continue the running sum
                    span_loss += span_prob_table[0][1]  # Begin
                else:
                    span_loss = span_prob_table[0][1]  # Begin
                for token_id in range(1, span_prob_table.shape[0]):
                    span_loss += span_prob_table[token_id][2]  # Inside
                span_loss /= span_length

                # decode the span back to character offsets
                decode_start = test_batch[batch_id].token_to_chars(encoded_span[0] + 1)[0]
                decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0] + 1
                span_text = test_batch["batch_text"][sample_mapping[batch_id]][decode_start:decode_end]
                if not change_sample and encoded_span[0] == 0 and batch_id != 0:
                    # span crosses two overflow chunks: merge with the previous piece
                    presample = sample_result["span_ranked"].pop(-1)
                    sample_result["span_ranked"].append([presample[0] + span_text, math.e ** float(span_loss)])
                else:
                    sample_result["span_ranked"].append([span_text, math.e ** float(span_loss)])
        result.append(sample_result)

    # sorting by probability is disabled here: predict_single needs the spans
    # in their original text order
    # for sample in result:
    #     sample["span_ranked"] = sorted(
    #         sample["span_ranked"], key=lambda x: x[1], reverse=reverse
    #     )
    return result


def predict_single(text):
    input_dict = [{"span_labels": []}]
    input_dict[0]["original_text"] = text

    # first pass: the CRF model marks candidate spans
    tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)
    prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=prediction_dataset.collate_fn,
    )
    predict_data = prediction.test_predict(prediction_loader, device, model_crf)
    prediction.add_sentence_table(predict_data)

    # second pass: the softmax model scores the sentence-level spans
    prepare_span_data(predict_data)
    tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)
    prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=prediction_dataset.collate_fn,
    )
    span_ranked = rank_spans(prediction_loader, device, model_roberta)

    # interleave unhighlighted text (score 0) with the scored spans so the
    # result covers the whole original text in order
    result = []
    sample = span_ranked[0]
    orig = sample["original_text"]
    cur = 0
    for s, score in sample["span_ranked"]:
        end = orig.index(s, cur)
        if cur != end:
            result.append([orig[cur:end], 0])
        result.append([s, score])
        cur = end + len(s)
    if cur < len(orig):
        result.append([orig[cur:], 0])
    return result


if __name__ == "__main__":
    s = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
    print(predict_single(s))
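The scoring rule in rank_spans is exp of the average per-token log probability, i.e. the geometric mean of the span's token probabilities (the B probability for the first token, the I probability for the rest). A toy sketch of the arithmetic, with made-up numbers:

import math

token_probs = [0.9, 0.8, 0.7]  # P(B), P(I), P(I) for a 3-token span
score = math.exp(sum(math.log(p) for p in token_probs) / len(token_probs))
print(round(score, 3))  # 0.796, the geometric mean of 0.9, 0.8 and 0.7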
code/items_dataset.py
ADDED
@@ -0,0 +1,153 @@
import torch
from torch.utils.data import Dataset
from args import args


class items_dataset(Dataset):
    def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
        self.data_set = data_set
        self.tokenizer = tokenizer
        self.label_dict = label_dict
        self.max_length = max_length
        self.encode_max_length = max_length - 2  # reserve room for [CLS] and [SEP]
        self.batch_max_length = max_length
        self.stride = stride

    def __getitem__(self, index):
        return self.data_set[index]

    def __len__(self):
        return len(self.data_set)

    def create_label_list(self, span_label, max_len):
        # character-level tag table: 0 = "O", 1 = "B", 2 = "I"
        table = torch.zeros(max_len)
        for start, end in span_label:
            table[start:end] = 2  # "I"
            table[start] = 1  # "B"
        return table

    def encode_label(self, encoded, batch_table):
        batch_encode_seq_lens = []
        sample_mapping = encoded["overflow_to_sample_mapping"]
        offset_mapping = encoded["offset_mapping"]
        encoded_label = torch.zeros(len(sample_mapping), self.encode_max_length, dtype=torch.long)
        for id_in_batch in range(len(sample_mapping)):
            encode_len = 0
            table = batch_table[sample_mapping[id_in_batch]]
            for i in range(self.max_length):
                char_start, char_end = offset_mapping[id_in_batch][i]
                # (0, 0) offsets mark [CLS], [SEP] and padding tokens
                if char_start != 0 or char_end != 0:
                    encode_len += 1
                    encoded_label[id_in_batch][i - 1] = table[char_start].long()
            batch_encode_seq_lens.append(encode_len)
        return encoded_label, batch_encode_seq_lens

    def create_crf_mask(self, batch_encode_seq_lens):
        mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
        for i, batch_len in enumerate(batch_encode_seq_lens):
            mask[i][:batch_len] = True
        return mask

    def boundary_encoded(self, encodings, batch_boundary):
        # currently unused by collate_fn
        batch_boundary_encoded = []
        for batch_id, span_labels in enumerate(batch_boundary):
            boundary_encoded = []
            end = 0
            for boundary in span_labels:
                end += boundary
                encoded_end = encodings[batch_id].char_to_token(end - 1)

                # walk back until a character that maps to a token is found
                tmp_end = end
                while encoded_end is None and tmp_end > 0:
                    tmp_end -= 1
                    encoded_end = encodings[batch_id].char_to_token(tmp_end - 1)
                if encoded_end is not None:  # was `if end != None`, which is always true
                    encoded_end += 1

                if encoded_end > self.encode_max_length:
                    boundary_encoded.append(self.encode_max_length)
                    break
                else:
                    boundary_encoded.append(encoded_end)
            for i in range(len(boundary_encoded) - 1, 0, -1):
                boundary_encoded[i] = boundary_encoded[i] - boundary_encoded[i - 1]

            batch_boundary_encoded.append(boundary_encoded)
        return batch_boundary_encoded

    def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
        """Find the spans whose agreement counts lie in [min_agree, max_agree]."""
        ans_span = []
        start, end = (0, 0)
        pre_p = agreement_table[0]
        for i, word_agreement in enumerate(agreement_table):
            curr_p = word_agreement
            if curr_p != pre_p:
                if start != end:
                    ans_span.append([start, end])
                start = i
                end = i
                pre_p = curr_p
            if word_agreement < min_agree:
                start += 1
            if word_agreement <= max_agree:
                end += 1
            pre_p = curr_p
        if start != end:
            ans_span.append([start, end])
        if len(ans_span) <= 1 or min_agree == max_agree:
            return ans_span
        # merge adjacent spans
        span_concate = []
        start, end = [ans_span[0][0], ans_span[0][1]]
        for span_id in range(1, len(ans_span)):
            if ans_span[span_id - 1][1] == ans_span[span_id][0]:
                ans_span[span_id] = [ans_span[span_id - 1][0], ans_span[span_id][1]]
                if span_id == len(ans_span) - 1:
                    span_concate.append(ans_span[span_id])
            elif span_id == len(ans_span) - 1:
                span_concate.extend([ans_span[span_id - 1], ans_span[span_id]])
            else:
                span_concate.append(ans_span[span_id - 1])
        return span_concate

    def collate_fn(self, batch_sample):
        batch_text = []
        batch_table = []
        batch_span_label = []
        seq_lens = []
        for sample in batch_sample:
            batch_text.append(sample["original_text"])
            batch_table.append(self.create_label_list(sample["span_labels"], len(sample["original_text"])))
            batch_span_label.append(sample["span_labels"])
            seq_lens.append(len(sample["original_text"]))
        self.batch_max_length = max(seq_lens)
        if self.batch_max_length > self.encode_max_length:
            self.batch_max_length = self.encode_max_length

        encoded = self.tokenizer(
            batch_text,
            truncation=True,
            max_length=512,
            padding="max_length",
            stride=self.stride,
            return_overflowing_tokens=True,
            return_tensors="pt",
            return_offsets_mapping=True,
        )

        encoded["labels"], batch_encode_seq_lens = self.encode_label(encoded, batch_table)
        encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
        encoded["span_labels"] = batch_span_label
        encoded["batch_text"] = batch_text
        return encoded
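A worked example of cal_agreement_span, called unbound with None for self exactly as the rest of the code does (the agreement table is made up):

from items_dataset import items_dataset

# positions with agreement in [min_agree, max_agree] form spans,
# and spans that touch are merged by the second pass
table = [0, 1, 2, 2, 0, 1, 2]
spans = items_dataset.cal_agreement_span(None, agreement_table=table, min_agree=1, max_agree=2)
print(spans)  # [[1, 4], [5, 7]] -- half-open [start, end) index spans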
code/models.py
ADDED
@@ -0,0 +1,53 @@
import torch
from torch.nn import functional, CrossEntropyLoss, Softmax
from torchcrf import CRF
from transformers import RobertaModel, BertModel

from args import args, config


class Model_Crf(torch.nn.Module):
    def __init__(self, config):
        super(Model_Crf, self).__init__()
        self.bert = BertModel.from_pretrained(args.pre_model_name)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, args.label_size)
        self.crf = CRF(num_tags=args.label_size, batch_first=True)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_mask=None,
                labels=None, span_labels=None, start_positions=None, end_positions=None,
                testing=False, crf_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        sequence_output = sequence_output[:, 1:-1, :]  # remove [CLS], [SEP]
        logits = self.classifier(sequence_output)  # [batch, max_len, label_size]
        outputs = (logits,)
        if labels is not None:
            # the CRF returns a log likelihood, so negate it to get a loss
            loss = self.crf(emissions=logits, tags=labels, mask=crf_mask, reduction="mean")
            outputs = (-1 * loss,) + outputs
        return outputs


class Model_Softmax(torch.nn.Module):
    def __init__(self, config):
        super(Model_Softmax, self).__init__()
        self.bert = BertModel.from_pretrained(args.pre_model_name)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
        self.classifier = torch.nn.Linear(config.hidden_size, args.label_size)
        self.loss_calculater = CrossEntropyLoss()
        self.softmax = Softmax(dim=-1)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_mask=None,
                labels=None, span_labels=None, start_positions=None, end_positions=None,
                testing=False, crf_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        sequence_output = sequence_output[:, 1:-1, :]  # remove [CLS], [SEP]
        logits = self.classifier(sequence_output)  # [batch, max_len, label_size]
        logits = self.softmax(logits)
        outputs = (logits,)
        if labels is not None:
            labels = functional.one_hot(labels, num_classes=args.label_size).float()
            loss = self.loss_calculater(logits, labels)
            outputs = (loss,) + outputs
        return outputs
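A quick shape check for the wrappers above; this is a sketch that downloads the hfl/chinese-roberta-wwm-ext weights on first run, and the five-character input is made up:

import torch
from transformers import AutoTokenizer
from args import args, config
from models import Model_Softmax

tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name)
model = Model_Softmax(config)
enc = tokenizer("貓咪犯錯後", return_tensors="pt")
with torch.no_grad():
    (logits,) = model(input_ids=enc["input_ids"],
                      token_type_ids=enc["token_type_ids"],
                      attention_mask=enc["attention_mask"])
# [CLS]/[SEP] are stripped, so 5 characters -> 5 positions x 3 labels (O/B/I)
print(logits.shape)  # torch.Size([1, 5, 3])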
code/prediction.py
ADDED
@@ -0,0 +1,92 @@
import torch
from args import args, config
from tqdm import tqdm
from items_dataset import items_dataset


def test_predict(test_loader, device, model, min_label=1, max_label=3):
    model.eval()
    result = []

    for i, test_batch in enumerate(tqdm(test_loader)):
        batch_text = test_batch["batch_text"]
        input_ids = test_batch["input_ids"].to(device)
        token_type_ids = test_batch["token_type_ids"].to(device)
        attention_mask = test_batch["attention_mask"].to(device)
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]
        output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                       attention_mask=attention_mask, labels=None, crf_mask=crf_mask)
        if args.use_crf:
            prediction = model.crf.decode(output[0], crf_mask)
        else:
            prediction = torch.max(output[0], -1).indices

        # build the result for every sample
        sample_id = -1
        sample_result = {"text_a": test_batch["batch_text"][0]}
        for batch_id in range(len(sample_mapping)):
            change_sample = sample_id != sample_mapping[batch_id]
            if change_sample:
                sample_id = sample_mapping[batch_id]
                sample_result = {"text_a": test_batch["batch_text"][sample_id]}
                decode_span_table = torch.zeros(len(test_batch["batch_text"][sample_id]))

            spans = items_dataset.cal_agreement_span(None, agreement_table=prediction[batch_id],
                                                     min_agree=min_label, max_agree=max_label)
            # decode spans back to character offsets
            for span in spans:
                if span[0] == 0:
                    span[0] += 1
                if span[1] == 1:
                    span[1] += 1

                # move the span start forward until it maps to a character
                while True:
                    start = test_batch[batch_id].token_to_chars(span[0])
                    if start is not None or span[0] >= span[1]:
                        break
                    span[0] += 1

                # move the span end backward until it maps to a character
                while True:
                    end = test_batch[batch_id].token_to_chars(span[1])
                    if end is not None or span[0] >= span[1]:
                        break
                    span[1] -= 1

                if span[0] < span[1]:
                    de_start = test_batch[batch_id].token_to_chars(span[0])[0]
                    de_end = test_batch[batch_id].token_to_chars(span[1] - 1)[0]
                    decode_span_table[de_start:de_end] = 2  # inside
                    decode_span_table[de_start] = 1  # begin
            if change_sample:
                sample_result["predict_span_table"] = decode_span_table
                result.append(sample_result)
    model.train()
    return result


def add_sentence_table(result):
    # expand predicted spans to whole sentences, split on common punctuation
    pattern = ":;。,?!~!: "
    for sample in result:
        boundary_list = []
        for i, char in enumerate(sample["text_a"]):
            if char in pattern:
                boundary_list.append(i)
        boundary_list.append(len(sample["text_a"]) + 1)
        start = 0
        end = 0
        pre_states = False
        sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
        for boundary in boundary_list:
            end = boundary
            if sum(sample["predict_span_table"][start:end]) > 0:
                if pre_states:
                    # the previous sentence was highlighted too, so join them
                    sample["predict_sentence_table"][start - 1:end] = 2
                else:
                    sample["predict_sentence_table"][start:end] = 2
                sample["predict_sentence_table"][start] = 1
                pre_states = True
            else:
                pre_states = False
            start = end + 1
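A toy illustration of add_sentence_table (the two-sentence text is made up): one predicted character anywhere in a sentence marks the whole sentence with B/I tags:

import torch
from prediction import add_sentence_table

sample = {"text_a": "今天下雨。記得帶傘。", "predict_span_table": torch.zeros(10)}
sample["predict_span_table"][7] = 1  # a single predicted character in sentence 2
add_sentence_table([sample])
# tensor([0., 0., 0., 0., 0., 1., 2., 2., 2., 0.]) -- "記得帶傘" becomes B, I, I, I
print(sample["predict_sentence_table"])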
code/rank.ipynb
ADDED
@@ -0,0 +1,201 @@
# %%
from args import args, config
from items_dataset import items_dataset
from torch.utils.data import DataLoader
from models import Model_Crf, Model_Softmax
from transformers import AutoTokenizer
from tqdm import tqdm
import prediction
import torch
import math

# %%
directory = "../model/"
model_name = "roberta_CRF.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = Model_Crf(config).to(device)
model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))

# %%
input_dict = [{"span_labels": []}]
input_dict[0]["original_text"] = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)
prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=prediction_dataset.collate_fn)
predict_data = prediction.test_predict(prediction_loader, device, model)

# %%
prediction.add_sentence_table(predict_data)
print(predict_data[0])

# %%
def prepare_span_data(dataset):
    """Prepare span labels for each sample."""
    for sample in dataset:
        spans = items_dataset.cal_agreement_span(None, agreement_table=sample["predict_sentence_table"], min_agree=1, max_agree=2)
        sample["span_labels"] = spans
        sample["original_text"] = sample["text_a"]
        del sample["text_a"]

prepare_span_data(predict_data)
tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)
prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=prediction_dataset.collate_fn)

index = 0
print(predict_data[index]["original_text"])
print(predict_data[index]["span_labels"])

# %%
directory = "../model/"
model_name = "roberta_softmax.pt"
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = Model_Softmax(config).to(device)
model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))

# %%
def rank_spans(test_loader, device, model, reverse=True):
    """Calculate each span's probability as e**(average per-token log likelihood)."""
    model.eval()
    result = []

    for i, test_batch in enumerate(tqdm(test_loader)):
        batch_text = test_batch["batch_text"]
        input_ids = test_batch["input_ids"].to(device)
        token_type_ids = test_batch["token_type_ids"].to(device)
        attention_mask = test_batch["attention_mask"].to(device)
        labels = test_batch["labels"]
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]
        output = model(input_ids=input_ids, token_type_ids=token_type_ids,
                       attention_mask=attention_mask, labels=None, crf_mask=crf_mask)
        output = torch.nn.functional.softmax(output[0], dim=-1)

        # build the result for every sample
        sample_id = 0
        sample_result = {"original_text": test_batch["batch_text"][sample_id], "span_ranked": []}
        for batch_id in range(len(sample_mapping)):
            change_sample = sample_id != sample_mapping[batch_id]
            if change_sample:
                sample_id = sample_mapping[batch_id]
                result.append(sample_result)
                sample_result = {"original_text": test_batch["batch_text"][sample_id], "span_ranked": []}

            encoded_spans = items_dataset.cal_agreement_span(None, agreement_table=labels[batch_id], min_agree=1, max_agree=2)
            for encoded_span in encoded_spans:
                # calculate the span's average log likelihood
                span_length = encoded_span[1] - encoded_span[0]
                span_prob_table = torch.log(output[batch_id][encoded_span[0]:encoded_span[1]])
                if not change_sample and encoded_span[0] == 0 and batch_id != 0:  # span crosses two tensors
                    span_loss += span_prob_table[0][1]  # Begin
                else:
                    span_loss = span_prob_table[0][1]  # Begin
                for token_id in range(1, span_prob_table.shape[0]):
                    span_loss += span_prob_table[token_id][2]  # Inside
                span_loss /= span_length

                # span decode
                decode_start = test_batch[batch_id].token_to_chars(encoded_span[0] + 1)[0]
                decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0] + 1
                span_text = test_batch["batch_text"][sample_mapping[batch_id]][decode_start:decode_end]
                if not change_sample and encoded_span[0] == 0 and batch_id != 0:  # span crosses two tensors
                    presample = sample_result["span_ranked"].pop(-1)
                    sample_result["span_ranked"].append([presample[0] + span_text, math.e ** float(span_loss)])
                else:
                    sample_result["span_ranked"].append([span_text, math.e ** float(span_loss)])
        result.append(sample_result)

    # sort spans by probability
    for sample in result:
        sample["span_ranked"] = sorted(sample["span_ranked"], key=lambda x: x[1], reverse=reverse)
    return result

# %%
span_ranked = rank_spans(prediction_loader, device, model)
for sample in span_ranked:
    print(sample["original_text"])
    print(sample["span_ranked"])

# %% [notebook metadata: kernel "for_project", Python 3.9.13, nbformat 4]