Vitrous committed
Commit 01f12ae · verified · 1 Parent(s): 2a47b1a

Update app.py

Files changed (1)
  1. app.py +52 -1
app.py CHANGED
@@ -16,7 +16,45 @@ app = FastAPI(root_path="/api/v1")
 
  # Load the model and tokenizer
  model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
+ mistral_model = "mistralai/Mistral-7B-Instruct-v0.2"
  conversations = {}
+ device = "cuda"  # the device to load the model onto
+
+ def load_mistral_model():
+     """
+     Loads the Mistral-7B-Instruct model and tokenizer.
+
+     Returns:
+         tuple: A tuple containing the loaded model and tokenizer.
+     """
+     if torch.cuda.is_available():
+         print("CUDA is available. GPU will be used.")
+     else:
+         print("CUDA is not available. CPU will be used.")
+
+     model = AutoModelForCausalLM.from_pretrained(mistral_model)
+     tokenizer = AutoTokenizer.from_pretrained(mistral_model)
+
+     return model, tokenizer
+
+ model, tokenizer = load_mistral_model()
+
+ def mistral_generated_response(msg_prompt, persona_desc_prompt):
+     # Build the [INST]-style prompt expected by Mistral-7B-Instruct
+     user_prompt = f'{msg_prompt} [/INST]'
+     persona_prompt = f'{persona_desc_prompt} [/INST]'
+     prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
+
+     # Tokenize the hand-built prompt string and generate on the chosen device
+     encodeds = tokenizer(prompt_template, return_tensors="pt").input_ids
+     model_inputs = encodeds.to(device)
+     model.to(device)
+     generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+     decoded = tokenizer.batch_decode(generated_ids)
+     response = decoded[0]
+     return response
+
 
  def load_model_norm():
      """
@@ -35,7 +73,7 @@ def load_model_norm():
 
      return model, tokenizer
 
- model, tokenizer = load_model_norm()
+ #model, tokenizer = load_model_norm()
 
 
  def generate_response(msg_prompt: str) -> dict:
@@ -192,3 +230,16 @@ async def get_response(thread_id: int):
      response = thread['responses'][-1]
 
      return {'response': response}
+
+ @app.post("/mistral_chat")
+ async def mistral_chat(prompt: dict):
+     try:
+         msg_prompt = prompt.get("msg_prompt")
+         persona_desc_prompt = prompt.get("persona_desc_prompt")
+         if not msg_prompt or not persona_desc_prompt:
+             return {"error": "msg_prompt and persona_desc_prompt are required fields."}
+
+         response = mistral_generated_response(msg_prompt, persona_desc_prompt)
+         return {"response": response, "prompt": {"msg_prompt": msg_prompt, "persona_desc_prompt": persona_desc_prompt}}
+     except Exception as e:
+         return {"error": str(e)}