cutechicken commited on
Commit
85ff42c
ยท
verified ยท
1 Parent(s): 6360699

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -100
app.py CHANGED
@@ -6,11 +6,6 @@ import os
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
9
- import gc
10
-
11
- # GPU ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ
12
- torch.cuda.empty_cache()
13
- gc.collect()
14
 
15
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
16
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
@@ -37,97 +32,58 @@ h3 {
37
  }
38
  """
39
 
40
- # ๋””๋ฐ”์ด์Šค ์„ค์ •
41
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
-
43
- # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ with ์—๋Ÿฌ ์ฒ˜๋ฆฌ
44
- try:
45
- model = AutoModelForCausalLM.from_pretrained(
46
- MODEL_ID,
47
- torch_dtype=torch.bfloat16,
48
- device_map="auto",
49
- low_cpu_mem_usage=True,
50
- )
51
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
52
- except Exception as e:
53
- print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
54
- raise
55
 
56
- # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ with ์—๋Ÿฌ ์ฒ˜๋ฆฌ
57
- try:
58
- dataset = load_dataset("elyza/ELYZA-tasks-100")
59
- print(dataset)
60
 
61
- split_name = "train" if "train" in dataset else "test"
62
- examples_list = list(dataset[split_name])
63
- examples = random.sample(examples_list, 50)
64
- example_inputs = [[example['input']] for example in examples]
65
- except Exception as e:
66
- print(f"๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
67
- examples = []
68
- example_inputs = []
69
 
70
- def error_handler(func):
71
- def wrapper(*args, **kwargs):
72
- try:
73
- return func(*args, **kwargs)
74
- except Exception as e:
75
- print(f"Error in {func.__name__}: {str(e)}")
76
- return "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
77
- return wrapper
78
-
79
- @error_handler
80
  @spaces.GPU
81
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
82
- try:
83
- print(f'message is - {message}')
84
- print(f'history is - {history}')
85
-
86
- # GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
87
- torch.cuda.empty_cache()
88
-
89
- conversation = []
90
- for prompt, answer in history:
91
- conversation.extend([
92
- {"role": "user", "content": prompt},
93
- {"role": "assistant", "content": answer}
94
- ])
95
- conversation.append({"role": "user", "content": message})
96
-
97
- input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
98
- inputs = tokenizer(input_ids, return_tensors="pt").to(device)
99
-
100
- streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
101
-
102
- generate_kwargs = dict(
103
- inputs,
104
- streamer=streamer,
105
- top_k=top_k,
106
- top_p=top_p,
107
- repetition_penalty=penalty,
108
- max_new_tokens=max_new_tokens,
109
- do_sample=True,
110
- temperature=temperature,
111
- eos_token_id=[255001],
112
- )
113
-
114
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
115
- thread.start()
116
 
117
- buffer = ""
118
- for new_text in streamer:
119
- buffer += new_text
120
- yield buffer
121
-
122
- except Exception as e:
123
- print(f"Stream chat error: {str(e)}")
124
- yield "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
125
- finally:
126
- # ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
127
- torch.cuda.empty_cache()
128
- gc.collect()
129
 
130
- chatbot = gr.Chatbot(height=500)
131
 
132
  CSS = """
133
  /* ์ „์ฒด ํŽ˜์ด์ง€ ์Šคํƒ€์ผ๋ง */
@@ -136,7 +92,6 @@ body {
136
  min-height: 100vh;
137
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
138
  }
139
-
140
  /* ๋ฉ”์ธ ์ปจํ…Œ์ด๋„ˆ */
141
  .container {
142
  max-width: 1200px;
@@ -149,7 +104,6 @@ body {
149
  transform: perspective(1000px) translateZ(0);
150
  transition: all 0.3s ease;
151
  }
152
-
153
  /* ์ œ๋ชฉ ์Šคํƒ€์ผ๋ง */
154
  h1 {
155
  color: #2d3436;
@@ -159,14 +113,12 @@ h1 {
159
  text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
160
  transform: perspective(1000px) translateZ(20px);
161
  }
162
-
163
  h3 {
164
  text-align: center;
165
  color: #2d3436;
166
  font-size: 1.5rem;
167
  margin: 1rem 0;
168
  }
169
-
170
  /* ์ฑ„ํŒ…๋ฐ•์Šค ์Šคํƒ€์ผ๋ง */
171
  .chatbox {
172
  background: white;
@@ -179,7 +131,6 @@ h3 {
179
  transform: translateZ(0);
180
  transition: all 0.3s ease;
181
  }
182
-
183
  /* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ๋ง */
184
  .chatbox .messages .message.user {
185
  background: linear-gradient(145deg, #e1f5fe, #bbdefb);
@@ -190,7 +141,6 @@ h3 {
190
  transform: translateZ(10px);
191
  animation: messageIn 0.3s ease-out;
192
  }
193
-
194
  .chatbox .messages .message.bot {
195
  background: linear-gradient(145deg, #f5f5f5, #eeeeee);
196
  border-radius: 15px;
@@ -200,7 +150,6 @@ h3 {
200
  transform: translateZ(10px);
201
  animation: messageIn 0.3s ease-out;
202
  }
203
-
204
  /* ๋ฒ„ํŠผ ์Šคํƒ€์ผ๋ง */
205
  .duplicate-button {
206
  background: linear-gradient(145deg, #24292e, #1a1e22) !important;
@@ -212,12 +161,10 @@ h3 {
212
  border: none !important;
213
  cursor: pointer !important;
214
  }
215
-
216
  .duplicate-button:hover {
217
  transform: translateY(-2px) !important;
218
  box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
219
  }
220
-
221
  /* ์ž…๋ ฅ ํ•„๋“œ ์Šคํƒ€์ผ๋ง */
222
  """
223
 
@@ -228,13 +175,13 @@ with gr.Blocks(css=CSS) as demo:
228
  chatbot=chatbot,
229
  fill_height=True,
230
  theme="soft",
231
- additional_inputs_accordion=gr.Accordion(label="โš™๏ธ ์˜ต์…˜", open=False, render=False),
232
  additional_inputs=[
233
  gr.Slider(
234
  minimum=0,
235
  maximum=1,
236
  step=0.1,
237
- value=0.3,
238
  label="์˜จ๋„",
239
  render=False,
240
  ),
 
6
  from threading import Thread
7
  import random
8
  from datasets import load_dataset
 
 
 
 
 
9
 
10
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
11
  MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
 
32
  }
33
  """
34
 
35
+ # ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
36
+ model = AutoModelForCausalLM.from_pretrained(
37
+ MODEL_ID,
38
+ torch_dtype=torch.bfloat16,
39
+ device_map="auto",
40
+ )
41
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
42
 
43
+ # ๋ฐ์ดํ„ฐ์…‹ ๋กœ๋“œ
44
+ dataset = load_dataset("elyza/ELYZA-tasks-100")
45
+ print(dataset)
 
46
 
47
+ split_name = "train" if "train" in dataset else "test"
48
+ examples_list = list(dataset[split_name])
49
+ examples = random.sample(examples_list, 50)
50
+ example_inputs = [[example['input']] for example in examples]
 
 
 
 
51
 
 
 
 
 
 
 
 
 
 
 
52
  @spaces.GPU
53
  def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
54
+ print(f'message is - {message}')
55
+ print(f'history is - {history}')
56
+ conversation = []
57
+ for prompt, answer in history:
58
+ conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
59
+ conversation.append({"role": "user", "content": message})
60
+
61
+ input_ids = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
62
+ inputs = tokenizer(input_ids, return_tensors="pt").to(0)
63
+
64
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
65
+
66
+ generate_kwargs = dict(
67
+ inputs,
68
+ streamer=streamer,
69
+ top_k=top_k,
70
+ top_p=top_p,
71
+ repetition_penalty=penalty,
72
+ max_new_tokens=max_new_tokens,
73
+ do_sample=True,
74
+ temperature=temperature,
75
+ eos_token_id=[255001],
76
+ )
77
+
78
+ thread = Thread(target=model.generate, kwargs=generate_kwargs)
79
+ thread.start()
 
 
 
 
 
 
 
 
80
 
81
+ buffer = ""
82
+ for new_text in streamer:
83
+ buffer += new_text
84
+ yield buffer
 
 
 
 
 
 
 
 
85
 
86
+ chatbot = gr.Chatbot(height=500)
87
 
88
  CSS = """
89
  /* ์ „์ฒด ํŽ˜์ด์ง€ ์Šคํƒ€์ผ๋ง */
 
92
  min-height: 100vh;
93
  font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
94
  }
 
95
  /* ๋ฉ”์ธ ์ปจํ…Œ์ด๋„ˆ */
96
  .container {
97
  max-width: 1200px;
 
104
  transform: perspective(1000px) translateZ(0);
105
  transition: all 0.3s ease;
106
  }
 
107
  /* ์ œ๋ชฉ ์Šคํƒ€์ผ๋ง */
108
  h1 {
109
  color: #2d3436;
 
113
  text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
114
  transform: perspective(1000px) translateZ(20px);
115
  }
 
116
  h3 {
117
  text-align: center;
118
  color: #2d3436;
119
  font-size: 1.5rem;
120
  margin: 1rem 0;
121
  }
 
122
  /* ์ฑ„ํŒ…๋ฐ•์Šค ์Šคํƒ€์ผ๋ง */
123
  .chatbox {
124
  background: white;
 
131
  transform: translateZ(0);
132
  transition: all 0.3s ease;
133
  }
 
134
  /* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ๋ง */
135
  .chatbox .messages .message.user {
136
  background: linear-gradient(145deg, #e1f5fe, #bbdefb);
 
141
  transform: translateZ(10px);
142
  animation: messageIn 0.3s ease-out;
143
  }
 
144
  .chatbox .messages .message.bot {
145
  background: linear-gradient(145deg, #f5f5f5, #eeeeee);
146
  border-radius: 15px;
 
150
  transform: translateZ(10px);
151
  animation: messageIn 0.3s ease-out;
152
  }
 
153
  /* ๋ฒ„ํŠผ ์Šคํƒ€์ผ๋ง */
154
  .duplicate-button {
155
  background: linear-gradient(145deg, #24292e, #1a1e22) !important;
 
161
  border: none !important;
162
  cursor: pointer !important;
163
  }
 
164
  .duplicate-button:hover {
165
  transform: translateY(-2px) !important;
166
  box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
167
  }
 
168
  /* ์ž…๋ ฅ ํ•„๋“œ ์Šคํƒ€์ผ๋ง */
169
  """
170
 
 
175
  chatbot=chatbot,
176
  fill_height=True,
177
  theme="soft",
178
+ additional_inputs_accordion=gr.Accordion(label="โš™๏ธ ์˜ต์…˜์…˜", open=False, render=False),
179
  additional_inputs=[
180
  gr.Slider(
181
  minimum=0,
182
  maximum=1,
183
  step=0.1,
184
+ value=0.8,
185
  label="์˜จ๋„",
186
  render=False,
187
  ),