Spaces:
Sleeping
Sleeping
import os | |
import transformers | |
from transformers import pipeline | |
from transformers.pipelines.token_classification import TokenClassificationPipeline | |
import py_vncorenlp | |
os.system('pwd') | |
os.system('sudo update-alternatives --config java') | |
os.mkdir('/home/user/app/vncorenlp') | |
py_vncorenlp.download_model(save_dir='/home/user/app/vncorenlp') | |
rdrsegmenter = py_vncorenlp.VnCoreNLP(annotators=["wseg"], save_dir='/home/user/app/vncorenlp') | |
class MyPipeline(TokenClassificationPipeline): | |
def preprocess(self, sentence, offset_mapping=None): | |
truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False | |
model_inputs = self.tokenizer( | |
sentence, | |
return_tensors=self.framework, | |
truncation=truncation, | |
return_special_tokens_mask=True, | |
return_offsets_mapping=self.tokenizer.is_fast, | |
) | |
length = len(model_inputs['input_ids'][0]) - 2 | |
tokens = self.tokenizer.tokenize(sentence) | |
seek = 0 | |
offset_mapping_list = [[(0, 0)]] | |
for i in range(length): | |
if tokens[i][-2:] == '@@': | |
offset_mapping_list[0].append((seek, seek + len(tokens[i]) - 2)) | |
seek += len(tokens[i]) - 2 | |
else: | |
offset_mapping_list[0].append((seek, seek + len(tokens[i]))) | |
seek += len(tokens[i]) + 1 | |
offset_mapping_list[0].append((0, 0)) | |
# if offset_mapping: | |
# model_inputs["offset_mapping"] = offset_mapping | |
model_inputs['offset_mapping'] = offset_mapping_list | |
model_inputs["sentence"] = sentence | |
return model_inputs | |
model_checkpoint = "DD0101/disfluency-large" | |
my_classifier = pipeline( | |
"token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline) | |
import gradio as gr | |
def ner(text): | |
text = " ".join(rdrsegmenter.word_segment(text)) | |
# Some words in lowercase like "đà nẵng" will get error (due to vncorenlp) | |
text = text.replace("đà ", " đà") | |
output = my_classifier(text) | |
for entity in output: | |
entity['entity'] = entity.pop('entity_group') | |
# Remove Disfluency-entities to return a sentence with "Fluency" version | |
list_str = list(text) | |
for entity in output[::-1]: # if we use default order of output list, we will shorten the length of the sentence, so the words later are not in the correct start and end index | |
start = max(0, entity['start'] - 1) | |
end = min(len(list_str), entity['end'] + 1) | |
list_str[start:end] = ' ' | |
fluency_sentence = "".join(list_str).strip() # use strip() in case we need to remove entity at the beginning or the end of sentence | |
# (without strip(): "Giá vé khứ hồi à nhầm giá vé một chiều ..." -> " giá vé một chiều ...") | |
fluency_sentence = fluency_sentence[0].upper() + fluency_sentence[1:] # since capitalize() just lowercase whole sentence first then uppercase the first letter | |
# Replace words like "Đà_Nẵng" to "Đà Nẵng" | |
text = text.replace("_", " ") | |
fluency_sentence = fluency_sentence.replace("_", " ") | |
return {'text': text, 'entities': output}, fluency_sentence | |
examples = ['Tôi cần thuê à tôi muốn bay một chuyến khứ hồi từ Đà Nẵng đến Đà Lạt', | |
'Giá vé một chiều à không khứ hồi từ Đà Nẵng đến Vinh dưới 2 triệu đồng giá vé khứ hồi từ Quy Nhơn đến Vinh dưới 3 triệu đồng giá vé khứ hồi từ Buôn Ma Thuột đến Quy Nhơn à đến Vinh dưới 4 triệu rưỡi', | |
'Cho tôi biết các chuyến bay đến Đà Nẵng vào ngày 12 mà không ngày 14 tháng sáu', | |
'Những chuyến bay nào khởi hành từ Thành phố Hồ Chí Minh bay đến Frankfurt mà nối chuyến ở Singapore và hạ cánh trước 10 giờ ý tôi là 9 giờ tối', | |
'Thành Phố nào có VNA ừm thôi cho tôi xem tất cả các chuyến bay từ Thanh Hóa hay Nghệ An nhỉ à Thanh Hóa đến Đà Lạt vào Thứ ba à thôi tôi cần vào Thứ hai' | |
] | |
demo = gr.Interface(ner, | |
gr.Textbox(label='Sentence', placeholder="Enter your sentence here..."), | |
outputs=[gr.HighlightedText(label='Highlighted Output'), gr.Textbox(label='"Fluency" version')], | |
examples=examples, | |
title="Disfluency Detection", | |
description="This is an easy-to-use built in Gradio for desmontrating a NER System that identifies disfluency-entities in \ | |
Vietnamese utterances", | |
theme=gr.themes.Soft()) | |
demo.launch() |