Woziii committed
Commit 65194e4
1 Parent(s): 63dd69c

Update app.py

Files changed (1)
  1. app.py +141 -32
app.py CHANGED
@@ -12,9 +12,9 @@ import time
 login(token=os.environ["HF_TOKEN"])
 
 # Structure hiérarchique des modèles
-models_hierarchy = {
+model_hierarchy = {
     "meta-llama": {
-        "Llama-2": ["7b", "13b", "70b"],
+        "Llama-2": ["7B", "13B", "70B"],
         "Llama-3": ["8B", "3.2-3B", "3.1-8B"]
     },
     "mistralai": {
@@ -22,7 +22,7 @@ models_hierarchy = {
         "Mixtral": ["8x7B-v0.1"]
     },
     "google": {
-        "gemma": ["2b", "9b", "27b"]
+        "Gemma": ["2B", "9B", "27B"]
     },
     "croissantllm": {
         "CroissantLLM": ["Base"]
@@ -31,35 +31,35 @@ models_hierarchy = {
 
 # Langues supportées par modèle
 models_and_languages = {
-    "meta-llama/Llama-2-7b-hf": ["en"],
-    "meta-llama/Llama-2-13b-hf": ["en"],
-    "meta-llama/Llama-2-70b-hf": ["en"],
-    "meta-llama/Meta-Llama-3-8B": ["en"],
+    "meta-llama/Llama-2-7B": ["en"],
+    "meta-llama/Llama-2-13B": ["en"],
+    "meta-llama/Llama-2-70B": ["en"],
+    "meta-llama/Llama-3-8B": ["en"],
     "meta-llama/Llama-3.2-3B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
     "meta-llama/Llama-3.1-8B": ["en", "de", "fr", "it", "pt", "hi", "es", "th"],
     "mistralai/Mistral-7B-v0.1": ["en"],
-    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
     "mistralai/Mistral-7B-v0.3": ["en"],
-    "google/gemma-2-2b": ["en"],
-    "google/gemma-2-9b": ["en"],
-    "google/gemma-2-27b": ["en"],
+    "mistralai/Mixtral-8x7B-v0.1": ["en", "fr", "it", "de", "es"],
+    "google/Gemma-2B": ["en"],
+    "google/Gemma-9B": ["en"],
+    "google/Gemma-27B": ["en"],
     "croissantllm/CroissantLLMBase": ["en", "fr"]
 }
 
 # Paramètres recommandés pour chaque modèle
 model_parameters = {
-    "meta-llama/Llama-2-13b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-7b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Llama-2-70b-hf": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
-    "meta-llama/Meta-Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
+    "meta-llama/Llama-2-7B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-13B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-2-70B": {"temperature": 0.8, "top_p": 0.9, "top_k": 40},
+    "meta-llama/Llama-3-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
     "meta-llama/Llama-3.2-3B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
     "meta-llama/Llama-3.1-8B": {"temperature": 0.75, "top_p": 0.9, "top_k": 50},
     "mistralai/Mistral-7B-v0.1": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
-    "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
     "mistralai/Mistral-7B-v0.3": {"temperature": 0.7, "top_p": 0.9, "top_k": 50},
-    "google/gemma-2-2b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/gemma-2-9b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
-    "google/gemma-2-27b": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "mistralai/Mixtral-8x7B-v0.1": {"temperature": 0.8, "top_p": 0.95, "top_k": 50},
+    "google/Gemma-2B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/Gemma-9B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
+    "google/Gemma-27B": {"temperature": 0.7, "top_p": 0.95, "top_k": 40},
     "croissantllm/CroissantLLMBase": {"temperature": 0.8, "top_p": 0.92, "top_k": 50}
 }
 
@@ -69,24 +69,20 @@ tokenizer = None
 selected_language = None
 
 def update_model_choices(company):
-    return gr.Dropdown(choices=list(models_hierarchy[company].keys()), value=None)
+    return gr.Dropdown(choices=list(model_hierarchy[company].keys()), value=None)
 
 def update_variation_choices(company, model_name):
-    return gr.Dropdown(choices=models_hierarchy[company][model_name], value=None)
+    return gr.Dropdown(choices=model_hierarchy[company][model_name], value=None)
 
 def load_model(company, model_name, variation, progress=gr.Progress()):
     global model, tokenizer
-
     full_model_name = f"{company}/{model_name}-{variation}"
-    if full_model_name not in models_and_languages:
-        full_model_name = f"{company}/{model_name}{variation}"
 
     try:
         progress(0, desc="Chargement du tokenizer")
         tokenizer = AutoTokenizer.from_pretrained(full_model_name)
         progress(0.5, desc="Chargement du modèle")
 
-        # Configurations spécifiques par modèle
         if "mixtral" in full_model_name.lower():
            model = AutoModelForCausalLM.from_pretrained(
                full_model_name,
@@ -106,9 +102,8 @@ def load_model(company, model_name, variation, progress=gr.Progress()):
 
         progress(1.0, desc="Modèle chargé")
         available_languages = models_and_languages[full_model_name]
-
-        # Mise à jour des sliders avec les valeurs recommandées
         params = model_parameters[full_model_name]
+
         return (
             f"Modèle {full_model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}",
             gr.Dropdown(choices=available_languages, value=available_languages[0], visible=True, interactive=True),
@@ -119,15 +114,129 @@ def load_model(company, model_name, variation, progress=gr.Progress()):
     except Exception as e:
         return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown(visible=False), None, None, None
 
-# Le reste du code reste inchangé...
+def set_language(lang):
+    global selected_language
+    selected_language = lang
+    return f"Langue sélectionnée : {lang}"
+
+def ensure_token_display(token):
+    if token.isdigit() or (token.startswith('-') and token[1:].isdigit()):
+        return tokenizer.decode([int(token)])
+    return token
+
+def analyze_next_token(input_text, temperature, top_p, top_k):
+    global model, tokenizer, selected_language
+
+    if model is None or tokenizer is None:
+        return "Veuillez d'abord charger un modèle.", None, None
+
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+
+    try:
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        last_token_logits = outputs.logits[0, -1, :]
+        probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
+
+        top_k = 10
+        top_probs, top_indices = torch.topk(probabilities, top_k)
+        top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
+        prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
+
+        prob_text = "Prochains tokens les plus probables :\n\n"
+        for word, prob in prob_data.items():
+            prob_text += f"{word}: {prob:.2%}\n"
+
+        prob_plot = plot_probabilities(prob_data)
+        attention_plot = plot_attention(inputs["input_ids"][0].cpu(), last_token_logits.cpu())
+
+        return prob_text, attention_plot, prob_plot
+    except Exception as e:
+        return f"Erreur lors de l'analyse : {str(e)}", None, None
+
+def generate_text(input_text, temperature, top_p, top_k):
+    global model, tokenizer, selected_language
+
+    if model is None or tokenizer is None:
+        return "Veuillez d'abord charger un modèle."
+
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
+
+    try:
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=10,
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k
+            )
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        return generated_text
+    except Exception as e:
+        return f"Erreur lors de la génération : {str(e)}"
+
+def plot_probabilities(prob_data):
+    words = list(prob_data.keys())
+    probs = list(prob_data.values())
+
+    fig, ax = plt.subplots(figsize=(12, 6))
+    bars = ax.bar(range(len(words)), probs, color='lightgreen')
+    ax.set_title("Probabilités des tokens suivants les plus probables")
+    ax.set_xlabel("Tokens")
+    ax.set_ylabel("Probabilité")
+
+    ax.set_xticks(range(len(words)))
+    ax.set_xticklabels(words, rotation=45, ha='right')
+
+    for i, (bar, word) in enumerate(zip(bars, words)):
+        height = bar.get_height()
+        ax.text(i, height, f'{height:.2%}',
+                ha='center', va='bottom', rotation=0)
+
+    plt.tight_layout()
+    return fig
+
+def plot_attention(input_ids, last_token_logits):
+    input_tokens = [ensure_token_display(tokenizer.decode([id])) for id in input_ids]
+    attention_scores = torch.nn.functional.softmax(last_token_logits, dim=-1)
+    top_k = min(len(input_tokens), 10)
+    top_attention_scores, _ = torch.topk(attention_scores, top_k)
+
+    fig, ax = plt.subplots(figsize=(14, 7))
+    sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
+    ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
+    ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
+    ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
+
+    cbar = ax.collections[0].colorbar
+    cbar.set_label("Score d'attention", fontsize=12)
+    cbar.ax.tick_params(labelsize=10)
+
+    plt.tight_layout()
+    return fig
+
+def reset():
+    global model, tokenizer, selected_language
+    model = None
+    tokenizer = None
+    selected_language = None
+    return (
+        gr.Dropdown(choices=list(model_hierarchy.keys()), value=None),
+        gr.Dropdown(visible=False),
+        gr.Dropdown(visible=False),
+        "", 1.0, 1.0, 50, None, None, None, None, gr.Dropdown(visible=False), ""
+    )
 
 with gr.Blocks() as demo:
     gr.Markdown("# LLM&BIAS")
 
     with gr.Accordion("Sélection du modèle"):
-        company_dropdown = gr.Dropdown(choices=list(models_hierarchy.keys()), label="Choisissez une société")
-        model_dropdown = gr.Dropdown(label="Choisissez un modèle", choices=[])
-        variation_dropdown = gr.Dropdown(label="Choisissez une variation", choices=[])
+        company_dropdown = gr.Dropdown(choices=list(model_hierarchy.keys()), label="Choisissez une société")
+        model_dropdown = gr.Dropdown(label="Choisissez un modèle", visible=False)
+        variation_dropdown = gr.Dropdown(label="Choisissez une variation", visible=False)
         load_button = gr.Button("Charger le modèle")
         load_output = gr.Textbox(label="Statut du chargement")
         language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
@@ -156,7 +265,7 @@ with gr.Blocks() as demo:
     model_dropdown.change(update_variation_choices, inputs=[company_dropdown, model_dropdown], outputs=[variation_dropdown])
     load_button.click(load_model,
                       inputs=[company_dropdown, model_dropdown, variation_dropdown],
-                      outputs=[load_output, language_dropdown, temperature, top_p, top_k])
+                      outputs=[load_output, language_dropdown])
    language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
    analyze_button.click(analyze_next_token,
                         inputs=[input_text, temperature, top_p, top_k],
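
The functions added in this commit reference gr, torch, plt, sns, login, AutoTokenizer and AutoModelForCausalLM, none of which are imported inside the visible hunks. As a hedged sketch only, the preamble of app.py is assumed to look roughly like this (the exact import block sits above the first hunk and is not shown in the diff):

# Assumed preamble of app.py (not part of this diff): the code added above uses
# os, time, torch, gradio (gr), matplotlib (plt), seaborn (sns), huggingface_hub.login
# and transformers, so imports along these lines are presumed to exist before line 12.
import os
import time

import gradio as gr
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from huggingface_hub import login
from transformers import AutoModelForCausalLM, AutoTokenizer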
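
For reference, a small self-contained sketch of how the three dropdowns resolve to a Hub model id: the hierarchy excerpt is copied from the diff, the selections are hypothetical, and the f-string is the same one load_model uses.

# Hypothetical dropdown selections, resolved the same way load_model does.
model_hierarchy = {
    "meta-llama": {
        "Llama-2": ["7B", "13B", "70B"],
        "Llama-3": ["8B", "3.2-3B", "3.1-8B"]
    }
}

company = "meta-llama"      # first dropdown: a key of model_hierarchy
model_name = "Llama-2"      # second dropdown: a key of model_hierarchy[company]
variation = "7B"            # third dropdown: an entry of model_hierarchy[company][model_name]

full_model_name = f"{company}/{model_name}-{variation}"
print(full_model_name)      # -> meta-llama/Llama-2-7B, a key of models_and_languages and model_parameters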