raheemuddin commited on
Commit
283e40c
·
1 Parent(s): 4cf8bc3

test with dbert-base

Browse files
Files changed (3) hide show
  1. handler.py +10 -3
  2. model.onnx +0 -3
  3. requirements.txt +2 -1
handler.py CHANGED
@@ -1,22 +1,29 @@
1
  from typing import Dict, List, Any
 
 
 
 
2
  # from optimum.onnxruntime import ORTModelForSequenceClassification
3
  from transformers import AutoModel
4
  # from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
 
6
  from transformers import pipeline, AutoTokenizer
7
 
8
- checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
 
9
 
10
  class EndpointHandler():
11
 
12
  def __init__(self, path=""):
13
  # load the optimized model
14
  # model = ORTModelForSequenceClassification.from_pretrained(path)
15
- model = AutoModel.from_pretrained(checkpoint)
16
  # model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
17
 
18
  # tokenizer = AutoTokenizer.from_pretrained(path)
19
- tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)
 
 
20
 
21
  # create inference pipeline
22
  self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
 
1
  from typing import Dict, List, Any
2
+ # import torch
3
+ from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
4
+
5
+
6
  # from optimum.onnxruntime import ORTModelForSequenceClassification
7
  from transformers import AutoModel
8
  # from transformers import AutoModelForSequenceClassification, AutoTokenizer
9
 
10
  from transformers import pipeline, AutoTokenizer
11
 
12
+ # checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
13
+ checkpoint = "distilbert-base-uncased"
14
 
15
  class EndpointHandler():
16
 
17
  def __init__(self, path=""):
18
  # load the optimized model
19
  # model = ORTModelForSequenceClassification.from_pretrained(path)
20
+ # model = AutoModel.from_pretrained(checkpoint)
21
  # model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
22
 
23
  # tokenizer = AutoTokenizer.from_pretrained(path)
24
+ # tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)
25
+ model = DistilBertForSequenceClassification.from_pretrained(checkpoint)
26
+ tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
27
 
28
  # create inference pipeline
29
  self.pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
model.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ff21bc94e07b6a5b20db9d5ffc692adb220989684a59aee0a11b78607dc4560
3
- size 172529172
 
 
 
 
requirements.txt CHANGED
@@ -5,4 +5,5 @@ mkl
5
  # spacy
6
  transformers
7
  datasets
8
- evaluate
 
 
5
  # spacy
6
  transformers
7
  datasets
8
+ evaluate
9
+ # torch