Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -21,7 +21,7 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
21 |
# Defining a custom stopping criteria class for the model's text generation.
|
22 |
class StopOnTokens(StoppingCriteria):
|
23 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
24 |
-
stop_ids = [
|
25 |
for stop_id in stop_ids:
|
26 |
if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
|
27 |
return True
|
@@ -36,6 +36,7 @@ def predict(message, history):
|
|
36 |
# Formatting the input for the model.
|
37 |
system_prompt = "<|im_start|>system\nYou are Phixtral, a helpful AI assistant.<|im_end|>"
|
38 |
messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
|
|
|
39 |
input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
|
40 |
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
|
41 |
generate_kwargs = dict(
|
|
|
21 |
# Defining a custom stopping criteria class for the model's text generation.
|
22 |
class StopOnTokens(StoppingCriteria):
|
23 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
24 |
+
stop_ids = [50256, 50295] # IDs of tokens where the generation should stop.
|
25 |
for stop_id in stop_ids:
|
26 |
if input_ids[0][-1] == stop_id: # Checking if the last generated token is a stop token.
|
27 |
return True
|
|
|
36 |
# Formatting the input for the model.
|
37 |
system_prompt = "<|im_start|>system\nYou are Phixtral, a helpful AI assistant.<|im_end|>"
|
38 |
messages = system_prompt + "".join(["".join(["\n<|im_start|>user\n" + item[0], "<|im_end|>\n<|im_start|>assistant\n" + item[1]]) for item in history_transformer_format])
|
39 |
+
print(messages)
|
40 |
input_ids = tokenizer([messages], return_tensors="pt").to('cuda')
|
41 |
streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
|
42 |
generate_kwargs = dict(
|