"""Gradio chat UI for google/gemma-7b-it served through the Hugging Face Inference API."""

from huggingface_hub import InferenceClient
import gradio as gr

inference_client = InferenceClient("google/gemma-7b-it")


def format_prompt(input_text, history):
    """Build a prompt in the Gemma chat-template format.

    Follows the chat template documented on the official model page
    (https://huggingface.co/google/gemma-7b-it): each turn is wrapped as
    ``<start_of_turn>{role}\\n{text}<end_of_turn>\\n`` and the prompt ends with an
    open ``model`` turn so the model writes the next reply.

    Args:
        input_text: The new user message.
        history: Previous ``(user_message, model_response)`` pairs; may be empty
            or ``None`` for a fresh conversation.

    Returns:
        The formatted prompt string.
    """
    prompt = ""
    for previous_prompt, response in history or []:
        prompt += (
            f"<start_of_turn>user\n{previous_prompt}<end_of_turn>\n"
            f"<start_of_turn>model\n{response}<end_of_turn>\n"
        )
    # Leave the model turn open so generation continues as the model's reply.
    prompt += f"<start_of_turn>user\n{input_text}<end_of_turn>\n<start_of_turn>model\n"
    return prompt


def generate(prompt, history, temperature=0.95, max_new_tokens=512, top_p=0.9, repetition_penalty=1.0):
    """Stream a reply to *prompt*, yielding the accumulated text after each token.

    Args:
        prompt: The new user message.
        history: Previous ``(user_message, model_response)`` pairs supplied by
            ``gr.ChatInterface``; may be empty or ``None``.
        temperature: Sampling temperature.
        max_new_tokens: Upper bound on the number of generated tokens.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The reply generated so far, growing as tokens are streamed.
    """
    kwargs = dict(
        temperature=float(temperature),
        # Slider values can arrive as floats; the API expects an integer count.
        max_new_tokens=int(max_new_tokens),
        top_p=float(top_p),
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history or [])
    stream = inference_client.text_generation(
        formatted_prompt,
        **kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )

    output = ""
    for chunk in stream:
        # Skip special tokens (e.g. end-of-turn / end-of-sequence markers) so
        # they are not rendered as text in the chat window.
        if chunk.token.special:
            continue
        output += chunk.token.text
        yield output


additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.85,
        minimum=0.1,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more random and varied model responses",
    ),
    gr.Slider(
        label="Max new tokens",
        value=512,
        minimum=128,
        # 1024 is the largest value on the 128 + 64k step grid.
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum number of new tokens generated in the model response",
    ),
    gr.Slider(
        label="Top-p (random sampling)",
        value=0.80,
        minimum=0.1,
        maximum=1,
        step=0.05,
        interactive=True,
        info="A smaller value generates the highest probability tokens, a higher value (~ 1) allows low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.0,
        minimum=0.5,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalizes repeated tokens in model response",
    ),
]

chatbot = gr.Chatbot(height=500)

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<h1 style='text-align: center'>Google Gemma 7B IT</h1>")
    gr.ChatInterface(
        generate,
        chatbot=chatbot,
        retry_btn=None,
        undo_btn=None,
        clear_btn="Clear",
        description="This chatbot is using a Hugging Face Inference Client for the google/gemma-7b-it model.",
        additional_inputs=additional_inputs,
        examples=[["Explain artificial intelligence in a few lines."]],
    )

if __name__ == "__main__":
    demo.queue().launch()