mgoin committed
Commit f51b330
1 Parent(s): e3233e4

Update app.py

Files changed (1)
app.py +15 -17
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
+import uuid
 
 import gradio as gr
-import spaces
+# import spaces
 import torch
 from transformers import AutoTokenizer
 from vllm import LLM, SamplingParams
@@ -15,18 +16,15 @@ DESCRIPTION = """\
 """
 
 if not torch.cuda.is_available():
-    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+    raise ValueError("Running on CPU 🥶 This demo does not work on CPU.")
 
-
-if torch.cuda.is_available():
-    model_id = "nm-testing/OpenHermes-2.5-Mistral-7B-pruned50"
-    model = LLM(model_id, max_model_len=MAX_INPUT_TOKEN_LENGTH)
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
+model_id = "neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50"
+model = LLM(model_id, max_model_len=MAX_INPUT_TOKEN_LENGTH)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+tokenizer.use_default_system_prompt = False
 
-@spaces.GPU
-def generate(
+# @spaces.GPU
+async def generate(
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
@@ -35,7 +33,7 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> str:
+):
     conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
@@ -53,11 +51,11 @@ def generate(
         repetition_penalty=repetition_penalty,
     )
 
-    outputs = model.generate(formatted_conversation, sampling_params)
-
-    for output in outputs:
-        generated_text = output.outputs[0].text
-    return generated_text
+    stream = await model.add_request(uuid.uuid4().hex, formatted_conversation, sampling_params)
+
+    async for request_output in stream:
+        text = request_output.outputs[0].text
+        yield text
 
 
 chat_interface = gr.ChatInterface(
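
For reference, a minimal sketch of the streaming pattern this change moves toward, written against vLLM's async engine API (AsyncLLMEngine / AsyncEngineArgs) rather than the offline LLM object used in the commit; the model id and sampling values are assumptions carried over from the demo above, and this is illustrative only, not part of the commit:

# Illustrative sketch (not from this commit): streaming generation with vLLM's
# async engine, which yields partial RequestOutput objects as tokens arrive.
import uuid

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

# Assumption: same model id as the demo above.
engine = AsyncLLMEngine.from_engine_args(
    AsyncEngineArgs(model="neuralmagic/OpenHermes-2.5-Mistral-7B-pruned50")
)

async def stream_text(prompt: str):
    params = SamplingParams(temperature=0.6, top_p=0.9, max_tokens=256)
    request_id = uuid.uuid4().hex  # each in-flight request needs a unique id
    async for request_output in engine.generate(prompt, params, request_id):
        # outputs[0].text holds the cumulative generation so far
        yield request_output.outputs[0].text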