Vitrous commited on
Commit
4aadb45
·
verified ·
1 Parent(s): 01f12ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -21,7 +21,7 @@ conversations = {}
21
  device = "cuda" # the device to load the model onto
22
 
23
  def mistral_model():
24
- """
25
  Loads the GPT-3.5 model and tokenizer.
26
 
27
  Returns:
@@ -40,21 +40,6 @@ def mistral_model():
40
 
41
  model, tokenizer = mistral_model()
42
 
43
- def mistral_generated_response(msg_prompt, persona_desc_prompt):
44
- user_prompt = f'{msg_prompt} [/INST]'
45
- persona_prompt = f'{persona_desc_prompt} [/INST]'
46
- prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
47
-
48
- encodeds = tokenizer.apply_chat_template(prompt_template, return_tensors="pt")
49
-
50
- model_inputs = encodeds.to(device)
51
- model.to(device)
52
- generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
53
- decoded = tokenizer.batch_decode(generated_ids)
54
- response = (decoded[0])
55
- return response
56
-
57
-
58
 
59
  def load_model_norm():
60
  """
@@ -75,6 +60,19 @@ def load_model_norm():
75
 
76
  #model, tokenizer = load_model_norm()
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  def generate_response(msg_prompt: str) -> dict:
80
  """
 
21
  device = "cuda" # the device to load the model onto
22
 
23
  def mistral_model():
24
+ """
25
  Loads the GPT-3.5 model and tokenizer.
26
 
27
  Returns:
 
40
 
41
  model, tokenizer = mistral_model()
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def load_model_norm():
45
  """
 
60
 
61
  #model, tokenizer = load_model_norm()
62
 
63
+ def mistral_generated_response(msg_prompt, persona_desc_prompt):
64
+ user_prompt = f'{msg_prompt} [/INST]'
65
+ persona_prompt = f'{persona_desc_prompt} [/INST]'
66
+ prompt_template = f'''### [INST] Instruction:{persona_prompt} [INST] {user_prompt}'''
67
+
68
+ encodeds = tokenizer.apply_chat_template(prompt_template, return_tensors="pt")
69
+
70
+ model_inputs = encodeds.to(device)
71
+ model.to(device)
72
+ generated_ids = model.generate(model_inputs, max_new_tokens=1000, do_sample=True)
73
+ decoded = tokenizer.batch_decode(generated_ids)
74
+ response = (decoded[0])
75
+ return response
76
 
77
  def generate_response(msg_prompt: str) -> dict:
78
  """