crang commited on
Commit
c4bbe12
·
verified ·
1 Parent(s): 1970af8
Files changed (1) hide show
  1. app.py +55 -76
app.py CHANGED
@@ -1,94 +1,75 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 
 
 
 
4
  import os
5
  from threading import Thread
6
  import spaces
7
  import time
8
- import subprocess
9
 
10
- MIN_TOKENS=128
11
- MAX_TOKENS=8192
12
- DEFAULT_TOKENS=2048
13
- DURATION=60
14
 
15
- # Install flash attention
16
- subprocess.run(
17
- "pip install flash-attn --no-build-isolation",
18
- env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
19
- shell=True,
20
  )
21
 
22
- # Load model and tokenizer once when the app starts
23
- model_token = os.environ["HF_TOKEN"]
24
  model = AutoModelForCausalLM.from_pretrained(
25
- "microsoft/Phi-3-mini-128k-instruct",
26
- token=model_token,
27
- trust_remote_code=True,
28
  )
29
- tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=model_token)
30
-
31
- # Set device (GPU or CPU)
32
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
- model.to(device)
34
-
35
- # Define error handling function
36
- def handle_error(error):
37
- return {"error": str(error)}
38
-
39
- # Define chat function with input validation and error handling
40
- @spaces.GPU(duration=DURATION)
41
- def chat(message, history, temperature, do_sample, max_tokens):
42
- try:
43
- # Validate input
44
- if not message:
45
- raise ValueError("Please enter a message")
46
- if temperature < 0 or temperature > 1:
47
- raise ValueError("Temperature must be between 0 and 1")
48
- if max_tokens < MIN_TOKENS or max_tokens > MAX_TOKENS:
49
- raise ValueError(f"Max tokens must be between {MIN_TOKENS} and {MAX_TOKENS}")
50
 
51
- # Prepare chat history
52
- chat = []
53
- for item in history:
54
- chat.append({"role": "user", "content": item[0]})
55
- if item[1] is not None:
56
- chat.append({"role": "assistant", "content": item[1]})
57
- chat.append({"role": "user", "content": message})
58
 
59
- # Generate response
60
- messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
61
- model_inputs = tokenizer([messages], return_tensors="pt").to(device)
62
- streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
63
- generate_kwargs = dict(
64
- model_inputs,
65
- streamer=streamer,
66
- max_new_tokens=max_tokens,
67
- do_sample=do_sample,
68
- temperature=temperature,
69
- eos_token_id=[tokenizer.eos_token_id],
70
- )
71
 
72
- # Generate response
73
- t = Thread(target=model.generate, kwargs=generate_kwargs)
74
- t.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
- # Yield partial responses
77
- partial_text = ""
78
- for new_text in streamer:
79
- partial_text += new_text
80
- yield partial_text
81
-
82
- # Yield final response
83
  yield partial_text
84
 
85
- except Exception as e:
86
- yield handle_error(e)
 
87
 
88
- # Create Gradio interface
89
  demo = gr.ChatInterface(
90
  fn=chat,
91
  examples=[["Write me a poem about Machine Learning."]],
 
92
  additional_inputs_accordion=gr.Accordion(
93
  label="⚙️ Parameters", open=False, render=False
94
  ),
@@ -96,20 +77,18 @@ demo = gr.ChatInterface(
96
  gr.Slider(
97
  minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
98
  ),
99
- gr.Checkbox(label="Sampling", value=True),
100
  gr.Slider(
101
- minimum=MIN_TOKENS,
102
- maximum=MAX_TOKENS,
103
  step=1,
104
- value=DEFAULT_TOKENS,
105
  label="Max new tokens",
106
  render=False,
107
  ),
108
  ],
109
  stop_btn="Stop Generation",
110
  title="Chat With LLMs",
111
- description="Now Running [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)",
112
  )
113
-
114
- # Launch Gradio app
115
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import (
4
+ AutoModelForCausalLM,
5
+ AutoTokenizer,
6
+ TextIteratorStreamer,
7
+ BitsAndBytesConfig,
8
+ )
9
  import os
10
  from threading import Thread
11
  import spaces
12
  import time
 
13
 
14
+ token = os.environ["HF_TOKEN"]
 
 
 
15
 
16
+ quantization_config = BitsAndBytesConfig(
17
+ load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16
 
 
 
18
  )
19
 
 
 
20
  model = AutoModelForCausalLM.from_pretrained(
21
+ "microsoft/Phi-3-mini-128k-instruct", quantization_config=quantization_config, token=token
 
 
22
  )
23
+ tok = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
25
 
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda")
28
+ print(f"Using GPU: {torch.cuda.get_device_name(device)}")
29
+ else:
30
+ device = torch.device("cpu")
31
+ print("Using CPU")
 
 
 
 
 
 
32
 
33
+ @spaces.GPU(duration=150)
34
+ def chat(message, history, temperature,do_sample, max_tokens):
35
+ chat = []
36
+ for item in history:
37
+ chat.append({"role": "user", "content": item[0]})
38
+ if item[1] is not None:
39
+ chat.append({"role": "assistant", "content": item[1]})
40
+ chat.append({"role": "user", "content": message})
41
+ messages = tok.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
42
+ model_inputs = tok([messages], return_tensors="pt").to(device)
43
+ streamer = TextIteratorStreamer(
44
+ tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
45
+ )
46
+ generate_kwargs = dict(
47
+ model_inputs,
48
+ streamer=streamer,
49
+ max_new_tokens=max_tokens,
50
+ do_sample=True,
51
+ temperature=temperature,
52
+ )
53
+
54
+ if temperature == 0:
55
+ generate_kwargs['do_sample'] = False
56
+
57
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
58
+ t.start()
59
 
60
+ partial_text = ""
61
+ for new_text in streamer:
62
+ partial_text += new_text
 
 
 
 
63
  yield partial_text
64
 
65
+ tokens = len(tok.tokenize(partial_text))
66
+ yield partial_text
67
+
68
 
 
69
  demo = gr.ChatInterface(
70
  fn=chat,
71
  examples=[["Write me a poem about Machine Learning."]],
72
+ # multimodal=False,
73
  additional_inputs_accordion=gr.Accordion(
74
  label="⚙️ Parameters", open=False, render=False
75
  ),
 
77
  gr.Slider(
78
  minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
79
  ),
80
+ gr.Checkbox(label="Sampling",value=True),
81
  gr.Slider(
82
+ minimum=128,
83
+ maximum=4096,
84
  step=1,
85
+ value=512,
86
  label="Max new tokens",
87
  render=False,
88
  ),
89
  ],
90
  stop_btn="Stop Generation",
91
  title="Chat With LLMs",
92
+ description="Now Running [Microsoft Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct) in 4bit"
93
  )
 
 
94
  demo.launch()