File size: 4,033 Bytes
c193104
d76943e
 
 
c193104
 
d76943e
 
 
93e4eee
 
d76943e
 
 
 
 
93e4eee
 
 
 
 
 
 
 
c193104
d76943e
93e4eee
 
 
 
 
d76943e
 
c193104
93e4eee
 
 
d76943e
c193104
93e4eee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d76943e
93e4eee
c193104
93e4eee
 
c193104
93e4eee
d76943e
93e4eee
d76943e
93e4eee
d76943e
 
 
c193104
d76943e
 
 
 
 
 
 
c193104
 
d76943e
 
 
 
 
 
 
 
 
93e4eee
c193104
d76943e
 
93e4eee
 
 
d76943e
 
93e4eee
 
 
d76943e
 
c193104
93e4eee
d76943e
 
c193104
d76943e
93e4eee
c193104
d76943e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import gradio as gr
import os
from typing import Iterator
import sambanova


def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a chat reply to ``message`` from ``sambanova.Streamer``.

    Builds a role/content message list (system prompt, prior turns, then the
    new user message) and yields the cumulative reply text after every chunk
    the streamer produces, so the chat window can render partial output.

    Args:
        message: The user's new message for this turn.
        chat_history: Prior turns as ``(user, assistant)`` pairs. A side that
            is empty/falsy is skipped. ``None`` is treated as no history.
        system_message: System prompt placed first in the conversation.
        max_tokens: Forwarded to the streamer as ``new_tokens``.
        temperature: Forwarded to the streamer as ``temperature``.
        top_p: Forwarded to the streamer as ``top_p``.
        top_k: Forwarded to the streamer as ``top_k``.
        repetition_penalty: Accepted for interface compatibility only; it is
            NOT forwarded to ``sambanova.Streamer``.

    Yields:
        The full reply accumulated so far (not just the latest delta).
    """
    conversation = [{"role": "system", "content": system_message}]

    for turn in chat_history or []:
        user_text, assistant_text = turn[0], turn[1]
        if user_text:
            conversation.append({"role": "user", "content": user_text})
        if assistant_text:
            conversation.append({"role": "assistant", "content": assistant_text})

    # The message being answered is passed separately from chat_history;
    # without this the model would never see the user's current request.
    conversation.append({"role": "user", "content": message})

    response = ""
    for chunk in sambanova.Streamer(conversation,
                                    new_tokens=max_tokens,
                                    temperature=temperature,
                                    top_k=top_k,
                                    top_p=top_p):
        response += chunk
        yield response


# Upper bound and initial value for the "Max tokens" slider in the chat UI.
MAX_MAX_TOKENS = 2_048
DEFAULT_MAX_TOKENS = 1_024
# Overridable through the MAX_INPUT_TOKEN_LENGTH environment variable.
# NOTE(review): not referenced by any code visible in this file.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))

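# Earlier configuration, kept for reference only; it is not valid as-is.
# It uses the old names MAX_MAX_NEW_TOKENS / DEFAULT_MAX_NEW_TOKENS, which are
# now MAX_MAX_TOKENS / DEFAULT_MAX_TOKENS. It also has no "System message"
# input, so its first slider would bind to generate()'s system_message.
# chat_interface = gr.ChatInterface(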
#     fn=generate,
#     additional_inputs=[
#         gr.Slider(
#             label="Max new tokens",
#             minimum=1,
#             maximum=MAX_MAX_NEW_TOKENS,
#             step=1,
#             value=DEFAULT_MAX_NEW_TOKENS,
#         ),
#         gr.Slider(
#             label="Temperature",
#             minimum=0.1,
#             maximum=4.0,
#             step=0.1,
#             value=0.6,
#         ),
#         gr.Slider(
#             label="Top-p (nucleus sampling)",
#             minimum=0.05,
#             maximum=1.0,
#             step=0.05,
#             value=0.9,
#         ),
#         gr.Slider(
#             label="Top-k",
#             minimum=1,
#             maximum=1000,
#             step=1,
#             value=50,
#         ),
#         gr.Slider(
#             label="Repetition penalty",
#             minimum=1.0,
#             maximum=2.0,
#             step=0.05,
#             value=1.2,
#         ),
#     ],
#     stop_btn=None,
#     fill_height=True,
#     examples=[
#         ["Which one is bigger? 4.9 or 4.11"],
#         [
#             "Can you explain briefly to me what is the Python programming language?"
#         ],
#         ["Explain the plot of Cinderella in a sentence."],
#         ["How many hours does it take a man to eat a Helicopter?"],
#         [
#             "Write a 100-word article on 'Benefits of Open-Source in AI research'"
#         ],
#     ],
#     cache_examples=False,
# )

# Chat UI wired to generate(). The additional inputs follow (message, history)
# in generate()'s parameter order: system_message, max_tokens, temperature,
# top_p, top_k. repetition_penalty has no control, so it keeps its default.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(label="Max tokens", minimum=1, maximum=MAX_MAX_TOKENS,
                  step=1, value=DEFAULT_MAX_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0,
                  step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0,
                  step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    ],
    # Each sample prompt becomes a one-element [message] example row.
    examples=[
        [prompt]
        for prompt in (
            "Which one is bigger? 4.9 or 4.11",
            "Can you explain briefly to me what is the Python programming language?",
            "Explain the plot of Cinderella in a sentence.",
            "How many hours does it take a man to eat a Helicopter?",
            "Write a 100-word article on 'Benefits of Open-Source in AI research'",
        )
    ],
    cache_examples=False,
)

# Page layout: a title heading above the chat interface defined above.
demo = gr.Blocks()
with demo:
    gr.Markdown('# Sambanova model inference LLAMA 405B')
    chat_interface.render()

if __name__ == "__main__":
    # Cap the pending-request queue at 20, then start the server.
    demo.queue(max_size=20)
    demo.launch()