# Claire-Chat-0.1 / app.py
# Hugging Face Space demo (author: Jeronymous, initial commit 0dcddb0, ~14.1 kB)
import gradio as gr
import transformers
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
import torch
import unicodedata
import re
# Default generation hyper-parameters (also used as initial slider values in the UI)
default_max_new_tokens = 100
default_temperature = 1.0
default_top_k = 10
default_top_p = 0.99
default_repetition_penalty = 1.0

model_name = "OpenLLM-France/Claire-7B-0.1"

print("Loading model...")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",  # spread layers across available devices
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,  # 4-bit quantization (requires bitsandbytes)
)

print("Optimizing model...")
import optimum
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)

print("Setup chat...")
eos_token_id = tokenizer.eos_token_id
# Token id of "[": Claire marks speaker turns as "[Intervenant N:]", so the
# opening bracket is used as an early-stop token when generating a single turn.
newspk_token_id = tokenizer.encode("[")
assert len(newspk_token_id) == 1
newspk_token_id = newspk_token_id[0]
# Class to encapsulate the Claire chatbot
class ClaireChatBot:
    """Stateful chat wrapper around the module-level Claire ``model``/``tokenizer``.

    Translates between the speaker names shown in the UI (e.g. "VOUS:",
    "CHATBOT:") and the "[Intervenant N:]" tags the model was trained on,
    builds the prompt, generates, and trims unfinished trailing text.
    """

    def __init__(
        self,
        # Chat will display...
        user_name="VOUS:",
        bot_name="CHATBOT:",
        other_name_regex_in=r"AUTRE (\d+):",
        other_name_regex_out=r"AUTRE \1:",
        # but Claire was trained on...
        user_internal_tag="[Intervenant 1:]",
        bot_internal_tag="[Intervenant 2:]",
        other_internal_tag_regex_in=r"\[Intervenant (\d+):\]",
        other_internal_tag_regex_out=r"\[Intervenant \1:\]",
    ):
        self.user_name = user_name
        self.bot_name = bot_name
        self.other_name_regex_in = other_name_regex_in
        self.other_name_regex_out = other_name_regex_out
        self.user_internal_tag = user_internal_tag
        self.bot_internal_tag = bot_internal_tag
        self.other_internal_tag_regex_in = other_internal_tag_regex_in
        self.other_internal_tag_regex_out = other_internal_tag_regex_out
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # True when the last generation ended on an opening "[" that was
        # stripped from the displayed text (the next prompt must then drop
        # the leading "[" of the user tag to avoid doubling it).
        self.has_started_bracket = False
        # Conversation history in display format (user-facing names).
        self.history = ""
        # Conversation history in internal "[Intervenant N:]" format.
        self.history_raw = ""
        # Feed the previous raw output back as context on the next call.
        self.reinject_history = False
        # Prepend the accumulated display history to the returned text.
        self.reshow_history = False

    def predict(
        self,
        user_message,
        bot_message_start="",
        conversation_history="",
        generate_several_turns=False,
        max_new_tokens=default_max_new_tokens,
        temperature=default_temperature,
        top_k=default_top_k,
        top_p=default_top_p,
        repetition_penalty=default_repetition_penalty,
    ):
        """Generate Claire's reply to ``user_message``.

        Parameters:
            user_message: what the user says this turn.
            bot_message_start: optional forced beginning of the bot's reply.
            conversation_history: prior turns in display format ("VOUS:",
                "CHATBOT:"); when empty, falls back to ``self.history_raw``.
            generate_several_turns: when True, let generation run until EOS
                (possibly several speaker turns); otherwise stop at the next
                "[" (start of a new speaker tag).
            max_new_tokens / temperature / top_k / top_p / repetition_penalty:
                standard sampling parameters passed to ``model.generate``.

        Returns:
            The generated conversation text in display format (optionally
            prefixed with the stored history if ``self.reshow_history``).
        """
        user_message = claire_text_preproc_message(user_message)
        bot_message_start = claire_text_preproc_message(bot_message_start)
        if conversation_history:
            # Format conversation history: display names -> internal tags
            for spk_in, spk_out in [
                (self.user_name, self.user_internal_tag),
                (self.bot_name, self.bot_internal_tag),
            ]:
                conversation_history = conversation_history.replace(spk_in, spk_out)
            conversation_history = re.sub(self.other_name_regex_in, self.other_internal_tag_regex_out, conversation_history)
            conversation_history = claire_text_preproc_conversation(conversation_history)
            conversation_history = conversation_history.rstrip() + "\n"
        else:
            # (Only relevant if self.reinject_history is True)
            conversation_history = self.history_raw
        user_internal_tag = self.user_internal_tag
        if self.has_started_bracket:
            # The previous generation already emitted the opening "[".
            user_internal_tag = user_internal_tag[1:]
        # Combine the user and bot messages into a conversation
        conversation = f"{conversation_history}{user_internal_tag} {user_message}\n{self.bot_internal_tag} {bot_message_start if bot_message_start else ''}".strip()
        # Encode the conversation using the tokenizer
        input_ids = tokenizer.encode(
            conversation, return_tensors="pt", add_special_tokens=False
        )
        input_ids = input_ids.to(self.device)
        # Generate a response using Claire
        response = model.generate(
            input_ids=input_ids,
            use_cache=False,
            early_stopping=False,
            temperature=temperature,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            pad_token_id=eos_token_id,
            # Single-turn mode stops as soon as a new speaker tag "[" starts.
            eos_token_id=eos_token_id if generate_several_turns else newspk_token_id,
        )
        # Decode the generated response to text.
        # NOTE(review): this decodes the full sequence (prompt + generation),
        # so response_text contains the whole conversation.
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)
        # Remove last unfinished speech turn/sentence/phrase
        line_breaks = [u.span(0)[0] for u in re.finditer("\n", response_text)]
        remove_last_sentence = True
        if generate_several_turns:
            # Drop the (possibly truncated) last speaker turn entirely.
            if len(line_breaks) >= 2:
                response_text = response_text[: line_breaks[-1]]
                line_breaks.pop(-1)
                remove_last_sentence = False
        if remove_last_sentence and len(line_breaks) == 1:
            # Cut at the last sentence end (. ! ?) after the final line break;
            # failing that, cut at the last phrase end (, ;).
            sentence_ends = [
                u.span(0)[0] for u in re.finditer(r"[\.!?]", response_text)
            ]
            sentence_ends = [p for p in sentence_ends if p > line_breaks[-1]]
            if sentence_ends:
                response_text = response_text[: sentence_ends[-1] + 1]
            else:
                phrase_ends = [
                    u.span(0)[0] for u in re.finditer(r"[,;]", response_text)
                ]
                phrase_ends = [p for p in phrase_ends if p > line_breaks[-1]]
                if phrase_ends:
                    response_text = response_text[: phrase_ends[-1] + 1]
        # Generation may have stopped right on the "[" of a new speaker tag.
        ended_with_bracket = response_text.endswith("[")
        if self.reinject_history:
            self.history_raw = response_text
            self.has_started_bracket = ended_with_bracket
        if ended_with_bracket:
            response_text = response_text[:-1]
        # Convert internal tags back to display names. Order matters: the full
        # tag must be replaced before the bracketless variant.
        for spk_in, spk_out in [
            (self.user_internal_tag, self.user_name),
            (self.user_internal_tag[1:], self.user_name),  # Starting bracket may be missing
            (self.bot_internal_tag, self.bot_name),
        ]:
            response_text = response_text.replace(spk_in, spk_out)
        response_text = re.sub(self.other_internal_tag_regex_in, self.other_name_regex_out, response_text)
        if self.reshow_history:
            previous_history = self.history
            self.history = previous_history + response_text + "\n"
        else:
            previous_history = ""
        return previous_history + response_text
def claire_text_preproc_conversation(text):
    """Normalize a multi-line conversation history before prompting Claire.

    Fixes special characters, then collapses redundant whitespace while
    preserving line breaks (one line per speech turn).
    """
    return collapse_whitespaces_conversations(format_special_characters(text))
def claire_text_preproc_message(text):
    """Normalize a single chat message before prompting Claire.

    Fixes special characters, flattens all whitespace to single spaces, and
    turns square/curly brackets into parentheses (brackets are reserved for
    speaker tags).
    """
    cleaned = format_special_characters(text)
    cleaned = collapse_whitespaces_message(cleaned)
    return replace_brackets(cleaned)
def collapse_whitespaces_conversations(text):
    """Collapse redundant whitespace in a conversation, keeping line breaks.

    Squeezes runs of blank lines and of spaces/tabs, removes spaces at line
    starts and before '.' or ',', then strips leading whitespace and trailing
    spaces.
    """
    substitutions = (
        (r"\n+", "\n"),        # squeeze runs of blank lines
        (r"[ \t]+", " "),      # squeeze horizontal whitespace
        (r"\n ", "\n"),        # no space at the start of a line
        (r" ([\.,])", r"\1"),  # no space before '.' or ','
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.lstrip().rstrip(" ")
def collapse_whitespaces_message(text):
    """Flatten all whitespace (including newlines) in a message to single
    spaces, remove spaces before '.' or ',', and strip the edges."""
    for pattern, replacement in ((r"\s+", " "), (r" ([\.,])", r"\1")):
        text = re.sub(pattern, replacement, text)
    return text.lstrip().rstrip(" ")
def replace_brackets(text):
    """Replace square and curly brackets with parentheses.

    Brackets are reserved for Claire's "[Intervenant N:]" speaker tags, so
    they must not appear inside user-provided text.
    """
    return text.translate(str.maketrans("[{]}", "(())"))
def format_special_characters(text):
    """Normalize typographic and special characters to plain ASCII-ish forms.

    Applies NFC unicode normalization, then a fixed sequence of substitutions:
    ellipsis, French/typographic quotes, apostrophes, dashes, unbreakable
    spaces, non-printable characters, and ordinal superscripts. The order of
    the "er"/"e" superscript rules matters (longest match first).
    """
    text = unicodedata.normalize("NFC", text)
    text = re.sub("…", "...", text)
    text = re.sub(r"[«“][^\S\r\n]*", '"', text)   # opening quotes (+ glued space)
    text = re.sub(r"[^\S\r\n]*[»”″„]", '"', text)  # closing quotes (+ glued space)
    text = re.sub(r"(``|'')", '"', text)
    text = re.sub(r"[’‘‛ʿ]", "'", text)
    text = re.sub("‚", ",", text)
    text = re.sub(r"–", "-", text)
    text = re.sub("[  ]", " ", text)  # unbreakable spaces
    text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F-\x9F]", "", text)  # non-printable characters
    text = re.sub(r"ᵉʳ", "er", text)
    text = re.sub(r"ᵉ", "e", text)
    return text
# Create the Claire chatbot instance
chatbot = ClaireChatBot()

# Define the Gradio interface
title = "Démo de conversation avec Claire"
description = "Simulation de conversations en Français avec [Claire](https://huggingface.co/OpenLLM-France/Claire-7B-0.1), sans recherche de vérité, et avec potentiellement un peu d'humour."

# Shared sampling defaults, reused by every example row below
default_parameters = [
    default_temperature,
    default_top_k,
    default_top_p,
    default_repetition_penalty,
]

# Example rows shown in the UI; each matches the order of the `inputs` list:
# (user_message, bot_message_start, conversation_history,
#  generate_several_turns, max_new_tokens, *default_parameters)
examples = [
    [
        "Nous allons commencer cette interview avec une question un peu classique. Quel est votre sport préféré?",  # user_message
        "",  # bot_message_start
        "",  # conversation_history
        True,  # generate_several_turns
        200,  # max_new_tokens
        *default_parameters,
    ],
    [
        "Que vas-tu nous cuisiner aujourd'hui?",  # user_message
        "Alors, nous allons voir la recette de",  # bot_message_start
        "VOUS: Bonjour Claire.\nCHATBOT: Bonjour Dominique.",  # conversation_history
        False,  # generate_several_turns
        default_max_new_tokens,  # max_new_tokens
        *default_parameters,
    ],
]

# # Test
# chatbot.predict(*examples[0])
# Gradio input components, in the exact order of ClaireChatBot.predict's
# parameters (user_message, bot_message_start, conversation_history,
# generate_several_turns, max_new_tokens, temperature, top_k, top_p,
# repetition_penalty).
inputs = [
    # user_message
    gr.Textbox(
        "",
        label="Prompt",
        info="Tapez ce que vous voulez dire au ChatBot",
        type="text",
        lines=2,
    ),
    # bot_message_start: optional forced beginning of the bot's reply
    gr.Textbox(
        "",
        label="Début de réponse",
        info="Vous pouvez taper ici ce que commence à vous répondre le ChatBot",
        type="text",
    ),
    # conversation_history (display format: "VOUS:" / "CHATBOT:")
    gr.Textbox(
        "",
        label="Historique",
        info="Vous pouvez copier-coller (et modifier?) ici votre historique de conversation, pour continuer cette conversation",
        type="text",
        lines=3,
    ),
    # generate_several_turns
    gr.Checkbox(
        False,
        label="Plus qu'un tour de parole",
        info="Générer aussi comment pourrait continuer la conversation (plusieurs tours de parole incluant le vôtre)",
    ),
    # max_new_tokens
    gr.Slider(
        label="Longueur max",
        info="Longueur maximale du texte généré (en nombre de 'tokens' ~ mots et ponctuations)",
        value=default_max_new_tokens,
        minimum=25,
        maximum=1000,
        step=25,
        interactive=True,
    ),
    # temperature
    gr.Slider(
        label="Température",
        info="Une valeur élevée augmente la diversité du texte généré, mais peut aussi produire des résultats incohérents",
        value=default_temperature,
        minimum=0.1,
        maximum=1.9,
        step=0.1,
        interactive=True,
    ),
    # top_k
    gr.Slider(
        label="Top-k",
        info="Une valeur élevée permet d'explorer plus d'alternatives, mais augmente les temps de calcul",
        value=default_top_k,
        minimum=1,
        maximum=50,
        step=1,
        interactive=True,
    ),
    # top_p
    gr.Slider(
        label="Top-p",
        info="Une valeur élevée permet d'explorer des alternatives moins probables",
        value=default_top_p,
        minimum=0.9,
        maximum=1.0,
        step=0.01,
        interactive=True,
    ),
    # repetition_penalty
    gr.Slider(
        label="Pénalité de répétition",
        info="Pénalisation des répétitions",
        value=default_repetition_penalty,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
    ),
]
# Custom dark Gradio theme: Monochrome base with emerald/teal hues, plus a
# large set of explicit color overrides ("*primary_N" / "*secondary_N" refer
# to the theme's hue palettes; "_dark" variants apply in dark mode).
theme = gr.themes.Monochrome(
    secondary_hue="emerald",
    neutral_hue="teal",
).set(
    body_background_fill="*primary_950",
    body_background_fill_dark="*secondary_950",
    body_text_color="*primary_50",
    body_text_color_dark="*secondary_100",
    body_text_color_subdued="*primary_300",
    body_text_color_subdued_dark="*primary_300",
    background_fill_primary="*primary_600",
    background_fill_primary_dark="*primary_400",
    background_fill_secondary="*primary_950",
    background_fill_secondary_dark="*primary_950",
    border_color_accent="*secondary_600",
    border_color_primary="*secondary_50",
    border_color_primary_dark="*secondary_50",
    color_accent="*secondary_50",
    color_accent_soft="*primary_500",
    color_accent_soft_dark="*primary_500",
    link_text_color="*secondary_950",
    link_text_color_dark="*primary_50",
    link_text_color_active="*primary_50",
    link_text_color_active_dark="*primary_50",
    link_text_color_hover="*primary_50",
    link_text_color_hover_dark="*primary_50",
    link_text_color_visited="*primary_50",
    block_background_fill="*primary_950",
    block_background_fill_dark="*primary_950",
    block_border_color="*secondary_500",
    block_border_color_dark="*secondary_500",
    block_info_text_color="*primary_50",
    block_info_text_color_dark="*primary_50",
    block_label_background_fill="*primary_950",
    block_label_background_fill_dark="*secondary_950",
    block_label_border_color="*secondary_500",
    block_label_border_color_dark="*secondary_500",
    block_label_text_color="*secondary_500",
    block_label_text_color_dark="*secondary_500",
    block_title_background_fill="*primary_950",
    panel_background_fill="*primary_950",
    panel_border_color="*primary_950",
    checkbox_background_color="*primary_950",
    checkbox_background_color_dark="*primary_950",
    checkbox_background_color_focus="*primary_950",
    checkbox_border_color="*secondary_500",
    input_background_fill="*primary_800",
    input_background_fill_focus="*primary_950",
    input_background_fill_hover="*secondary_950",
    input_placeholder_color="*secondary_950",
    slider_color="*primary_950",
    slider_color_dark="*primary_950",
    table_even_background_fill="*primary_800",
    table_odd_background_fill="*primary_600",
    button_primary_background_fill="*primary_800",
    button_primary_background_fill_dark="*primary_800",
)
# Wire everything into a simple Gradio Interface: the bound method
# chatbot.predict receives the `inputs` values positionally and its returned
# string is shown in a single text output.
iface = gr.Interface(
    fn=chatbot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs="text",
    theme=theme,
)
print("Launching chat...")
# Launch the Gradio interface for the model
iface.launch(share=True)