import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import asyncio
import gc

# Authentication (HF_TOKEN must be available as an environment variable / Space secret)
login(token=os.environ["HF_TOKEN"])
# Models and their metadata, organized by family, type, and variation
models_info = {
    "Meta-llama": {
        "Llama 2": {
            "7B": {"name": "meta-llama/Llama-2-7b-hf", "languages": ["en"]},
            "13B": {"name": "meta-llama/Llama-2-13b-hf", "languages": ["en"]},
        },
        "Llama 3": {
            "8B": {"name": "meta-llama/Llama-3-8B", "languages": ["en"]},
            "3.2-3B": {"name": "meta-llama/Llama-3.2-3B", "languages": ["en", "de", "fr", "it", "pt", "hi", "es", "th"]},
        },
    },
    "Mistral AI": {
        "Mistral": {
            "7B-v0.1": {"name": "mistralai/Mistral-7B-v0.1", "languages": ["en"]},
            "7B-v0.3": {"name": "mistralai/Mistral-7B-v0.3", "languages": ["en"]},
        },
        "Mixtral": {
            "8x7B-v0.1": {"name": "mistralai/Mixtral-8x7B-v0.1", "languages": ["en", "fr", "it", "de", "es"]},
        },
    },
    "Google": {
        "Gemma": {
            "2B": {"name": "google/gemma-2-2b", "languages": ["en"]},
            "7B": {"name": "google/gemma-2-7b", "languages": ["en"]},
        },
    },
    "CroissantLLM": {
        "CroissantLLMBase": {
            "Base": {"name": "croissantllm/CroissantLLMBase", "languages": ["en", "fr"]},
        },
    },
}
# Recommended sampling parameters for each model
model_parameters = {
    "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
    "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
    "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
    "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
    "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "google/gemma-2-7b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
    "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50},
}
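# Note: model_parameters is informational only; nothing below reads it. A small
# helper along these lines (hypothetical, not wired into the UI) could look up
# the recommended settings for a selected model, e.g. to preset the
# Temperature/Top-p/Top-k sliders:
def get_recommended_parameters(model_name):
    # Fall back to the app's slider defaults when a model has no entry.
    return model_parameters.get(model_name, {"temperature": 1.0, "top_p": 1.0, "top_k": 50})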
# Global state: cache of loaded (model, tokenizer) pairs keyed by model name
model_cache = {}

# Helpers for the cascading model-selection dropdowns
def update_model_type(family):
    return gr.Dropdown(choices=list(models_info[family].keys()), value=None, interactive=True)

def update_model_variation(family, model_type):
    if family and model_type:
        return gr.Dropdown(choices=list(models_info[family][model_type].keys()), value=None, interactive=True)
    return gr.Dropdown(choices=[], value=None, interactive=False)

def update_selected_model(family, model_type, variation):
    if family and model_type and variation:
        model_name = models_info[family][model_type][variation]["name"]
        languages = models_info[family][model_type][variation]["languages"]
        return model_name, gr.Dropdown(choices=languages, value=languages[0], visible=True, interactive=True)
    return "", gr.Dropdown(visible=False)
async def load_model_async(model_name, progress=gr.Progress()):
    try:
        if model_name not in model_cache:
            progress(0.1, f"Loading tokenizer for {model_name}...")
            tokenizer = await asyncio.to_thread(AutoTokenizer.from_pretrained, model_name)
            progress(0.4, f"Loading model {model_name}...")
            model = await asyncio.to_thread(AutoModelForCausalLM.from_pretrained, model_name,
                                            torch_dtype=torch.float16, device_map="auto", low_cpu_mem_usage=True)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token
            model_cache[model_name] = (model, tokenizer)
        progress(1.0, f"Model {model_name} loaded successfully")
        return f"Model {model_name} loaded successfully"
    except Exception as e:
        return f"Error while loading model {model_name}: {str(e)}"
def set_language(lang):
    return f"Selected language: {lang}"

def ensure_token_display(token, tokenizer):
    # If a token decodes to a bare (possibly negative) integer ID, re-decode it so it displays as text.
    if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
        return tokenizer.decode([int(token)])
    return token
async def analyze_next_token(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
    if model_name not in model_cache:
        return "Please load the model first", None, None
    model, tokenizer = model_cache[model_name]
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    try:
        progress(0.5, "Analysis in progress...")
        with torch.no_grad():
            outputs = model(**inputs)
        last_token_logits = outputs.logits[0, -1, :]
        # Note: the displayed distribution is the raw softmax over the final logits;
        # temperature and top_p are accepted for interface symmetry but not applied here.
        probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
        top_k = min(10, top_k)
        top_probs, top_indices = torch.topk(probabilities, top_k)
        top_words = [ensure_token_display(tokenizer.decode([idx.item()]), tokenizer) for idx in top_indices]
        prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
        prob_text = "Most likely next tokens:\n\n"
        for word, prob in prob_data.items():
            prob_text += f"{word}: {prob:.2%}\n"
        prob_plot = plot_probabilities(prob_data)
        attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu(), tokenizer)
        progress(1.0, "Analysis complete")
        return prob_text, attention_plot, prob_plot
    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None
async def generate_text(model_name, input_text, temperature, top_p, top_k, progress=gr.Progress()):
    if model_name not in model_cache:
        return "Please load the model first"
    model, tokenizer = model_cache[model_name]
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    try:
        progress(0.5, "Generation in progress...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to take effect
                temperature=temperature,
                top_p=top_p,
                top_k=top_k
            )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        progress(1.0, "Generation complete")
        return generated_text
    except Exception as e:
        return f"Error during generation: {str(e)}"
def plot_probabilities(prob_data):
    try:
        words = list(prob_data.keys())
        probs = list(prob_data.values())
        fig, ax = plt.subplots(figsize=(12, 6))
        bars = ax.bar(range(len(words)), probs, color='lightgreen')
        ax.set_title("Probabilities of the most likely next tokens")
        ax.set_xlabel("Tokens")
        ax.set_ylabel("Probability")
        ax.set_xticks(range(len(words)))
        ax.set_xticklabels(words, rotation=45, ha='right')
        for i, (bar, word) in enumerate(zip(bars, words)):
            height = bar.get_height()
            ax.text(i, height, f'{height:.2%}',
                    ha='center', va='bottom', rotation=0)
        plt.tight_layout()
        return fig
    except Exception as e:
        print(f"Error while creating the probability chart: {str(e)}")
        return None
def plot_attention(input_ids, last_token_logits, tokenizer):
    # Note: this is a proxy visualization. The heatmap values are the top next-token
    # probabilities (softmax over the final logits), not true attention weights; the
    # last input tokens are shown on the x-axis for context only.
    try:
        input_tokens = [ensure_token_display(tokenizer.decode([id]), tokenizer) for id in input_ids]
        attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
        top_k = min(len(input_tokens), 10)
        top_attention_scores, _ = torch.topk(attention_scores, top_k)
        fig, ax = plt.subplots(figsize=(14, 7))
        sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
        ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
        ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
        ax.set_title("Attention scores for the last tokens", fontsize=16)
        cbar = ax.collections[0].colorbar
        cbar.set_label("Attention score", fontsize=12)
        cbar.ax.tick_params(labelsize=10)
        plt.tight_layout()
        return fig
    except Exception as e:
        print(f"Error while creating the attention chart: {str(e)}")
        return None
def reset():
    # Drop all cached models/tokenizers and free GPU memory
    model_cache.clear()
    gc.collect()
    torch.cuda.empty_cache()
    return (
        "", 1.0, 1.0, 50, None, None, None, None,
        gr.Dropdown(choices=list(models_info.keys()), value=None, interactive=True),
        gr.Dropdown(choices=[], value=None, interactive=False),
        gr.Dropdown(choices=[], value=None, interactive=False),
        "", gr.Dropdown(visible=False), ""
    )
def reset_comparison():
    # 17 values, matching the outputs list of comp_reset_button.click below
    return [gr.Dropdown(choices=[], value=None) for _ in range(4)] + ["", "", gr.Dropdown(choices=[], value=None), 1.0, 1.0, 50, "", "", "", None, None, None, None]
async def compare_models(model1, model2, input_text, temp, top_p, top_k, progress=gr.Progress()):
    if model1 not in model_cache or model2 not in model_cache:
        return "Please load both models first", "", None, None, None, None
    progress(0.1, "Analyzing the first model...")
    results1 = await analyze_next_token(model1, input_text, temp, top_p, top_k)
    progress(0.4, "Analyzing the second model...")
    results2 = await analyze_next_token(model2, input_text, temp, top_p, top_k)
    progress(1.0, "Comparison complete")
    return (
        results1[0], results2[0],  # next-token probabilities (text)
        results1[2], results2[2],  # probability bar charts
        results1[1], results2[1]   # attention heatmaps
    )
with gr.Blocks() as demo:
    gr.Markdown("# LLM&BIAS")
    with gr.Tabs():
        with gr.Tab("Individual analysis"):
            with gr.Accordion("Model selection", open=True):
                with gr.Row():
                    model_family = gr.Dropdown(choices=list(models_info.keys()), label="Model family", interactive=True)
                    model_type = gr.Dropdown(choices=[], label="Model type", interactive=False)
                    model_variation = gr.Dropdown(choices=[], label="Model variation", interactive=False)
            selected_model = gr.Textbox(label="Selected model", interactive=False)
            load_button = gr.Button("Load model")
            load_output = gr.Textbox(label="Loading status")
            language_dropdown = gr.Dropdown(label="Choose a language", visible=False)
            language_output = gr.Textbox(label="Selected language")
            with gr.Row():
                temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
                top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
            input_text = gr.Textbox(label="Input text", lines=3)
            analyze_button = gr.Button("Analyze next token")
            next_token_probs = gr.Textbox(label="Next-token probabilities")
            with gr.Row():
                attention_plot = gr.Plot(label="Attention visualization")
                prob_plot = gr.Plot(label="Next-token probabilities")
            generate_button = gr.Button("Generate text")
            generated_text = gr.Textbox(label="Generated text")
            reset_button = gr.Button("Reset")
        with gr.Tab("Model comparison"):
            with gr.Row():
                model1_family = gr.Dropdown(choices=list(models_info.keys()), label="Model 1 family", interactive=True)
                model1_type = gr.Dropdown(choices=[], label="Model 1 type", interactive=False)
                model1_variation = gr.Dropdown(choices=[], label="Model 1 variation", interactive=False)
            with gr.Row():
                model2_family = gr.Dropdown(choices=list(models_info.keys()), label="Model 2 family", interactive=True)
                model2_type = gr.Dropdown(choices=[], label="Model 2 type", interactive=False)
                model2_variation = gr.Dropdown(choices=[], label="Model 2 variation", interactive=False)
            model1_selected = gr.Textbox(label="Selected model 1", interactive=False)
            model2_selected = gr.Textbox(label="Selected model 2", interactive=False)
            load_models_button = gr.Button("Load models")
            load_models_output = gr.Textbox(label="Model loading status")
            comparison_language = gr.Dropdown(label="Language for the comparison", choices=[], interactive=False)
            with gr.Row():
                comp_temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
                comp_top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
                comp_top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
            comp_input_text = gr.Textbox(label="Input text for the comparison", lines=3)
            compare_button = gr.Button("Compare models")
            with gr.Row():
                model1_output = gr.Textbox(label="Model 1 probabilities", lines=10)
                model2_output = gr.Textbox(label="Model 2 probabilities", lines=10)
            with gr.Row():
                model1_prob_plot = gr.Plot(label="Token probabilities (Model 1)")
                model2_prob_plot = gr.Plot(label="Token probabilities (Model 2)")
            with gr.Row():
                model1_attention_plot = gr.Plot(label="Attention (Model 1)")
                model2_attention_plot = gr.Plot(label="Attention (Model 2)")
            comp_reset_button = gr.Button("Reset comparison")
    # Events for the individual-analysis tab
    model_family.change(update_model_type, inputs=[model_family], outputs=[model_type])
    model_type.change(update_model_variation, inputs=[model_family, model_type], outputs=[model_variation])
    model_variation.change(update_selected_model, inputs=[model_family, model_type, model_variation], outputs=[selected_model, language_dropdown])
    load_button.click(load_model_async, inputs=[selected_model], outputs=[load_output])
    language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
    analyze_button.click(analyze_next_token, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[next_token_probs, attention_plot, prob_plot])
    generate_button.click(generate_text, inputs=[selected_model, input_text, temperature, top_p, top_k], outputs=[generated_text])
    reset_button.click(reset, outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, model_family, model_type, model_variation, selected_model, language_dropdown, language_output])

    # Events for the comparison tab
    model1_family.change(update_model_type, inputs=[model1_family], outputs=[model1_type])
    model1_type.change(update_model_variation, inputs=[model1_family, model1_type], outputs=[model1_variation])
    model1_variation.change(update_selected_model, inputs=[model1_family, model1_type, model1_variation], outputs=[model1_selected, comparison_language])
    model2_family.change(update_model_type, inputs=[model2_family], outputs=[model2_type])
    model2_type.change(update_model_variation, inputs=[model2_family, model2_type], outputs=[model2_variation])
    model2_variation.change(update_selected_model, inputs=[model2_family, model2_type, model2_variation], outputs=[model2_selected, comparison_language])

    async def load_both_models(model1, model2):
        result1 = await load_model_async(model1)
        result2 = await load_model_async(model2)
        return f"Model 1: {result1}\nModel 2: {result2}"
    load_models_button.click(load_both_models, inputs=[model1_selected, model2_selected], outputs=[load_models_output])
    compare_button.click(
        compare_models,
        inputs=[model1_selected, model2_selected, comp_input_text, comp_temperature, comp_top_p, comp_top_k],
        outputs=[model1_output, model2_output, model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
    )
    comp_reset_button.click(
        reset_comparison,
        outputs=[model1_type, model1_variation, model2_type, model2_variation, model1_selected, model2_selected, comparison_language,
                 comp_temperature, comp_top_p, comp_top_k, comp_input_text, model1_output, model2_output,
                 model1_prob_plot, model2_prob_plot, model1_attention_plot, model2_attention_plot]
    )
if __name__ == "__main__":
    demo.launch()
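# Usage sketch (assumption: this file is saved as app.py; package versions are not pinned here):
#   pip install gradio torch transformers huggingface_hub matplotlib seaborn numpy
#   HF_TOKEN=<your token> python app.py
# Gated checkpoints (e.g. the meta-llama and google/gemma models) also require accepting
# their licenses on the Hugging Face Hub with the account that issued the token.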