# RAGOndevice / app.py
# Hugging Face Hub file-viewer residue (not part of the program):
#   cutechicken's picture | Update app.py | 6360699 verified | raw | history blame | 8.87 kB
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread
import random
from datasets import load_dataset
import gc
# Release cached GPU memory and run a garbage-collection pass at start-up.
torch.cuda.empty_cache()
gc.collect()

# Values read from the environment; each is None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
MODELS = os.getenv("MODELS")

# Hub repository id of the chat model; MODEL_NAME is the part after the last "/".
MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024"
MODEL_NAME = MODEL_ID.rsplit("/", 1)[-1]

# HTML snippet used as the page heading.
TITLE = "<h1><center>์˜จ๋””๋ฐ”์ด์Šค AI(Open LLM ๋ชจ๋ธ)</center></h1>"
# Initial stylesheet: duplicate-button, h3 and chat-message colours.
# NOTE(review): CSS is assigned a second time further down in this file,
# before it is passed to gr.Blocks, so this first value appears to be
# superseded -- confirm and remove if it is indeed unused.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""
# Device selection: CUDA when it is available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the model first and then its tokenizer. Any failure is printed and
# re-raised, so the application does not start without a model.
try:
    _model_load_options = {
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
        "low_cpu_mem_usage": True,
    }
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_load_options)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
except Exception as load_error:
    print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(load_error)}")
    raise
# Load the example dataset and draw a random subset of prompts from it.
# Any failure is printed and both example lists fall back to empty.
try:
    dataset = load_dataset("elyza/ELYZA-tasks-100")
    print(dataset)
    # Prefer the "train" split; otherwise use "test".
    split_name = "train" if "train" in dataset else "test"
    examples_list = list(dataset[split_name])
    # random.sample raises ValueError when asked for more items than the
    # population holds; clamp the size so a small split still yields
    # examples instead of dropping into the except branch.
    examples = random.sample(examples_list, min(50, len(examples_list)))
    # One single-element row per example (its 'input' field).
    example_inputs = [[example['input']] for example in examples]
except Exception as e:
    print(f"๋ฐ์ดํ„ฐ์…‹ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
    examples = []
    example_inputs = []
def error_handler(func):
    """Decorate *func* so exceptions are logged and replaced by a fallback message.

    For a plain function, an exception raised by the call is printed and the
    fallback message is returned in place of the result.

    For a generator function, the wrapper is itself a generator function, so
    callers that check ``inspect.isgeneratorfunction`` still see a generator.
    An exception raised while iterating is printed and the fallback message
    is yielded as the final item. Calling a generator function only creates
    the generator object, so a ``try`` around the call alone would never see
    errors raised from its body.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper with the same call signature and metadata as *func*.
    """
    # Imported locally so this block does not depend on the file's import list.
    import functools
    import inspect

    fallback_message = "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค. ์ž ์‹œ ํ›„ ๋‹ค์‹œ ์‹œ๋„ํ•ด์ฃผ์„ธ์š”."
    # Not every callable has __name__ (e.g. functools.partial objects).
    func_name = getattr(func, "__name__", repr(func))

    if inspect.isgeneratorfunction(func):
        @functools.wraps(func)
        def generator_wrapper(*args, **kwargs):
            try:
                yield from func(*args, **kwargs)
            except Exception as e:
                print(f"Error in {func_name}: {str(e)}")
                yield fallback_message
        return generator_wrapper

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print(f"Error in {func_name}: {str(e)}")
            return fallback_message
    return wrapper
@error_handler
@spaces.GPU
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Generate a reply to *message* and yield it incrementally.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        temperature: Sampling temperature. Values <= 0 select greedy decoding,
            because sampling requires a strictly positive temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; only forwarded when strictly positive.

    Yields:
        The reply text accumulated so far, once per streamed chunk. If an
        error occurs, an apology message is yielded instead.
    """
    try:
        print(f'message is - {message}')
        print(f'history is - {history}')
        # Release cached GPU memory before generating.
        torch.cuda.empty_cache()

        # Rebuild the conversation in the role/content format expected by
        # the tokenizer's chat template.
        conversation = []
        for prompt, answer in history:
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer},
            ])
        conversation.append({"role": "user", "content": message})

        prompt_text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        # The rendered template already contains the special tokens, so they
        # must not be added a second time when tokenizing the text.
        inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(device)
        streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

        generate_kwargs = dict(
            inputs,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            eos_token_id=[255001],
        )
        if temperature > 0:
            generate_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
            )
        else:
            # A temperature of 0 is rejected when sampling; decode greedily.
            generate_kwargs.update(do_sample=False)
        if penalty > 0:
            # A repetition penalty must be strictly positive; otherwise the
            # model's own default is left in place.
            generate_kwargs["repetition_penalty"] = penalty

        # Run generation in a background thread; the streamer hands decoded
        # text back to this generator as it is produced.
        thread = Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        buffer = ""
        for new_text in streamer:
            buffer += new_text
            yield buffer
    except Exception as e:
        print(f"Stream chat error: {str(e)}")
        yield "์ฃ„์†กํ•ฉ๋‹ˆ๋‹ค. ์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค."
    finally:
        # Free memory whether generation succeeded or failed.
        torch.cuda.empty_cache()
        gc.collect()
# Chat display component; it is passed to gr.ChatInterface further down.
chatbot = gr.Chatbot(height=500)
# Page stylesheet passed to gr.Blocks(css=...). Sections: page body, main
# container, headings, chat box, chat messages, and the duplicate button.
# NOTE(review): the message rules use `animation: messageIn`, but no
# @keyframes rule named messageIn is defined in this stylesheet, and the
# final section comment is not followed by any rules.
CSS = """
/* ์ „์ฒด ํŽ˜์ด์ง€ ์Šคํƒ€์ผ๋ง */
body {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
min-height: 100vh;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
/* ๋ฉ”์ธ ์ปจํ…Œ์ด๋„ˆ */
.container {
max-width: 1200px;
margin: 0 auto;
padding: 2rem;
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
box-shadow: 0 20px 40px rgba(0, 0, 0, 0.1);
backdrop-filter: blur(10px);
transform: perspective(1000px) translateZ(0);
transition: all 0.3s ease;
}
/* ์ œ๋ชฉ ์Šคํƒ€์ผ๋ง */
h1 {
color: #2d3436;
font-size: 2.5rem;
text-align: center;
margin-bottom: 2rem;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.1);
transform: perspective(1000px) translateZ(20px);
}
h3 {
text-align: center;
color: #2d3436;
font-size: 1.5rem;
margin: 1rem 0;
}
/* ์ฑ„ํŒ…๋ฐ•์Šค ์Šคํƒ€์ผ๋ง */
.chatbox {
background: white;
border-radius: 15px;
box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
backdrop-filter: blur(4px);
border: 1px solid rgba(255, 255, 255, 0.18);
padding: 1rem;
margin: 1rem 0;
transform: translateZ(0);
transition: all 0.3s ease;
}
/* ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ๋ง */
.chatbox .messages .message.user {
background: linear-gradient(145deg, #e1f5fe, #bbdefb);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
.chatbox .messages .message.bot {
background: linear-gradient(145deg, #f5f5f5, #eeeeee);
border-radius: 15px;
padding: 1rem;
margin: 0.5rem;
box-shadow: 5px 5px 15px rgba(0, 0, 0, 0.05);
transform: translateZ(10px);
animation: messageIn 0.3s ease-out;
}
/* ๋ฒ„ํŠผ ์Šคํƒ€์ผ๋ง */
.duplicate-button {
background: linear-gradient(145deg, #24292e, #1a1e22) !important;
color: white !important;
border-radius: 100vh !important;
padding: 0.8rem 1.5rem !important;
box-shadow: 3px 3px 10px rgba(0, 0, 0, 0.2) !important;
transition: all 0.3s ease !important;
border: none !important;
cursor: pointer !important;
}
.duplicate-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.3) !important;
}
/* ์ž…๋ ฅ ํ•„๋“œ ์Šคํƒ€์ผ๋ง */
"""
# Assemble the page: a heading followed by a chat interface whose extra
# inputs feed stream_chat's generation parameters.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)

    # Collapsed accordion that holds the extra inputs; created unrendered so
    # the chat interface can place it.
    options_accordion = gr.Accordion(label="โš™๏ธ ์˜ต์…˜", open=False, render=False)

    # Extra inputs, listed in the positional order stream_chat receives them
    # after (message, history): temperature, max_new_tokens, top_p, top_k,
    # penalty.
    temperature_slider = gr.Slider(
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.3,
        label="์˜จ๋„",
        render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=128,
        maximum=8000,
        step=1,
        value=4000,
        label="์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
        render=False,
    )
    top_p_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        step=0.1,
        value=0.8,
        label="์ƒ์œ„ ํ™•๋ฅ ",
        render=False,
    )
    top_k_slider = gr.Slider(
        minimum=1,
        maximum=20,
        step=1,
        value=20,
        label="์ƒ์œ„ K",
        render=False,
    )
    penalty_slider = gr.Slider(
        minimum=0.0,
        maximum=2.0,
        step=0.1,
        value=1.0,
        label="๋ฐ˜๋ณต ํŒจ๋„ํ‹ฐ",
        render=False,
    )

    # Fixed example prompts offered in the interface (one prompt per row).
    sample_prompts = [
        ["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ๊ณผํ•™ ํ”„๋กœ์ ํŠธ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด๋ฅผ ์ฃผ์„ธ์š”."],
        ["๋งˆํฌ๋‹ค์šด์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ธŒ๋ ˆ์ดํฌ์•„์›ƒ ๊ฒŒ์ž„ ๋งŒ๋“ค๊ธฐ ํŠœํ† ๋ฆฌ์–ผ์„ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”."],
        ["์ดˆ๋Šฅ๋ ฅ์„ ๊ฐ€์ง„ ์ฃผ์ธ๊ณต์˜ SF ์ด์•ผ๊ธฐ ์‹œ๋‚˜๋ฆฌ์˜ค๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”. ๋ณต์„  ์„ค์ •, ํ…Œ๋งˆ์™€ ๋กœ๊ทธ๋ผ์ธ์„ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์‚ฌ์šฉํ•ด์ฃผ์„ธ์š”"],
        ["์•„์ด์˜ ์—ฌ๋ฆ„๋ฐฉํ•™ ์ž์œ ์—ฐ๊ตฌ๋ฅผ ์œ„ํ•œ 5๊ฐ€์ง€ ์•„์ด๋””์–ด์™€ ๊ทธ ๋ฐฉ๋ฒ•์„ ๊ฐ„๋‹จํžˆ ์•Œ๋ ค์ฃผ์„ธ์š”."],
        ["ํผ์ฆ ๊ฒŒ์ž„ ์Šคํฌ๋ฆฝํŠธ ์ž‘์„ฑ์„ ์œ„ํ•œ ์กฐ์–ธ ๋ถ€ํƒ๋“œ๋ฆฝ๋‹ˆ๋‹ค"],
        ["๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ๋ธ”๋ก ๊นจ๊ธฐ ๊ฒŒ์ž„ ์ œ์ž‘ ๊ต๊ณผ์„œ๋ฅผ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”"],
        ["์‹ค๋ฒ„ ๅทๆŸณ๋ฅผ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”"],
        ["์ผ๋ณธ์–ด ๊ด€์šฉ๊ตฌ, ์†๋‹ด์— ๊ด€ํ•œ ์‹œํ—˜ ๋ฌธ์ œ๋ฅผ ๋งŒ๋“ค์–ด์ฃผ์„ธ์š”"],
        ["๋„๋ผ์—๋ชฝ์˜ ๋“ฑ์žฅ์ธ๋ฌผ์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
        ["์˜ค์ฝ”๋…ธ๋ฏธ์•ผํ‚ค ๋งŒ๋“œ๋Š” ๋ฐฉ๋ฒ•์„ ์•Œ๋ ค์ฃผ์„ธ์š”"],
        ["๋ฌธ์ œ 9.11๊ณผ 9.9 ์ค‘ ์–ด๋Š ๊ฒƒ์ด ๋” ํฐ๊ฐ€์š”? step by step์œผ๋กœ ๋…ผ๋ฆฌ์ ์œผ๋กœ ์ƒ๊ฐํ•ด์ฃผ์„ธ์š”."],
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        theme="soft",
        additional_inputs_accordion=options_accordion,
        additional_inputs=[
            temperature_slider,
            max_tokens_slider,
            top_p_slider,
            top_k_slider,
            penalty_slider,
        ],
        examples=sample_prompts,
        cache_examples=False,
    )

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()