Woziii commited on
Commit
892a160
1 Parent(s): 756f692

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -6,9 +6,10 @@ import numpy as np
6
  from huggingface_hub import login
7
  import os
8
 
 
9
  login(token=os.environ["HF_TOKEN"])
10
 
11
- # Liste des modèles
12
  models = [
13
  "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
14
  "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
@@ -17,7 +18,7 @@ models = [
17
  "croissantllm/CroissantLLMBase"
18
  ]
19
 
20
- # Variables globales pour stocker le modèle et le tokenizer
21
  model = None
22
  tokenizer = None
23
 
@@ -26,14 +27,14 @@ def load_model(model_name):
26
  tokenizer = AutoTokenizer.from_pretrained(model_name)
27
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
28
 
29
- # Définir le token de padding s'il n'existe pas
30
  if tokenizer.pad_token is None:
31
  tokenizer.pad_token = tokenizer.eos_token
32
  model.config.pad_token_id = model.config.eos_token_id
33
 
34
  return f"Modèle {model_name} chargé avec succès sur GPU."
35
 
36
- def generate_text(input_text, temperature, top_p, top_k_value):
37
  global model, tokenizer
38
 
39
  inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
@@ -44,31 +45,25 @@ def generate_text(input_text, temperature, top_p, top_k_value):
44
  max_new_tokens=50,
45
  temperature=temperature,
46
  top_p=top_p,
47
- top_k=top_k_value,
48
  output_attentions=True,
49
- output_scores=True, # Activer les scores pour obtenir les logits
50
  return_dict_in_generate=True
51
  )
52
 
53
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
54
 
55
- # Obtenir les logits pour le dernier token généré
56
  last_token_logits = outputs.scores[-1][0]
57
-
58
- # Appliquer softmax pour obtenir les probabilités
59
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
60
 
61
- # Obtenir les top 5 tokens les plus probables
62
- top_k = 5
63
- top_probs, top_indices = torch.topk(probabilities, top_k)
64
  top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
65
 
66
- # Préparer les données pour le graphique des probabilités
67
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
68
 
69
- # Extraire les attentions
70
  attentions = torch.cat([att[-1].mean(dim=1) for att in outputs.attentions], dim=0).cpu().numpy()
71
-
72
  attention_data = {
73
  'attention': attentions,
74
  'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0])
@@ -107,6 +102,7 @@ def plot_probabilities(prob_data):
107
  def reset():
108
  return "", 1.0, 1.0, 50, None, None, None
109
 
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown("# Générateur de texte avec visualisation d'attention")
112
 
@@ -131,6 +127,7 @@ with gr.Blocks() as demo:
131
 
132
  reset_button = gr.Button("Réinitialiser")
133
 
 
134
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
135
  generate_button.click(generate_text,
136
  inputs=[input_text, temperature, top_p, top_k],
@@ -138,4 +135,5 @@ with gr.Blocks() as demo:
138
  reset_button.click(reset,
139
  outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
140
 
 
141
  demo.launch()
 
6
  from huggingface_hub import login
7
  import os
8
 
9
+ # Authentification Hugging Face avec ton token d'accès
10
  login(token=os.environ["HF_TOKEN"])
11
 
12
+ # Liste des modèles disponibles
13
  models = [
14
  "meta-llama/Llama-2-13b", "meta-llama/Llama-2-7b", "meta-llama/Llama-2-70b",
15
  "meta-llama/Meta-Llama-3-8B", "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.1-8B",
 
18
  "croissantllm/CroissantLLMBase"
19
  ]
20
 
21
+ # Variables pour le modèle et le tokenizer
22
  model = None
23
  tokenizer = None
24
 
 
27
  tokenizer = AutoTokenizer.from_pretrained(model_name)
28
  model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
29
 
30
+ # Assurer que le token de padding est défini si nécessaire
31
  if tokenizer.pad_token is None:
32
  tokenizer.pad_token = tokenizer.eos_token
33
  model.config.pad_token_id = model.config.eos_token_id
34
 
35
  return f"Modèle {model_name} chargé avec succès sur GPU."
36
 
37
+ def generate_text(input_text, temperature, top_p, top_k):
38
  global model, tokenizer
39
 
40
  inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
 
45
  max_new_tokens=50,
46
  temperature=temperature,
47
  top_p=top_p,
48
+ top_k=top_k,
49
  output_attentions=True,
 
50
  return_dict_in_generate=True
51
  )
52
 
53
  generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
54
 
55
+ # Logits et probabilités du dernier token généré
56
  last_token_logits = outputs.scores[-1][0]
 
 
57
  probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
58
 
59
+ # Top 5 des mots les plus probables
60
+ top_probs, top_indices = torch.topk(probabilities, 5)
 
61
  top_words = [tokenizer.decode([idx.item()]) for idx in top_indices]
62
 
 
63
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
64
 
65
+ # Extraction des attentions
66
  attentions = torch.cat([att[-1].mean(dim=1) for att in outputs.attentions], dim=0).cpu().numpy()
 
67
  attention_data = {
68
  'attention': attentions,
69
  'tokens': tokenizer.convert_ids_to_tokens(outputs.sequences[0])
 
102
  def reset():
103
  return "", 1.0, 1.0, 50, None, None, None
104
 
105
+ # Interface Gradio
106
  with gr.Blocks() as demo:
107
  gr.Markdown("# Générateur de texte avec visualisation d'attention")
108
 
 
127
 
128
  reset_button = gr.Button("Réinitialiser")
129
 
130
+ # Association des actions avec les boutons
131
  load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
132
  generate_button.click(generate_text,
133
  inputs=[input_text, temperature, top_p, top_k],
 
135
  reset_button.click(reset,
136
  outputs=[input_text, temperature, top_p, top_k, output_text, attention_plot, prob_plot])
137
 
138
+ # Lancement de l'application
139
  demo.launch()