Spaces:
Runtime error
Runtime error
from args import args, config | |
from items_dataset import items_dataset | |
from torch.utils.data import DataLoader | |
from models import Model_Crf, Model_Softmax | |
from transformers import AutoTokenizer | |
from tqdm import tqdm | |
import prediction | |
import torch | |
import math | |
# Directory prefix of the saved checkpoints; it is concatenated directly with the
# file name below, so it is expected to already end with a path separator.
directory = args.SAVE_MODEL_PATH

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", 0)
else:
    device = torch.device("cpu")

# First model: built from the shared config, then restored from "roberta_CRF.pt".
model_name = "roberta_CRF.pt"
model_crf = Model_Crf(config).to(device)
model_crf.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)

# Second model: same procedure with the softmax variant and its own checkpoint.
model_name = "roberta_softmax.pt"
model_roberta = Model_Softmax(config).to(device)
model_roberta.load_state_dict(
    state_dict=torch.load(directory + model_name, map_location=device)
)
def prepare_span_data(dataset):
    """Attach span labels to every sample and rename its text field, in place.

    For each sample the spans are computed by
    ``items_dataset.cal_agreement_span`` from the sample's
    ``"predict_sentence_table"`` entry (with ``min_agree=1`` and
    ``max_agree=2``) and stored under ``"span_labels"``. The ``"text_a"``
    entry is then moved to ``"original_text"``.

    Args:
        dataset: iterable of mutable mappings holding the keys
            ``"predict_sentence_table"`` and ``"text_a"``.

    Returns:
        None. The samples are modified in place.
    """
    for record in dataset:
        record["span_labels"] = items_dataset.cal_agreement_span(
            None,
            agreement_table=record["predict_sentence_table"],
            min_agree=1,
            max_agree=2,
        )
        # Move the text to its new key; "text_a" no longer exists afterwards.
        record["original_text"] = record.pop("text_a")
def rank_spans(test_loader, device, model, reverse=True):
    """Score every labelled span as e**(average per-token log-likelihood).

    For each batch the model output is soft-maxed over its last dimension.
    For every span reported by ``items_dataset.cal_agreement_span`` on a row's
    labels, the log-probability of class index 1 ("Begin") at the first token
    and of class index 2 ("Inside") at the remaining tokens are summed and
    divided by the span length. The span text is cut out of the sample's
    original text using the row's ``token_to_chars`` offsets.

    Args:
        test_loader: iterable of batches. A batch is read by string key
            (``"batch_text"``, ``"input_ids"``, ``"token_type_ids"``,
            ``"attention_mask"``, ``"labels"``, ``"crf_mask"``,
            ``"overflow_to_sample_mapping"``) and by integer row index to get
            an object exposing ``token_to_chars``.
        device: device the input tensors are moved to before the forward pass.
        model: callable model; ``model(...)[0]`` holds the per-token scores.
        reverse: unused. It belonged to a sorting step that is disabled; the
            parameter is kept so existing callers keep working.

    Returns:
        A list with one dict per sample of each batch, of the form
        ``{"original_text": str, "span_ranked": [[span_text, probability], ...]}``.
        Spans are kept in the order they were encountered.
    """
    model.eval()
    result = []

    for test_batch in tqdm(test_loader):
        input_ids = test_batch["input_ids"].to(device)
        token_type_ids = test_batch["token_type_ids"].to(device)
        attention_mask = test_batch["attention_mask"].to(device)
        labels = test_batch["labels"]
        crf_mask = test_batch["crf_mask"].to(device)
        sample_mapping = test_batch["overflow_to_sample_mapping"]

        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            output = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=None,
                crf_mask=crf_mask,
            )
            output = torch.nn.functional.softmax(output[0], dim=-1)

        # Rows of a batch are grouped by the sample they were cut from; a new
        # result entry is opened each time the mapped sample id changes.
        sample_id = 0
        sample_result = {
            "original_text": test_batch["batch_text"][sample_id],
            "span_ranked": [],
        }
        span_loss = None
        for batch_id, mapped_sample in enumerate(sample_mapping):
            change_sample = bool(sample_id != mapped_sample)
            if change_sample:
                sample_id = mapped_sample
                result.append(sample_result)
                sample_result = {
                    "original_text": test_batch["batch_text"][sample_id],
                    "span_ranked": [],
                }

            encoded_spans = items_dataset.cal_agreement_span(
                None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
            )
            row_encoding = test_batch[batch_id]
            for encoded_span in encoded_spans:
                span_start, span_end = encoded_span[0], encoded_span[1]
                span_length = span_end - span_start
                span_prob_table = torch.log(output[batch_id][span_start:span_end])

                # A span starting at token 0 of a follow-up row of the same
                # sample is treated as the continuation of that sample's
                # previous span. Merge only when such a span exists; without
                # this check the accumulator was unbound/stale and the pop()
                # below failed on an empty list.
                continues_previous = bool(
                    not change_sample
                    and span_start == 0
                    and batch_id != 0
                    and sample_result["span_ranked"]
                )

                begin_log_prob = span_prob_table[0][1]  # Begin
                if continues_previous:
                    span_loss = span_loss + begin_log_prob
                else:
                    span_loss = begin_log_prob
                for token_log_probs in span_prob_table[1:]:
                    span_loss = span_loss + token_log_probs[2]  # Inside
                span_loss = span_loss / span_length

                # Map token positions back to character offsets of the text.
                decode_start = row_encoding.token_to_chars(span_start + 1)[0]
                decode_end = row_encoding.token_to_chars(span_end)[0] + 1
                span_text = test_batch["batch_text"][mapped_sample][
                    decode_start:decode_end
                ]
                span_prob = math.e ** float(span_loss)

                if continues_previous:
                    presample = sample_result["span_ranked"].pop(-1)
                    sample_result["span_ranked"].append(
                        [presample[0] + span_text, span_prob]
                    )
                else:
                    sample_result["span_ranked"].append([span_text, span_prob])
        # Flush the last sample of this batch.
        result.append(sample_result)

    # NOTE: spans are intentionally not sorted by probability here, so
    # ``reverse`` has no effect and spans stay in encounter order.
    return result
def _interleave_spans(orig, scored_spans):
    """Cover *orig* left to right with ``[text, score]`` segments.

    Each ``(span_text, score)`` pair is located in *orig* at or after the end
    of the previous span; text between spans is emitted with score 0.
    Raises ValueError when a span text cannot be found from the cursor on.
    """
    segments = []
    cursor = 0
    for span_text, score in scored_spans:
        start = orig.index(span_text, cursor)
        if start != cursor:
            segments.append([orig[cursor:start], 0])
        segments.append([span_text, score])
        cursor = start + len(span_text)
    if cursor < len(orig):
        segments.append([orig[cursor:], 0])
    return segments


def predict_single(text):
    """Split *text* into scored spans and unscored filler segments.

    The text is first run through ``model_crf`` (via
    ``prediction.test_predict`` and ``prediction.add_sentence_table``), the
    resulting samples are prepared with ``prepare_span_data``, and the spans
    are then scored by ``rank_spans`` using ``model_roberta``.

    Args:
        text: the raw input string.

    Returns:
        A list of ``[segment_text, score]`` pairs that covers the original
        text from start to end; segments outside any span get score 0.

    Raises:
        ValueError: if a scored span cannot be located in the original text
            after the previously placed span.
    """
    # One tokenizer serves both stages; it used to be loaded twice with
    # identical arguments.
    tokenizer = AutoTokenizer.from_pretrained(
        args.pre_model_name, add_prefix_space=True
    )

    # Stage 1: predict span labels with the CRF model.
    input_dict = [{"span_labels": [], "original_text": text}]
    prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=False,  # inference: keep input order deterministic
        collate_fn=prediction_dataset.collate_fn,
    )
    predict_data = prediction.test_predict(prediction_loader, device, model_crf)
    prediction.add_sentence_table(predict_data)
    prepare_span_data(predict_data)

    # Stage 2: score the predicted spans with the softmax model.
    prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
    prediction_loader = DataLoader(
        prediction_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=prediction_dataset.collate_fn,
    )
    span_ranked = rank_spans(prediction_loader, device, model_roberta)

    # Nothing came back (e.g. an empty loader): the whole text is filler.
    if not span_ranked:
        return [[text, 0]] if text else []

    # Only one text was submitted, so only the first sample is relevant.
    sample = span_ranked[0]
    return _interleave_spans(sample["original_text"], sample["span_ranked"])
if __name__ == "__main__":
    # Smoke run: push one sample article through the pipeline and print the
    # resulting list of [segment, score] pairs.
    s = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
    highlighted = predict_single(s)
    print(highlighted)