DD0101 commited on
Commit
9eafe2a
1 Parent(s): a3a3ecf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
+ from transformers.pipelines.token_classification import TokenClassificationPipeline
4
+
5
+ class MyPipeline(TokenClassificationPipeline):
6
+ def preprocess(self, sentence, offset_mapping=None):
7
+ truncation = True if self.tokenizer.model_max_length and self.tokenizer.model_max_length > 0 else False
8
+ model_inputs = self.tokenizer(
9
+ sentence,
10
+ return_tensors=self.framework,
11
+ truncation=truncation,
12
+ return_special_tokens_mask=True,
13
+ return_offsets_mapping=self.tokenizer.is_fast,
14
+ )
15
+
16
+
17
+ length = len(model_inputs['input_ids'][0]) - 2
18
+ tokens = tokenizer.tokenize(sentence)
19
+ seek = 0
20
+ offset_mapping_list = [[(0, 0)]]
21
+ for i in range(length):
22
+ if tokens[i][-2:] == '@@':
23
+ offset_mapping_list[0].append((seek, seek + len(tokens[i]) - 2))
24
+ seek += len(tokens[i]) - 2
25
+ else:
26
+ offset_mapping_list[0].append((seek, seek + len(tokens[i])))
27
+ seek += len(tokens[i]) + 1
28
+ offset_mapping_list[0].append((0, 0))
29
+
30
+ # if offset_mapping:
31
+ # model_inputs["offset_mapping"] = offset_mapping
32
+
33
+ model_inputs['offset_mapping'] = offset_mapping_list
34
+ model_inputs["sentence"] = sentence
35
+
36
+ return model_inputs
37
+
38
+ model_checkpoint = "DD0101/disfluency-base"
39
+
40
+ my_classifier = pipeline(
41
+ "token-classification", model=model_checkpoint, aggregation_strategy="simple", pipeline_class=MyPipeline)
42
+
43
+
44
+ import gradio as gr
45
+
46
+ def ner(text):
47
+ output = my_classifier(text)
48
+ for entity in output:
49
+ entity['entity'] = entity.pop('entity_group')
50
+ return {'text': text, 'entities': output}
51
+
52
+ examples = dataset['test'][:10]['text']
53
+
54
+ demo = gr.Interface(ner,
55
+ gr.Textbox(label='Text', placeholder="Enter sentence here..."),
56
+ gr.HighlightedText(label='Highlighted Output'),
57
+ examples=examples,
58
+ title="Disfluency Detection",
59
+ description="This is an easy-to-use built in Gradio for desmontrating a NER System that identifies disfluency-entities in \
60
+ Vietnamese utterances",
61
+ theme=gr.themes.Soft())
62
+
63
+ demo.launch()
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+