Woziii commited on
Commit
6cca076
1 Parent(s): 7e2f9cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -62
app.py CHANGED
@@ -7,30 +7,12 @@ import matplotlib.pyplot as plt
7
  import seaborn as sns
8
  import numpy as np
9
  import time
10
- from langdetect import detect
11
 
12
  # Authentification
13
  login(token=os.environ["HF_TOKEN"])
14
 
15
- # Liste des modèles
16
- models = [
17
- "meta-llama/Llama-2-13b-hf",
18
- "meta-llama/Llama-2-7b-hf",
19
- "meta-llama/Llama-2-70b-hf",
20
- "meta-llama/Meta-Llama-3-8B",
21
- "meta-llama/Llama-3.2-3B",
22
- "meta-llama/Llama-3.1-8B",
23
- "mistralai/Mistral-7B-v0.1",
24
- "mistralai/Mixtral-8x7B-v0.1",
25
- "mistralai/Mistral-7B-v0.3",
26
- "google/gemma-2-2b",
27
- "google/gemma-2-9b",
28
- "google/gemma-2-27b",
29
- "croissantllm/CroissantLLMBase"
30
- ]
31
-
32
- # Dictionnaire des langues supportées par modèle
33
- model_languages = {
34
  "meta-llama/Llama-2-13b-hf": ["en"],
35
  "meta-llama/Llama-2-7b-hf": ["en"],
36
  "meta-llama/Llama-2-70b-hf": ["en"],
@@ -49,6 +31,7 @@ model_languages = {
49
  # Variables globales
50
  model = None
51
  tokenizer = None
 
52
 
53
  def load_model(model_name, progress=gr.Progress()):
54
  global model, tokenizer
@@ -57,21 +40,40 @@ def load_model(model_name, progress=gr.Progress()):
57
  tokenizer = AutoTokenizer.from_pretrained(model_name)
58
  progress(0.5, desc="Chargement du modèle")
59
 
60
- # Configuration générique pour tous les modèles
61
- model = AutoModelForCausalLM.from_pretrained(
62
- model_name,
63
- torch_dtype=torch.float16,
64
- device_map="auto",
65
- low_cpu_mem_usage=True
66
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  if tokenizer.pad_token is None:
69
  tokenizer.pad_token = tokenizer.eos_token
70
 
71
  progress(1.0, desc="Modèle chargé")
72
- return f"Modèle {model_name} chargé avec succès."
 
73
  except Exception as e:
74
- return f"Erreur lors du chargement du modèle : {str(e)}"
 
 
 
 
 
75
 
76
  def ensure_token_display(token):
77
  """Assure que le token est affiché correctement."""
@@ -80,29 +82,23 @@ def ensure_token_display(token):
80
  return token
81
 
82
  def analyze_next_token(input_text, temperature, top_p, top_k):
83
- global model, tokenizer
84
 
85
  if model is None or tokenizer is None:
86
  return "Veuillez d'abord charger un modèle.", None, None
87
 
88
- # Détection de la langue
89
- detected_lang = detect(input_text)
90
- if detected_lang not in model_languages.get(model.config._name_or_path, []):
91
- return f"Langue détectée ({detected_lang}) non supportée par ce modèle.", None, None
92
-
93
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
94
 
95
  try:
96
  with torch.no_grad():
97
  outputs = model(**inputs)
98
 
99
  last_token_logits = outputs.logits[0, -1, :]
100
- probabilities = torch.nn.functional.softmax(last_token_logits / temperature, dim=-1)
101
 
102
- top_k = min(top_k, probabilities.size(-1))
103
  top_probs, top_indices = torch.topk(probabilities, top_k)
104
  top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
105
-
106
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
107
 
108
  prob_text = "Prochains tokens les plus probables :\n\n"
@@ -117,27 +113,22 @@ def analyze_next_token(input_text, temperature, top_p, top_k):
117
  return f"Erreur lors de l'analyse : {str(e)}", None, None
118
 
119
  def generate_text(input_text, temperature, top_p, top_k):
120
- global model, tokenizer
121
 
122
  if model is None or tokenizer is None:
123
  return "Veuillez d'abord charger un modèle."
124
 
125
- # Détection de la langue
126
- detected_lang = detect(input_text)
127
- if detected_lang not in model_languages.get(model.config._name_or_path, []):
128
- return f"Langue détectée ({detected_lang}) non supportée par ce modèle."
129
-
130
- inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
131
 
132
  try:
133
- outputs = model.generate(
134
- **inputs,
135
- max_new_tokens=50,
136
- do_sample=True,
137
- temperature=temperature,
138
- top_p=top_p,
139
- top_k=top_k
140
- )
141
 
142
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
143
  return generated_text
@@ -172,7 +163,7 @@ def plot_attention(input_ids, last_token_logits):
172
  top_attention_scores, _ = torch.topk(attention_scores, top_k)
173
 
174
  fig, ax = plt.subplots(figsize=(14, 7))
175
- sns.heatmap(top_attention_scores.unsqueeze(0).cpu().numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
176
  ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
177
  ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
178
  ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
@@ -185,18 +176,21 @@ def plot_attention(input_ids, last_token_logits):
185
  return fig
186
 
187
  def reset():
188
- global model, tokenizer
189
  model = None
190
  tokenizer = None
191
- return "", 1.0, 1.0, 50, None, None, None, None
 
192
 
193
  with gr.Blocks() as demo:
194
- gr.Markdown("# Analyse et génération de texte avec LLM")
195
 
196
  with gr.Accordion("Sélection du modèle"):
197
- model_dropdown = gr.Dropdown(choices=models, label="Choisissez un modèle")
198
  load_button = gr.Button("Charger le modèle")
199
  load_output = gr.Textbox(label="Statut du chargement")
 
 
200
 
201
  with gr.Row():
202
  temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
@@ -212,12 +206,13 @@ with gr.Blocks() as demo:
212
  attention_plot = gr.Plot(label="Visualisation de l'attention")
213
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
214
 
215
- generate_button = gr.Button("Générer la suite du texte")
216
  generated_text = gr.Textbox(label="Texte généré")
217
 
218
  reset_button = gr.Button("Réinitialiser")
219
 
220
- load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output])
 
221
  analyze_button.click(analyze_next_token,
222
  inputs=[input_text, temperature, top_p, top_k],
223
  outputs=[next_token_probs, attention_plot, prob_plot])
@@ -225,7 +220,7 @@ with gr.Blocks() as demo:
225
  inputs=[input_text, temperature, top_p, top_k],
226
  outputs=[generated_text])
227
  reset_button.click(reset,
228
- outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text])
229
 
230
  if __name__ == "__main__":
231
  demo.launch()
 
7
  import seaborn as sns
8
  import numpy as np
9
  import time
 
10
 
11
  # Authentification
12
  login(token=os.environ["HF_TOKEN"])
13
 
14
+ # Liste des modèles et leurs langues supportées
15
+ models_and_languages = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  "meta-llama/Llama-2-13b-hf": ["en"],
17
  "meta-llama/Llama-2-7b-hf": ["en"],
18
  "meta-llama/Llama-2-70b-hf": ["en"],
 
31
  # Variables globales
32
  model = None
33
  tokenizer = None
34
+ selected_language = None
35
 
36
  def load_model(model_name, progress=gr.Progress()):
37
  global model, tokenizer
 
40
  tokenizer = AutoTokenizer.from_pretrained(model_name)
41
  progress(0.5, desc="Chargement du modèle")
42
 
43
+ # Configurations spécifiques par modèle
44
+ if "mixtral" in model_name.lower():
45
+ model = AutoModelForCausalLM.from_pretrained(
46
+ model_name,
47
+ torch_dtype=torch.float16,
48
+ device_map="auto",
49
+ load_in_8bit=True
50
+ )
51
+ elif "llama" in model_name.lower() or "mistral" in model_name.lower():
52
+ model = AutoModelForCausalLM.from_pretrained(
53
+ model_name,
54
+ torch_dtype=torch.float16,
55
+ device_map="auto"
56
+ )
57
+ else:
58
+ model = AutoModelForCausalLM.from_pretrained(
59
+ model_name,
60
+ torch_dtype=torch.float16,
61
+ device_map="auto"
62
+ )
63
 
64
  if tokenizer.pad_token is None:
65
  tokenizer.pad_token = tokenizer.eos_token
66
 
67
  progress(1.0, desc="Modèle chargé")
68
+ available_languages = models_and_languages[model_name]
69
+ return f"Modèle {model_name} chargé avec succès. Langues disponibles : {', '.join(available_languages)}", gr.Dropdown.update(choices=available_languages, value=available_languages[0], visible=True)
70
  except Exception as e:
71
+ return f"Erreur lors du chargement du modèle : {str(e)}", gr.Dropdown.update(visible=False)
72
+
73
+ def set_language(lang):
74
+ global selected_language
75
+ selected_language = lang
76
+ return f"Langue sélectionnée : {lang}"
77
 
78
  def ensure_token_display(token):
79
  """Assure que le token est affiché correctement."""
 
82
  return token
83
 
84
  def analyze_next_token(input_text, temperature, top_p, top_k):
85
+ global model, tokenizer, selected_language
86
 
87
  if model is None or tokenizer is None:
88
  return "Veuillez d'abord charger un modèle.", None, None
89
 
90
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
91
 
92
  try:
93
  with torch.no_grad():
94
  outputs = model(**inputs)
95
 
96
  last_token_logits = outputs.logits[0, -1, :]
97
+ probabilities = torch.nn.functional.softmax(last_token_logits, dim=-1)
98
 
99
+ top_k = 10
100
  top_probs, top_indices = torch.topk(probabilities, top_k)
101
  top_words = [ensure_token_display(tokenizer.decode([idx.item()])) for idx in top_indices]
 
102
  prob_data = {word: prob.item() for word, prob in zip(top_words, top_probs)}
103
 
104
  prob_text = "Prochains tokens les plus probables :\n\n"
 
113
  return f"Erreur lors de l'analyse : {str(e)}", None, None
114
 
115
  def generate_text(input_text, temperature, top_p, top_k):
116
+ global model, tokenizer, selected_language
117
 
118
  if model is None or tokenizer is None:
119
  return "Veuillez d'abord charger un modèle."
120
 
121
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
 
 
 
 
 
122
 
123
  try:
124
+ with torch.no_grad():
125
+ outputs = model.generate(
126
+ **inputs,
127
+ max_new_tokens=1,
128
+ temperature=temperature,
129
+ top_p=top_p,
130
+ top_k=top_k
131
+ )
132
 
133
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
134
  return generated_text
 
163
  top_attention_scores, _ = torch.topk(attention_scores, top_k)
164
 
165
  fig, ax = plt.subplots(figsize=(14, 7))
166
+ sns.heatmap(top_attention_scores.unsqueeze(0).numpy(), annot=True, cmap="YlOrRd", cbar=True, ax=ax, fmt='.2%')
167
  ax.set_xticklabels(input_tokens[-top_k:], rotation=45, ha="right", fontsize=10)
168
  ax.set_yticklabels(["Attention"], rotation=0, fontsize=10)
169
  ax.set_title("Scores d'attention pour les derniers tokens", fontsize=16)
 
176
  return fig
177
 
178
  def reset():
179
+ global model, tokenizer, selected_language
180
  model = None
181
  tokenizer = None
182
+ selected_language = None
183
+ return "", 1.0, 1.0, 50, None, None, None, None, gr.Dropdown.update(visible=False), ""
184
 
185
  with gr.Blocks() as demo:
186
+ gr.Markdown("# Analyse et génération de texte")
187
 
188
  with gr.Accordion("Sélection du modèle"):
189
+ model_dropdown = gr.Dropdown(choices=list(models_and_languages.keys()), label="Choisissez un modèle")
190
  load_button = gr.Button("Charger le modèle")
191
  load_output = gr.Textbox(label="Statut du chargement")
192
+ language_dropdown = gr.Dropdown(label="Choisissez une langue", visible=False)
193
+ language_output = gr.Textbox(label="Langue sélectionnée")
194
 
195
  with gr.Row():
196
  temperature = gr.Slider(0.1, 2.0, value=1.0, label="Température")
 
206
  attention_plot = gr.Plot(label="Visualisation de l'attention")
207
  prob_plot = gr.Plot(label="Probabilités des tokens suivants")
208
 
209
+ generate_button = gr.Button("Générer le prochain mot")
210
  generated_text = gr.Textbox(label="Texte généré")
211
 
212
  reset_button = gr.Button("Réinitialiser")
213
 
214
+ load_button.click(load_model, inputs=[model_dropdown], outputs=[load_output, language_dropdown])
215
+ language_dropdown.change(set_language, inputs=[language_dropdown], outputs=[language_output])
216
  analyze_button.click(analyze_next_token,
217
  inputs=[input_text, temperature, top_p, top_k],
218
  outputs=[next_token_probs, attention_plot, prob_plot])
 
220
  inputs=[input_text, temperature, top_p, top_k],
221
  outputs=[generated_text])
222
  reset_button.click(reset,
223
+ outputs=[input_text, temperature, top_p, top_k, next_token_probs, attention_plot, prob_plot, generated_text, language_dropdown, language_output])
224
 
225
  if __name__ == "__main__":
226
  demo.launch()