import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
import os
# Log in to Hugging Face with a token (expects HF_TOKEN in the environment)
login(token=os.environ["HF_TOKEN"])
MODEL_LIST = [
    "meta-llama/Llama-2-13b-hf",
    "meta-llama/Llama-2-7b-hf",
    "meta-llama/Llama-2-70b-hf",
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "meta-llama/Llama-3.1-8B",
    "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x7B-v0.1",
    "mistralai/Mistral-7B-v0.3",
    "google/gemma-2-2b",
    "google/gemma-2-9b",
    "google/gemma-2-27b",
    "croissantllm/CroissantLLMBase",
]
# Cache for models and tokenizers that have already been loaded
loaded_models = {}
# Load a model and its tokenizer, caching the pair so repeated selections
# in the UI do not reload the weights
def load_model(model_name):
    if model_name not in loaded_models:
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16, device_map="auto"
        )
        loaded_models[model_name] = (model, tokenizer)
    return loaded_models[model_name]
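# A minimal usage sketch (illustrative only; "google/gemma-2-2b" is just one
# entry from MODEL_LIST): the second call returns the cached pair instead of
# reloading the weights.
#   model, tokenizer = load_model("google/gemma-2-2b")
#   model, tokenizer = load_model("google/gemma-2-2b")  # cache hit, no reload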
# Text generation plus the most probable candidates for the next token
def generate_text(model_name, input_text, temperature, top_p, top_k):
    model, tokenizer = load_model(model_name)
    inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
    # Generate; do_sample=True is required for temperature/top_p/top_k to take
    # effect, and return_dict_in_generate + output_scores are required to read
    # the per-step logits from output.scores below
    output = model.generate(
        **inputs,
        max_new_tokens=50,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        return_dict_in_generate=True,
        output_scores=True,
    )
    # Decode the generated sequence
    generated_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    # Top-5 most probable tokens at the last generation step
    last_token_logits = output.scores[-1][0]
    probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
    top_tokens = torch.topk(probabilities, k=5)
    probable_words = [tokenizer.decode([token_id]) for token_id in top_tokens.indices.tolist()]
    return generated_text, probable_words
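# A hedged variant (an assumption, not part of the original app): pairing each
# candidate word with its probability for a more informative display.
#   probable_words = [
#       f"{tokenizer.decode([i])} ({p:.2%})"
#       for i, p in zip(top_tokens.indices.tolist(), top_tokens.values.tolist())
#   ]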
# Gradio user interface
def reset_interface():
    # Three empty strings: one per component wired to the reset button below
    return "", "", ""
def main():
    with gr.Blocks() as app:
        with gr.Accordion("Model selection", open=True):
            model_name = gr.Dropdown(choices=MODEL_LIST, label="Available models", value=MODEL_LIST[0])
        with gr.Row():
            input_text = gr.Textbox(label="Input text", placeholder="Type your text here...")
        with gr.Accordion("Parameters", open=True):
            # Minimum of 0.01: transformers rejects temperature=0 when do_sample=True
            temperature = gr.Slider(minimum=0.01, maximum=1, value=0.7, step=0.01, label="Temperature")
            top_p = gr.Slider(minimum=0, maximum=1, value=0.9, step=0.01, label="Top_p")
            top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top_k")
        with gr.Row():
            generate_button = gr.Button("Generate")
            reset_button = gr.Button("Reset")
        generated_text_output = gr.Textbox(label="Generated text", placeholder="The generated text will appear here...")
        probable_words_output = gr.Textbox(label="Most probable words", placeholder="The most probable words will appear here...")
        # Run generation
        generate_button.click(generate_text, inputs=[model_name, input_text, temperature, top_p, top_k], outputs=[generated_text_output, probable_words_output])
        # Reset the interface
        reset_button.click(reset_interface, outputs=[input_text, generated_text_output, probable_words_output])
    app.launch()

if __name__ == "__main__":
    main()
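# To run the Space locally (assumes a CUDA GPU with enough memory for the
# selected model and a valid Hugging Face token):
#   HF_TOKEN=hf_... python app.py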