import re
import unicodedata

import gradio as gr
import optimum
import torch
import transformers
from optimum.bettertransformer import BetterTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM

# Default variables
default_max_new_tokens = 100
default_temperature = 1.0
default_top_k = 10
default_top_p = 0.99
default_repetition_penalty = 1.0

model_name = "OpenLLM-France/Claire-7B-0.1"

print("Loading model...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)

print("Optimizing model...")
model = BetterTransformer.transform(model)

print("Setup chat...")
eos_token_id = tokenizer.eos_token_id
# Generation of a single turn stops on "[", i.e. when a new speaker tag opens.
newspk_token_id = tokenizer.encode("[")
if len(newspk_token_id) != 1:
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    raise ValueError(
        f"Expected '[' to be encoded as a single token, got {newspk_token_id}"
    )
newspk_token_id = newspk_token_id[0]


def _unescape_brackets(template):
    """Return a ``re.sub`` replacement template without escaped brackets.

    In a replacement template (as opposed to a pattern) ``\\[`` and ``\\]``
    are unknown escapes: ``re.sub`` leaves them alone, so the backslash ends
    up in the substituted text. Brackets need no escaping there, so the
    backslash is dropped. Group references such as ``\\1`` are untouched.
    """
    return re.sub(r"\\([\[\]])", r"\1", template)


def _drop_unfinished_tail(response_text, generate_several_turns):
    """Remove the last unfinished speech turn/sentence/phrase of a generation.

    When several turns were generated and the text has at least two line
    breaks, everything from the last line break on is dropped. Otherwise, if
    the text has exactly one line break, the text is cut after the last
    sentence end (``.``, ``!``, ``?``) found after that line break or, failing
    that, after the last phrase end (``,``, ``;``). In all other cases the
    text is returned unchanged.
    """
    line_breaks = [m.start() for m in re.finditer("\n", response_text)]
    remove_last_sentence = True
    if generate_several_turns and len(line_breaks) >= 2:
        response_text = response_text[: line_breaks[-1]]
        line_breaks.pop()
        remove_last_sentence = False
    # NOTE(review): the `== 1` condition means no sentence trimming happens
    # when the prompt already spans several lines (e.g. with a pasted
    # history) -- kept as in the original, confirm this is intended.
    if remove_last_sentence and len(line_breaks) == 1:
        last_break = line_breaks[-1]
        # Sentence ends are preferred; phrase ends are only a fallback.
        for end_pattern in (r"[\.!?]", r"[,;]"):
            ends = [
                m.start()
                for m in re.finditer(end_pattern, response_text)
                if m.start() > last_break
            ]
            if ends:
                return response_text[: ends[-1] + 1]
    return response_text


# Class to encapsulate the Claire chatbot
class ClaireChatBot:
    """Wrap the Claire model behind a single ``predict`` entry point.

    The chat shows display names (``VOUS:``, ``CHATBOT:``, ``AUTRE n:``)
    while the prompt sent to the model uses internal speaker tags
    (``[Intervenant n:]``). This class converts between both forms.
    """

    def __init__(
        self,
        # Chat will display...
        user_name="VOUS:",
        bot_name="CHATBOT:",
        other_name_regex_in=r"AUTRE (\d+):",
        other_name_regex_out=r"AUTRE \1:",
        # but Claire was trained on...
        user_internal_tag="[Intervenant 1:]",
        bot_internal_tag="[Intervenant 2:]",
        other_internal_tag_regex_in=r"\[Intervenant (\d+):\]",
        other_internal_tag_regex_out=r"\[Intervenant \1:\]",
    ):
        """Store the speaker names/tags and reset the conversation state.

        Args:
            user_name: Displayed name of the user.
            bot_name: Displayed name of the bot.
            other_name_regex_in: Pattern matching displayed names of other
                speakers (group 1 is the speaker number).
            other_name_regex_out: Replacement template producing the
                displayed name of another speaker.
            user_internal_tag: Tag used for the user in the model prompt.
            bot_internal_tag: Tag used for the bot in the model prompt.
            other_internal_tag_regex_in: Pattern matching internal tags of
                other speakers (group 1 is the speaker number).
            other_internal_tag_regex_out: Replacement template producing the
                internal tag of another speaker.
        """
        self.user_name = user_name
        self.bot_name = bot_name
        self.other_name_regex_in = other_name_regex_in
        # The two "*_regex_out" values are replacement templates, not
        # patterns: escaped brackets would leak a literal backslash into the
        # prompt (e.g. "\[Intervenant 3:\]"), so they are normalised here.
        self.other_name_regex_out = _unescape_brackets(other_name_regex_out)
        self.user_internal_tag = user_internal_tag
        self.bot_internal_tag = bot_internal_tag
        self.other_internal_tag_regex_in = other_internal_tag_regex_in
        self.other_internal_tag_regex_out = _unescape_brackets(
            other_internal_tag_regex_out
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.has_started_bracket = False
        self.history = ""
        self.history_raw = ""
        self.reinject_history = False
        self.reshow_history = False

    def predict(
        self,
        user_message,
        bot_message_start="",
        conversation_history="",
        generate_several_turns=False,
        max_new_tokens=default_max_new_tokens,
        temperature=default_temperature,
        top_k=default_top_k,
        top_p=default_top_p,
        repetition_penalty=default_repetition_penalty,
    ):
        """Generate the continuation of a conversation.

        Args:
            user_message: What the user says.
            bot_message_start: Optional beginning of the bot answer.
            conversation_history: Optional previous conversation, written
                with the displayed speaker names. When empty,
                ``self.history_raw`` is used instead.
            generate_several_turns: If true, generation stops on the
                end-of-sequence token rather than on the first new speaker
                tag.
            max_new_tokens: Maximum number of generated tokens.
            temperature: Sampling temperature.
            top_k: Top-k sampling parameter.
            top_p: Top-p sampling parameter.
            repetition_penalty: Repetition penalty.

        Returns:
            The decoded text (prompt included, since the whole generated
            sequence is decoded) with internal tags replaced by displayed
            names, preceded by ``self.history`` when ``self.reshow_history``
            is set.
        """
        user_message = claire_text_preproc_message(user_message)
        bot_message_start = claire_text_preproc_message(bot_message_start)

        if conversation_history:
            # Format conversation history
            for spk_in, spk_out in [
                (self.user_name, self.user_internal_tag),
                (self.bot_name, self.bot_internal_tag),
            ]:
                conversation_history = conversation_history.replace(spk_in, spk_out)
            conversation_history = re.sub(
                self.other_name_regex_in,
                self.other_internal_tag_regex_out,
                conversation_history,
            )
            conversation_history = claire_text_preproc_conversation(
                conversation_history
            )
            conversation_history = conversation_history.rstrip() + "\n"
        else:
            conversation_history = self.history_raw

        # (Only relevant if self.reinject_history is True)
        # The re-injected raw history may already end with "[", in which case
        # the opening bracket of the user tag must not be repeated.
        user_internal_tag = self.user_internal_tag
        if self.has_started_bracket:
            user_internal_tag = user_internal_tag[1:]

        # Combine the user and bot messages into a conversation
        conversation = f"{conversation_history}{user_internal_tag} {user_message}\n{self.bot_internal_tag} {bot_message_start}".strip()

        # Encode the conversation using the tokenizer
        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )
        input_ids = input_ids.to(self.device)

        # Generate a response using Claire
        response = model.generate(
            input_ids=input_ids,
            use_cache=False,
            early_stopping=False,
            temperature=temperature,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=eos_token_id,
            eos_token_id=eos_token_id if generate_several_turns else newspk_token_id,
        )

        # Decode the generated response to text
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)

        # Remove last unfinished speech turn/sentence/phrase
        response_text = _drop_unfinished_tail(response_text, generate_several_turns)

        ended_with_bracket = response_text.endswith("[")

        if self.reinject_history:
            self.history_raw = response_text
            self.has_started_bracket = ended_with_bracket

        if ended_with_bracket:
            response_text = response_text[:-1]

        # Convert the internal tags back to the displayed names
        for spk_in, spk_out in [
            (self.user_internal_tag, self.user_name),
            (self.user_internal_tag[1:], self.user_name),  # Starting bracket may be missing
            (self.bot_internal_tag, self.bot_name),
        ]:
            response_text = response_text.replace(spk_in, spk_out)
        response_text = re.sub(
            self.other_internal_tag_regex_in, self.other_name_regex_out, response_text
        )

        if self.reshow_history:
            previous_history = self.history
            self.history = previous_history + response_text + "\n"
        else:
            previous_history = ""

        return previous_history + response_text


def claire_text_preproc_conversation(text):
    """Normalize a multi-line conversation (line breaks are preserved)."""
    text = format_special_characters(text)
    text = collapse_whitespaces_conversations(text)
    return text


def claire_text_preproc_message(text):
    """Normalize a single message: one line, no square or curly brackets."""
    text = format_special_characters(text)
    text = collapse_whitespaces_message(text)
    text = replace_brackets(text)
    return text


def collapse_whitespaces_conversations(text):
    """Collapse repeated line breaks and spaces/tabs, keeping one turn per line.

    Also removes the space at the start of a line and before ``.`` or ``,``,
    strips leading whitespace and trailing spaces.
    """
    text = re.sub(r"\n+", "\n", text)
    text = re.sub(r"[ \t]+", " ", text)
    text = re.sub(r"\n ", "\n", text)
    text = re.sub(r" ([\.,])", r"\1", text)
    return text.lstrip().rstrip(" ")


def collapse_whitespaces_message(text):
    """Collapse every whitespace run (line breaks included) into one space.

    Also removes the space before ``.`` or ``,``, strips leading whitespace
    and trailing spaces.
    """
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r" ([\.,])", r"\1", text)
    return text.lstrip().rstrip(" ")


def replace_brackets(text):
    """Replace square and curly brackets with parentheses.

    Square brackets delimit speaker tags in the prompt, so they must not
    appear inside a message.
    """
    text = re.sub(r"[\[\{]", "(", text)
    text = re.sub(r"[\]\}]", ")", text)
    return text


def format_special_characters(text):
    """Apply NFC normalization and map typographic characters to plain ones.

    Ellipsis, quotation marks, apostrophes, dashes and unbreakable spaces are
    replaced by plain equivalents, non-printable control characters are
    removed, and superscript ordinal suffixes are spelled out.
    """
    text = unicodedata.normalize("NFC", text)
    for before, after in [
        ("…", "..."),
        (r"[«“][^\S\r\n]*", '"'),
        (r"[^\S\r\n]*[»”″„]", '"'),
        (r"(``|'')", '"'),
        (r"[’‘‛ʿ]", "'"),
        ("‚", ","),
        (r"–", "-"),
        # Unbreakable spaces (no-break space and narrow no-break space),
        # written as escapes so that they stay visible and survive editors.
        # NOTE(review): confirm no other no-break variant was intended.
        ("[\u00A0\u202F]", " "),
        (r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", ""),  # non-printable characters
        # ("·", "."),
        (r"ᵉʳ", "er"),
        (r"ᵉ", "e"),
    ]:
        text = re.sub(before, after, text)
    return text


# Create the Claire chatbot instance
chatbot = ClaireChatBot()

# Define the Gradio interface
title = "Démo de conversation avec Claire"
description = "Simulation de conversations en Français avec [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1), sans recherche de vérité, et avec potentiellement un peu d'humour."
# Sampling defaults shared by every example row (same order as the sliders
# that follow "Longueur max" in `inputs`).
default_parameters = [
    default_temperature,
    default_top_k,
    default_top_p,
    default_repetition_penalty,
]

# Each example lists the values of the input components, in order:
# user_message, bot_message_start, conversation_history,
# generate_several_turns, max_new_tokens, then the sampling parameters.
_interview_example = [
    "Nous allons commencer cette interview avec une question un peu classique. Quel est votre sport préféré?",
    "",
    "",
    True,
    200,
    *default_parameters,
]
_recipe_example = [
    "Que vas-tu nous cuisiner aujourd'hui?",
    "Alors, nous allons voir la recette de",
    "VOUS: Bonjour Claire.\nCHATBOT: Bonjour Dominique.",
    False,
    default_max_new_tokens,
    *default_parameters,
]
examples = [_interview_example, _recipe_example]

# # Test
# chatbot.predict(*examples[0])


def _make_slider(label, info, value, minimum, maximum, step):
    """Build an interactive slider; all sliders of the demo share this shape."""
    return gr.Slider(
        label=label,
        info=info,
        value=value,
        minimum=minimum,
        maximum=maximum,
        step=step,
        interactive=True,
    )


# Input components, in the order of the parameters of `chatbot.predict`.
inputs = [
    gr.Textbox(
        "",
        label="Prompt",
        info="Tapez ce que vous voulez dire au ChatBot",
        type="text",
        lines=2,
    ),
    gr.Textbox(
        "",
        label="Début de réponse",
        info="Vous pouvez taper ici ce que commence à vous répondre le ChatBot",
        type="text",
    ),
    gr.Textbox(
        "",
        label="Historique",
        info="Vous pouvez copier-coller (et modifier?) ici votre historique de conversation, pour continuer cette conversation",
        type="text",
        lines=3,
    ),
    gr.Checkbox(
        False,
        label="Plus qu'un tour de parole",
        info="Générer aussi comment pourrait continuer la conversation (plusieurs tours de parole incluant le vôtre)",
    ),
    _make_slider(
        "Longueur max",
        "Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)",
        default_max_new_tokens,
        25,
        1000,
        25,
    ),
    _make_slider(
        "Température",
        "Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents",
        default_temperature,
        0.1,
        1.9,
        0.1,
    ),
    _make_slider(
        "Top-k",
        "Une valeur élevée permet d'explorer plus d'alternatives, mais augmente les temps de calcul",
        default_top_k,
        1,
        50,
        1,
    ),
    _make_slider(
        "Top-p",
        "Une valeur élevée permet d'explorer des alternatives moins probables",
        default_top_p,
        0.9,
        1.0,
        0.01,
    ),
    _make_slider(
        "Pénalité de répétition",
        "Pénalisation des répétitions",
        default_repetition_penalty,
        1.0,
        2.0,
        0.05,
    ),
]

# Overrides applied on top of the Monochrome base theme, grouped by area.
_theme_overrides = {
    # Page body
    "body_background_fill": "*primary_950",
    "body_background_fill_dark": "*secondary_950",
    "body_text_color": "*primary_50",
    "body_text_color_dark": "*secondary_100",
    "body_text_color_subdued": "*primary_300",
    "body_text_color_subdued_dark": "*primary_300",
    # Generic backgrounds, borders and accents
    "background_fill_primary": "*primary_600",
    "background_fill_primary_dark": "*primary_400",
    "background_fill_secondary": "*primary_950",
    "background_fill_secondary_dark": "*primary_950",
    "border_color_accent": "*secondary_600",
    "border_color_primary": "*secondary_50",
    "border_color_primary_dark": "*secondary_50",
    "color_accent": "*secondary_50",
    "color_accent_soft": "*primary_500",
    "color_accent_soft_dark": "*primary_500",
    # Links
    "link_text_color": "*secondary_950",
    "link_text_color_dark": "*primary_50",
    "link_text_color_active": "*primary_50",
    "link_text_color_active_dark": "*primary_50",
    "link_text_color_hover": "*primary_50",
    "link_text_color_hover_dark": "*primary_50",
    "link_text_color_visited": "*primary_50",
    # Blocks and panels
    "block_background_fill": "*primary_950",
    "block_background_fill_dark": "*primary_950",
    "block_border_color": "*secondary_500",
    "block_border_color_dark": "*secondary_500",
    "block_info_text_color": "*primary_50",
    "block_info_text_color_dark": "*primary_50",
    "block_label_background_fill": "*primary_950",
    "block_label_background_fill_dark": "*secondary_950",
    "block_label_border_color": "*secondary_500",
    "block_label_border_color_dark": "*secondary_500",
    "block_label_text_color": "*secondary_500",
    "block_label_text_color_dark": "*secondary_500",
    "block_title_background_fill": "*primary_950",
    "panel_background_fill": "*primary_950",
    "panel_border_color": "*primary_950",
    # Form controls
    "checkbox_background_color": "*primary_950",
    "checkbox_background_color_dark": "*primary_950",
    "checkbox_background_color_focus": "*primary_950",
    "checkbox_border_color": "*secondary_500",
    "input_background_fill": "*primary_800",
    "input_background_fill_focus": "*primary_950",
    "input_background_fill_hover": "*secondary_950",
    "input_placeholder_color": "*secondary_950",
    "slider_color": "*primary_950",
    "slider_color_dark": "*primary_950",
    # Tables and buttons
    "table_even_background_fill": "*primary_800",
    "table_odd_background_fill": "*primary_600",
    "button_primary_background_fill": "*primary_800",
    "button_primary_background_fill_dark": "*primary_800",
}

_base_theme = gr.themes.Monochrome(
    secondary_hue="emerald",
    neutral_hue="teal",
)
theme = _base_theme.set(**_theme_overrides)

# Wire the chatbot to the components defined above; the output is plain text.
iface = gr.Interface(
    fn=chatbot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs="text",
    theme=theme,
)

print("Launching chat...")

# Launch the Gradio interface for the model
iface.launch(share=True)