codelion commited on
Commit
1d952a0
1 Parent(s): 03eca98

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +16 -0
README.md CHANGED
@@ -54,6 +54,22 @@ class OptILMClassifier(nn.Module):
54
  logits = self.classifier(combined_input)
55
  return logits
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def preprocess_input(tokenizer, system_prompt, initial_query):
58
  combined_input = f"{system_prompt}\n\nUser: {initial_query}"
59
  encoding = tokenizer.encode_plus(
 
54
  logits = self.classifier(combined_input)
55
  return logits
56
 
57
+
58
+ def load_optillm_model():
59
+ device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
60
+ # Load the base model
61
+ base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
62
+ # Create the OptILMClassifier
63
+ model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
64
+ model.to(device)
65
+ # Download the safetensors file
66
+ safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
67
+ # Load the state dict from the safetensors file
68
+ load_model(model, safetensors_path)
69
+
70
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
71
+ return model, tokenizer, device
72
+
73
  def preprocess_input(tokenizer, system_prompt, initial_query):
74
  combined_input = f"{system_prompt}\n\nUser: {initial_query}"
75
  encoding = tokenizer.encode_plus(