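# app.py — Gradio chat UI for the havenhq/mamba-chat model (the filename is an
# assumption, inferred from the usual Spaces entry-point convention). The script
# loads the Mamba checkpoint via mamba_ssm, borrows Zephyr's chat template, and
# serves a gr.ChatInterface with sampling controls.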
import gradio as gr
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from argparse import ArgumentParser
def get_args():
    parser = ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--device", type=str, default="cuda", help="Device to run the model on")
    parser.add_argument("--model", type=str, default="havenhq/mamba-chat", help="Model to use")
    parser.add_argument(
        "--share",
        action="store_true",
        default=False,
        help="share your instance publicly through gradio",
    )
    try:
        args = parser.parse_args()
    except SystemExit:  # argparse raises SystemExit on invalid arguments
        parser.print_help()
        exit(0)
    return args
if __name__ == "__main__":
    args = get_args()

    device = args.device
    model_name = args.model
    eos = "<|endoftext|>"

    # The Mamba checkpoint ships without a chat template, so borrow Zephyr's
    # (<|user|> / <|assistant|> turns), which mamba-chat was fine-tuned on.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.eos_token = eos
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.chat_template = AutoTokenizer.from_pretrained(
        "HuggingFaceH4/zephyr-7b-beta"
    ).chat_template

    # mamba_ssm's fused selective-scan kernels are CUDA-only; fp16 halves the
    # memory footprint of the weights.
    model = MambaLMHeadModel.from_pretrained(
        model_name, device=device, dtype=torch.float16
    )
    def chat_with_mamba(
        user_message,
        history: list[list[str]],
        temperature: float = 0.9,
        top_p: float = 0.7,
        max_length: int = 2000,
    ):
        # Convert Gradio's [[user, assistant], ...] history into the
        # role/content dicts that apply_chat_template expects.
        history_dict: list[dict[str, str]] = []
        for user_m, assistant_m in history:
            history_dict.append(dict(role="user", content=user_m))
            history_dict.append(dict(role="assistant", content=assistant_m))
        history_dict.append(dict(role="user", content=user_message))

        input_ids = tokenizer.apply_chat_template(
            history_dict, return_tensors="pt", add_generation_prompt=True
        ).to(device)

        out = model.generate(
            input_ids=input_ids,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
        )
        decoded = tokenizer.batch_decode(out)
        # Keep only the text after the final assistant tag and drop the EOS marker.
        assistant_message = (
            decoded[0].split("<|assistant|>\n")[-1].replace(eos, "")
        )
        return assistant_message
    demo = gr.ChatInterface(
        fn=chat_with_mamba,
        # examples=[
        #     "Explain what is state space model",
        #     "Nice to meet you!",
        #     "'Mamba is way better than ChatGPT.' Is this statement correct?",
        # ],
        additional_inputs=[
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="temperature"),
            gr.Slider(minimum=0, maximum=1, step=0.1, value=0.7, label="top_p"),
            gr.Number(value=2000, label="max_length"),
        ],
        title="Mamba Chat",
    )
    demo.launch(server_port=args.port, share=args.share)
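# Example invocation (a sketch; the flags come from get_args above, and all
# defaults shown here are the script's own — a CUDA GPU is required by the
# mamba_ssm kernels):
#   python app.py --port 7860 --device cuda --model havenhq/mamba-chat --share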