# Spaces:
# Runtime error
# Runtime error
import re
import unicodedata

import gradio as gr
import torch
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
# Default variables | |
default_max_new_tokens = 100 | |
default_temperature = 1.0 | |
default_top_k = 10 | |
default_top_p = 0.99 | |
default_repetition_penalty = 1.0 | |
model_name = "OpenLLM-France/Claire-7B-0.1" | |
print("Loading model...") | |
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) | |
model = transformers.AutoModelForCausalLM.from_pretrained( | |
model_name, | |
device_map="auto", | |
torch_dtype=torch.bfloat16, | |
load_in_4bit=True, | |
) | |
print("Optimizing model...") | |
import optimum | |
from optimum.bettertransformer import BetterTransformer | |
model = BetterTransformer.transform(model) | |
print("Setup chat...") | |
eos_token_id = tokenizer.eos_token_id | |
newspk_token_id = tokenizer.encode("[") | |
assert len(newspk_token_id) == 1 | |
newspk_token_id = newspk_token_id[0] | |
# Class to encapsulate the Claire chatbot | |
class ClaireChatBot: | |
def __init__( | |
self, | |
# Chat will display... | |
user_name="VOUS:", | |
bot_name="CHATBOT:", | |
other_name_regex_in=r"AUTRE (\d+):", | |
other_name_regex_out=r"AUTRE \1:", | |
# but Claire was trained on... | |
user_internal_tag="[Intervenant 1:]", | |
bot_internal_tag="[Intervenant 2:]", | |
other_internal_tag_regex_in=r"\[Intervenant (\d+):\]", | |
other_internal_tag_regex_out=r"\[Intervenant \1:\]", | |
): | |
self.user_name = user_name | |
self.bot_name = bot_name | |
self.other_name_regex_in = other_name_regex_in | |
self.other_name_regex_out = other_name_regex_out | |
self.user_internal_tag = user_internal_tag | |
self.bot_internal_tag = bot_internal_tag | |
self.other_internal_tag_regex_in = other_internal_tag_regex_in | |
self.other_internal_tag_regex_out = other_internal_tag_regex_out | |
self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
self.has_started_bracket = False | |
self.history = "" | |
self.history_raw = "" | |
self.reinject_history = False | |
self.reshow_history = False | |
def predict( | |
self, | |
user_message, | |
bot_message_start="", | |
conversation_history="", | |
generate_several_turns=False, | |
max_new_tokens=default_max_new_tokens, | |
temperature=default_temperature, | |
top_k=default_top_k, | |
top_p=default_top_p, | |
repetition_penalty=default_repetition_penalty, | |
): | |
user_message = claire_text_preproc_message(user_message) | |
bot_message_start = claire_text_preproc_message(bot_message_start) | |
if conversation_history: | |
# Format conversation history | |
for spk_in, spk_out in [ | |
(self.user_name, self.user_internal_tag), | |
(self.bot_name, self.bot_internal_tag), | |
]: | |
conversation_history = conversation_history.replace(spk_in, spk_out) | |
conversation_history = re.sub(self.other_name_regex_in, self.other_internal_tag_regex_out, conversation_history) | |
conversation_history = claire_text_preproc_conversation(conversation_history) | |
conversation_history = conversation_history.rstrip() + "\n" | |
else: | |
conversation_history = self.history_raw | |
# (Only relevant if self.reinject_history is True) | |
user_internal_tag = self.user_internal_tag | |
if self.has_started_bracket: | |
user_internal_tag = user_internal_tag[1:] | |
# Combine the user and bot messages into a conversation | |
conversation = f"{conversation_history}{user_internal_tag} {user_message}\n{self.bot_internal_tag} {bot_message_start if bot_message_start else ''}".strip() | |
# Encode the conversation using the tokenizer | |
input_ids = tokenizer.encode( | |
conversation, return_tensors="pt", add_special_tokens=False | |
) | |
input_ids = input_ids.to(self.device) | |
# Generate a response using Claire | |
response = model.generate( | |
input_ids=input_ids, | |
use_cache=False, | |
early_stopping=False, | |
temperature=temperature, | |
do_sample=True, | |
max_new_tokens=max_new_tokens, | |
top_k=top_k, | |
top_p=top_p, | |
repetition_penalty=repetition_penalty, | |
pad_token_id=eos_token_id, | |
eos_token_id=eos_token_id if generate_several_turns else newspk_token_id, | |
) | |
# Decode the generated response to text | |
response_text = tokenizer.decode(response[0], skip_special_tokens=True) | |
# Remove last unfinished speech turn/sentence/phrase | |
line_breaks = [u.span(0)[0] for u in re.finditer("\n", response_text)] | |
remove_last_sentence = True | |
if generate_several_turns: | |
if len(line_breaks) >= 2: | |
response_text = response_text[: line_breaks[-1]] | |
line_breaks.pop(-1) | |
remove_last_sentence = False | |
if remove_last_sentence and len(line_breaks) == 1: | |
sentence_ends = [ | |
u.span(0)[0] for u in re.finditer(r"[\.!?]", response_text) | |
] | |
sentence_ends = [p for p in sentence_ends if p > line_breaks[-1]] | |
if sentence_ends: | |
response_text = response_text[: sentence_ends[-1] + 1] | |
else: | |
phrase_ends = [ | |
u.span(0)[0] for u in re.finditer(r"[,;]", response_text) | |
] | |
phrase_ends = [p for p in phrase_ends if p > line_breaks[-1]] | |
if phrase_ends: | |
response_text = response_text[: phrase_ends[-1] + 1] | |
ended_with_bracket = response_text.endswith("[") | |
if self.reinject_history: | |
self.history_raw = response_text | |
self.has_started_bracket = ended_with_bracket | |
if ended_with_bracket: | |
response_text = response_text[:-1] | |
for spk_in, spk_out in [ | |
(self.user_internal_tag, self.user_name), | |
(self.user_internal_tag[1:], self.user_name), # Starting bracket may be missing | |
(self.bot_internal_tag, self.bot_name), | |
]: | |
response_text = response_text.replace(spk_in, spk_out) | |
response_text = re.sub(self.other_internal_tag_regex_in, self.other_name_regex_out, response_text) | |
if self.reshow_history: | |
previous_history = self.history | |
self.history = previous_history + response_text + "\n" | |
else: | |
previous_history = "" | |
return previous_history + response_text | |
def claire_text_preproc_conversation(text): | |
text = format_special_characters(text) | |
text = collapse_whitespaces_conversations(text) | |
return text | |
def claire_text_preproc_message(text): | |
text = format_special_characters(text) | |
text = collapse_whitespaces_message(text) | |
text = replace_brackets(text) | |
return text | |
def collapse_whitespaces_conversations(text): | |
text = re.sub(r"\n+", "\n", text) | |
text = re.sub(r"[ \t]+", " ", text) | |
text = re.sub(r"\n ", "\n", text) | |
text = re.sub(r" ([\.,])", r"\1", text) | |
return text.lstrip().rstrip(" ") | |
def collapse_whitespaces_message(text): | |
text = re.sub(r"\s+", " ", text) | |
text = re.sub(r" ([\.,])", r"\1", text) | |
return text.lstrip().rstrip(" ") | |
def replace_brackets(text): | |
text = re.sub(r"[\[\{]", "(", text) | |
text = re.sub(r"[\]\}]", ")", text) | |
return text | |
def format_special_characters(text): | |
text = unicodedata.normalize("NFC", text) | |
for before, after in [ | |
("…", "..."), | |
(r"[«“][^\S\r\n]*", '"'), | |
(r"[^\S\r\n]*[»”″„]", '"'), | |
(r"(``|'')", '"'), | |
(r"[’‘‛ʿ]", "'"), | |
("‚", ","), | |
(r"–", "-"), | |
("[ ]", " "), # unbreakable spaces | |
(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", ""), # non-printable characters | |
# ("·", "."), | |
(r"ᵉʳ", "er"), | |
(r"ᵉ", "e"), | |
]: | |
text = re.sub(before, after, text) | |
return text | |
# Create the Claire chatbot instance | |
chatbot = ClaireChatBot() | |
# Define the Gradio interface | |
title = "Démo de conversation avec Claire" | |
description = "Simulation de conversations en Français avec [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1), sans recherche de vérité, et avec potentiellement un peu d'humour." | |
default_parameters = [ | |
default_temperature, | |
default_top_k, | |
default_top_p, | |
default_repetition_penalty, | |
] | |
examples = [ | |
[ | |
"Nous allons commencer cette interview avec une question un peu classique. Quel est votre sport préféré?", # user_message | |
"", # bot_message_start | |
"", # conversation_history | |
True, # generate_several_turns | |
200, # max_new_tokens | |
*default_parameters, | |
], | |
[ | |
"Que vas-tu nous cuisiner aujourd'hui?", # user_message | |
"Alors, nous allons voir la recette de", # bot_message_start | |
"VOUS: Bonjour Claire.\nCHATBOT: Bonjour Dominique.", # conversation_history | |
False, # generate_several_turns | |
default_max_new_tokens, # max_new_tokens | |
*default_parameters, | |
], | |
] | |
# # Test | |
# chatbot.predict(*examples[0]) | |
inputs = [ | |
gr.Textbox( | |
"", | |
label="Prompt", | |
info="Tapez ce que vous voulez dire au ChatBot", | |
type="text", | |
lines=2, | |
), | |
gr.Textbox( | |
"", | |
label="Début de réponse", | |
info="Vous pouvez taper ici ce que commence à vous répondre le ChatBot", | |
type="text", | |
), | |
gr.Textbox( | |
"", | |
label="Historique", | |
info="Vous pouvez copier-coller (et modifier?) ici votre historique de conversation, pour continuer cette conversation", | |
type="text", | |
lines=3, | |
), | |
gr.Checkbox( | |
False, | |
label="Plus qu'un tour de parole", | |
info="Générer aussi comment pourrait continuer la conversation (plusieurs tours de parole incluant le vôtre)", | |
), | |
gr.Slider( | |
label="Longueur max", | |
info="Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)", | |
value=default_max_new_tokens, | |
minimum=25, | |
maximum=1000, | |
step=25, | |
interactive=True, | |
), | |
gr.Slider( | |
label="Température", | |
info="Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents", | |
value=default_temperature, | |
minimum=0.1, | |
maximum=1.9, | |
step=0.1, | |
interactive=True, | |
), | |
gr.Slider( | |
label="Top-k", | |
info="Une valeur élevée permet d'explorer plus d'alternatives, mais augmente les temps de calcul", | |
value=default_top_k, | |
minimum=1, | |
maximum=50, | |
step=1, | |
interactive=True, | |
), | |
gr.Slider( | |
label="Top-p", | |
info="Une valeur élevée permet d'explorer des alternatives moins probables", | |
value=default_top_p, | |
minimum=0.9, | |
maximum=1.0, | |
step=0.01, | |
interactive=True, | |
), | |
gr.Slider( | |
label="Pénalité de répétition", | |
info="Pénalisation des répétitions", | |
value=default_repetition_penalty, | |
minimum=1.0, | |
maximum=2.0, | |
step=0.05, | |
interactive=True, | |
), | |
] | |
theme = gr.themes.Monochrome( | |
secondary_hue="emerald", | |
neutral_hue="teal", | |
).set( | |
body_background_fill="*primary_950", | |
body_background_fill_dark="*secondary_950", | |
body_text_color="*primary_50", | |
body_text_color_dark="*secondary_100", | |
body_text_color_subdued="*primary_300", | |
body_text_color_subdued_dark="*primary_300", | |
background_fill_primary="*primary_600", | |
background_fill_primary_dark="*primary_400", | |
background_fill_secondary="*primary_950", | |
background_fill_secondary_dark="*primary_950", | |
border_color_accent="*secondary_600", | |
border_color_primary="*secondary_50", | |
border_color_primary_dark="*secondary_50", | |
color_accent="*secondary_50", | |
color_accent_soft="*primary_500", | |
color_accent_soft_dark="*primary_500", | |
link_text_color="*secondary_950", | |
link_text_color_dark="*primary_50", | |
link_text_color_active="*primary_50", | |
link_text_color_active_dark="*primary_50", | |
link_text_color_hover="*primary_50", | |
link_text_color_hover_dark="*primary_50", | |
link_text_color_visited="*primary_50", | |
block_background_fill="*primary_950", | |
block_background_fill_dark="*primary_950", | |
block_border_color="*secondary_500", | |
block_border_color_dark="*secondary_500", | |
block_info_text_color="*primary_50", | |
block_info_text_color_dark="*primary_50", | |
block_label_background_fill="*primary_950", | |
block_label_background_fill_dark="*secondary_950", | |
block_label_border_color="*secondary_500", | |
block_label_border_color_dark="*secondary_500", | |
block_label_text_color="*secondary_500", | |
block_label_text_color_dark="*secondary_500", | |
block_title_background_fill="*primary_950", | |
panel_background_fill="*primary_950", | |
panel_border_color="*primary_950", | |
checkbox_background_color="*primary_950", | |
checkbox_background_color_dark="*primary_950", | |
checkbox_background_color_focus="*primary_950", | |
checkbox_border_color="*secondary_500", | |
input_background_fill="*primary_800", | |
input_background_fill_focus="*primary_950", | |
input_background_fill_hover="*secondary_950", | |
input_placeholder_color="*secondary_950", | |
slider_color="*primary_950", | |
slider_color_dark="*primary_950", | |
table_even_background_fill="*primary_800", | |
table_odd_background_fill="*primary_600", | |
button_primary_background_fill="*primary_800", | |
button_primary_background_fill_dark="*primary_800", | |
) | |
iface = gr.Interface( | |
fn=chatbot.predict, | |
title=title, | |
description=description, | |
examples=examples, | |
inputs=inputs, | |
outputs="text", | |
theme=theme, | |
) | |
print("Launching chat...") | |
# Launch the Gradio interface for the model | |
iface.launch(share=True) | |