pszemraj commited on
Commit
57a7aa0
1 Parent(s): 84b8f8b

add bettertransformer for CoLA model

Browse files
Files changed (1) hide show
  1. app.py +21 -4
app.py CHANGED
@@ -1,21 +1,38 @@
1
  import re
2
  import os
 
3
 
4
  from cleantext import clean
5
  import gradio as gr
6
  from tqdm.auto import tqdm
7
  from transformers import pipeline
 
8
 
9
 
10
  checker_model_name = "textattack/roberta-base-CoLA"
11
  corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
12
 
13
  # pipelines
14
- checker = pipeline(
15
- "text-classification",
16
- checker_model_name,
17
- )
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
20
  # load onnx runtime unless HF_DEMO_NO_USE_ONNX is set
21
  from optimum.pipelines import pipeline
 
1
  import re
2
  import os
3
+ import gc
4
 
5
  from cleantext import clean
6
  import gradio as gr
7
  from tqdm.auto import tqdm
8
  from transformers import pipeline
9
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
10
 
11
 
12
  checker_model_name = "textattack/roberta-base-CoLA"
13
  corrector_model_name = "pszemraj/flan-t5-large-grammar-synthesis"
14
 
15
  # pipelines
 
 
 
 
16
 
17
+
18
+ if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
19
+ from optimum.bettertransformer import BetterTransformer
20
+
21
+ model_hf = AutoModelForSequenceClassification.from_pretrained(checker_model_name)
22
+ tokenizer = AutoTokenizer.from_pretrained(checker_model_name)
23
+ model = BetterTransformer.transform(model_hf, keep_original_model=False)
24
+
25
+ checker = pipeline(
26
+ "text-classification",
27
+ model=model,
28
+ tokenizer=tokenizer,
29
+ )
30
+ else:
31
+ checker = pipeline(
32
+ "text-classification",
33
+ checker_model_name,
34
+ )
35
+ gc.collect()
36
  if os.environ.get("HF_DEMO_NO_USE_ONNX") is None:
37
  # load onnx runtime unless HF_DEMO_NO_USE_ONNX is set
38
  from optimum.pipelines import pipeline