Spaces:
Runtime error
Runtime error
File size: 8,654 Bytes
9fdc3cc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
from args import args, config
from items_dataset import items_dataset
from torch.utils.data import DataLoader
from models import Model_Crf, Model_Softmax
from transformers import AutoTokenizer
from tqdm import tqdm
import prediction
import torch
import math
# Both checkpoints are loaded once, at import time, and reused by
# predict_single(). The first CUDA GPU is used when available, else the CPU.
directory = args.SAVE_MODEL_PATH
device = (
    torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
)


def _load_checkpoint(model, filename):
    """Load the state dict stored at ``directory + filename`` into *model*.

    Tensors are mapped onto ``device`` while loading. Returns *model* so the
    call can be used inline.

    NOTE(review): the path is built by plain string concatenation, so
    ``args.SAVE_MODEL_PATH`` presumably ends with a path separator — confirm.
    """
    model.load_state_dict(
        state_dict=torch.load(directory + filename, map_location=device)
    )
    return model


# Model_Crf checkpoint: passed to prediction.test_predict() in predict_single().
model_name = "roberta_CRF.pt"
model_crf = _load_checkpoint(Model_Crf(config).to(device), model_name)

# Model_Softmax checkpoint: passed to rank_spans() in predict_single().
model_name = "roberta_softmax.pt"
model_roberta = _load_checkpoint(Model_Softmax(config).to(device), model_name)
def prepare_span_data(dataset):
    """Convert first-pass predictions into span-labelled samples, in place.

    For every sample dict in *dataset*:
      * ``span_labels`` is set to the spans that
        ``items_dataset.cal_agreement_span`` (called unbound, with ``None``
        as ``self``, ``min_agree=1`` and ``max_agree=2``) derives from the
        sample's ``predict_sentence_table``;
      * the ``text_a`` key is renamed to ``original_text``.

    Nothing is returned; a sample without ``text_a`` raises KeyError.
    """
    for record in dataset:
        record["span_labels"] = items_dataset.cal_agreement_span(
            None,
            agreement_table=record["predict_sentence_table"],
            min_agree=1,
            max_agree=2,
        )
        # Rename the key: pop() removes "text_a" and yields its value.
        record["original_text"] = record.pop("text_a")
def rank_spans(test_loader, device, model, reverse=True):
"""Calculate each span probability by e**(word average log likelihood)"""
model.eval()
result = []
for i, test_batch in enumerate(tqdm(test_loader)):
batch_text = test_batch["batch_text"]
input_ids = test_batch["input_ids"].to(device)
token_type_ids = test_batch["token_type_ids"].to(device)
attention_mask = test_batch["attention_mask"].to(device)
labels = test_batch["labels"]
crf_mask = test_batch["crf_mask"].to(device)
sample_mapping = test_batch["overflow_to_sample_mapping"]
output = model(
input_ids=input_ids,
token_type_ids=token_type_ids,
attention_mask=attention_mask,
labels=None,
crf_mask=crf_mask,
)
output = torch.nn.functional.softmax(output[0], dim=-1)
# make result of every sample
sample_id = 0
sample_result = {
"original_text": test_batch["batch_text"][sample_id],
"span_ranked": [],
}
for batch_id in range(len(sample_mapping)):
change_sample = False
# make sure status
if sample_id != sample_mapping[batch_id]:
change_sample = True
if change_sample:
sample_id = sample_mapping[batch_id]
result.append(sample_result)
sample_result = {
"original_text": test_batch["batch_text"][sample_id],
"span_ranked": [],
}
encoded_spans = items_dataset.cal_agreement_span(
None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
)
# print(encoded_spans)
for encoded_span in encoded_spans:
# calculate span loss
span_lenght = encoded_span[1] - encoded_span[0]
# print(span_lenght)
span_prob_table = torch.log(
output[batch_id][encoded_span[0] : encoded_span[1]]
)
if (
not change_sample and encoded_span[0] == 0 and batch_id != 0
): # span cross two tensors
span_loss += span_prob_table[0][1] # Begin
else:
span_loss = span_prob_table[0][1] # Begin
for token_id in range(1, span_prob_table.shape[0]):
span_loss += span_prob_table[token_id][2] # Inside
span_loss /= span_lenght
# span decode
decode_start = test_batch[batch_id].token_to_chars(encoded_span[0] + 1)[
0
]
decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0] + 1
# print((decode_start, decode_end))
span_text = test_batch["batch_text"][sample_mapping[batch_id]][
decode_start:decode_end
]
if (
not change_sample and encoded_span[0] == 0 and batch_id != 0
): # span cross two tensors
presample = sample_result["span_ranked"].pop(-1)
sample_result["span_ranked"].append(
[presample[0] + span_text, math.e ** float(span_loss)]
)
else:
sample_result["span_ranked"].append(
[span_text, math.e ** float(span_loss)]
)
result.append(sample_result)
# sorted spans by probability
# for sample in result:
# sample["span_ranked"] = sorted(
# sample["span_ranked"], key=lambda x: x[1], reverse=reverse
# )
return result
def predict_single(text):
input_dict = [{"span_labels": []}]
input_dict[0]["original_text"] = text
tokenizer = AutoTokenizer.from_pretrained(
args.pre_model_name, add_prefix_space=True
)
prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
prediction_loader = DataLoader(
prediction_dataset,
batch_size=args.batch_size,
shuffle=True,
collate_fn=prediction_dataset.collate_fn,
)
predict_data = prediction.test_predict(prediction_loader, device, model_crf)
prediction.add_sentence_table(predict_data)
prepare_span_data(predict_data)
tokenizer = AutoTokenizer.from_pretrained(
args.pre_model_name, add_prefix_space=True
)
prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
prediction_loader = DataLoader(
prediction_dataset,
batch_size=args.batch_size,
shuffle=False,
collate_fn=prediction_dataset.collate_fn,
)
span_ranked = rank_spans(prediction_loader, device, model_roberta)
# for sample in span_ranked:
# print(sample["original_text"])
# print(sample["span_ranked"])
result = []
sample = span_ranked[0]
orig = sample["original_text"]
cur = 0
for s, score in sample["span_ranked"]:
# print()
# print('ORIG', repr(orig))
# print('CCUR', repr(orig[cur:]))
# print('SSSS', repr(s))
# print()
end = orig.index(s, cur)
if cur != end:
result.append([orig[cur:end], 0])
result.append([s, score])
cur = end + len(s)
if cur < len(orig):
result.append([orig[cur:], 0])
return result
if __name__ == "__main__":
    # Demo: run the full two-pass pipeline on a sample Chinese article and
    # print the resulting [segment, score] list.
    demo_text = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
    segments = predict_single(demo_text)
    print(segments)
|