davidberenstein1957 (HF staff) committed
Commit 05c0b89
1 Parent(s): 2c31d1f

Update generate function

Files changed (1)
  1. app.py +34 -41
app.py CHANGED
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import os
+from threading import Thread
 from typing import Iterator
 
 import gradio as gr
@@ -14,14 +15,9 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
 
 if torch.cuda.is_available():
-    # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
-    # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
-    # tokenizer = AutoTokenizer.from_pretrained(model_id)
-    pass
-
-    style = None
-
-    AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 @spaces.GPU
@@ -34,39 +30,36 @@ def generate(
     top_k: int = 40,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    # conversation = []
-    # for user, assistant in chat_history:
-    #     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    # conversation.append({"role": "user", "content": message})
-
-    # input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
-    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    # input_ids = input_ids.to(model.device)
-
-    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    # generate_kwargs = dict(
-    #     {"input_ids": input_ids},
-    #     streamer=streamer,
-    #     max_new_tokens=max_new_tokens,
-    #     do_sample=True,
-    #     top_p=top_p,
-    #     top_k=top_k,
-    #     temperature=temperature,
-    #     num_beams=1,
-    #     repetition_penalty=repetition_penalty,
-    # )
-    # t = Thread(target=model.generate, kwargs=generate_kwargs)
-    # t.start()
-
-    # outputs = []
-    # for text in streamer:
-    #     outputs.append(text)
-    #     yield "".join(outputs)
-
-    for char in "help":
-        yield "help"
+    conversation = []
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
 
 chat_interface = ChatInterface(
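
For reference, below is a minimal, self-contained sketch of the Thread + TextIteratorStreamer streaming pattern that the updated generate function relies on. It is illustrative only: the tiny model id (sshleifer/tiny-gpt2) and the prompt are placeholder assumptions, while the Space itself loads meta-llama/Meta-Llama-3.1-8B-Instruct as shown in the diff above.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model for illustration; the Space uses meta-llama/Meta-Llama-3.1-8B-Instruct.
model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, world", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# model.generate blocks until decoding finishes, so it runs on a worker thread
# while the main thread drains the streamer chunk by chunk.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
thread.start()

outputs = []
for text in streamer:
    outputs.append(text)
    print("".join(outputs))  # the Space yields this growing string to gr.ChatInterface instead
thread.join()

Running generate on a background thread is what lets the Gradio callback yield partial completions while tokens are still being produced; iterating the streamer on the main thread ends automatically once generation completes.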