Update app.py
app.py
CHANGED
@@ -21,7 +21,7 @@ conversations = {}
 device = "cuda" # the device to load the model onto
 
 def mistral_model():
-
+    """
     Loads the GPT-3.5 model and tokenizer.
 
     Returns:
@@ -40,21 +40,6 @@ def mistral_model():
 
 model, tokenizer = mistral_model()
 
-def mistral_generated_response(msg_prompt, persona_desc_prompt):
-    user_prompt = f'{msg_prompt} [/INST]'
-    persona_prompt = f'{persona_desc_prompt} [/INST]'
-    prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
-
-    encodeds = tokenizer.apply_chat_template(prompt_template, return_tensors="pt")
-
-    model_inputs = encodeds.to(device)
-    model.to(device)
-    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
-    decoded = tokenizer.batch_decode(generated_ids)
-    response = (decoded[0])
-    return response
-
-
 
 def load_model_norm():
     """
@@ -75,6 +60,19 @@ def load_model_norm():
 
 #model, tokenizer = load_model_norm()
 
+def mistral_generated_response(msg_prompt, persona_desc_prompt):
+    user_prompt = f'{msg_prompt} [/INST]'
+    persona_prompt = f'{persona_desc_prompt} [/INST]'
+    prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
+
+    encodeds = tokenizer.apply_chat_template(prompt_template, return_tensors="pt")
+
+    model_inputs = encodeds.to(device)
+    model.to(device)
+    generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
+    decoded = tokenizer.batch_decode(generated_ids)
+    response = (decoded[0])
+    return response
 
 def generate_response(msg_prompt: str) -> dict:
     """