Update app.py
app.py
CHANGED
@@ -16,7 +16,45 @@ app = FastAPI(root_path="/api/v1")
 
 # Load the model and tokenizer
 model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-GPTQ"
+mistral_model = "mistralai/Mistral-7B-Instruct-v0.2"
 conversations = {}
+device = "cuda" if torch.cuda.is_available() else "cpu"  # the device to load the model onto
+
+def load_mistral_model():
+    """
+    Loads the Mistral-7B-Instruct model and tokenizer.
+
+    Returns:
+        tuple: A tuple containing the loaded model and tokenizer.
+    """
+    if torch.cuda.is_available():
+        print("CUDA is available. GPU will be used.")
+    else:
+        print("CUDA is not available. CPU will be used.")
+
+    model = AutoModelForCausalLM.from_pretrained(mistral_model)
+
+    tokenizer = AutoTokenizer.from_pretrained(mistral_model)
+
+    return model, tokenizer
+
+model, tokenizer = load_mistral_model()
+
+def mistral_generated_response(msg_prompt, persona_desc_prompt):
+    user_prompt = f'{msg_prompt} [/INST]'
+    persona_prompt = f'{persona_desc_prompt} [/INST]'
+    prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
+    # apply_chat_template expects a list of message dicts, so wrap the prompt as a single user turn
+    encodeds = tokenizer.apply_chat_template([{"role": "user", "content": prompt_template}], return_tensors="pt")
+
+    model_inputs = encodeds.to(device)
+    model.to(device)
+    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+    decoded = tokenizer.batch_decode(generated_ids)
+    response = decoded[0]
+    return response
+
+
 
 def load_model_norm():
     """
@@ -35,7 +73,7 @@ def load_model_norm():
 
     return model, tokenizer
 
-model, tokenizer = load_model_norm()
+#model, tokenizer = load_model_norm()
 
 
 def generate_response(msg_prompt: str) -> dict:
@@ -192,3 +230,16 @@ async def get_response(thread_id: int):
     response = thread['responses'][-1]
 
     return {'response': response}
+
+@app.post("/mistral_chat")
+async def mistral_chat(prompt: dict):
+    try:
+        msg_prompt = prompt.get("msg_prompt")
+        persona_desc_prompt = prompt.get("persona_desc_prompt")
+        if not msg_prompt or not persona_desc_prompt:
+            return {"error": "msg_prompt and persona_desc_prompt are required fields."}
+
+        response = mistral_generated_response(msg_prompt, persona_desc_prompt)
+        return {"response": response, "prompt": {"msg_prompt": msg_prompt, "persona_desc_prompt": persona_desc_prompt}}
+    except Exception as e:
+        return {"error": str(e)}