#!/usr/bin/env python import os from threading import Thread from typing import Iterator import gradio as gr import spaces import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer HF_TOKEN = os.environ['HF_TOKEN'] DESCRIPTION = """# đ GEITje-7B-chat đ ## Een groot open Nederlands taalmodel [_Coming soon_](https://github.com/Rijgersberg/GEITje)""" if not torch.cuda.is_available(): DESCRIPTION += "\n
Running on CPU đĨļ This demo does not work on CPU.
" MAX_MAX_NEW_TOKENS = 2048 DEFAULT_MAX_NEW_TOKENS = 1024 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096")) if torch.cuda.is_available(): model_id = "Rijgersberg/GEITje-7B-chat" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto", token=HF_TOKEN) tokenizer = AutoTokenizer.from_pretrained(model_id, token=HF_TOKEN) @spaces.GPU def generate( message: str, chat_history: list[tuple[str, str]], max_new_tokens: int = 1024, temperature: float = 0.06, top_p: float = 0.95, top_k: int = 40, repetition_penalty: float = 1.2, ) -> Iterator[str]: conversation = [] for user, assistant in chat_history: conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]) conversation.append({"role": "user", "content": message}) input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt") if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH: input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:] gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.") input_ids = input_ids.to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( {"input_ids": input_ids}, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=True, top_p=top_p, top_k=top_k, temperature=temperature, num_beams=1, repetition_penalty=repetition_penalty, ) t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) chat_interface = gr.ChatInterface( fn=generate, chatbot=gr.Chatbot(height=400), additional_inputs=[ gr.Slider( label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS, ), gr.Slider( label="Temperature", minimum=0., maximum=1.2, step=0.05, value=0.2, ), gr.Slider( label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9, ), gr.Slider( label="Top-k", minimum=1, maximum=1000, step=1, value=50, ), gr.Slider( label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2, ), ], examples=[ ["""Welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geit, bus"?"""], ["Schrijf een nieuwsbericht voor De Speld over de inzet van een kudde geiten door het Nederlands Forensisch Instituut"], ["Wat zijn leuke dingen om te doen als ik een weekendje naar Friesland ga?"], ["Kan je naar de maan fietsen?"], ["Wat is het belang van open source taalmodellen?"], ], title="đ GEITje 7B Chat", description="""Een eerste chatbot op basis van GEITje 7B: een groot open Nederlands taalmodel. Dit is een chatbot gebaseerd op GEITje 7B, gemaakt voor demonstratiedoeleinden. Generatieve taalmodellen maken fouten, controleer daarom feiten voordat je ze overneemt. GEITJje Chat is niet uitgebreid getraind om _gealigned_ te zijn met menselijke waarden. Het is daarom mogelijk dat het problematische output genereert, zeker als het daartoe ge_prompt_ wordt. Voor meer info over GEITJje: zie de đ README op GitHub.""", submit_btn="Genereer", stop_btn="Stop", retry_btn="đ Opnieuw", undo_btn="âŠī¸ Ongedaan maken", clear_btn="đī¸ Wissen", ) with gr.Blocks(css="style.css") as demo: chat_interface.render() if __name__ == "__main__": demo.queue(max_size=20).launch()