# Hugging Face Spaces page header captured with this source:
# Spaces: Paused
import gc
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer
# Authenticate with the Hugging Face Hub using the access token stored in the
# HF_TOKEN secret (required for gated repositories such as meta-llama/*).
# A missing secret no longer crashes the app at start-up with a KeyError:
# the interface still launches and non-gated models remain loadable.
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)
# Hub repository ids offered in the model dropdown.
# Every entry must be in the Transformers format (config.json + weights),
# because load_model() loads it with AutoTokenizer / AutoModelForCausalLM.
# For Llama 2 this means the "-hf" repositories: the plain
# "meta-llama/Llama-2-*b" repos ship Meta's original checkpoint files.
models = [
    "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]
# Currently loaded model and its tokenizer. Both stay None until
# load_model() assigns them.
model, tokenizer = None, None
def load_model(model_name):
    """Load *model_name* into the module-level ``model`` / ``tokenizer`` globals.

    Args:
        model_name: Hugging Face Hub repository id picked in the dropdown
            (``None`` or empty when nothing has been selected yet).

    Returns:
        A status message for the "Statut du chargement" textbox.
    """
    global model, tokenizer
    if not model_name:
        return "Veuillez sélectionner un modèle."

    # Release the previous model before loading the next one. Otherwise both
    # sets of weights are referenced at the same time during from_pretrained.
    model = None
    tokenizer = None
    gc.collect()
    on_gpu = torch.cuda.is_available()
    if on_gpu:
        torch.cuda.empty_cache()

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        # float16 is slow or unsupported for many ops on CPU builds, so it is
        # used only when a GPU is available.
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        # generate_text() asks for attention weights (output_attentions=True).
        # The eager implementation is the one that returns them.
        attn_implementation="eager",
    )

    # Batched tokenization with padding needs a pad token; fall back to EOS.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    device_label = "GPU" if on_gpu else "CPU"
    return f"Modèle {model_name} chargé avec succès sur {device_label}."
def _top_token_probabilities(step_scores, tok, k=5):
    """Return ``{token_text: probability}`` for the *k* most likely tokens of one step."""
    probabilities = torch.softmax(step_scores.float(), dim=-1)
    k = min(k, probabilities.numel())
    top_probs, top_ids = torch.topk(probabilities, k)
    prob_data = {}
    for token_id, prob in zip(top_ids.tolist(), top_probs.tolist()):
        label = tok.decode([token_id])
        if label in prob_data:
            # Distinct ids can decode to the same text; disambiguate the label
            # so one entry does not silently overwrite the other.
            label = f"{label} [{token_id}]"
        prob_data[label] = prob
    return prob_data


def _build_attention_matrix(step_attentions, seq_len):
    """Stitch per-step last-layer attentions into one square matrix.

    ``generate`` returns one entry per decoding step. The first step covers
    all prompt positions, while each later step adds a single query position,
    so the tensors have different shapes and cannot simply be concatenated.

    Args:
        step_attentions: ``outputs.attentions`` from ``model.generate``: one
            tuple per step, holding one tensor per layer of shape
            (batch, heads, query_len, key_len).
        seq_len: length of the generated sequence (prompt + new tokens).

    Returns:
        A float32 array of shape ``(seq_len - 1, seq_len - 1)``: rows are query
        positions, columns key positions, head-averaged, for the first batch item.
    """
    size = max(seq_len - 1, 0)
    matrix = np.zeros((size, size), dtype=np.float32)
    row = 0
    for layers in step_attentions or ():
        if row >= size:
            break
        # Last layer, first batch item, mean over heads -> (query_len, key_len).
        weights = layers[-1][0].float().mean(dim=0).cpu().numpy()
        # Clip both axes: some caches report key_len as the full cache size.
        n_q = min(weights.shape[0], size - row)
        n_k = min(weights.shape[1], size)
        matrix[row:row + n_q, :n_k] = weights[:n_q, :n_k]
        row += n_q
    return matrix


def generate_text(input_text, temperature, top_p, top_k):
    """Sample a continuation of *input_text* and build the two diagnostic plots.

    Args:
        input_text: prompt typed by the user (truncated to 512 tokens).
        temperature, top_p, top_k: sampling parameters from the sliders.

    Returns:
        ``(generated_text, attention_figure, probability_figure)``. When no
        model is loaded, returns a message and two ``None`` plots instead.
    """
    if model is None or tokenizer is None:
        return "Veuillez d'abord charger un modèle.", None, None

    inputs = tokenizer(
        input_text, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            # Without do_sample=True, generate() decodes greedily and ignores
            # temperature / top_p / top_k.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            # Slider values may arrive as floats; the top-k warper requires an int.
            top_k=int(top_k),
            output_attentions=True,
            # outputs.scores is None unless this is requested.
            output_scores=True,
            return_dict_in_generate=True,
        )

    sequence = outputs.sequences[0]
    generated_text = tokenizer.decode(sequence, skip_special_tokens=True)

    # Top-5 candidates at the final decoding step, i.e. the distribution the
    # last generated token was sampled from.
    prob_data = _top_token_probabilities(outputs.scores[-1][0], tokenizer, k=5)

    # The final token is never fed back through the model, so the attention
    # map covers every position except the last one.
    attention_data = {
        'attention': _build_attention_matrix(outputs.attentions, len(sequence)),
        'tokens': tokenizer.convert_ids_to_tokens(sequence[:-1].tolist()),
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
def plot_attention(attention_data):
    """Render an attention matrix as a heat map labelled with its tokens.

    Args:
        attention_data: dict with ``'attention'`` (expected to be a 2-D array,
            rows = query positions, columns = key positions) and ``'tokens'``
            (token strings labelling those positions).

    Returns:
        A matplotlib Figure for a gr.Plot component.
    """
    attention = np.asarray(attention_data['attention'])
    tokens = list(attention_data['tokens'])
    n_rows, n_cols = attention.shape[:2]

    # Build the Figure directly instead of through pyplot. pyplot keeps every
    # figure it creates in a global registry until closed (this function never
    # closes it, and it runs on every generation) and is not thread-safe.
    fig = plt.Figure(figsize=(10, 10))
    ax = fig.add_subplot()
    im = ax.imshow(attention, cmap='viridis')
    fig.colorbar(im, ax=ax)

    # Never label more ticks than the matrix has rows/columns. Recent matplotlib
    # versions reject a label list whose length differs from the tick count.
    x_labels = tokens[:n_cols]
    y_labels = tokens[:n_rows]
    ax.set_xticks(range(len(x_labels)))
    ax.set_yticks(range(len(y_labels)))
    ax.set_xticklabels(x_labels, rotation=90)
    ax.set_yticklabels(y_labels)
    ax.set_title("Carte d'attention")
    fig.tight_layout()
    return fig
def plot_probabilities(prob_data):
    """Draw a bar chart of candidate next tokens and their probabilities.

    Args:
        prob_data: mapping of token text -> probability (as built by
            generate_text()).

    Returns:
        A matplotlib Figure for a gr.Plot component.
    """
    words = list(prob_data.keys())
    probs = list(prob_data.values())

    # Create the Figure directly rather than through pyplot. pyplot keeps every
    # figure in a global registry until it is closed (this function never
    # closes it) and is not thread-safe.
    fig = plt.Figure(figsize=(10, 5))
    ax = fig.add_subplot()
    ax.bar(words, probs)
    ax.set_title("Probabilités des tokens suivants les plus probables")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probabilité")
    # Axis-level equivalent of plt.xticks(rotation=45); it does not depend on
    # pyplot's notion of the "current" axes.
    ax.tick_params(axis="x", labelrotation=45)
    fig.tight_layout()
    return fig
def reset():
    """Return the default value of every component cleared by "Réinitialiser".

    Order: prompt, temperature, top-p, top-k, generated text, attention plot,
    probability plot.
    """
    cleared_prompt = ""
    default_temperature, default_top_p, default_top_k = 1.0, 1.0, 50
    empty = None
    return (
        cleared_prompt,
        default_temperature,
        default_top_p,
        default_top_k,
        empty,
        empty,
        empty,
    )
# Gradio user interface.
with gr.Blocks() as demo:
    gr.Markdown("# Générateur de texte avec visualisation d'attention")

    # Step 1: pick a checkpoint and load it.
    with gr.Accordion("Sélection du modèle"):
        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=models)
        load_button = gr.Button("Charger le modèle")
        load_output = gr.Textbox(label="Statut du chargement")

    # Step 2: sampling parameters forwarded to generate_text().
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")

    # Step 3: prompt and generated text.
    input_text = gr.Textbox(label="Texte d'entrée")
    generate_button = gr.Button("Générer")
    output_text = gr.Textbox(label="Texte généré")

    # Diagnostics returned alongside the generated text.
    with gr.Row():
        attention_plot = gr.Plot(label="Visualisation de l'attention")
        prob_plot = gr.Plot(label="Probabilités des tokens suivants")

    reset_button = gr.Button("Réinitialiser")

    # Event wiring. Each outputs list is ordered like the tuple its handler returns.
    generation_inputs = [input_text, temperature, top_p, top_k]
    generation_outputs = [output_text, attention_plot, prob_plot]

    load_button.click(fn=load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(fn=generate_text, inputs=generation_inputs, outputs=generation_outputs)
    reset_button.click(fn=reset, outputs=generation_inputs + generation_outputs)

# Start the web server.
demo.launch()