import os os.system("pip install transformers") os.system("pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117") import transformers from transformers import pipeline from transformers.pipelines.token_classification import TokenClassificationPipeline class MyPipeline(TokenClassificationPipeline): def preprocess(self, sentence, offset_mapping=None): truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False model_inputs = self.tokenizer( sentence, return_tensors=self.framework, truncation=truncation, return_special_tokens_mask=True, return_offsets_mapping=self.tokenizer.is_fast, ) length = len(model_inputs['input_ids'][0]) - 2 tokens = self.tokenizer.tokenize(sentence) seek = 0 offset_mapping_list = [[(0, 0)]] for i in range(length): if tokens[i][-2:] == '@@': offset_mapping_list[0].append((seek, seek + len(tokens[i]) - 2)) seek += len(tokens[i]) - 2 else: offset_mapping_list[0].append((seek, seek + len(tokens[i]))) seek += len(tokens[i]) + 1 offset_mapping_list[0].append((0, 0)) # if offset_mapping: # model_inputs["offset_mapping"] = offset_mapping model_inputs['offset_mapping'] = offset_mapping_list model_inputs["sentence"] = sentence return model_inputs model_checkpoint = "DD0101/disfluency-base" my_classifier = pipeline( "token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline) import gradio as gr def ner(text): output = my_classifier(text) for entity in output: entity['entity'] = entity.pop('entity_group') return {'text': text, 'entities': output} examples = ['tôi cần thuê à tôi muốn bay một chuyến khứ_hồi từ đà_nẵng đến đà_lạt', 'giá vé một_chiều à không khứ_hồi từ đà_nẵng đến vinh dưới 2 triệu đồng giá vé khứ_hồi từ quy nhơn đến vinh dưới 3 triệu đồng giá vé khứ_hồi từ buôn_ma_thuột đến quy nhơn à đến vinh dưới 4 triệu rưỡi', 'cho tôi biết các chuyến bay đến đà_nẵng vào ngày 12 mà không ngày 14 tháng sáu', 'những chuyến bay nào khởi_hành từ thành_phố hồ_chí_minh bay đến frankfurt mà nối chuyến ở singapore và hạ_cánh trước 10 giờ ý tôi là 9 giờ tối' ] demo = gr.Interface(ner, gr.Textbox(label='Text', placeholder="Enter sentence here..."), gr.HighlightedText(label='Highlighted Output'), examples=examples, title="Disfluency Detection", description="This is an easy-to-use built in Gradio for desmontrating a NER System that identifies disfluency-entities in \ Vietnamese utterances", theme=gr.themes.Soft()) demo.launch()