nisten committed on
Commit
a622fef
1 Parent(s): 5598c41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -5,10 +5,11 @@ import subprocess
5
  import sys
6
 
7
  # Install required packages
8
- subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "einops", "accelerate", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
- from transformers import OlmoeForCausalLM, AutoTokenizer
 
12
 
13
  model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
14
 
@@ -51,7 +52,7 @@ def generate_response(message, history, temperature, max_new_tokens):
51
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
52
 
53
  try:
54
- streamer = gr.TextIteratorStreamer(tokenizer, skip_special_tokens=True)
55
  generation_kwargs = dict(
56
  inputs=inputs,
57
  max_new_tokens=max_new_tokens,
@@ -61,7 +62,7 @@ def generate_response(message, history, temperature, max_new_tokens):
61
  streamer=streamer
62
  )
63
 
64
- thread = torch.multiprocessing.Process(target=model.generate, kwargs=generation_kwargs)
65
  thread.start()
66
 
67
  generated_text = ""
 
5
  import sys
6
 
7
  # Install required packages
8
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--force-reinstall", "--no-deps", "einops", "accelerate", "torch", "git+https://github.com/Muennighoff/transformers.git@olmoe"])
9
  subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
10
 
11
+ from transformers import OlmoeForCausalLM, AutoTokenizer, TextIteratorStreamer
12
+ from threading import Thread
13
 
14
  model_name = "allenai/OLMoE-1B-7B-0924-Instruct"
15
 
 
52
  inputs = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(DEVICE)
53
 
54
  try:
55
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
56
  generation_kwargs = dict(
57
  inputs=inputs,
58
  max_new_tokens=max_new_tokens,
 
62
  streamer=streamer
63
  )
64
 
65
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
66
  thread.start()
67
 
68
  generated_text = ""