# Requires: pip install transformers gradio

from transformers import pipeline
from transformers.pipelines.token_classification import TokenClassificationPipeline

# The tokenizer marks subword continuations with a trailing "@@". preprocess() is
# overridden to rebuild the character-level offset mapping by hand, so the
# pipeline can map token predictions back to spans of the original sentence.
class MyPipeline(TokenClassificationPipeline):
  def preprocess(self, sentence, offset_mapping=None):
    truncation = bool(self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0)
    model_inputs = self.tokenizer(
        sentence,
        return_tensors=self.framework,
        truncation=truncation,
        return_special_tokens_mask=True,
        return_offsets_mapping=self.tokenizer.is_fast,
    )

    # Rebuild character offsets token by token: a piece ending in "@@" continues
    # the current word (strip the marker, no space skipped), otherwise the piece
    # ends a word and the cursor also skips the following space.
    length = len(model_inputs['input_ids'][0]) - 2  # exclude [CLS] and [SEP]
    tokens = self.tokenizer.tokenize(sentence)
    seek = 0
    offset_mapping_list = [[(0, 0)]]  # (0, 0) placeholder for [CLS]
    for i in range(length):
      if tokens[i][-2:] == '@@':
        offset_mapping_list[0].append((seek, seek + len(tokens[i]) - 2))
        seek += len(tokens[i]) - 2
      else:
        offset_mapping_list[0].append((seek, seek + len(tokens[i])))
        seek += len(tokens[i]) + 1
    offset_mapping_list[0].append((0, 0))  # (0, 0) placeholder for [SEP]

    # The offset_mapping argument accepted by the base pipeline is deliberately
    # ignored in favour of the offsets recomputed above.
    model_inputs['offset_mapping'] = offset_mapping_list
    model_inputs["sentence"] = sentence

    return model_inputs
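
# Worked example of the rebuilt offset mapping (hypothetical tokens, for
# illustration only):
#   sentence = "ab cd"  ->  tokens = ["a@@", "b", "cd"]
#   offset_mapping_list[0] == [(0, 0), (0, 1), (1, 2), (3, 5), (0, 0)]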

model_checkpoint = "DD0101/disfluency-base"

# Token-classification pipeline using the custom preprocessing above;
# aggregation_strategy="simple" groups token predictions into entity spans.
my_classifier = pipeline(
  "token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline)


import gradio as gr

def ner(text):
  output = my_classifier(text)
  # gr.HighlightedText expects each entity dict to carry an 'entity' key, so
  # rename the 'entity_group' key produced by the aggregation strategy.
  for entity in output:
    entity['entity'] = entity.pop('entity_group')
  return {'text': text, 'entities': output}

# NOTE: `dataset` is assumed to have been loaded earlier (e.g. with
# datasets.load_dataset) and to expose a "text" column in its test split.
examples = dataset['test'][:10]['text']

demo = gr.Interface(ner,
                    gr.Textbox(label='Text', placeholder="Enter sentence here..."),
                    gr.HighlightedText(label='Highlighted Output'),
                    examples=examples,
                    title="Disfluency Detection",
                    description="An easy-to-use Gradio demo of an NER system that "
                                "detects disfluency entities in Vietnamese utterances.",
                    theme=gr.themes.Soft())

demo.launch()