Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -5,10 +5,11 @@ import subprocess
|
|
5 |
import sys
|
6 |
|
7 |
# Install required packages
|
8 |
-
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "einops", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
|
9 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
10 |
|
11 |
-
from transformers import OlmoeForCausalLM, AutoTokenizer
|
|
|
12 |
|
13 |
model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
|
14 |
|
@@ -51,7 +52,7 @@ def generate_response(message, history, temperature, max_new_tokens):
|
|
51 |
inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
|
52 |
|
53 |
try:
|
54 |
-
streamer =
|
55 |
generation_kwargs = dict(
|
56 |
inputs=inputs,
|
57 |
max_new_tokens=max_new_tokens,
|
@@ -61,7 +62,7 @@ def generate_response(message, history, temperature, max_new_tokens):
|
|
61 |
streamer=streamer
|
62 |
)
|
63 |
|
64 |
-
thread =
|
65 |
thread.start()
|
66 |
|
67 |
generated_text = ""
|
|
|
5 |
import sys
|
6 |
|
7 |
# Install required packages
|
8 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
|
9 |
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
10 |
|
11 |
+
from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
|
12 |
+
from threading import Thread
|
13 |
|
14 |
model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
|
15 |
|
|
|
52 |
inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
|
53 |
|
54 |
try:
|
55 |
+
streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
|
56 |
generation_kwargs = dict(
|
57 |
inputs=inputs,
|
58 |
max_new_tokens=max_new_tokens,
|
|
|
62 |
streamer=streamer
|
63 |
)
|
64 |
|
65 |
+
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
66 |
thread.start()
|
67 |
|
68 |
generated_text = ""
|