from huggingface_hub import InferenceClient
import gradio as gr

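# Client for the serverless Hugging Face Inference API. Note that
# google/gemma-7b-it is a gated model, so a valid access token may be
# required (read from the HF_TOKEN environment variable or passed via the
# token= argument of InferenceClient).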
inference_client = InferenceClient("google/gemma-7b-it")

# Format the conversation according to the chat template on the official
# model page (https://huggingface.co/google/gemma-7b-it), e.g.:
#   <start_of_turn>user
#   Hello<end_of_turn>
#   <start_of_turn>model
def format_prompt(input_text, history):
    prompt = ""
    if history:
        for previous_prompt, response in history:
            prompt += (
                f"<start_of_turn>user\n{previous_prompt}<end_of_turn>\n"
                f"<start_of_turn>model\n{response}<end_of_turn>\n"
            )
    prompt += f"<start_of_turn>user\n{input_text}<end_of_turn>\n<start_of_turn>model\n"
    return prompt

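# Callback for gr.ChatInterface: called with the newest user message and the
# running history, it yields the accumulated response as tokens stream in.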
def generate(prompt, history):
    if not history:
        history = []

    # Generation parameters; repetition_penalty=1.0 applies no penalty.
    kwargs = dict(
        temperature=1.0,
        max_new_tokens=512,
        top_p=0.9,
        repetition_penalty=1.0,
        do_sample=True,
    )

    formatted_prompt = format_prompt(prompt, history)

    # Stream tokens from the endpoint; return_full_text=False keeps the
    # prompt from being echoed back into the chat window.
    response = inference_client.text_generation(
        formatted_prompt, **kwargs, stream=True, details=True, return_full_text=False
    )

    output = ""
    for chunk in response:
        output += chunk.token.text
        yield output


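# Build the UI: a pre-sized Chatbot component wired into a ChatInterface
# that streams from generate().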
chatbot = gr.Chatbot(height=500)
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.HTML("<center><h1>Google Gemma 7B IT</h1></center>")
    gr.ChatInterface(
        generate,
        chatbot=chatbot,
        retry_btn=None,
        undo_btn=None,
        clear_btn="Clear",
        description="This chatbot is using a Hugging Face Inference Client for the google/gemma-7b-it model.",
        examples=[["Explain artificial intelligence in a few lines."]]
    )
demo.queue().launch()
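
# queue() enables request queuing, which Gradio's streaming (generator)
# responses rely on; launch() then serves the app when this file is run
# directly (e.g. `python app.py`, assuming it is saved as app.py).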