mlabonne committed on
Commit
209aee6
·
1 Parent(s): 044264a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -21,7 +21,7 @@ model = AutoModelForCausalLM.from_pretrained(
21
  # Defining a custom stopping criteria class for the model's text generation.
22
  class StopOnTokens(StoppingCriteria):
23
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
24
- stop_ids = [2] # IDs of tokens where the generation should stop.
25
  for stop_id in stop_ids:
26
  if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
27
  return True
@@ -36,6 +36,7 @@ def predict(message, history):
36
  # Formatting the input for the model.
37
  system_prompt = "<|im_start|>system\nYou are Phixtral, a helpful AI assistant.<|im_end|>"
38
  messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
 
39
  input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
40
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
41
  generate_kwargs = dict(
 
21
  # Defining a custom stopping criteria class for the model's text generation.
22
  class StopOnTokens(StoppingCriteria):
23
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
24
+ stop_ids = [50256, 50295] # IDs of tokens where the generation should stop.
25
  for stop_id in stop_ids:
26
  if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
27
  return True
 
36
  # Formatting the input for the model.
37
  system_prompt = "<|im_start|>system\nYou are Phixtral, a helpful AI assistant.<|im_end|>"
38
  messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
39
+ print(messages)
40
  input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
41
  streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
42
  generate_kwargs = dict(