# RAGOndevice / app.py
# Hugging Face Space source by cutechicken; commit d6a3ccb ("Update app.py", verified), 28.9 kB.
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import pandas as pd
from typing import List, Tuple
import json
from datetime import datetime
import pyarrow.parquet as pq
import pypdf
import io
import pyarrow.parquet as pq
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams
from tabulate import tabulate # tabulate ์ถ”๊ฐ€
import platform
import subprocess
import pytesseract
from pdf2image import convert_from_path
# Global state shared between Gradio callbacks: analysis text of the most
# recently uploaded file (set in stream_chat, reset by the Clear button).
current_file_context = None
# Environment / model configuration.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")
MODEL_NAME = MODEL_ID.split("/")[-1]
model = None  # loaded lazily on first use (see stream_chat / load_model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Korean Wikipedia Q&A dataset used as the retrieval corpus.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)
# Number of leading training questions indexed for TF-IDF retrieval.
NUM_INDEXED_QUESTIONS = 10000
# Fit the TF-IDF vectorizer on the indexed questions.
print("TF-IDF ๋ฒกํ„ฐํ™” ์‹œ์ž‘...")
# Slice the Dataset first so only the first N rows are fetched. Indexing the
# column first (dataset['question'][:N]) builds a list of every question in the
# split before slicing, in `datasets` releases where column access returns a list.
questions = wiki_dataset['train'][:NUM_INDEXED_QUESTIONS]['question']
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF ๋ฒกํ„ฐํ™” ์™„๋ฃŒ")
class ChatHistory:
    """Persist user/assistant conversation pairs as JSON on disk.

    ``self.history`` is a list of entries shaped like
    ``{"timestamp": <ISO-8601 str>, "messages": [<user dict>, <assistant dict>]}``,
    where each message dict has ``role`` and ``content`` keys. Every mutation is
    written straight back to ``history_file``. Save/load failures are printed,
    never raised.
    """

    # Default location of the JSON history file.
    DEFAULT_HISTORY_FILE = "/tmp/chat_history.json"

    def __init__(self, history_file: str = DEFAULT_HISTORY_FILE):
        """Create the store and load any history previously saved to ``history_file``."""
        self.history = []
        self.history_file = history_file
        self.load_history()

    def add_conversation(self, user_msg: str, assistant_msg: str):
        """Append one exchange stamped with the current local time, then save."""
        conversation = {
            "timestamp": datetime.now().isoformat(),
            "messages": [
                {"role": "user", "content": user_msg},
                {"role": "assistant", "content": assistant_msg},
            ],
        }
        self.history.append(conversation)
        self.save_history()

    def format_for_display(self) -> List[List[str]]:
        """Return the history as ``[[user_text, assistant_text], ...]`` pairs."""
        return [
            [conv["messages"][0]["content"], conv["messages"][1]["content"]]
            for conv in self.history
        ]

    def get_messages_for_api(self) -> List[dict]:
        """Flatten the history into alternating user/assistant ``{"role", "content"}`` dicts."""
        messages = []
        for conv in self.history:
            messages.append({"role": "user", "content": conv["messages"][0]["content"]})
            messages.append({"role": "assistant", "content": conv["messages"][1]["content"]})
        return messages

    def clear_history(self):
        """Drop all stored exchanges and persist the empty history."""
        self.history = []
        self.save_history()

    def save_history(self):
        """Write the history to ``history_file`` as UTF-8 JSON (best effort)."""
        try:
            with open(self.history_file, 'w', encoding='utf-8') as f:
                json.dump(self.history, f, ensure_ascii=False, indent=2)
        except Exception as e:
            print(f"ํžˆ์Šคํ† ๋ฆฌ ์ €์žฅ ์‹คํŒจ: {e}")

    def load_history(self):
        """Load ``history_file`` if it exists; fall back to an empty history on any error."""
        try:
            if os.path.exists(self.history_file):
                with open(self.history_file, 'r', encoding='utf-8') as f:
                    loaded = json.load(f)
                # add_conversation() appends to self.history, so anything other than
                # a JSON list (e.g. a hand-edited "{}") would break later calls.
                if not isinstance(loaded, list):
                    raise ValueError("history file does not contain a JSON list")
                self.history = loaded
        except Exception as e:
            print(f"ํžˆ์Šคํ† ๋ฆฌ ๋กœ๋“œ ์‹คํŒจ: {e}")
            self.history = []


# Module-wide chat history store.
chat_history = ChatHistory()
def find_relevant_context(query, top_k=3):
    """Return up to ``top_k`` indexed Wikipedia Q&A pairs most similar to ``query``.

    Uses the module-level TF-IDF ``vectorizer`` / ``question_vectors`` fitted on
    ``questions``. TfidfVectorizer L2-normalizes rows by default, so the dot
    product below is the cosine similarity.

    Args:
        query: user query text.
        top_k: maximum number of results; values <= 0 yield an empty list.

    Returns:
        A list of ``{'question', 'answer', 'similarity'}`` dicts ordered by
        descending similarity. Only entries with similarity > 0 are included.
    """
    if top_k <= 0:
        # argsort(...)[-0:] would select *every* index, so guard explicitly.
        return []

    query_vector = vectorizer.transform([query])
    similarities = (query_vector @ question_vectors.T).toarray()[0]
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    train_split = wiki_dataset['train']
    relevant_contexts = []
    for idx in top_indices:
        score = float(similarities[idx])
        if score <= 0:
            continue
        # Fetch just this row; indexing the 'answer' column first would
        # materialize the whole column for every hit.
        row = train_split[int(idx)]
        relevant_contexts.append({
            'question': questions[idx],
            'answer': row['answer'],
            'similarity': score,
        })
    return relevant_contexts
def _summarize_markdown_table(lines):
    """Return ``(columns, body_rows)`` for the first markdown pipe table in ``lines``, or None."""
    start = next((i for i, line in enumerate(lines) if line.strip().startswith('|')), None)
    if start is None:
        return None
    table = []
    for line in lines[start:]:
        stripped = line.strip()
        if not stripped.startswith('|'):
            break
        table.append(stripped)
    header, body = table[0], table[1:]
    # Skip the "|---|:--|" separator row that follows a pipe-table header.
    if body and set(body[0]) <= set('|-: '):
        body = body[1:]
    # "| a | b |" has one more pipe than it has columns.
    return header.count('|') - 1, len(body)


def analyze_file_content(content, file_type):
    """Analyze file content and return a one-line structural summary.

    For ``'parquet'`` / ``'csv'`` the content is the report produced by
    ``read_uploaded_file``. Its first line is a title, so the summary comes from
    the first markdown pipe table in it (the data preview). The column count is
    taken from that table's header, and the row count is the number of preview
    rows shown. If no table is present, a failure message is returned.

    Any other type gets a code summary (def/class/import line counts) when code
    keywords appear, and a document summary (lines, paragraphs, words) otherwise.
    """
    if file_type in ['parquet', 'csv']:
        try:
            summary = _summarize_markdown_table(content.split('\n'))
        except Exception:
            summary = None
        if summary is None:
            return "โŒ ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ ๋ถ„์„ ์‹คํŒจ"
        columns, rows = summary
        return f"๐Ÿ“Š ๋ฐ์ดํ„ฐ์…‹ ๊ตฌ์กฐ: {columns}๊ฐœ ์ปฌ๋Ÿผ, {rows}๊ฐœ ๋ฐ์ดํ„ฐ"

    lines = content.split('\n')
    total_lines = len(lines)

    if any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function']):
        functions = sum(1 for line in lines if 'def ' in line)
        classes = sum(1 for line in lines if 'class ' in line)
        imports = sum(1 for line in lines if 'import ' in line or 'from ' in line)
        return f"๐Ÿ’ป ์ฝ”๋“œ ๊ตฌ์กฐ: {total_lines}์ค„ (ํ•จ์ˆ˜: {functions}, ํด๋ž˜์Šค: {classes}, ์ž„ํฌํŠธ: {imports})"

    paragraphs = content.count('\n\n') + 1
    words = len(content.split())
    return f"๐Ÿ“ ๋ฌธ์„œ ๊ตฌ์กฐ: {total_lines}์ค„, {paragraphs}๋‹จ๋ฝ, ์•ฝ {words}๋‹จ์–ด"
def extract_pdf_text_with_ocr(file_path):
    """OCR every page of a PDF and return the combined text.

    Pages are rendered to JPEG images with ``convert_from_path`` and read with
    ``pytesseract`` using Korean + English models. Each page's text is preceded
    by a page-number header. A page whose OCR raises is logged and skipped. Any
    other failure is returned as an error string instead of being raised.
    """
    # Tesseract options are identical for every page, so build them once.
    ocr_config = r'--oem 3 --psm 6 -l kor+eng'
    try:
        # Windows needs an explicit poppler location; None means the default lookup.
        poppler_path = r"C:\Program Files\poppler-0.68.0\bin" if platform.system() == 'Windows' else None
        page_images = convert_from_path(
            file_path,
            poppler_path=poppler_path,
            fmt='jpeg',
            grayscale=False,
            size=(1700, None),  # render wider for better OCR accuracy
        )
        chunks = []
        for page_number, page_image in enumerate(page_images, start=1):
            try:
                page_text = pytesseract.image_to_string(page_image, config=ocr_config)
            except Exception as exc:
                print(f"ํŽ˜์ด์ง€ {page_number} OCR ์˜ค๋ฅ˜: {str(exc)}")
                continue
            chunks.append(f"\n--- ํŽ˜์ด์ง€ {page_number} ---\n{page_text}\n")
        return "".join(chunks)
    except Exception as exc:
        return f"PDF ํ…์ŠคํŠธ ์ถ”์ถœ ์˜ค๋ฅ˜: {str(exc)}"
# Encodings tried, in order, for CSV and plain-text uploads.
_TEXT_ENCODINGS = ['utf-8', 'cp949', 'euc-kr', 'latin1']


def _uploaded_file_path(file):
    """Return a filesystem path for ``file``: a str/PathLike path or an object exposing ``.name``."""
    # gr.File(type="filepath") passes a plain str; older upload objects expose .name.
    if isinstance(file, (str, os.PathLike)):
        return os.fspath(file)
    return file.name


def _describe_dataframe(df, title, preview):
    """Build the shared tabular report: basic info, column dtypes, preview, missing values."""
    parts = [
        title,
        "1. ๊ธฐ๋ณธ ์ •๋ณด:\n",
        f"- ์ „์ฒด ํ–‰ ์ˆ˜: {len(df):,}๊ฐœ\n",
        f"- ์ „์ฒด ์—ด ์ˆ˜: {len(df.columns)}๊ฐœ\n",
        f"- ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {df.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB\n\n",
        "2. ์ปฌ๋Ÿผ ์ •๋ณด:\n",
    ]
    parts.extend(f"- {col} ({df[col].dtype})\n" for col in df.columns)
    parts.append("\n3. ๋ฐ์ดํ„ฐ ๋ฏธ๋ฆฌ๋ณด๊ธฐ:\n")
    parts.append(preview)
    parts.append("\n\n4. ๊ฒฐ์ธก์น˜ ์ •๋ณด:\n")
    null_counts = df.isnull().sum()
    parts.extend(
        f"- {col}: {count:,}๊ฐœ ({count/len(df)*100:.1f}%)\n"
        for col, count in null_counts[null_counts > 0].items()
    )
    return "".join(parts)


def _read_parquet_file(path):
    """Return ``(report, "parquet")``, or ``(error message, "error")`` if reading fails."""
    try:
        df = pq.read_table(path).to_pandas()
        preview = tabulate(df.head(5), headers='keys', tablefmt='pipe', showindex=False)
        content = _describe_dataframe(df, "๐Ÿ“Š Parquet ํŒŒ์ผ ๋ถ„์„:\n\n", preview)
        # Basic statistics for numeric columns.
        numeric_cols = df.select_dtypes(include=['int64', 'float64']).columns
        if len(numeric_cols) > 0:
            content += "\n5. ์ˆ˜์น˜ํ˜• ์ปฌ๋Ÿผ ํ†ต๊ณ„:\n"
            content += tabulate(df[numeric_cols].describe(), headers='keys', tablefmt='pipe')
        return content, "parquet"
    except Exception as e:
        return f"Parquet ํŒŒ์ผ ์ฝ๊ธฐ ์˜ค๋ฅ˜: {str(e)}", "error"


def _read_pdf_file(path):
    """Return ``(report, "pdf")`` with metadata and a text preview, or ``(error message, "error")``."""
    try:
        pdf_reader = pypdf.PdfReader(path)
        total_pages = len(pdf_reader.pages)
        content = "๐Ÿ“‘ PDF ๋ฌธ์„œ ๋ถ„์„:\n\n"
        content += "1. ๊ธฐ๋ณธ ์ •๋ณด:\n"
        content += f"- ์ด ํŽ˜์ด์ง€ ์ˆ˜: {total_pages}ํŽ˜์ด์ง€\n"
        if pdf_reader.metadata:
            content += "\n2. ๋ฉ”ํƒ€๋ฐ์ดํ„ฐ:\n"
            for key, value in pdf_reader.metadata.items():
                if value and str(key).startswith('/'):
                    content += f"- {key[1:]}: {value}\n"
        # Try the embedded text layer first (pdfminer).
        try:
            text = extract_text(
                path,
                laparams=LAParams(
                    line_margin=0.5,
                    word_margin=0.1,
                    char_margin=2.0,
                    all_texts=True,
                ),
            )
        except Exception:
            text = ""
        # No text layer (e.g. a scanned document): fall back to OCR.
        if not text.strip():
            text = extract_pdf_text_with_ocr(path)
        if text:
            words = text.split()
            lines = text.split('\n')
            content += "\n3. ํ…์ŠคํŠธ ๋ถ„์„:\n"
            content += f"- ์ด ๋‹จ์–ด ์ˆ˜: {len(words):,}๊ฐœ\n"
            content += f"- ๊ณ ์œ  ๋‹จ์–ด ์ˆ˜: {len(set(words)):,}๊ฐœ\n"
            content += f"- ์ด ๋ผ์ธ ์ˆ˜: {len(lines):,}๊ฐœ\n"
            content += "\n4. ๋ณธ๋ฌธ ๋‚ด์šฉ:\n"
            preview_length = min(2000, len(text))
            content += f"--- ์ฒ˜์Œ {preview_length}์ž ---\n"
            content += text[:preview_length]
            if len(text) > preview_length:
                content += f"\n... (์ด {len(text):,}์ž ์ค‘ ์ผ๋ถ€ ํ‘œ์‹œ)\n"
        else:
            content += "\nโš ๏ธ ํ…์ŠคํŠธ ์ถ”์ถœ ์‹คํŒจ"
        return content, "pdf"
    except Exception as e:
        return f"PDF ํŒŒ์ผ ์ฝ๊ธฐ ์˜ค๋ฅ˜: {str(e)}", "error"


def _read_csv_file(path):
    """Return ``(report, "csv")``, trying each encoding in turn; raise ValueError if none decodes."""
    for encoding in _TEXT_ENCODINGS:
        try:
            df = pd.read_csv(path, encoding=encoding)
        except UnicodeDecodeError:
            continue
        preview = df.head(5).to_markdown(index=False)
        return _describe_dataframe(df, "๐Ÿ“Š CSV ํŒŒ์ผ ๋ถ„์„:\n\n", preview), "csv"
    raise ValueError(f"์ง€์›๋˜๋Š” ์ธ์ฝ”๋”ฉ์œผ๋กœ ํŒŒ์ผ์„ ์ฝ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค ({', '.join(_TEXT_ENCODINGS)})")


def _describe_text_content(content):
    """Return the analysis appendix for a text upload (code-oriented when code keywords appear)."""
    lines = content.split('\n')
    total_lines = len(lines)
    non_empty_lines = len([line for line in lines if line.strip()])
    is_code = any(keyword in content.lower() for keyword in ['def ', 'class ', 'import ', 'function'])
    analysis = "\n๐Ÿ“ ํŒŒ์ผ ๋ถ„์„:\n"
    if is_code:
        functions = len([line for line in lines if 'def ' in line])
        classes = len([line for line in lines if 'class ' in line])
        imports = len([line for line in lines if 'import ' in line or 'from ' in line])
        analysis += "- ํŒŒ์ผ ์œ ํ˜•: ์ฝ”๋“œ\n"
        analysis += f"- ์ „์ฒด ๋ผ์ธ ์ˆ˜: {total_lines:,}์ค„\n"
        analysis += f"- ํ•จ์ˆ˜ ์ˆ˜: {functions}๊ฐœ\n"
        analysis += f"- ํด๋ž˜์Šค ์ˆ˜: {classes}๊ฐœ\n"
        analysis += f"- import ๋ฌธ ์ˆ˜: {imports}๊ฐœ\n"
    else:
        words = len(content.split())
        chars = len(content)
        analysis += "- ํŒŒ์ผ ์œ ํ˜•: ํ…์ŠคํŠธ\n"
        analysis += f"- ์ „์ฒด ๋ผ์ธ ์ˆ˜: {total_lines:,}์ค„\n"
        analysis += f"- ์‹ค์ œ ๋‚ด์šฉ์ด ์žˆ๋Š” ๋ผ์ธ ์ˆ˜: {non_empty_lines:,}์ค„\n"
        analysis += f"- ๋‹จ์–ด ์ˆ˜: {words:,}๊ฐœ\n"
        analysis += f"- ๋ฌธ์ž ์ˆ˜: {chars:,}๊ฐœ\n"
    return analysis


def _read_text_file(path):
    """Return ``(file text + analysis, "text")``; raise ValueError if no encoding decodes it."""
    for encoding in _TEXT_ENCODINGS:
        try:
            with open(path, 'r', encoding=encoding) as f:
                content = f.read()
        except UnicodeDecodeError:
            continue
        return content + _describe_text_content(content), "text"
    raise ValueError(f"์ง€์›๋˜๋Š” ์ธ์ฝ”๋”ฉ์œผ๋กœ ํŒŒ์ผ์„ ์ฝ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค ({', '.join(_TEXT_ENCODINGS)})")


def read_uploaded_file(file):
    """Read an uploaded file and return ``(content, kind)``.

    Args:
        file: a path (str / os.PathLike, as passed by ``gr.File(type="filepath")``),
            an object exposing the path as ``.name``, or None.

    Returns:
        ``("", "")`` for None. Otherwise a report plus a kind of ``"parquet"``,
        ``"pdf"``, ``"csv"`` or ``"text"`` (chosen by file extension; anything
        unrecognized is read as text). On failure the result is
        ``(error message, "error")``.
    """
    if file is None:
        return "", ""
    try:
        path = _uploaded_file_path(file)
        file_ext = os.path.splitext(path)[1].lower()
        if file_ext == '.parquet':
            return _read_parquet_file(path)
        if file_ext == '.pdf':
            return _read_pdf_file(path)
        if file_ext == '.csv':
            return _read_csv_file(path)
        return _read_text_file(path)
    except Exception as e:
        return f"ํŒŒ์ผ ์ฝ๊ธฐ ์˜ค๋ฅ˜: {str(e)}", "error"
# ํŒŒ์ผ ์—…๋กœ๋“œ ์ด๋ฒคํŠธ ํ•ธ๋“ค๋ง ์ˆ˜์ •
def init_msg():
return "ํŒŒ์ผ์„ ๋ถ„์„ํ•˜๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค..."
# Stylesheet passed to gr.Blocks(css=...) in create_demo.
# The folder icon uses the ASCII CSS escape \1F4C1 (written "\\1F4C1" in this
# Python string) instead of a literal emoji, so it cannot be garbled by encoding.
CSS = """
/* 3D ์Šคํƒ€์ผ CSS */
:root {
--primary-color: #2196f3;
--secondary-color: #1976d2;
--background-color: #f0f2f5;
--card-background: #ffffff;
--text-color: #333333;
--shadow-color: rgba(0, 0, 0, 0.1);
}
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
.container {
transform-style: preserve-3d;
perspective: 1000px;
}
.chatbot {
background: var(--card-background);
border-radius: 20px;
box-shadow:
0 10px 20px var(--shadow-color),
0 6px 6px var(--shadow-color);
transform: translateZ(0);
transition: transform 0.3s ease;
backdrop-filter: blur(10px);
}
.chatbot:hover {
transform: translateZ(10px);
}
/* ๋ฉ”์‹œ์ง€ ์ž…๋ ฅ ์˜์—ญ */
.input-area {
background: var(--card-background);
border-radius: 15px;
padding: 15px;
margin-top: 20px;
box-shadow:
0 5px 15px var(--shadow-color),
0 3px 3px var(--shadow-color);
transform: translateZ(0);
transition: all 0.3s ease;
display: flex;
align-items: center;
gap: 10px;
}
.input-area:hover {
transform: translateZ(5px);
}
/* ๋ฒ„ํŠผ ์Šคํƒ€์ผ */
.custom-button {
background: linear-gradient(145deg, var(--primary-color), var(--secondary-color));
color: white;
border: none;
border-radius: 10px;
padding: 10px 20px;
font-weight: 600;
cursor: pointer;
transform: translateZ(0);
transition: all 0.3s ease;
box-shadow:
0 4px 6px var(--shadow-color),
0 1px 3px var(--shadow-color);
}
.custom-button:hover {
transform: translateZ(5px) translateY(-2px);
box-shadow:
0 7px 14px var(--shadow-color),
0 3px 6px var(--shadow-color);
}
/* ํŒŒ์ผ ์—…๋กœ๋“œ ๋ฒ„ํŠผ */
.file-upload-icon {
background: linear-gradient(145deg, #64b5f6, #42a5f5);
color: white;
border-radius: 8px;
font-size: 2em;
cursor: pointer;
display: flex;
align-items: center;
justify-content: center;
height: 70px;
width: 70px;
transition: all 0.3s ease;
box-shadow: 0 2px 5px rgba(0,0,0,0.1);
}
.file-upload-icon:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0,0,0,0.2);
}
/* ํŒŒ์ผ ์—…๋กœ๋“œ ๋ฒ„ํŠผ ๋‚ด๋ถ€ ์š”์†Œ ์Šคํƒ€์ผ๋ง */
.file-upload-icon > .wrap {
display: flex !important;
align-items: center;
justify-content: center;
width: 100%;
height: 100%;
}
.file-upload-icon > .wrap > p {
display: none !important;
}
.file-upload-icon > .wrap::before {
content: "\\1F4C1";
font-size: 2em;
display: block;
}
/* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ */
.message {
background: var(--card-background);
border-radius: 15px;
padding: 15px;
margin: 10px 0;
box-shadow:
0 4px 6px var(--shadow-color),
0 1px 3px var(--shadow-color);
transform: translateZ(0);
transition: all 0.3s ease;
}
.message:hover {
transform: translateZ(5px);
}
.chat-container {
height: 600px !important;
margin-bottom: 10px;
}
.input-container {
height: 70px !important;
display: flex;
align-items: center;
gap: 10px;
margin-top: 5px;
}
.input-textbox {
height: 70px !important;
border-radius: 8px !important;
font-size: 1.1em !important;
padding: 10px 15px !important;
display: flex !important;
align-items: flex-start !important; /* ํ…์ŠคํŠธ ์ž…๋ ฅ ์œ„์น˜๋ฅผ ์œ„๋กœ ์กฐ์ • */
}
.input-textbox textarea {
padding-top: 5px !important; /* ํ…์ŠคํŠธ ์ƒ๋‹จ ์—ฌ๋ฐฑ ์กฐ์ • */
}
.send-button {
height: 70px !important;
min-width: 70px !important;
font-size: 1.1em !important;
}
/* ์„ค์ • ํŒจ๋„ ๊ธฐ๋ณธ ์Šคํƒ€์ผ */
.settings-panel {
padding: 20px;
margin-top: 20px;
}
"""
# GPU memory management.
def clear_cuda_memory():
    """Release cached CUDA allocator memory on the current device; no-op without CUDA.

    The previous ``hasattr(torch.cuda, 'empty_cache')`` guard was always true, and
    entering ``torch.cuda.device('cuda')`` raises when CUDA is unavailable. That
    made this helper crash inside stream_chat's error handler. Checking
    ``is_available()`` keeps it safe to call anywhere.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
# ๋ชจ๋ธ ๋กœ๋“œ ํ•จ์ˆ˜ ์ˆ˜์ •
@spaces.GPU
def load_model():
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.bfloat16,
device_map="auto",
)
return model
except Exception as e:
print(f"๋ชจ๋ธ ๋กœ๋“œ ์˜ค๋ฅ˜: {str(e)}")
raise
def _build_wiki_context(query):
    """Format the Wikipedia Q&A pairs retrieved for ``query`` as prompt text ("" on failure)."""
    try:
        relevant_contexts = find_relevant_context(query)
        wiki_context = "\n\n๊ด€๋ จ ์œ„ํ‚คํ”ผ๋””์•„ ์ •๋ณด:\n"
        for ctx in relevant_contexts:
            wiki_context += f"Q: {ctx['question']}\nA: {ctx['answer']}\n์œ ์‚ฌ๋„: {ctx['similarity']:.3f}\n\n"
        return wiki_context
    except Exception as e:
        print(f"์ปจํ…์ŠคํŠธ ๊ฒ€์ƒ‰ ์˜ค๋ฅ˜: {str(e)}")
        return ""


def _encode_conversation(conversation, max_length, device):
    """Render ``conversation`` with the chat template and tokenize it to at most ``max_length`` tokens.

    ``conversation`` is expected to be ``[user, assistant, ..., user]``. Whole
    user/assistant pairs are dropped from the front first, so the template stays
    well-formed. Only if the final user turn alone is still too long are leading
    tokens cut; that keeps the generation prompt at the end intact.

    Returns:
        A dict of tensors on ``device``, suitable for ``model.generate(**inputs)``.
    """
    turns = list(conversation)
    while True:
        prompt = tokenizer.apply_chat_template(turns, tokenize=False, add_generation_prompt=True)
        # The rendered prompt already contains the template's special tokens (e.g. BOS);
        # letting the tokenizer add them again would duplicate them.
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        if encoded["input_ids"].shape[-1] <= max_length or len(turns) <= 1:
            break
        turns = turns[2:]  # drop the oldest user/assistant pair
    # The slice is a no-op when the prompt already fits.
    return {name: tensor[:, -max_length:].to(device) for name, tensor in encoded.items()}


@spaces.GPU
def stream_chat(message: str, history: list, uploaded_file, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Generate a streamed model reply to ``message`` for the Gradio chat UI.

    The prompt combines the uploaded-file analysis, TF-IDF-retrieved Wikipedia
    Q&A pairs and the last 10 chat exchanges. ``model.generate`` runs in a
    background thread and its output is streamed back.

    Args:
        message: textbox content. It equals ``init_msg()`` when the call was
            triggered by a file upload.
        history: chatbot value as a list of ``[user, assistant]`` pairs.
        uploaded_file: current value of the file component (may be None).
        temperature, top_p, top_k, penalty: sampling settings for ``generate``.
        max_new_tokens: generation length, capped at 2048.

    Yields:
        ``("", history + [[message, partial_reply]])`` after each streamed chunk
        (the "" clears the textbox). On failure it yields
        ``("", history + [[message, error_text]])`` once.
    """
    global model, current_file_context
    try:
        if model is None:
            model = load_model()
        print(f'message is - {message}')
        print(f'history is - {history}')

        # The upload event first puts init_msg() into the textbox, so that exact
        # text marks a call triggered by a fresh file upload.
        file_context = ""
        if uploaded_file and message == init_msg():
            try:
                content, file_type = read_uploaded_file(uploaded_file)
                if content:
                    file_analysis = analyze_file_content(content, file_type)
                    file_context = f"\n\n๐Ÿ“„ ํŒŒ์ผ ๋ถ„์„ ๊ฒฐ๊ณผ:\n{file_analysis}\n\nํŒŒ์ผ ๋‚ด์šฉ:\n```\n{content}\n```"
                    current_file_context = file_context  # reused by later questions
                    message = "์—…๋กœ๋“œ๋œ ํŒŒ์ผ์„ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”."
            except Exception as e:
                print(f"ํŒŒ์ผ ๋ถ„์„ ์˜ค๋ฅ˜: {str(e)}")
                file_context = f"\n\nโŒ ํŒŒ์ผ ๋ถ„์„ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
        elif current_file_context:
            file_context = current_file_context

        if torch.cuda.is_available():
            print(f"CUDA ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        # Only the most recent exchanges go into the prompt; the chatbot still
        # receives (and displays) the full history.
        max_history_length = 10
        recent_history = history[-max_history_length:]

        wiki_context = _build_wiki_context(message)

        conversation = []
        for prompt, answer in recent_history:
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ])
        final_message = file_context + wiki_context + "\nํ˜„์žฌ ์งˆ๋ฌธ: " + message
        conversation.append({"role": "user", "content": final_message})

        max_length = 4096  # prompt token budget
        inputs = _encode_conversation(conversation, max_length, model.device)

        if torch.cuda.is_available():
            print(f"์ž…๋ ฅ ํ…์„œ ์ƒ์„ฑ ํ›„ CUDA ๋ฉ”๋ชจ๋ฆฌ: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=penalty,
            max_new_tokens=min(max_new_tokens, 2048),
            do_sample=True,
            temperature=temperature,
            eos_token_id=[255001],  # presumably the model's end-of-turn token id — TODO confirm
        )

        clear_cuda_memory()
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()
        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield "", history + [[message, buffer]]
        # Make sure generation has fully finished before freeing cached memory.
        thread.join()
        clear_cuda_memory()
    except Exception as e:
        error_message = f"์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
        print(f"Stream chat ์˜ค๋ฅ˜: {error_message}")
        clear_cuda_memory()
        yield "", history + [[message, error_message]]
def create_demo():
    """Build the Gradio Blocks chat UI and wire its events to ``stream_chat``.

    Returns:
        The configured ``gr.Blocks`` app; the caller launches it.
    """
    with gr.Blocks(css=CSS) as demo:
        with gr.Column(elem_classes="markdown-style"):
            gr.Markdown("""
# ๐Ÿค– OnDevice AI RAG
#### ๐Ÿ“Š RAG: Upload and Analyze Files (TXT, CSV, PDF, Parquet files)
Upload your files for data analysis and learning
""")
        chatbot = gr.Chatbot(
            value=[],
            height=600,
            label="GiniGEN AI Assistant",
            elem_classes="chat-container",
        )
        with gr.Row(elem_classes="input-container"):
            with gr.Column(scale=1, min_width=70):
                file_upload = gr.File(
                    type="filepath",
                    elem_classes="file-upload-icon",
                    scale=1,
                    container=True,
                    interactive=True,
                    show_label=False,
                )
            with gr.Column(scale=3):
                msg = gr.Textbox(
                    show_label=False,
                    placeholder="Type your message here... ๐Ÿ’ญ",
                    container=False,
                    elem_classes="input-textbox",
                    scale=1,
                )
            with gr.Column(scale=1, min_width=70):
                send = gr.Button(
                    "Send",
                    elem_classes="send-button custom-button",
                    scale=1,
                )
            with gr.Column(scale=1, min_width=70):
                clear = gr.Button(
                    "Clear",
                    elem_classes="clear-button custom-button",
                    scale=1,
                )
        with gr.Accordion("๐ŸŽฎ Advanced Settings", open=False):
            with gr.Row():
                with gr.Column(scale=1):
                    temperature = gr.Slider(
                        minimum=0, maximum=1, step=0.1, value=0.8,
                        label="Creativity Level ๐ŸŽจ",
                    )
                    max_new_tokens = gr.Slider(
                        minimum=128, maximum=8000, step=1, value=4000,
                        label="Maximum Token Count ๐Ÿ“",
                    )
                with gr.Column(scale=1):
                    top_p = gr.Slider(
                        minimum=0.0, maximum=1.0, step=0.1, value=0.8,
                        label="Diversity Control ๐ŸŽฏ",
                    )
                    top_k = gr.Slider(
                        minimum=1, maximum=20, step=1, value=20,
                        label="Selection Range ๐Ÿ“Š",
                    )
                    penalty = gr.Slider(
                        minimum=0.0, maximum=2.0, step=0.1, value=1.0,
                        label="Repetition Penalty ๐Ÿ”„",
                    )
        gr.Examples(
            examples=[
                ["Please analyze this code and suggest improvements:\ndef fibonacci(n):\n if n <= 1: return n\n return fibonacci(n-1) + fibonacci(n-2)"],
                ["Please analyze this data and provide insights:\nAnnual Revenue (Million)\n2019: 1200\n2020: 980\n2021: 1450\n2022: 2100\n2023: 1890"],
                ["Please solve this math problem step by step: 'When a circle's area is twice that of its inscribed square, find the relationship between the circle's radius and the square's side length.'"],
                ["Please analyze this marketing campaign's ROI and suggest improvements:\nTotal Cost: $50,000\nReach: 1M users\nClick Rate: 2.3%\nConversion Rate: 0.8%\nAverage Purchase: $35"],
            ],
            inputs=msg,
        )

        def clear_conversation():
            """Forget the stored file context and reset the chat, file and textbox."""
            global current_file_context
            current_file_context = None
            return [], None, "Start a new conversation..."

        # Every chat trigger shares the same inputs and outputs.
        chat_inputs = [msg, chatbot, file_upload, temperature, max_new_tokens, top_p, top_k, penalty]
        chat_outputs = [msg, chatbot]

        msg.submit(stream_chat, inputs=chat_inputs, outputs=chat_outputs)
        send.click(stream_chat, inputs=chat_inputs, outputs=chat_outputs)
        # Use .upload rather than .change. .change also fires when the file is
        # cleared (e.g. by the Clear button returning None), which sent the
        # init_msg() placeholder to the model as if it were a question.
        file_upload.upload(
            fn=init_msg,
            outputs=msg,
            queue=False,
        ).then(
            fn=stream_chat,
            inputs=chat_inputs,
            outputs=chat_outputs,
            queue=True,
        )
        clear.click(
            fn=clear_conversation,
            outputs=[chatbot, file_upload, msg],
            queue=False,
        )
    return demo
if __name__ == "__main__":
    # Build the Gradio UI and start serving it.
    demo = create_demo()
    demo.launch()