Soratz committed on
Commit
6749904
1 Parent(s): 871520e

added gradio files

Browse files
Files changed (4) hide show
  1. README.md +5 -5
  2. app.py +244 -0
  3. model.py +57 -0
  4. style.css +16 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Lmm Use
3
- emoji: 🌍
4
- colorFrom: green
5
- colorTo: green
6
  sdk: gradio
7
- sdk_version: 4.7.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Mistralai Mistral 7B V0.1
3
+ emoji:
4
+ colorFrom: gray
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.4.1
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Iterator
3
+
4
+ import gradio as gr
5
+
6
+ from model import run
7
+
8
# System prompt prepended to every conversation sent to the model.
DEFAULT_SYSTEM_PROMPT = "You are Mistral. You are AI-assistant, you are polite, give only truthful information and are based on the Mistral-7B model from Mistral AI. You can communicate in different languages equally well."
# Hard upper bound for the "Max new tokens" slider; generate() rejects larger values.
MAX_MAX_NEW_TOKENS = 4096
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 256
# Limit on the accumulated input, checked before generation starts.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered at the top of the page.
DESCRIPTION = """
# [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
"""
16
+
17
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Empty the input textbox and stash the submitted text for later steps."""
    cleared_box = ''
    return cleared_box, message
19
+
20
+
21
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Echo the new user message into the chat history (in place) with an
    empty bot slot, and return the same history list."""
    history += [(message, '')]
    return history
25
+
26
+
27
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Drop the most recent (user, bot) turn and hand back the user text,
    so the caller can retry or restore it. Empty history yields ''."""
    if history:
        last_user, _last_bot = history.pop()
    else:
        last_user = ''
    return history, last_user or ''
34
+
35
+
36
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream chatbot updates for a new user message.

    Yields the chat history extended with the partial bot response as
    tokens arrive from the model.

    Raises:
        ValueError: if max_new_tokens exceeds MAX_MAX_NEW_TOKENS.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        # Was a bare `raise ValueError`; include context for debuggability.
        raise ValueError(
            f'max_new_tokens must be <= {MAX_MAX_NEW_TOKENS}, got {max_new_tokens}')

    # The last entry is the placeholder turn added by display_input;
    # drop it here and re-append the turn with streamed content.
    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)
    try:
        first_response = next(generator)
        yield history + [(message, first_response)]
    except StopIteration:
        # Model produced no tokens; still show the user turn with an empty reply.
        yield history + [(message, '')]
    for response in generator:
        yield history + [(message, response)]
57
+
58
+
59
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run an example prompt to completion and return ('', final chat history).

    The textbox value ('') and the fully generated history are returned as
    a pair for the Gradio example handler.
    """
    # Initialize the accumulator: in the original code `x` was unbound
    # (NameError) if the generator yielded nothing.
    history: list[tuple[str, str]] = []
    for history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 1024, 1, 0.95, 50):
        pass  # Drain the stream; keep only the final state.
    return '', history
64
+
65
+
66
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Raise a user-visible error when the accumulated input is too long.

    Character counts are used as a cheap proxy for token counts. The
    original version ignored `system_prompt` entirely and counted only the
    *number* of history turns rather than their text, so it undercounted
    badly; all three inputs now contribute their full text length.

    Raises:
        gr.Error: when the estimated length exceeds MAX_INPUT_TOKEN_LENGTH.
    """
    history_length = sum(len(user) + len(bot) for user, bot in chat_history)
    input_token_length = len(message) + history_length + len(system_prompt)
    if input_token_length > MAX_INPUT_TOKEN_LENGTH:
        raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
70
+
71
+
72
# Build the Gradio UI: a chat playground with retry/undo/clear controls and
# generation-parameter sliders, wired to the streaming `generate` function.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Group():
        chatbot = gr.Chatbot(label='Playground')
        with gr.Row():
            # Multi-line input box sitting next to the submit button.
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Hello there!',
                scale=10,
                lines=5
            )
            submit_button = gr.Button('Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
        with gr.Row():
            retry_button = gr.Button('🔄 Retry', variant='secondary')
            undo_button = gr.Button('↩️ Undo', variant='secondary')
            clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Holds the last submitted message so later chain steps can use it
    # after the textbox has already been cleared.
    saved_input = gr.State()

    with gr.Accordion(label='⚙️ Advanced options', open=False):
        # System prompt is displayed but not editable (interactive=False).
        system_prompt = gr.Textbox(label='System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=5,
                                   interactive=False)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.1,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=10,
        )

    # Pressing Enter in the textbox: clear + save the message, echo it into
    # the chat, validate the input length, then stream the model response.
    # The `.success(...)` step only runs if check_input_token_length did
    # not raise.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # The Submit button triggers the same chain as pressing Enter.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: remove the last turn, re-echo its user message, regenerate.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: remove the last turn and put its user message back in the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: wipe both the chat history and the saved input.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Queue requests (up to 32 waiting) and launch without a public share link
# or API docs page.
demo.queue(max_size=32).launch(share=False, show_api=False)
model.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Iterator
3
+
4
+ from text_generation import Client
5
+
6
# Hugging Face model served through the hosted Inference API.
model_id = 'mistralai/Mistral-7B-Instruct-v0.1'

API_URL = "https://api-inference.huggingface.co/models/" + model_id
# Read token for the Inference API, taken from the environment.
# NOTE(review): if HF_READ_TOKEN is unset this sends "Authorization: Bearer None"
# — confirm the API tolerates that or guard the header.
HF_TOKEN = os.environ.get("HF_READ_TOKEN", None)

# text-generation client used for streaming token-by-token generation.
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
# Marker strings that terminate streaming when they appear in a token.
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"
17
+
18
+
19
def get_prompt(message: str, chat_history: list[tuple[str, str]],
               system_prompt: str) -> str:
    """Assemble the instruction prompt: system block, prior turns, new message.

    Every user input except the very first in the conversation is stripped
    of surrounding whitespace; bot responses are always stripped.
    """
    parts = [f'<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n']
    for turn_index, (user_input, response) in enumerate(chat_history):
        if turn_index > 0:
            user_input = user_input.strip()
        parts.append(f'{user_input} [/INST] {response.strip()} </s><s>[INST] ')
    # The new message counts as "first" only when there is no history at all.
    final_message = message.strip() if chat_history else message
    parts.append(f'{final_message} [/INST]')
    return ''.join(parts)
31
+
32
+
33
def run(message: str,
        chat_history: list[tuple[str, str]],
        system_prompt: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.1,
        top_p: float = 0.9,
        top_k: int = 50) -> Iterator[str]:
    """Stream the model's reply, yielding the accumulated text after each token.

    Generation stops (without yielding the offending token) as soon as a
    streamed token contains either end-of-sequence marker.
    """
    prompt = get_prompt(message, chat_history, system_prompt)

    stream = client.generate_stream(
        prompt,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
    )
    accumulated = ""
    for event in stream:
        token_text = event.token.text
        if EOS_STRING in token_text or EOT_STRING in token_text:
            break
        accumulated += token_text
        yield accumulated
    return accumulated
style.css ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* Center the page title. */
h1 {
  text-align: center;
}

/* "Duplicate Space" button: centered blue pill. */
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}

/* Constrain the root Gradio container's width and add top spacing. */
#component-0 {
  max-width: 900px;
  margin: auto;
  padding-top: 1.5rem;
}