File size: 5,193 Bytes
691f3d7 51a7d9e 9e9c8af bd34f0b 9e9c8af edb9e8a 51a7d9e 15f4d32 9e9c8af da8a347 5ba2e07 248b38e 9e9c8af 51a7d9e beb5c13 51a7d9e 9e9c8af beb5c13 bd34f0b 9e9c8af bd34f0b 51a7d9e 2024746 8830af9 133be07 bd34f0b 9e9c8af 51a7d9e 3b9cb87 92e7c12 3b9cb87 bd34f0b 639e063 edb9e8a 92e7c12 edb9e8a bd34f0b 51a7d9e 922f584 51a7d9e edb9e8a 51a7d9e edb9e8a 51a7d9e a3e36c2 51a7d9e 781217c 51a7d9e 579ca70 51a7d9e ef2eb9e 51a7d9e bd34f0b 28514c1 bd34f0b 51a7d9e 9e9c8af d217d72 57f7053 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import os
import subprocess
import sys

# Best-effort install of flash-attn at startup, before the model is loaded.
# - The child environment is the parent's os.environ plus
#   FLASH_ATTENTION_SKIP_CUDA_BUILD. Passing a dict containing only that one
#   variable would replace the whole environment (dropping PATH, HOME, CUDA
#   settings, ...), because `env` is not merged by subprocess.
# - pip is run through sys.executable so the package lands in the interpreter
#   that runs this app, and an argument list avoids invoking a shell.
# - check=False keeps a failed install from aborting startup.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's
# setup to skip compiling CUDA kernels locally — confirm against its setup.py.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread
# Hub repo id of the checkpoint served by this app (an AWQ variant, per its name).
model_id = "team-hatakeyama-phase2/Tanuki-8x8B-dpo-v1.0-AWQ"

# Keyword arguments forwarded to AutoModelForCausalLM.from_pretrained.
# trust_remote_code=True permits executing modeling code shipped in the repo.
# To offload weights to disk, additionally pass offload_folder="offload" and,
# if needed, offload_state_dict=True.
_MODEL_LOAD_KWARGS = {
    "torch_dtype": torch.float16,
    "device_map": "sequential",
    "trust_remote_code": True,
}

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **_MODEL_LOAD_KWARGS)
# Page header HTML, rendered at the top of the app with gr.HTML(TITLE).
TITLE = "<h1><center>Tanuki-8x8B-dpo-v1.0-AWQ Chat webui</center></h1>"
# Intro HTML rendered under the title with gr.HTML(DESCRIPTION). The link points
# at the non-AWQ Tanuki-8x8B-dpo-v1.0 repo, not the AWQ checkpoint loaded here.
DESCRIPTION = """
<h3>MODEL: <a href="https://huggingface.co/weblab-GENIAC/Tanuki-8x8B-dpo-v1.0">Tanuki-8x8B-dpo-v1.0</a></h3>
<center>
<p>This model is designed for conversational interactions.</p>
</center>
"""
# Custom stylesheet passed to gr.Blocks(css=CSS).
# NOTE(review): nothing in this file creates an element with the
# .duplicate-button class, and the .chatbox .messages .message.* selectors may
# not match the DOM of the installed Gradio version — confirm they take effect.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
h3 {
text-align: center;
}
.chatbox .messages .message.user {
background-color: #e1f5fe;
}
.chatbox .messages .message.bot {
background-color: #eeeeee;
}
"""
def _history_to_messages(history):
    """Convert Gradio chat history into a list of chat-template message dicts.

    Accepts both Gradio history formats: ``[user, assistant]`` pairs (the
    "tuples" format) and ``{"role": ..., "content": ...}`` dicts (the
    "messages" format). ``None`` entries, such as a user turn with no reply,
    are skipped rather than sent to the chat template as null content.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            conversation.append({"role": turn["role"], "content": turn["content"]})
            continue
        prompt, answer = turn
        if prompt is not None:
            conversation.append({"role": "user", "content": prompt})
        if answer is not None:
            conversation.append({"role": "assistant", "content": answer})
    return conversation


@spaces.GPU(duration=120)
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's reply to *message*, given the earlier chat *history*.

    Used as the Gradio ChatInterface callback. Generation runs in a background
    thread, and this generator yields the reply accumulated so far after each
    decoded chunk.

    Args:
        message: Latest user message.
        history: Earlier turns, as [user, assistant] pairs or role/content dicts.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        penalty: Repetition penalty; a value <= 0 disables the penalty (1.0).

    Yields:
        The cumulative reply text.
    """
    print(f'Message: {message}')
    print(f'History: {history}')
    conversation = _history_to_messages(history)
    conversation.append({"role": "user", "content": message})

    # return_dict=True makes apply_chat_template return a dict-like encoding
    # (input_ids, attention_mask, ...). Without it a bare tensor is returned,
    # and indexing it with "input_ids" fails.
    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = inputs["input_ids"].to(model.device)
    attention_mask = inputs.get("attention_mask")

    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)

    # Gradio sliders may deliver ints, while transformers' logits processors
    # insist on float temperature/penalty and int top_k, so cast explicitly.
    temperature = float(temperature)
    penalty = float(penalty)
    if penalty <= 0:
        # transformers rejects non-positive repetition penalties; 1.0 means "no penalty".
        penalty = 1.0
    # Sampling with temperature <= 0 is rejected by transformers, so the
    # slider's 0 end is treated as a request for greedy decoding.
    do_sample = temperature > 0

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        repetition_penalty=penalty,
        max_new_tokens=int(max_new_tokens),
        do_sample=do_sample,
        # NOTE(review): id 2 is assumed to be this model's end-of-sequence
        # token — confirm against tokenizer.eos_token_id.
        eos_token_id=[2],
    )
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.to(model.device)
    if do_sample:
        generate_kwargs.update(temperature=temperature, top_p=top_p, top_k=int(top_k))

    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted once generate() finishes; reap the worker thread.
    thread.join()
# Chat display widget; handed to the ChatInterface below, which renders it.
chatbot = gr.Chatbot(height=500)

# The theme must be set on the top-level Blocks. A theme passed to a
# ChatInterface nested inside this Blocks has no effect on the rendered page.
with gr.Blocks(css=CSS, theme="soft") as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Generation settings, passed to stream_chat positionally after
        # (message, history) in this order: temperature, max_new_tokens,
        # top_p, top_k, penalty.
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=0.8,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=20,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.2,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=[
            ["Explain Deep Learning as a pirate."],
            ["Give me five ideas for a child's summer science project."],
            ["Provide advice for writing a script for a puzzle game."],
            ["Create a tutorial for building a breakout game using markdown."],
            ["超能力を持つ主人公のSF物語のシナリオを考えてください。伏線の設定、テーマやログラインを理論的に使用してください"],
            ["子供の夏休みの自由研究のための、5つのアイデアと、その手法を簡潔に教えてください。"],
            ["パズルゲームのスクリプト作成のためにアドバイスお願いします"],
            ["マークダウン記法にて、ブロック崩しのゲーム作成の教科書作成してください"],
            ["お笑いのトンチ大会のお題を考えてください"],
            ["日本語の慣用句、ことわざについての試験問題を考えてください"],
        ],
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
|