crang commited on
Commit
1d896f1
·
1 Parent(s): 34fad06

add application file

Browse files
Files changed (1) hide show
  1. app.py +115 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
4
+ import os
5
+ from threading import Thread
6
+ import spaces
7
+ import time
8
+ import subprocess
9
+
10
+ MIN_TOKENS=128
11
+ MAX_TOKENS=8192
12
+ DEFAULT_TOKENS=2048
13
+ DURATION=60
14
+
15
+ # Install flash attention
16
+ subprocess.run(
17
+ "pip install flash-attn --no-build-isolation",
18
+ env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
19
+ shell=True,
20
+ )
21
+
22
+ # Load model and tokenizer once when the app starts
23
+ model_token = os.environ["HF_TOKEN"]
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ "microsoft/Phi-3-mini-128k-instruct",
26
+ token=model_token,
27
+ trust_remote_code=True,
28
+ )
29
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", token=model_token)
30
+
31
+ # Set device (GPU or CPU)
32
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
+ model.to(device)
34
+
35
+ # Define error handling function
36
+ def handle_error(error):
37
+ return {"error": str(error)}
38
+
39
+ # Define chat function with input validation and error handling
40
+ @spaces.GPU(duration=DURATION)
41
+ def chat(message, history, temperature, do_sample, max_tokens):
42
+ try:
43
+ # Validate input
44
+ if not message:
45
+ raise ValueError("Please enter a message")
46
+ if temperature < 0 or temperature > 1:
47
+ raise ValueError("Temperature must be between 0 and 1")
48
+ if max_tokens < MIN_TOKENS or max_tokens > MAX_TOKENS:
49
+ raise ValueError(f"Max tokens must be between {MIN_TOKENS} and {MAX_TOKENS}")
50
+
51
+ # Prepare chat history
52
+ chat = []
53
+ for item in history:
54
+ chat.append({"role": "user", "content": item[0]})
55
+ if item[1] is not None:
56
+ chat.append({"role": "assistant", "content": item[1]})
57
+ chat.append({"role": "user", "content": message})
58
+
59
+ # Generate response
60
+ messages = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
61
+ model_inputs = tokenizer([messages], return_tensors="pt").to(device)
62
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
63
+ generate_kwargs = dict(
64
+ model_inputs,
65
+ streamer=streamer,
66
+ max_new_tokens=max_tokens,
67
+ do_sample=do_sample,
68
+ temperature=temperature,
69
+ eos_token_id=[tokenizer.eos_token_id],
70
+ )
71
+
72
+ # Generate response
73
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
74
+ t.start()
75
+
76
+ # Yield partial responses
77
+ partial_text = ""
78
+ for new_text in streamer:
79
+ partial_text += new_text
80
+ yield partial_text
81
+
82
+ # Yield final response
83
+ yield partial_text
84
+
85
+ except Exception as e:
86
+ yield handle_error(e)
87
+
88
+ # Create Gradio interface
89
+ demo = gr.ChatInterface(
90
+ fn=chat,
91
+ examples=[["Write me a poem about Machine Learning."]],
92
+ additional_inputs_accordion=gr.Accordion(
93
+ label="⚙️ Parameters", open=False, render=False
94
+ ),
95
+ additional_inputs=[
96
+ gr.Slider(
97
+ minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
98
+ ),
99
+ gr.Checkbox(label="Sampling", value=True),
100
+ gr.Slider(
101
+ minimum=MIN_TOKENS,
102
+ maximum=MAX_TOKENS,
103
+ step=1,
104
+ value=DEFAULT_TOKENS,
105
+ label="Max new tokens",
106
+ render=False,
107
+ ),
108
+ ],
109
+ stop_btn="Stop Generation",
110
+ title="Chat With LLMs",
111
+ description="Now Running [microsoft/Phi-3-mini-128k-instruct](https://huggingface.co/microsoft/Phi-3-mini-128k-instruct)",
112
+ )
113
+
114
+ # Launch Gradio app
115
+ demo.launch()