Joblib
English
llm
human-feedback
weak supervision
data filtering
Inference Endpoints
Christopher Glaze commited on
Commit
c5d744a
·
1 Parent(s): 6440c9b

Update device for simcse generator

Browse files
Files changed (1) hide show
  1. handler.py +7 -9
handler.py CHANGED
@@ -1,5 +1,5 @@
1
 
2
- from typing import Dict, List, Any, Union, Optional
3
  from pathlib import Path
4
  import json
5
  import joblib
@@ -12,11 +12,12 @@ from sklearn.base import TransformerMixin
12
 
13
  class SimcseGenerator(TransformerMixin):
14
  def __init__(
15
- self, device: str ='cpu', batch_size: int =16, model_name: str = "princeton-nlp/unsup-simcse-bert-base-uncased"
16
  ) -> None:
17
 
18
  self.model_name = model_name
19
- self.device = torch.device(device)
 
20
 
21
  tokenizer = AutoTokenizer.from_pretrained(model_name)
22
  model = AutoModel.from_pretrained(model_name).to(self.device)
@@ -53,13 +54,10 @@ class SimcseGenerator(TransformerMixin):
53
  return embeddings
54
 
55
  class EndpointHandler():
56
- def __init__(self, device: str = "cpu"):
57
- # Preload all the elements you are going to need at inference.
58
- # pseudo:
59
- # self.model= load_model(path)
60
 
61
  local_path = Path(__file__).parent
62
- self.device = device
63
  with open(local_path/'stop_words.json','r') as fp:
64
  self.stop_words = set(json.load(fp))
65
 
@@ -70,7 +68,7 @@ class EndpointHandler():
70
  self.instruction_pipeline = joblib.load(local_path/'instruction_classification_pipeline.joblib')
71
  self.response_pipeline = joblib.load(local_path/'response_quality_pipeline.joblib')
72
 
73
- self.simcse_generator = SimcseGenerator(device=self.device)
74
 
75
  def _get_stop_word_proportion(self, s):
76
  s = s.lower()
 
1
 
2
+ from typing import Dict, Union, Optional
3
  from pathlib import Path
4
  import json
5
  import joblib
 
12
 
13
  class SimcseGenerator(TransformerMixin):
14
  def __init__(
15
+ self, batch_size: int =16, model_name: str = "princeton-nlp/unsup-simcse-bert-base-uncased"
16
  ) -> None:
17
 
18
  self.model_name = model_name
19
+
20
+ self.device = torch.device('cpu')
21
 
22
  tokenizer = AutoTokenizer.from_pretrained(model_name)
23
  model = AutoModel.from_pretrained(model_name).to(self.device)
 
54
  return embeddings
55
 
56
  class EndpointHandler():
57
+ def __init__(self):
 
 
 
58
 
59
  local_path = Path(__file__).parent
60
+
61
  with open(local_path/'stop_words.json','r') as fp:
62
  self.stop_words = set(json.load(fp))
63
 
 
68
  self.instruction_pipeline = joblib.load(local_path/'instruction_classification_pipeline.joblib')
69
  self.response_pipeline = joblib.load(local_path/'response_quality_pipeline.joblib')
70
 
71
+ self.simcse_generator = SimcseGenerator()
72
 
73
  def _get_stop_word_proportion(self, s):
74
  s = s.lower()