# File size: 4,974 Bytes
# 44936b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from datetime import datetime

# Markdown blurb describing the model; rendered in the left column of the
# header row of the UI.
description = """
[🦎Salamandra-7b-instruct](https://huggingface.co/BSC-LT/salamandra-7b-instruct) is a Transformer-based decoder-only language model that has been pre-trained on 7.8 trillion tokens of highly curated data. 
The pre-training corpus contains text in 35 European languages and code. This model has been instruction-tuned and can be used as a general-purpose assistant.
"""

# Markdown community/credits panel; rendered in the right column of the
# header row of the UI.
join_us = """
## Join us:
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 
[![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) 
On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) 
On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)
🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗
"""

# Hub identifier of the checkpoint served by this demo.
model_id = "BSC-LT/salamandra-7b-instruct"

# Device that input tensors are moved to: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer and model are loaded once at import time and shared by every
# request handled by the app.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

def generate_text(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Generate one assistant reply for a system/user prompt pair.

    Args:
        system_prompt: Text sent as the ``system`` message.
        prompt: Text sent as the ``user`` message.
        temperature: Sampling temperature. Values <= 0 select greedy
            decoding, since sampling needs a strictly positive temperature.
        max_new_tokens: Upper bound on the number of generated tokens;
            coerced to ``int`` because UI sliders may deliver floats.
        top_p: Nucleus-sampling threshold; only used when sampling.
        repetition_penalty: Repetition penalty forwarded to
            ``model.generate``.

    Returns:
        The decoded reply (special tokens removed, surrounding whitespace
        stripped). Only newly generated tokens are decoded, so the prompt
        is never echoed back.
    """
    date_string = datetime.today().strftime('%Y-%m-%d')
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    chat_prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        date_string=date_string,
    )

    # The rendered template is encoded without extra special tokens so none
    # are added a second time (this follows the model card's usage example).
    # Tensors go to the model's own device because the model is placed by
    # `device_map="auto"`, not by the module-level `device` string.
    inputs = tokenizer(
        chat_prompt,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": float(repetition_penalty),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
        )
    else:
        # A zero temperature is rejected when sampling; decode greedily.
        generation_kwargs["do_sample"] = False

    outputs = model.generate(**inputs, **generation_kwargs)

    # Slice off the prompt tokens instead of splitting decoded text on
    # "assistant\n", which would truncate replies containing that string.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

def update_output(system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty):
    """Button callback: forward the UI values to ``generate_text``.

    Returns whatever ``generate_text`` returns for the same arguments.
    """
    reply = generate_text(
        system_prompt=system_prompt,
        prompt=prompt,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    return reply

# Build the Gradio UI. Components are laid out in the order they are created
# inside the nested Row/Column contexts, so statement order here is layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🦎 Welcome to Tonic's Salamandra-7b-instruct Demo")
    
    # Header row: model description on the left, community links on the right.
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(description)
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown(join_us)
    
    # Main row: inputs and generation parameters on the left, output on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Pre-filled, user-editable system message.
            system_prompt = gr.Textbox(
                lines=3, 
                label="🖥️ System Prompt", 
                value="You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."
            )
            prompt = gr.Textbox(lines=5, label="🙋‍♂️ User Prompt")
            generate_button = gr.Button("Generate with 🦎 Salamandra-7b-instruct")
            
            # Generation parameters, collapsed by default.
            with gr.Accordion("🧪 Parameters", open=False):
                temperature = gr.Slider(0.0, 1.0, value=0.7, label="🌡️ Temperature")
                max_new_tokens = gr.Slider(1, 1000, value=200, step=1, label="🔢 Max New Tokens")
                top_p = gr.Slider(0.0, 1.0, value=0.95, label="⚛️ Top P")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="🔁 Repetition Penalty")
        
        with gr.Column(scale=1):
            output = gr.Textbox(lines=10, label="🦎 Salamandra-7b-instruct Output")
    
    # On click, pass the six input components (in the order of
    # update_output's parameters) and write the returned text to `output`.
    generate_button.click(
        update_output,
        inputs=[system_prompt, prompt, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=output
    )
    
    # Example prompts bound to the user prompt textbox.
    # NOTE(review): `outputs=prompt` is given without an `fn`; it presumably
    # has no effect here — confirm against the Gradio Examples documentation.
    gr.Examples(
        examples=[
            ["At what temperature does water boil?"],
            ["Explain the concept of artificial intelligence in simple terms."],
            ["Write a short poem about the beauty of nature."],
            ["What are the main differences between Python and JavaScript?"],
            ["Describe the process of photosynthesis in plants."]
        ],
        inputs=prompt,
        outputs=prompt,
        label="Example Prompts"
    )

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()