File size: 14,083 Bytes
0dcddb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import gradio as gr
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import unicodedata
import re

# Default variables
default_max_new_tokens = 100
default_temperature = 1.0
default_top_k = 10
default_top_p = 0.99
default_repetition_penalty = 1.0

model_name = "OpenLLM-France/Claire-7B-0.1"

print("Loading model...")

tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
# 4-bit (bitsandbytes) loading is requested through a quantization config:
# passing `load_in_4bit` directly to `from_pretrained` is deprecated in recent
# transformers releases.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True),
)

print("Optimizing model...")

import optimum
from optimum.bettertransformer import BetterTransformer

model = BetterTransformer.transform(model)

print("Setup chat...")

eos_token_id = tokenizer.eos_token_id
# "[" opens every speaker tag ("[Intervenant N:]"). Single-turn generation
# stops on it, so it must map to exactly one token. Special tokens are not
# added, matching how prompts are encoded in ClaireChatBot.predict.
newspk_token_ids = tokenizer.encode("[", add_special_tokens=False)
if len(newspk_token_ids) != 1:
    # Explicit check (not `assert`) so it still runs under `python -O`.
    raise RuntimeError(
        f"Expected '[' to be encoded as a single token, got {newspk_token_ids}"
    )
newspk_token_id = newspk_token_ids[0]


# Class to encapsulate the Claire chatbot
class ClaireChatBot:
    """Chat front-end for Claire that translates between two speaker conventions.

    The UI shows turns prefixed with display names ("VOUS:", "CHATBOT:",
    "AUTRE <n>:" by default). Claire was trained on turns prefixed with tags
    such as "[Intervenant 1:]". ``predict`` converts the user-facing history to
    Claire's format and samples a continuation with the module-level ``model``
    and ``tokenizer``. It then trims the unfinished tail of the generation and
    converts the result back to display names.
    """

    def __init__(
        self,
        # Chat will display...
        user_name="VOUS:",
        bot_name="CHATBOT:",
        other_name_regex_in=r"AUTRE (\d+):",
        other_name_regex_out=r"AUTRE \1:",
        # but Claire was trained on...
        user_internal_tag="[Intervenant 1:]",
        bot_internal_tag="[Intervenant 2:]",
        other_internal_tag_regex_in=r"\[Intervenant (\d+):\]",
        other_internal_tag_regex_out=r"[Intervenant \1:]",
    ):
        """Configure the speaker names and tags.

        Args:
            user_name: Display prefix of the user's turns.
            bot_name: Display prefix of the bot's turns.
            other_name_regex_in: Regex matching display prefixes of other
                speakers; group 1 is the speaker number.
            other_name_regex_out: ``re.sub`` replacement template producing a
                display prefix from that speaker number.
            user_internal_tag: Claire's tag for the user's turns.
            bot_internal_tag: Claire's tag for the bot's turns.
            other_internal_tag_regex_in: Regex matching Claire's tags of other
                speakers; group 1 is the speaker number.
            other_internal_tag_regex_out: ``re.sub`` replacement template
                producing Claire's tag from that speaker number. It is written
                without backslashes on purpose. In a replacement string, an
                unknown escape such as ``\\[`` is kept verbatim, so the
                backslash would end up in the prompt.
        """
        self.user_name = user_name
        self.bot_name = bot_name
        self.other_name_regex_in = other_name_regex_in
        self.other_name_regex_out = other_name_regex_out

        self.user_internal_tag = user_internal_tag
        self.bot_internal_tag = bot_internal_tag
        self.other_internal_tag_regex_in = other_internal_tag_regex_in
        self.other_internal_tag_regex_out = other_internal_tag_regex_out

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Whether history_raw ends with the "[" that opens the next tag.
        self.has_started_bracket = False
        # Concatenated displayed outputs (only grown when reshow_history).
        self.history = ""
        # Raw model output of the last call (only kept when reinject_history).
        self.history_raw = ""
        # If True, the raw output is reused as the prompt history of the next
        # call whenever no explicit history is given.
        self.reinject_history = False
        # If True, predict returns all previous outputs before the new one.
        self.reshow_history = False

    def predict(
        self,
        user_message,
        bot_message_start="",
        conversation_history="",
        generate_several_turns=False,
        max_new_tokens=default_max_new_tokens,
        temperature=default_temperature,
        top_k=default_top_k,
        top_p=default_top_p,
        repetition_penalty=default_repetition_penalty,
    ):
        """Generate Claire's answer to ``user_message``.

        Args:
            user_message: What the user says next.
            bot_message_start: Text the bot's answer is forced to begin with.
            conversation_history: Previous dialogue written with display
                names. When empty, ``self.history_raw`` is used instead; it
                stays empty unless ``reinject_history`` is True.
            generate_several_turns: If True, generation stops only at the EOS
                token or ``max_new_tokens``, so later turns may be produced.
                Otherwise it stops at the first "[" token, i.e. at the next
                speaker tag.
            max_new_tokens, temperature, top_k, top_p, repetition_penalty:
                Sampling parameters forwarded to ``model.generate``.

        Returns:
            The prompt dialogue followed by the generated continuation, using
            display names. When ``reshow_history`` is True, earlier outputs
            come first.
        """
        user_message = claire_text_preproc_message(user_message)
        bot_message_start = claire_text_preproc_message(bot_message_start)

        user_internal_tag = self.user_internal_tag
        if conversation_history:
            conversation_history = self._history_to_internal(conversation_history)
        else:
            conversation_history = self.history_raw
            if self.has_started_bracket:
                # The stored raw history already ends with the "[" of the tag.
                user_internal_tag = user_internal_tag[1:]
            elif conversation_history and not conversation_history.endswith("\n"):
                # A trimmed raw history ends mid-line: start the user turn on
                # a new line.
                conversation_history += "\n"

        # Combine the user and bot messages into a conversation
        conversation = (
            f"{conversation_history}{user_internal_tag} {user_message}\n"
            f"{self.bot_internal_tag} {bot_message_start}"
        ).strip()

        # Encode the conversation using the tokenizer
        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )
        input_ids = input_ids.to(self.device)

        # Generate a response using Claire
        response = model.generate(
            input_ids=input_ids,
            use_cache=False,
            early_stopping=False,
            temperature=temperature,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id if generate_several_turns else newspk_token_id,
        )

        # The decoded output starts with the prompt itself. Only line breaks
        # after the prompt's own ones were generated by the model.
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)
        response_text = self._trim_unfinished_tail(
            response_text, conversation.count("\n"), generate_several_turns
        )

        ended_with_bracket = response_text.endswith("[")

        if self.reinject_history:
            self.history_raw = response_text
            self.has_started_bracket = ended_with_bracket

        if ended_with_bracket:
            response_text = response_text[:-1]

        response_text = self._internal_to_display(response_text)

        if self.reshow_history:
            previous_history = self.history
            self.history = previous_history + response_text + "\n"
        else:
            previous_history = ""

        return previous_history + response_text

    def _history_to_internal(self, history):
        """Rewrite a display-name history into Claire's tag format, newline-terminated."""
        for display, internal in (
            (self.user_name, self.user_internal_tag),
            (self.bot_name, self.bot_internal_tag),
        ):
            history = history.replace(display, internal)
        history = re.sub(
            self.other_name_regex_in, self.other_internal_tag_regex_out, history
        )
        history = claire_text_preproc_conversation(history)
        return history.rstrip() + "\n"

    def _internal_to_display(self, text):
        """Rewrite Claire's speaker tags in ``text`` into display names."""
        for internal, display in (
            (self.user_internal_tag, self.user_name),
            # The starting bracket may be missing.
            (self.user_internal_tag[1:], self.user_name),
            (self.bot_internal_tag, self.bot_name),
        ):
            text = text.replace(internal, display)
        return re.sub(
            self.other_internal_tag_regex_in, self.other_name_regex_out, text
        )

    @staticmethod
    def _trim_unfinished_tail(text, num_prompt_breaks, generate_several_turns):
        """Remove the part of ``text`` left unfinished by the token budget.

        Only the line breaks after the first ``num_prompt_breaks`` ones were
        generated.
        - Several turns with at least one generated break: drop everything
          after the last generated line break, since that last turn is
          unfinished.
        - No generated break: the bot's own turn did not end. Cut it after its
          last sentence end (".", "!", "?"). Failing that, cut after its last
          phrase end (",", ";").
        - Otherwise: return ``text`` unchanged.
        """
        line_breaks = [match.start() for match in re.finditer("\n", text)]
        generated_breaks = line_breaks[num_prompt_breaks:]
        if generated_breaks:
            if generate_several_turns:
                return text[: generated_breaks[-1]]
            return text
        if not line_breaks:
            return text
        turn_start = line_breaks[-1]
        for punctuation in (r"[\.!?]", r"[,;]"):
            ends = [
                match.start()
                for match in re.finditer(punctuation, text)
                if match.start() > turn_start
            ]
            if ends:
                return text[: ends[-1] + 1]
        return text


def claire_text_preproc_conversation(text):
    """Normalize a multi-line conversation before it is given to Claire.

    Special characters are first mapped to plain equivalents. Spacing is then
    collapsed while the line structure (one line per turn) is preserved.
    """
    normalized = format_special_characters(text)
    return collapse_whitespaces_conversations(normalized)


def claire_text_preproc_message(text):
    """Normalize a single message before it is inserted into the prompt.

    Special characters are mapped to plain equivalents and all whitespace is
    flattened onto one line. Square and curly brackets are turned into
    parentheses, presumably so that user text cannot mimic a speaker tag.
    """
    return replace_brackets(
        collapse_whitespaces_message(format_special_characters(text))
    )


def collapse_whitespaces_conversations(text):
    """Normalize spacing in a multi-line conversation while keeping its lines.

    - Runs of spaces/tabs become a single space.
    - Spaces at the start of a line are removed.
    - Empty lines are removed, including lines holding only spaces or tabs.
    - A space before "." or "," is dropped.
    - Leading whitespace and trailing spaces of the whole text are stripped.
    """
    text = re.sub(r"[ \t]+", " ", text)
    # Drop line indentation *before* collapsing line breaks, so that lines
    # holding only blanks collapse as well instead of leaving empty lines.
    text = re.sub(r"\n ", "\n", text)
    text = re.sub(r"\n+", "\n", text)
    text = re.sub(r" ([\.,])", r"\1", text)
    return text.lstrip().rstrip(" ")


def collapse_whitespaces_message(text):
    """Flatten one message onto a single line with normalized spacing.

    Every whitespace run (line breaks included) becomes one space, and a space
    before "." or "," is dropped. Leading whitespace and trailing spaces are
    stripped.
    """
    for pattern, replacement in ((r"\s+", " "), (r" ([\.,])", r"\1")):
        text = re.sub(pattern, replacement, text)
    return text.lstrip().rstrip(" ")


def replace_brackets(text):
    """Return ``text`` with "[" and "{" turned into "(" and "]" and "}" into ")"."""
    bracket_table = str.maketrans("[{]}", "(())")
    return text.translate(bracket_table)


# (compiled pattern, replacement) pairs applied in order by
# format_special_characters; compiled once at import time.
_SPECIAL_CHARACTER_REPLACEMENTS = [
    (re.compile(pattern), replacement)
    for pattern, replacement in [
        ("…", "..."),
        (r"[«“][^\S\r\n]*", '"'),  # opening quotes and the blanks after them
        (r"[^\S\r\n]*[»”″„]", '"'),  # closing quotes and the blanks before them
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        # Unbreakable spaces (no-break space, narrow no-break space), written
        # as escapes so they are visible in the source.
        ("[\u00a0\u202f]", " "),
        (r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", ""),  # non-printable characters
        # ("·", "."),
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    ]
]


def format_special_characters(text):
    """Map typographic characters in ``text`` to their plain-text equivalents.

    The text is NFC-normalized first. Then ellipses, curly/angle quotes,
    apostrophe variants, en dashes and unbreakable spaces are replaced, control
    characters other than tab, LF and CR are removed, and superscript ordinal
    suffixes ("ᵉʳ", "ᵉ") are spelled out.
    """
    text = unicodedata.normalize("NFC", text)
    for pattern, replacement in _SPECIAL_CHARACTER_REPLACEMENTS:
        text = pattern.sub(replacement, text)
    return text


# Single chatbot instance whose predict method serves the Gradio interface.
chatbot = ClaireChatBot()

# Texts shown at the top of the Gradio page.
title = "Démo de conversation avec Claire"
description = "Simulation de conversations en Français avec [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1), sans recherche de vérité, et avec potentiellement un peu d'humour."

# Sampling values shared by every example, in input order:
# temperature, top-k, top-p, repetition penalty.
default_parameters = [
    default_temperature,
    default_top_k,
    default_top_p,
    default_repetition_penalty,
]

# Each example lists, in input order: user_message, bot_message_start,
# conversation_history, generate_several_turns, max_new_tokens, followed by
# the shared sampling values above.
examples = [
    [
        "Nous allons commencer cette interview avec une question un peu classique. Quel est votre sport préféré?",
        "",
        "",
        True,
        200,
    ]
    + default_parameters,
    [
        "Que vas-tu nous cuisiner aujourd'hui?",
        "Alors, nous allons voir la recette de",
        "VOUS: Bonjour Claire.\nCHATBOT: Bonjour Dominique.",
        False,
        default_max_new_tokens,
    ]
    + default_parameters,
]

# Quick check without the UI:
# chatbot.predict(*examples[0])

# Gradio input components. gr.Interface passes their values positionally to
# chatbot.predict, so this order must match predict's parameters.
inputs = [
    # -> user_message
    gr.Textbox(
        "",
        label="Prompt",
        info="Tapez ce que vous voulez dire au ChatBot",
        type="text",
        lines=2,
    ),
    # -> bot_message_start
    gr.Textbox(
        "",
        label="Début de réponse",
        info="Vous pouvez taper ici ce que commence à vous répondre le ChatBot",
        type="text",
    ),
    # -> conversation_history (written with display names such as "VOUS:")
    gr.Textbox(
        "",
        label="Historique",
        info="Vous pouvez copier-coller (et modifier?) ici votre historique de conversation, pour continuer cette conversation",
        type="text",
        lines=3,
    ),
    # -> generate_several_turns
    gr.Checkbox(
        False,
        label="Plus qu'un tour de parole",
        info="Générer aussi comment pourrait continuer la conversation (plusieurs tours de parole incluant le vôtre)",
    ),
    # -> max_new_tokens
    gr.Slider(
        label="Longueur max",
        info="Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)",
        value=default_max_new_tokens,
        minimum=25,
        maximum=1000,
        step=25,
        interactive=True,
    ),
    # -> temperature
    gr.Slider(
        label="Température",
        info="Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents",
        value=default_temperature,
        minimum=0.1,
        maximum=1.9,
        step=0.1,
        interactive=True,
    ),
    # -> top_k
    gr.Slider(
        label="Top-k",
        info="Une valeur élevée permet d'explorer plus d'alternatives, mais augmente les temps de calcul",
        value=default_top_k,
        minimum=1,
        maximum=50,
        step=1,
        interactive=True,
    ),
    # -> top_p
    gr.Slider(
        label="Top-p",
        info="Une valeur élevée permet d'explorer des alternatives moins probables",
        value=default_top_p,
        minimum=0.9,
        maximum=1.0,
        step=0.01,
        interactive=True,
    ),
    # -> repetition_penalty
    gr.Slider(
        label="Pénalité de répétition",
        info="Pénalisation des répétitions",
        value=default_repetition_penalty,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
    ),
]

# Visual theme: Gradio's Monochrome theme with an emerald secondary hue and a
# teal neutral hue. .set() overrides individual theme variables; values such
# as "*primary_950" refer to shades of the theme's primary/secondary palettes.
theme = gr.themes.Monochrome(
    secondary_hue="emerald",
    neutral_hue="teal",
).set(
    # Page background and body text
    body_background_fill="*primary_950",
    body_background_fill_dark="*secondary_950",
    body_text_color="*primary_50",
    body_text_color_dark="*secondary_100",
    body_text_color_subdued="*primary_300",
    body_text_color_subdued_dark="*primary_300",
    background_fill_primary="*primary_600",
    background_fill_primary_dark="*primary_400",
    background_fill_secondary="*primary_950",
    background_fill_secondary_dark="*primary_950",
    # Borders and accents
    border_color_accent="*secondary_600",
    border_color_primary="*secondary_50",
    border_color_primary_dark="*secondary_50",
    color_accent="*secondary_50",
    color_accent_soft="*primary_500",
    color_accent_soft_dark="*primary_500",
    # Links
    link_text_color="*secondary_950",
    link_text_color_dark="*primary_50",
    link_text_color_active="*primary_50",
    link_text_color_active_dark="*primary_50",
    link_text_color_hover="*primary_50",
    link_text_color_hover_dark="*primary_50",
    link_text_color_visited="*primary_50",
    # Component blocks and their labels
    block_background_fill="*primary_950",
    block_background_fill_dark="*primary_950",
    block_border_color="*secondary_500",
    block_border_color_dark="*secondary_500",
    block_info_text_color="*primary_50",
    block_info_text_color_dark="*primary_50",
    block_label_background_fill="*primary_950",
    block_label_background_fill_dark="*secondary_950",
    block_label_border_color="*secondary_500",
    block_label_border_color_dark="*secondary_500",
    block_label_text_color="*secondary_500",
    block_label_text_color_dark="*secondary_500",
    block_title_background_fill="*primary_950",
    panel_background_fill="*primary_950",
    panel_border_color="*primary_950",
    # Inputs: checkbox, text fields, sliders
    checkbox_background_color="*primary_950",
    checkbox_background_color_dark="*primary_950",
    checkbox_background_color_focus="*primary_950",
    checkbox_border_color="*secondary_500",
    input_background_fill="*primary_800",
    input_background_fill_focus="*primary_950",
    input_background_fill_hover="*secondary_950",
    input_placeholder_color="*secondary_950",
    slider_color="*primary_950",
    slider_color_dark="*primary_950",
    # Tables (e.g. the examples table) and primary buttons
    table_even_background_fill="*primary_800",
    table_odd_background_fill="*primary_600",
    button_primary_background_fill="*primary_800",
    button_primary_background_fill_dark="*primary_800",
)

# Single-function interface: the values of `inputs` are passed to
# chatbot.predict and its returned string is shown in a text output.
iface = gr.Interface(
    fn=chatbot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs="text",
    theme=theme,
)

print("Launching chat...")

# Launch the Gradio interface for the model.
# share=True asks Gradio for a public share link in addition to the local server.
iface.launch(share=True)