import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
from huggingface_hub import login
import os

# Hugging Face authentication with your access token
login(token=os.environ["HF_TOKEN"])
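# NOTE: on a Hugging Face Space, HF_TOKEN is usually supplied as a Space secret;
# when running locally, export HF_TOKEN in your environment before launching.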

# List of available models
models = [
    "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
    "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b", "google/gemma-2-9b", "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase"
]
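# NOTE: most of the meta-llama, mistralai, and google checkpoints are gated on
# the Hub; the account behind HF_TOKEN must have accepted each model's license,
# or from_pretrained() will fail with an authorization error.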

# Globals holding the current model and tokenizer
model = None
tokenizer = None

def load_model(model_name):
    global model, tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
    # Make sure a padding token is defined if the tokenizer lacks one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id
    return f"Model {model_name} loaded successfully on GPU."

def generate_text(input_text, temperature, top_p, top_k):
    global model, tokenizer
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,  # temperature/top_p/top_k only take effect when sampling
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            output_scores=True,  # required so outputs.scores is populated below
            output_attentions=True,
            return_dict_in_generate=True
        )
    generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    # Logits of the last generated token (processed logits, i.e. after the
    # temperature/top-p/top-k warpers have been applied)
    last_token_logits = outputs.scores[-1][0]
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    # Top 5 most probable tokens
    top_probs, top_indices = torch.topk(probabilities, 5)
    top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
    prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
    # Attention extraction: use the first generation step, whose attentions span
    # the whole prompt (later steps attend from a single new token), take the
    # last layer and average over heads -> a (prompt_len, prompt_len) matrix
    attentions = outputs.attentions[0][-1].mean(dim=1)[0].float().cpu().numpy()
    attention_data = {
        'attention': attentions,
        'tokens': tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    }
    return generated_text, plot_attention(attention_data), plot_probabilities(prob_data)
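# Example call (illustrative values), assuming load_model() has already run:
#   text, att_fig, prob_fig = generate_text("The quick brown fox", 0.8, 0.95, 50)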

def plot_attention(attention_data):
    attention = attention_data['attention']
    tokens = attention_data['tokens']
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(attention, cmap='viridis')
    plt.colorbar(im)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention map")
    plt.tight_layout()
    return fig

def plot_probabilities(prob_data):
    words = list(prob_data.keys())
    probs = list(prob_data.values())
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(words, probs)
    ax.set_title("Probabilities of the most likely next tokens")
    ax.set_xlabel("Tokens")
    ax.set_ylabel("Probability")
    plt.xticks(rotation=45)
    plt.tight_layout()
    return fig
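# Returning the Matplotlib Figure objects (rather than calling plt.show())
# lets the gr.Plot components below render them directly in the interface.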

def reset():
    return "", 1.0, 1.0, 50, None, None, None

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Text generator with attention visualization")
    with gr.Accordion("Model selection"):
        model_dropdown = gr.Dropdown(choices=models, label="Choose a model")
        load_button = gr.Button("Load model")
        load_output = gr.Textbox(label="Loading status")
    with gr.Row():
        temperature = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
        top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")
    input_text = gr.Textbox(label="Input text")
    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated text")
    with gr.Row():
        attention_plot = gr.Plot(label="Attention visualization")
        prob_plot = gr.Plot(label="Next-token probabilities")
    reset_button = gr.Button("Reset")

    # Wire the buttons to their callbacks
    load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
    generate_button.click(generate_text,
                          inputs=[input_text, temperature, top_p, top_k],
                          outputs=[output_text, attention_plot, prob_plot])
    reset_button.click(reset,
                       outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])

# Launch the application
demo.launch()