bloom-daliy-dialogue-english / reconstructor.py
svjack's picture
Upload 3 files
7a71994
raw
history blame
1.2 kB
from predict import *
from transformers import (
T5ForConditionalGeneration,
T5TokenizerFast as T5Tokenizer,
)
import jieba.posseg as posseg
# Model identifier handed to ``from_pretrained`` for both the tokenizer and
# the model, so the two always come from the same checkpoint.
model_path = "svjack/T5-dialogue-collect-v5"
# ``T5Tokenizer`` is the import alias for ``T5TokenizerFast`` in this file.
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
# Module-level predictor used by ``predict_split`` below.
# NOTE(review): ``Obj`` is not defined in this file; presumably it comes from
# the ``from predict import *`` star-import — confirm.
rec_obj = Obj(model, tokenizer)
def process_one_sent(input_):
    """Segment a sentence with jieba and re-join the kept words with spaces.

    The text is tokenised with ``posseg.lcut``. Every token whose ``flag`` is
    ``"x"`` is dropped (presumably jieba's tag for punctuation / non-word
    tokens — confirm against the jieba docs), each remaining ``word`` is
    stripped of surrounding whitespace, and the words are joined with single
    spaces.

    Args:
        input_: The sentence to segment. Must be a ``str``.

    Returns:
        The kept words joined by single spaces.

    Raises:
        AssertionError: If ``input_`` is not a ``str``.
    """
    # isinstance (instead of an exact type() comparison) also admits str
    # subclasses; every input accepted before is still accepted.
    assert isinstance(input_, str), "input_ must be a str"
    kept_words = [
        token.word.strip()
        for token in posseg.lcut(input_)
        if token.flag != "x"
    ]
    return " ".join(kept_words)
def predict_split(sp_list, cut_tokens = True):
    """Ask the model to segment a list of text pieces and return the segments.

    The pieces are merged into one context string, placed into a fixed
    Chinese prompt, and sent to ``rec_obj.predict``. The first element of the
    prediction is split on the marker ``"分段:"``; each part is stripped and
    empty parts are discarded.

    Args:
        sp_list: List of strings forming the context. Must be a ``list``.
        cut_tokens: When true (the default), every piece is first passed
            through ``process_one_sent`` and the results are joined with
            spaces. When false, the raw pieces are concatenated with no
            separator.

    Returns:
        A list of non-empty, stripped segment strings.

    Raises:
        AssertionError: If ``sp_list`` is not a ``list``.
    """
    # isinstance (instead of an exact type() comparison) also admits list
    # subclasses; every input accepted before is still accepted.
    assert isinstance(sp_list, list), "sp_list must be a list"

    # Prompt shared by both modes, kept in one place so the two branches
    # cannot drift apart. Roughly: "Segment according to the context below /
    # Context: {} / Answer:". The text is model input and must stay verbatim.
    prompt_template = (
        "\n"
        "根据下面的上下文进行分段:\n"
        "上下文:{}\n"
        "答案:\n"
    )

    if cut_tokens:
        context = " ".join(process_one_sent(piece) for piece in sp_list)
    else:
        context = "".join(sp_list)
    src_text = prompt_template.format(context)

    # Retained from the original: echoes the final prompt to stdout.
    print(src_text)

    pred = rec_obj.predict(src_text)[0]
    segments = (part.strip() for part in pred.split("分段:"))
    return [segment for segment in segments if segment]