import os
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Define models as None to delay loading
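# Loading both 7B checkpoints at startup would roughly double GPU memory use,
# so each checkpoint is instantiated only the first time it is requested.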
model, model_instruct = None, None
tokenizer, tokenizer_instruct = None, None

# Define the response function with lazy loading
def generate_response(input_text, max_new_tokens, temperature, top_k, top_p, repetition_penalty, num_beams, length_penalty, model_choice):
    global model, model_instruct, tokenizer, tokenizer_instruct

    # Lazy loading of the selected model
    if model_choice == "Zamba2-7B":
        if model is None:  # Load only if not already loaded
            tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
            model = AutoModelForCausalLM.from_pretrained(
                "Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16
            )
        selected_model = model
        selected_tokenizer = tokenizer
    else:
        if model_instruct is None:  # Load only if not already loaded
            tokenizer_instruct = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-instruct")
            model_instruct = AutoModelForCausalLM.from_pretrained(
                "Zyphra/Zamba2-7B-instruct", device_map="cuda", torch_dtype=torch.bfloat16
            )
        selected_model = model_instruct
        selected_tokenizer = tokenizer_instruct
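    # Note: instruct-tuned checkpoints are typically prompted through the
    # tokenizer's chat template rather than raw text, e.g. (single-turn sketch):
    #   input_ids = selected_tokenizer.apply_chat_template(
    #       [{"role": "user", "content": input_text}],
    #       add_generation_prompt=True, return_tensors="pt"
    #   )
    # For simplicity this demo feeds plain text to both models.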

    # Tokenize and generate response
    input_ids = selected_tokenizer(input_text, return_tensors="pt").input_ids.to(selected_model.device)
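    # With num_beams > 1 and do_sample=True, generate() performs beam-search
    # multinomial sampling; setting "Number of Beams" to 1 in the UI gives
    # plain sampling, which is usually faster.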
    outputs = selected_model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    response = selected_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Gradio interface with model selection
demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter your input text...", label="Input Text"),
        gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens"),
        gr.Slider(0.1, 1.5, step=0.1, value=0.7, label="Temperature"),
        gr.Slider(1, 100, step=1, value=50, label="Top K"),
        gr.Slider(0.1, 1.0, step=0.1, value=0.9, label="Top P"),
        gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty"),
        gr.Slider(1, 10, step=1, value=5, label="Number of Beams"),
        gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty"),
        gr.Dropdown(["Zamba2-7B", "Zamba2-7B-instruct"], value="Zamba2-7B", label="Model Choice")
    ],
    outputs=gr.Textbox(label="Generated Response"),
    title="Zamba2-7B Model Selector",
    description="Choose a model and ask a question with customizable parameters."
)

if __name__ == "__main__":
    demo.launch()