File size: 4,526 Bytes
51a7d9e 9eefdf9 51a7d9e edb9e8a 51a7d9e 063316d 99a7a45 51a7d9e 99a7a45 51a7d9e 27d1730 51a7d9e 3bc2ef0 063316d 22f5f54 f2cc9dc 22f5f54 f2cc9dc 22f5f54 3bc2ef0 51a7d9e 9eefdf9 f663115 9a43acc 5312535 fd6304d 51a7d9e 6f1ee3e 51a7d9e fd6304d 487032c 99a7a45 6f1ee3e 030c23d 9eefdf9 639e063 edb9e8a 9eefdf9 5312535 030c23d f663115 51a7d9e 5312535 9eefdf9 51a7d9e 9eefdf9 0961bc7 9eefdf9 9a43acc 50b348c 9a43acc 9eefdf9 99a7a45 51a7d9e 063316d 51a7d9e 5312535 063316d 51a7d9e 063316d 5312535 51a7d9e 99a7a45 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteriaList, StoppingCriteria
import os
from threading import Thread
# Optional Hugging Face access token from the environment (None when unset).
# NOTE(review): HF_TOKEN is not referenced anywhere else in the visible file.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hub repository id of the model served by this Space.
MODEL_LIST = "THUDM/LongWriter-glm4-9b"
# HTML heading rendered at the top of the page.
TITLE = "<h1><center>GLM SPACE</center></h1>"
# HTML passed to the chatbot as its ``placeholder`` content.
PLACEHOLDER = '<h3><center>LongWriter-glm4-9b is trained based on glm-4-9b, and is capable of generating 10,000+ words at once.</center></h3>'
# Page CSS: styles elements carrying the "duplicate-button" class.
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Load the model once at import time, in bfloat16 and in eval mode.
# device_map="auto" lets transformers decide device placement; trust_remote_code
# is needed because the repo ships its own modeling/tokenizer code (the
# tokenizer's build_chat_input/get_command used below are not standard APIs).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_LIST,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_LIST, trust_remote_code=True)
class StopOnTokens(StoppingCriteria):
    """Stop generation when the most recent token is one of the stop ids.

    By default the stop ids are the tokenizer's EOS token plus the
    ``<|user|>`` and ``<|observation|>`` role tokens (presumably so the model
    halts instead of starting a new turn). Only the first sequence in the
    batch (``input_ids[0]``) is inspected.
    """

    def __init__(self, stop_ids=None):
        """Create the criterion.

        Args:
            stop_ids: Optional iterable of token ids to stop on. When omitted,
                the ids are resolved from the module-level ``tokenizer``.
        """
        super().__init__()
        # Resolve the ids once per criterion rather than on every generated token.
        if stop_ids is None:
            stop_ids = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                        tokenizer.get_command("<|observation|>")]
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the last token of the first sequence is a stop id."""
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in self.stop_ids)
@spaces.GPU()
def stream_chat(message: str, history: list, temperature: float, max_new_tokens: int):
    """Stream a model reply to ``message``, given the earlier chat ``history``.

    Args:
        message: The new user message.
        history: Prior turns, unpacked below as ``(user_prompt, answer)`` pairs.
        temperature: Sampling temperature from the UI slider. Values <= 0 (or a
            missing value) switch to greedy decoding, because ``generate``
            rejects a non-positive temperature when ``do_sample=True``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply text accumulated so far, after each streamed chunk.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')
    conversation = []
    for prompt, answer in history:
        conversation.extend([{"role": "user", "content": prompt}, {"role": "assistant", "content": answer}])
    print(f"Conversation is -\n{conversation}")

    # The new message is passed to build_chat_input separately from the history.
    input_ids = tokenizer.build_chat_input(message, history=conversation, role='user').input_ids.to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
    # Generation also ends on the <|user|> / <|observation|> role tokens, not just EOS.
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    generate_kwargs = dict(
        input_ids=input_ids,
        # UI sliders may deliver floats; generate expects an int token budget.
        max_new_tokens=int(max_new_tokens),
        streamer=streamer,
        repetition_penalty=1,
        stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
        eos_token_id=eos_token_id,
    )
    if temperature and temperature > 0:
        # NOTE: top_k=1 keeps only the single most likely token, so sampling is
        # effectively greedy and temperature does not change which token is picked.
        generate_kwargs.update(do_sample=True, top_k=1, temperature=float(temperature))
    else:
        # temperature <= 0 is invalid with do_sample=True (generate raises in the
        # worker thread and the stream only fails after its 60 s timeout), so
        # decode greedily instead.
        generate_kwargs.update(do_sample=False)

    # Run generation in a background thread; the streamer hands text back here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_token in streamer:
        # Skip empty chunks and any chunk still carrying a <|user|> marker.
        if new_token and '<|user|>' not in new_token:
            buffer += new_token
            yield buffer
    # The streamer is exhausted, so generation has finished; reap the worker.
    thread.join()
# Chat display handed to the ChatInterface; PLACEHOLDER is its placeholder content.
chatbot = gr.Chatbot(height=600, placeholder=PLACEHOLDER)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Built with render=False so the ChatInterface decides where they appear.
    # Creation order (accordion, temperature, max tokens) matches the original.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    temperature_slider = gr.Slider(
        minimum=0, maximum=1, step=0.1, value=0.5,
        label="Temperature", render=False,
    )
    max_tokens_slider = gr.Slider(
        minimum=1024, maximum=32768, step=1, value=4096,
        label="Max New Tokens", render=False,
    )

    # One-element rows: only the message box is pre-filled by each example.
    example_prompts = [
        ["Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."],
        ["What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."],
        ["Tell me a random fun fact about the Roman Empire."],
        ["Show me a code snippet of a website's sticky header in CSS and JavaScript."],
    ]

    # The sliders' values are passed to stream_chat after (message, history).
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=[temperature_slider, max_tokens_slider],
        examples=example_prompts,
        cache_examples=False,
    )

if __name__ == "__main__":
    demo.launch()
|