Spaces:
Runtime error
Runtime error
harsh-manvar
committed on
Commit
•
6b68879
1
Parent(s):
721d85a
Update model.py
Browse files
model.py
CHANGED
@@ -20,29 +20,28 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
|
|
20 |
|
21 |
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble a Llama-2-chat formatted prompt string.

    The system prompt is wrapped in <<SYS>> tags, each past
    (user, assistant) turn is appended as an [INST] block, and the new
    message is appended last, left open for the model to complete.
    """
    logger.info("get_prompt chat_history=%s",chat_history)
    logger.info("get_prompt system_prompt=%s",system_prompt)
    parts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    logger.info("texts=%s",parts)
    # Only the very first user input is kept verbatim (not stripped).
    for turn_index, (user_input, response) in enumerate(chat_history):
        if turn_index > 0:
            user_input = user_input.strip()
        parts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    # The new message is stripped whenever there is prior history.
    if chat_history:
        message = message.strip()
    logger.info("get_prompt message=%s",message)
    parts.append(f'{message} [/INST]')
    logger.info("get_prompt final texts=%s",parts)
    return ''.join(parts)
|
38 |
|
39 |
|
40 |
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return how many tokens the fully assembled chat prompt occupies."""
    logger.info("get_input_token_length=%s",message)
    full_prompt = get_prompt(message, chat_history, system_prompt)
    logger.info("prompt=%s",full_prompt)
    # add_special_tokens=False: the prompt text already carries <s>/[INST] markers.
    encoding = tokenizer([full_prompt], return_tensors='np', add_special_tokens=False)
    token_ids = encoding['input_ids']
    logger.info("input_ids=%s",token_ids)
    return token_ids.shape[-1]
|
47 |
|
48 |
|
@@ -75,6 +74,5 @@ def run(message: str,
|
|
75 |
|
76 |
outputs = []
|
77 |
for text in streamer:
|
78 |
-
logger.info("outputs", outputs)
|
79 |
outputs.append(text)
|
80 |
yield "".join(outputs)
|
|
|
20 |
|
21 |
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble a Llama-2-chat formatted prompt string.

    Args:
        message: The new user message to append last.
        chat_history: Prior (user_input, response) turns, oldest first.
        system_prompt: Text placed inside the <<SYS>> block.

    Returns:
        The concatenated prompt, ending with the open ``[/INST]`` that
        the model is expected to complete.
    """
    texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    # The first user input is _not_ stripped
    do_strip = False
    for user_input, response in chat_history:
        user_input = user_input.strip() if do_strip else user_input
        do_strip = True
        texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    # The new message is stripped only when there was prior history.
    message = message.strip() if do_strip else message
    texts.append(f'{message} [/INST]')
    return ''.join(texts)
|
37 |
|
38 |
|
39 |
def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
    """Return the number of input tokens for the fully assembled chat prompt.

    Args:
        message: The new user message.
        chat_history: Prior (user_input, response) turns.
        system_prompt: The system prompt text.

    Returns:
        The token count of the prompt produced by ``get_prompt``.
    """
    prompt = get_prompt(message, chat_history, system_prompt)
    # add_special_tokens=False: the prompt text already embeds <s>/[INST]
    # markers, so the tokenizer must not add its own BOS token.
    input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
    return input_ids.shape[-1]
|
46 |
|
47 |
|
|
|
74 |
|
75 |
outputs = []
|
76 |
for text in streamer:
|
|
|
77 |
outputs.append(text)
|
78 |
yield "".join(outputs)
|