cutechicken committed on
Commit 5c6c33c · verified · 1 Parent(s): 6d02d4a

Update app.py

Files changed (1)
  1. app.py +324 -10
app.py CHANGED
@@ -5,7 +5,7 @@ import pandas as pd
 import json
 from datetime import datetime
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import spaces
 from threading import Thread
 
@@ -17,8 +17,12 @@ class ModelManager:
     def __init__(self):
         self.tokenizer = None
         self.model = None
-        self.setup_model()
+        # initialization is deferred until the first request
 
+    def ensure_model_loaded(self):
+        if self.model is None or self.tokenizer is None:
+            self.setup_model()
+
     @spaces.GPU
     def setup_model(self):
         try:
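This hunk replaces eager loading in `__init__` with an `ensure_model_loaded()` guard. On ZeroGPU Spaces, `@spaces.GPU` attaches a GPU only for the duration of a decorated call, so loading the model at import time (outside any such call) is fragile; deferring the load to the first request keeps the heavy work inside a GPU context. A minimal sketch of the same lazy-initialization pattern, with stand-in names:

```python
class LazyModel:
    """Defer expensive setup until first use (illustrative stand-in names)."""

    def __init__(self):
        self.model = None  # nothing is loaded at construction time

    def ensure_loaded(self):
        if self.model is None:  # idempotent guard: load at most once
            self.model = self._load()
        return self.model

    def _load(self):
        print("loading weights...")  # placeholder for the real load
        return object()
```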
@@ -42,13 +46,320 @@ class ModelManager:
                 trust_remote_code=True,
                 low_cpu_mem_usage=True
             )
-            self.model.eval() # set to evaluation mode
+            self.model.eval()
             print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
 
-            # check that the model and tokenizer actually loaded
-            if self.model is None or self.tokenizer is None:
-                raise Exception("λͺ¨λΈ λ˜λŠ” ν† ν¬λ‚˜μ΄μ €κ°€ μ œλŒ€λ‘œ μ΄ˆκΈ°ν™”λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.")
-
+        except Exception as e:
+            print(f"λͺ¨λΈ λ‘œλ”© 쀑 였λ₯˜ λ°œμƒ: {e}")
+            raise Exception(f"λͺ¨λΈ λ‘œλ”© μ‹€νŒ¨: {e}")
+
+    @spaces.GPU
+    def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
+        try:
+            # make sure the model is loaded
+            self.ensure_model_loaded()
+
+            # prepare the input text
+            prompt = ""
+            for msg in messages:
+                role = msg["role"]
+                content = msg["content"]
+                if role == "system":
+                    prompt += f"System: {content}\n"
+                elif role == "user":
+                    prompt += f"Human: {content}\n"
+                elif role == "assistant":
+                    prompt += f"Assistant: {content}\n"
+            prompt += "Assistant: "
+
+            # encode the input
+            input_ids = self.tokenizer.encode(
+                prompt,
+                return_tensors="pt",
+                add_special_tokens=True
+            ).to(self.model.device)
+
+            # generate the response
+            with torch.no_grad():
+                output_ids = self.model.generate(
+                    input_ids,
+                    max_new_tokens=max_tokens,
+                    do_sample=True,
+                    temperature=temperature,
+                    top_p=top_p,
+                    pad_token_id=self.tokenizer.pad_token_id,
+                    eos_token_id=self.tokenizer.eos_token_id,
+                    num_return_sequences=1
+                )
+
+            # decode the response
+            generated_text = self.tokenizer.decode(
+                output_ids[0][input_ids.shape[1]:],
+                skip_special_tokens=True
+            )
+
+            # stream word by word
+            words = generated_text.split()
+            for word in words:
+                yield type('Response', (), {
+                    'choices': [type('Choice', (), {
+                        'delta': {'content': word + " "}
+                    })()]
+                })()
+
+        except Exception as e:
+            print(f"응닡 생성 쀑 였λ₯˜ λ°œμƒ: {e}")
+            raise Exception(f"응닡 생성 μ‹€νŒ¨: {e}")
+
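Note that `generate_response` does not stream from the model itself: `generate()` produces the full completion first, which is then re-emitted word by word as anonymous objects shaped like OpenAI-style chunks, so downstream code can keep writing `response.choices[0].delta.get('content', '')`. Here `delta` is a plain dict, which is why the consumer uses `.get()` rather than attribute access. The same shim reads more clearly with `types.SimpleNamespace` (behaviorally equivalent sketch); true incremental streaming would instead use the `TextIteratorStreamer` this commit removes from the imports, with `generate()` running on a background thread.

```python
from types import SimpleNamespace

def fake_stream(text):
    """Re-emit pre-generated text as OpenAI-style streaming chunks."""
    for word in text.split():
        # `delta` stays a dict so callers can use .get('content', '')
        yield SimpleNamespace(choices=[SimpleNamespace(delta={"content": word + " "})])

for chunk in fake_stream("hello streaming world"):
    print(chunk.choices[0].delta.get("content", ""), end="")
```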
+class ChatHistory:
+    def __init__(self):
+        self.history = []
+        self.history_file = "/tmp/chat_history.json"
+        self.load_history()
+
+    def add_conversation(self, user_msg: str, assistant_msg: str):
+        conversation = {
+            "timestamp": datetime.now().isoformat(),
+            "messages": [
+                {"role": "user", "content": user_msg},
+                {"role": "assistant", "content": assistant_msg}
+            ]
+        }
+        self.history.append(conversation)
+        self.save_history()
+
+    def format_for_display(self):
+        formatted = []
+        for conv in self.history:
+            formatted.append([
+                conv["messages"][0]["content"],
+                conv["messages"][1]["content"]
+            ])
+        return formatted
+
+    def get_messages_for_api(self):
+        messages = []
+        for conv in self.history:
+            messages.extend([
+                {"role": "user", "content": conv["messages"][0]["content"]},
+                {"role": "assistant", "content": conv["messages"][1]["content"]}
+            ])
+        return messages
+
+    def clear_history(self):
+        self.history = []
+        self.save_history()
+
+    def save_history(self):
+        try:
+            with open(self.history_file, 'w', encoding='utf-8') as f:
+                json.dump(self.history, f, ensure_ascii=False, indent=2)
+        except Exception as e:
+            print(f"νžˆμŠ€ν† λ¦¬ μ €μž₯ μ‹€νŒ¨: {e}")
+
+    def load_history(self):
+        try:
+            if os.path.exists(self.history_file):
+                with open(self.history_file, 'r', encoding='utf-8') as f:
+                    self.history = json.load(f)
+        except Exception as e:
+            print(f"νžˆμŠ€ν† λ¦¬ λ‘œλ“œ μ‹€νŒ¨: {e}")
+            self.history = []
+
+# create the global instances
+chat_history = ChatHistory()
+model_manager = ModelManager()
+
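`ChatHistory` keeps a timestamped conversation log and persists it to `/tmp/chat_history.json` on every write; on Spaces `/tmp` is ephemeral storage, so the log lasts only as long as the running container. An illustrative round trip through the API added above:

```python
history = ChatHistory()                   # loads /tmp/chat_history.json if present
history.add_conversation("hi", "hello!")  # appends one pair and saves immediately
print(history.format_for_display())       # [['hi', 'hello!']] -- Gradio pair format
print(history.get_messages_for_api())     # flat role/content dicts for prompting
history.clear_history()                   # empties the log and rewrites the file
```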
+def analyze_file_content(content, file_type):
+    """Analyze file content and return structural summary"""
+    if file_type in ['parquet', 'csv']:
+        try:
+            lines = content.split('\n')
+            header = lines[0]
+            columns = header.count('|') - 1
+            rows = len(lines) - 3
+            return f"📊 데이터셋 ꡬ쑰: {columns}개 컬럼, {rows}개 데이터"
+        except:
+            return "❌ 데이터셋 ꡬ쑰 뢄석 μ‹€νŒ¨"
+
+    lines = content.split('\n')
+    total_lines = len(lines)
+    non_empty_lines = len([line for line in lines if line.strip()])
+
+    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
+        functions = len([line for line in lines if 'def ' in line])
+        classes = len([line for line in lines if 'class ' in line])
+        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
+        return f"💻 μ½”λ“œ ꡬ쑰: {total_lines}쀄 (ν•¨μˆ˜: {functions}, 클래슀: {classes}, μž„ν¬νŠΈ: {imports})"
+
+    paragraphs = content.count('\n\n') + 1
+    words = len(content.split())
+    return f"📝 λ¬Έμ„œ ꡬ쑰: {total_lines}쀄, {paragraphs}단락, μ•½ {words}단어"
+
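The dataset branch assumes `content` is a markdown table like the one `read_uploaded_file` (below) builds with `df.to_markdown()`: the column count comes from the pipes in the header row, and the row count subtracts three for non-data lines. Checking that arithmetic against actual `to_markdown(index=False)` output (which has only two non-data lines, header and separator) suggests the row estimate runs one low for parquet previews:

```python
import pandas as pd  # to_markdown also needs the 'tabulate' package

df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
lines = df.to_markdown(index=False).split("\n")
print(lines[0].count("|") - 1)  # 2 columns: '|  a |  b |' has 3 pipes
print(len(lines))               # 4 lines: header + separator + 2 data rows
print(len(lines) - 3)           # 1 -- one less than the real row count
```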
+def read_uploaded_file(file):
+    if file is None:
+        return "", ""
+    try:
+        file_ext = os.path.splitext(file.name)[1].lower()
+
+        if file_ext == '.parquet':
+            df = pd.read_parquet(file.name, engine='pyarrow')
+            content = df.head(10).to_markdown(index=False)
+            return content, "parquet"
+        elif file_ext == '.csv':
+            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
+            for encoding in encodings:
+                try:
+                    df = pd.read_csv(file.name, encoding=encoding)
+                    content = f"📊 데이터 미리보기:\n{df.head(10).to_markdown(index=False)}\n\n"
+                    content += f"\n📈 데이터 정보:\n"
+                    content += f"- 전체 ν–‰ 수: {len(df)}\n"
+                    content += f"- 전체 μ—΄ 수: {len(df.columns)}\n"
+                    content += f"- 컬럼 λͺ©λ‘: {', '.join(df.columns)}\n"
+                    content += f"\n📋 컬럼 데이터 νƒ€μž…:\n"
+                    for col, dtype in df.dtypes.items():
+                        content += f"- {col}: {dtype}\n"
+                    null_counts = df.isnull().sum()
+                    if null_counts.any():
+                        content += f"\n⚠️ 결츑치:\n"
+                        for col, null_count in null_counts[null_counts > 0].items():
+                            content += f"- {col}: {null_count}개 λˆ„λ½\n"
+                    return content, "csv"
+                except UnicodeDecodeError:
+                    continue
+            raise UnicodeDecodeError(f"❌ μ§€μ›λ˜λŠ” μΈμ½”λ”©μœΌλ‘œ νŒŒμΌμ„ 읽을 수 μ—†μŠ΅λ‹ˆλ‹€ ({', '.join(encodings)})")
+        else:
+            encodings = ['utf-8', 'cp949', 'euc-kr', 'latin1']
+            for encoding in encodings:
+                try:
+                    with open(file.name, 'r', encoding=encoding) as f:
+                        content = f.read()
+                        return content, "text"
+                except UnicodeDecodeError:
+                    continue
+            raise UnicodeDecodeError(f"❌ μ§€μ›λ˜λŠ” μΈμ½”λ”©μœΌλ‘œ νŒŒμΌμ„ 읽을 수 μ—†μŠ΅λ‹ˆλ‹€ ({', '.join(encodings)})")
+    except Exception as e:
+        return f"❌ 파일 읽기 였λ₯˜: {str(e)}", "error"
+
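One detail worth knowing: `UnicodeDecodeError` cannot be constructed from a single message string (its constructor requires encoding, object, start, end, and reason), so the two `raise UnicodeDecodeError(f"...")` lines actually raise a `TypeError` at construction. The enclosing `except Exception` still converts that into the `("❌ 파일 읽기 였λ₯˜...", "error")` return, so the failure path works, just with a misleading exception type. Also, since `latin1` decodes any byte sequence, the final fallback in the plain-text branch always succeeds, so that raise is effectively unreachable there. A sketch of the fallback loop with a constructible exception (hypothetical helper name, illustrative only):

```python
def read_text_with_fallback(path, encodings=("utf-8", "cp949", "euc-kr", "latin1")):
    """Try each candidate encoding in order; report failure with a plain ValueError."""
    for encoding in encodings:
        try:
            with open(path, "r", encoding=encoding) as f:
                return f.read(), encoding
        except UnicodeDecodeError:
            continue  # undecodable with this encoding; try the next one
    raise ValueError(f"none of {encodings} could decode {path}")
```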
+def chat(message, history, uploaded_file, system_message="", max_tokens=4000, temperature=0.7, top_p=0.9):
+    if not message:
+        return "", history
+
+    system_prefix = """μ €λŠ” μ—¬λŸ¬λΆ„μ˜ μΉœκ·Όν•˜κ³  지적인 AI μ–΄μ‹œμŠ€ν„΄νŠΈ 'GiniGEN'μž…λ‹ˆλ‹€.. λ‹€μŒκ³Ό 같은 μ›μΉ™μœΌλ‘œ μ†Œν†΅ν•˜κ² μŠ΅λ‹ˆλ‹€:
+1. 🤝 μΉœκ·Όν•˜κ³  곡감적인 νƒœλ„λ‘œ λŒ€ν™”
+2. 💡 λͺ…ν™•ν•˜κ³  μ΄ν•΄ν•˜κΈ° μ‰¬μš΄ μ„€λͺ… 제곡
+3. 🎯 질문의 μ˜λ„λ₯Ό μ •ν™•νžˆ νŒŒμ•…ν•˜μ—¬ λ§žμΆ€ν˜• λ‹΅λ³€
+4. 📚 ν•„μš”ν•œ 경우 μ—…λ‘œλ“œλœ 파일 λ‚΄μš©μ„ μ°Έκ³ ν•˜μ—¬ ꡬ체적인 도움 제곡
+5. ✨ 좔가적인 톡찰과 μ œμ•ˆμ„ ν†΅ν•œ κ°€μΉ˜ μžˆλŠ” λŒ€ν™”
+항상 예의 λ°”λ₯΄κ³  μΉœμ ˆν•˜κ²Œ μ‘λ‹΅ν•˜λ©°, ν•„μš”ν•œ 경우 ꡬ체적인 μ˜ˆμ‹œλ‚˜ μ„€λͺ…을 μΆ”κ°€ν•˜μ—¬
+이해λ₯Ό λ•κ² μŠ΅λ‹ˆλ‹€."""
+
+    try:
+        # load the model on the first message
+        model_manager.ensure_model_loaded()
+
+        if uploaded_file:
+            content, file_type = read_uploaded_file(uploaded_file)
+            if file_type == "error":
+                error_message = content
+                chat_history.add_conversation(message, error_message)
+                return "", history + [[message, error_message]]
+
+            file_summary = analyze_file_content(content, file_type)
+
+            if file_type in ['parquet', 'csv']:
+                system_message += f"\n\n파일 λ‚΄μš©:\n```markdown\n{content}\n```"
+            else:
+                system_message += f"\n\n파일 λ‚΄μš©:\n```\n{content}\n```"
+
+            if message == "파일 뢄석을 μ‹œμž‘ν•©λ‹ˆλ‹€...":
+                message = f"""[파일 ꡬ쑰 뢄석] {file_summary}
+λ‹€μŒ κ΄€μ μ—μ„œ 도움을 λ“œλ¦¬κ² μŠ΅λ‹ˆλ‹€:
+1. 📋 μ „λ°˜μ μΈ λ‚΄μš© νŒŒμ•…
+2. 💡 μ£Όμš” νŠΉμ§• μ„€λͺ…
+3. 🎯 μ‹€μš©μ μΈ ν™œμš© λ°©μ•ˆ
+4. ✨ κ°œμ„  μ œμ•ˆ
+5. 💬 μΆ”κ°€ μ§ˆλ¬Έμ΄λ‚˜ ν•„μš”ν•œ μ„€λͺ…"""
+
+        messages = [{"role": "system", "content": system_prefix + system_message}]
+
+        if history:
+            for user_msg, assistant_msg in history:
+                messages.append({"role": "user", "content": user_msg})
+                messages.append({"role": "assistant", "content": assistant_msg})
+
+        messages.append({"role": "user", "content": message})
+
+        partial_message = ""
+
+        for response in model_manager.generate_response(
+            messages,
+            max_tokens=max_tokens,
+            temperature=temperature,
+            top_p=top_p
+        ):
+            token = response.choices[0].delta.get('content', '')
+            if token:
+                partial_message += token
+                current_history = history + [[message, partial_message]]
+                yield "", current_history
+
+        chat_history.add_conversation(message, partial_message)
+
+    except Exception as e:
+        error_msg = f"❌ 였λ₯˜κ°€ λ°œμƒν–ˆμŠ΅λ‹ˆλ‹€: {str(e)}"
+        chat_history.add_conversation(message, error_msg)
+        yield "", history + [[message, error_msg]]import os
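Note that new line 312, as committed, runs the `yield` statement together with `import os` (no newline at the join), and everything from line 313 through 362 below is a second copy of the module's opening section: the imports, the HF_TOKEN/MODEL_ID setup, and the ModelManager definition repeat what is already at the top of app.py. The pasted join would raise a SyntaxError when Python imports the module.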
+from dotenv import load_dotenv
+import gradio as gr
+import pandas as pd
+import json
+from datetime import datetime
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import spaces
+from threading import Thread
+
+# environment variables
+HF_TOKEN = os.getenv("HF_TOKEN")
+MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
+
+class ModelManager:
+    def __init__(self):
+        self.tokenizer = None
+        self.model = None
+        # initialization is deferred until the first request
+
+    def ensure_model_loaded(self):
+        if self.model is None or self.tokenizer is None:
+            self.setup_model()
+
+    @spaces.GPU
+    def setup_model(self):
+        try:
+            print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ‹œμž‘...")
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_ID,
+                use_fast=True,
+                token=HF_TOKEN,
+                trust_remote_code=True
+            )
+            if not self.tokenizer.pad_token:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            print("ν† ν¬λ‚˜μ΄μ € λ‘œλ”© μ™„λ£Œ")
+
+            print("λͺ¨λΈ λ‘œλ”© μ‹œμž‘...")
+            self.model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                token=HF_TOKEN,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+                low_cpu_mem_usage=True
+            )
+            self.model.eval()
+            print("λͺ¨λΈ λ‘œλ”© μ™„λ£Œ")
+
         except Exception as e:
             print(f"λͺ¨λΈ λ‘œλ”© 쀑 였λ₯˜ λ°œμƒ: {e}")
             raise Exception(f"λͺ¨λΈ λ‘œλ”© μ‹€νŒ¨: {e}")
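Throughout, prompts are built as a plain `System:/Human:/Assistant:` transcript rather than with the model's native chat format. Tokenizers in recent transformers releases expose `apply_chat_template`, which renders the checkpoint's own special-token layout from the same role/content dicts; an alternative sketch, assuming the CohereForAI/c4ai-command-r7b-12-2024 tokenizer ships a chat template (as Cohere's command checkpoints do):

```python
# alternative prompt construction using the tokenizer's built-in chat template;
# `tokenizer` is assumed to be loaded as in setup_model above
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "μ•ˆλ…•ν•˜μ„Έμš”!"},
]
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,  # append the assistant-turn marker
    return_tensors="pt",
)
```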
 
@@ -56,8 +367,8 @@ class ModelManager:
     @spaces.GPU
     def generate_response(self, messages, max_tokens=4000, temperature=0.7, top_p=0.9):
         try:
-            if self.model is None or self.tokenizer is None:
-                raise Exception("λͺ¨λΈμ΄ μ΄ˆκΈ°ν™”λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.")
+            # make sure the model is loaded
+            self.ensure_model_loaded()
 
             # prepare the input text
             prompt = ""
 
@@ -70,7 +381,7 @@
                 prompt += f"Human: {content}\n"
             elif role == "assistant":
                 prompt += f"Assistant: {content}\n"
-            prompt += "Assistant: " # response start prompt
+            prompt += "Assistant: "
 
             # encode the input
             input_ids = self.tokenizer.encode(
 
@@ -92,7 +403,7 @@
                     num_return_sequences=1
                 )
 
-            # decode and stream the response
+            # decode the response
             generated_text = self.tokenizer.decode(
                 output_ids[0][input_ids.shape[1]:],
                 skip_special_tokens=True
 
@@ -255,6 +566,9 @@ def chat(message, history, uploaded_file, system_message="", max_tokens=4000, temperature=0.7, top_p=0.9):
 이해λ₯Ό λ•κ² μŠ΅λ‹ˆλ‹€."""
 
     try:
+        # load the model on the first message
+        model_manager.ensure_model_loaded()
+
         if uploaded_file:
             content, file_type = read_uploaded_file(uploaded_file)
             if file_type == "error":
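For context, `chat` is a generator that yields `(textbox_value, chatbot_history)` pairs, the shape Gradio Blocks streaming expects when the outputs are a Textbox and a Chatbot; the UI wiring itself lives in an unchanged part of app.py and is not shown in this diff. A minimal sketch of such wiring (hypothetical component names, illustrative only):

```python
import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()   # renders [[user, assistant], ...] pairs
    msg = gr.Textbox()       # user input; cleared by the yielded ""
    file_box = gr.File()     # optional upload forwarded to chat()
    # streaming works because chat() is a generator
    msg.submit(chat, inputs=[msg, chatbot, file_box], outputs=[msg, chatbot])

demo.launch()
```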