# RAGOndevice / app.py
# cutechicken's picture
# Update app.py
# 0de5bb6 verified
# raw
# history blame
# 8.95 kB
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
# Free any unused cached GPU memory before the model is loaded.
torch.cuda.empty_cache()

# Configuration: optional values from the environment plus the fixed model id.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODELS = os.environ.get("MODELS")
MODEL_NAME = MODEL_ID.split("/")[-1]

# Load the model and its tokenizer.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load the Wikipedia Q&A dataset that serves as the retrieval corpus.
wiki_dataset = load_dataset("lcw99/wikipedia-korean-20240501-1million-qna")
print("Wikipedia dataset loaded:", wiki_dataset)

# Initialise and fit the TF-IDF vectorizer on the first 10000 questions only.
print("TF-IDF 벡터화 시작...")
# Slice the rows first and pick the column afterwards: selecting the column
# first would pull every question of the split before the slice discards
# all but the first 10000.
questions = wiki_dataset['train'][:10000]['question']
vectorizer = TfidfVectorizer(max_features=1000)
question_vectors = vectorizer.fit_transform(questions)
print("TF-IDF 벡터화 완료")
def find_relevant_context(query, top_k=3):
    """Return the indexed Q&A pairs whose questions best match ``query``.

    The query is transformed with the module-level ``vectorizer`` and scored
    against ``question_vectors`` (the TF-IDF matrix of ``questions``).

    Args:
        query: Question text to search for.
        top_k: Maximum number of matches to return. Values below 1 yield an
            empty list.

    Returns:
        A list of dicts with the keys ``'question'``, ``'answer'`` and
        ``'similarity'``, ordered from most to least similar. Entries whose
        similarity is not strictly positive are left out, so the list may
        hold fewer than ``top_k`` items (or none).
    """
    if top_k <= 0:
        # ``scores[-0:]`` is the whole array, so a zero top_k would otherwise
        # return every indexed question instead of none.
        return []

    # Vectorize the query.
    query_vector = vectorizer.transform([query])

    # Dot product of the query row with every indexed question row.
    # TfidfVectorizer L2-normalises rows by default, which makes this the
    # cosine similarity.
    similarities = (query_vector * question_vectors.T).toarray()[0]

    # Indices of the best-scoring questions, highest score first.
    top_indices = np.argsort(similarities)[-top_k:][::-1]

    train_split = wiki_dataset['train']
    relevant_contexts = []
    for idx in top_indices:
        score = similarities[idx]
        if score <= 0:
            # Scores are visited in descending order, so nothing after this
            # point can be positive either.
            break
        row_index = int(idx)  # plain int instead of a numpy integer
        # Fetch the single row and then its answer; selecting the 'answer'
        # column first would load the whole column for every match.
        relevant_contexts.append({
            'question': questions[row_index],
            'answer': train_split[row_index]['answer'],
            'similarity': score,
        })
    return relevant_contexts
@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream a retrieval-augmented reply to ``message``.

    Similar Q&A pairs are looked up with ``find_relevant_context`` and
    prepended to the user's question; the earlier turns in ``history`` are
    replayed as chat messages; generation runs in a background thread while
    this generator relays the text produced so far.

    Args:
        message: The user's current question.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        temperature: Sampling temperature. A value of 0 (or less) switches to
            greedy decoding, because sampling requires a strictly positive
            temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k sampling cutoff (used only when sampling).
        penalty: Repetition penalty. Non-positive values are not forwarded,
            leaving the generation default in place.

    Yields:
        The reply accumulated so far, growing with each streamed chunk.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')

    # Find related context and render it as a reference section.
    relevant_contexts = find_relevant_context(message)
    context_prompt = "\n\n관련 참고 정보:\n" + "".join(
        f"Q: {ctx['question']}\nA: {ctx['answer']}\n유사도: {ctx['similarity']:.3f}\n\n"
        for ctx in relevant_contexts
    )

    # Rebuild the conversation history as alternating user/assistant messages.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])

    # Final user turn: retrieved context followed by the current question.
    final_message = context_prompt + "\n현재 질문: " + message
    conversation.append({"role": "user", "content": final_message})

    # tokenize=False returns the rendered prompt text, which is tokenized
    # separately on the next line.
    prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    # Follow the model's device rather than assuming device index 0.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        # NOTE(review): hard-coded stop token id; presumably the model's
        # end-of-turn token - confirm against the tokenizer.
        eos_token_id=[255001],
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            top_p=float(top_p),
        )
    else:
        # The UI slider allows 0, which sampling rejects; the error would be
        # raised inside the worker thread and only surface here as a
        # streamer timeout. Decode greedily instead.
        generate_kwargs["do_sample"] = False
    if penalty > 0:
        # A repetition penalty must be strictly positive; skip invalid values.
        generate_kwargs["repetition_penalty"] = float(penalty)

    # generate() blocks until done, so it runs in a worker thread while this
    # generator consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer

    # The streamer is exhausted, so generation has finished; reap the thread.
    thread.join()
# Chat history component; it is handed to gr.ChatInterface through its
# `chatbot` argument when the UI is assembled.
chatbot = gr.Chatbot(height=500)
# Custom stylesheet handed to gr.Blocks(css=...).
# The message rules reference an animation named `messageIn`, so the
# stylesheet defines the matching @keyframes; without it that animation
# declaration has nothing to run.
# NOTE(review): the class selectors (.container, .chatbox, .messages, ...)
# only apply if elements with those classes exist in the rendered page -
# confirm against the Gradio version in use.
CSS = """
/* 전체 페이지 스타일링 */
body {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    min-height: 100vh;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* 메인 컨테이너 */
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 2rem;
    background: rgba(255, 255, 255, 0.95);
    border-radius: 20px;
    box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
    backdrop-filter: blur(10px);
    transform: perspective(1000px) translateZ(0);
    transition: all 0.3s ease;
}
/* 제목 스타일링 */
h1 {
    color: #2d3436;
    font-size: 2.5rem;
    text-align: center;
    margin-bottom: 2rem;
    text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
    transform: perspective(1000px) translateZ(20px);
}
h3 {
    text-align: center;
    color: #2d3436;
    font-size: 1.5rem;
    margin: 1rem 0;
}
/* 채팅박스 스타일링 */
.chatbox {
    background: white;
    border-radius: 15px;
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
    backdrop-filter: blur(4px);
    border: 1px solid rgba(255, 255, 255, 0.18);
    padding: 1rem;
    margin: 1rem 0;
    transform: translateZ(0);
    transition: all 0.3s ease;
}
/* 메시지 스타일링 */
.chatbox .messages .message.user {
    background: linear-gradient(145deg, #e1f5fe, #bbdefb);
    border-radius: 15px;
    padding: 1rem;
    margin: 0.5rem;
    box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
    transform: translateZ(10px);
    animation: messageIn 0.3s ease-out;
}
.chatbox .messages .message.bot {
    background: linear-gradient(145deg, #f5f5f5, #eeeeee);
    border-radius: 15px;
    padding: 1rem;
    margin: 0.5rem;
    box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
    transform: translateZ(10px);
    animation: messageIn 0.3s ease-out;
}
@keyframes messageIn {
    from {
        opacity: 0;
        transform: translateY(10px) translateZ(10px);
    }
    to {
        opacity: 1;
        transform: translateY(0) translateZ(10px);
    }
}
/* 버튼 스타일링 */
.duplicate-button {
    background: linear-gradient(145deg, #24292e, #1a1e22) !important;
    color: white !important;
    border-radius: 100vh !important;
    padding: 0.8rem 1.5rem !important;
    box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2) !important;
    transition: all 0.3s ease !important;
    border: none !important;
    cursor: pointer !important;
}
.duplicate-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
}
/* 입력 필드 스타일링 */
"""
with gr.Blocks(css=CSS) as demo:
    # Collapsed options panel. Like the sliders below it is created with
    # render=False and handed to ChatInterface instead of being placed here.
    options_accordion = gr.Accordion(label="โš™๏ธ ์˜ต์…˜", open=False, render=False)

    # One (minimum, maximum, step, value, label) entry per generation
    # setting, in the order stream_chat takes them after (message, history):
    # temperature, max_new_tokens, top_p, top_k, penalty.
    slider_specs = [
        (0, 1, 0.1, 0.8, "์˜จ๋„"),
        (128, 8000, 1, 4000, "์ตœ๋Œ€ ํ† ํฐ ์ˆ˜"),
        (0.0, 1.0, 0.1, 0.8, "์ƒ์œ„ ํ™•๋ฅ "),
        (1, 20, 1, 20, "์ƒ์œ„ K"),
        (0.0, 2.0, 0.1, 1.0, "๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ"),
    ]
    generation_sliders = [
        gr.Slider(
            minimum=lowest,
            maximum=highest,
            step=increment,
            value=initial,
            label=caption,
            render=False,
        )
        for lowest, highest, increment, initial, caption in slider_specs
    ]

    # Sample prompts offered in the UI; each becomes a one-element example row.
    example_prompts = [
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ์ ˆ๊ธฐ์™€ 24์ ˆ๊ธฐ์— ๋Œ€ํ•ด ์ž์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "์šฐ๋ฆฌ๋‚˜๋ผ ์ „ํ†ต ์Œ์‹ ์ค‘ ๊ฑด๊ฐ•์— ์ข‹์€ ๋ฐœํšจ์Œ์‹ 5๊ฐ€์ง€๋ฅผ ์ถ”์ฒœํ•˜๊ณ  ๊ทธ ํšจ๋Šฅ์„ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ๋Œ€ํ‘œ์ ์ธ ์‚ฐ๋“ค์„ ์†Œ๊ฐœํ•˜๊ณ , ๊ฐ ์‚ฐ์˜ ํŠน์ง•๊ณผ ๋“ฑ์‚ฐ ์ฝ”์Šค๋ฅผ ์ถ”์ฒœํ•ด์ฃผ์„ธ์š”.",
        "์‚ฌ๋ฌผ๋†€์ด์˜ ์•…๊ธฐ ๊ตฌ์„ฑ๊ณผ ์žฅ๋‹จ์— ๋Œ€ํ•ด ์ดˆ๋ณด์ž๋„ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ๊ฑด์ถ•๋ฌผ์— ๋‹ด๊ธด ๊ณผํ•™์  ์›๋ฆฌ๋ฅผ ํ˜„๋Œ€์  ๊ด€์ ์—์„œ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”.",
        "์กฐ์„ ์‹œ๋Œ€ ๊ณผ๊ฑฐ ์‹œํ—˜ ์ œ๋„๋ฅผ ํ˜„๋Œ€์˜ ์ž…์‹œ ์ œ๋„์™€ ๋น„๊ตํ•˜์—ฌ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ 4๋Œ€ ๊ถ๊ถ์„ ๋น„๊ตํ•˜์—ฌ ๊ฐ๊ฐ์˜ ํŠน์ง•๊ณผ ์—ญ์‚ฌ์  ์˜๋ฏธ๋ฅผ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ๋†€์ด๋ฅผ ํ˜„๋Œ€์ ์œผ๋กœ ์žฌํ•ด์„ํ•˜์—ฌ ์‹ค๋‚ด์—์„œ ํ•  ์ˆ˜ ์žˆ๋Š” ๋ฐฉ๋ฒ•์„ ์ œ์•ˆํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ธ€ ์ฐฝ์ œ ๊ณผ์ •๊ณผ ํ›ˆ๋ฏผ์ •์Œ์˜ ๊ณผํ•™์  ์›๋ฆฌ๋ฅผ ์ƒ์„ธํžˆ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ์ฐจ ๋ฌธํ™”์— ๋Œ€ํ•ด ์„ค๋ช…ํ•˜๊ณ , ๊ณ„์ ˆ๋ณ„๋กœ ์–ด์šธ๋ฆฌ๋Š” ์ „ํ†ต์ฐจ๋ฅผ ์ถ”์ฒœํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ์˜๋ณต์ธ ํ•œ๋ณต์˜ ๊ตฌ์กฐ์™€ ํŠน์ง•์„ ๊ณผํ•™์ , ๋ฏธํ•™์  ๊ด€์ ์—์„œ ๋ถ„์„ํ•ด์ฃผ์„ธ์š”.",
        "ํ•œ๊ตญ์˜ ์ „ํ†ต ๊ฐ€์˜ฅ ๊ตฌ์กฐ๋ฅผ ๊ธฐํ›„์™€ ํ™˜๊ฒฝ ๊ด€์ ์—์„œ ๋ถ„์„ํ•˜๊ณ , ํ˜„๋Œ€ ๊ฑด์ถ•์— ์ ์šฉํ•  ์ˆ˜ ์žˆ๋Š” ์š”์†Œ๋ฅผ ์ œ์•ˆํ•ด์ฃผ์„ธ์š”.",
    ]

    # Chat UI wired to stream_chat, using the pre-built chatbot component.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        theme="soft",
        additional_inputs_accordion=options_accordion,
        additional_inputs=generation_sliders,
        examples=[[prompt_text] for prompt_text in example_prompts],
        cache_examples=False,
    )

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()