import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
from string import Template
from huggingface_hub import login
# Log in to Hugging Face (read the access token from an environment variable)
login(os.getenv("ACCESS_TOKEN"))  # ACCESS_TOKEN is loaded from the environment
# Prompt template
prompt_template = Template("Human: ${inst} </s> Assistant: ")
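# e.g. prompt_template.safe_substitute({"inst": "Hi"}) yields "Human: Hi </s> Assistant: "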
# Load the model and tokenizer
model_name = "meta-llama/Llama-3.2-1B-Instruct"  # model repository id
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="cpu").eval()
# Generation settings (defaults that can be overridden from the Gradio UI)
default_generation_config = GenerationConfig(
    temperature=0.1,
    top_k=30,
    top_p=0.5,
    do_sample=True,
    num_beams=1,
    repetition_penalty=1.1,
    min_new_tokens=10,
    max_new_tokens=30
)
# Response generation function
def respond(message, history, system_message, max_tokens, temperature, top_p):
    # Start from the default generation settings, then apply the values from the UI
    generation_config = GenerationConfig(
        **default_generation_config.to_dict()  # copy of the defaults
    )
    generation_config.max_new_tokens = max_tokens  # override max_new_tokens
    generation_config.temperature = temperature    # override temperature
    generation_config.top_p = top_p
    # Build the prompt from the system message and the conversation history
    prompt = prompt_template.safe_substitute({"inst": system_message})
    for val in history:
        if val[0]:
            prompt += f"Human: {val[0]} </s> Assistant: {val[1]} </s> "
    prompt += f"Human: {message} </s> Assistant: "
    # Tokenize the prompt and generate a response
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    response_ids = model.generate(
        **inputs,
        generation_config=generation_config,
        eos_token_id=tokenizer.eos_token_id,  # stop at the end-of-sequence token
        pad_token_id=tokenizer.eos_token_id   # use the EOS token for padding as well
    )
    # Decode only the newly generated tokens (everything after the prompt)
    response_text = tokenizer.decode(response_ids[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    # Yield the text character by character to simulate streaming in the UI
    response = ""
    for char in response_text:
        response += char
        yield response
# Gradio chat interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly and knowledgeable assistant who can discuss a wide range of topics related to music, including genres, artists, albums, instruments, and music history.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=30, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.1, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.5, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
if __name__ == "__main__":
    demo.launch()
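# To run this locally (assuming the script is saved as app.py, the default entry
# point for a Gradio Space, and that ACCESS_TOKEN holds a Hugging Face token with
# access to the gated Llama 3.2 repository):
#   export ACCESS_TOKEN=hf_xxxxxxxx
#   python app.py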