harsh-manvar commited on
Commit
6b68879
1 Parent(s): 721d85a

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -10
model.py CHANGED
@@ -20,29 +20,28 @@ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
20
 
21
  def get_prompt(message: str, chat_history: list[tuple[str, str]],
22
  system_prompt: str) -> str:
23
- logger.info("get_prompt chat_history=%s",chat_history)
24
- logger.info("get_prompt system_prompt=%s",system_prompt)
25
  texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
26
- logger.info("texts=%s",texts)
27
- # The first user input is _not_ stripped
28
  do_strip = False
29
  for user_input, response in chat_history:
30
  user_input = user_input.strip() if do_strip else user_input
31
  do_strip = True
32
  texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
33
  message = message.strip() if do_strip else message
34
- logger.info("get_prompt message=%s",message)
35
  texts.append(f'{message} [/INST]')
36
- logger.info("get_prompt final texts=%s",texts)
37
  return ''.join(texts)
38
 
39
 
40
  def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
41
- logger.info("get_input_token_length=%s",message)
42
  prompt = get_prompt(message, chat_history, system_prompt)
43
- logger.info("prompt=%s",prompt)
44
  input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
45
- logger.info("input_ids=%s",input_ids)
46
  return input_ids.shape[-1]
47
 
48
 
@@ -75,6 +74,5 @@ def run(message: str,
75
 
76
  outputs = []
77
  for text in streamer:
78
- logger.info("outputs", outputs)
79
  outputs.append(text)
80
  yield "".join(outputs)
 
20
 
21
  def get_prompt(message: str, chat_history: list[tuple[str, str]],
22
  system_prompt: str) -> str:
23
+ #logger.info("get_prompt chat_history=%s",chat_history)
24
+ #logger.info("get_prompt system_prompt=%s",system_prompt)
25
  texts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
26
+ #logger.info("texts=%s",texts)
 
27
  do_strip = False
28
  for user_input, response in chat_history:
29
  user_input = user_input.strip() if do_strip else user_input
30
  do_strip = True
31
  texts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
32
  message = message.strip() if do_strip else message
33
+ #logger.info("get_prompt message=%s",message)
34
  texts.append(f'{message} [/INST]')
35
+ #logger.info("get_prompt final texts=%s",texts)
36
  return ''.join(texts)
37
 
38
 
39
  def get_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> int:
40
+ #logger.info("get_input_token_length=%s",message)
41
  prompt = get_prompt(message, chat_history, system_prompt)
42
+ #logger.info("prompt=%s",prompt)
43
  input_ids = tokenizer([prompt], return_tensors='np', add_special_tokens=False)['input_ids']
44
+ #logger.info("input_ids=%s",input_ids)
45
  return input_ids.shape[-1]
46
 
47
 
 
74
 
75
  outputs = []
76
  for text in streamer:
 
77
  outputs.append(text)
78
  yield "".join(outputs)