PhotographerAlpha7 committed on
Commit
5fb8b32
1 Parent(s): e3778e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from datasets import Dataset
4
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
5
+ import torch
6
+ import os
7
+ import matplotlib.pyplot as plt
8
+ import json
9
+ import io
10
+
11
# Module-level store for the column names detected in the uploaded file;
# read_file() rebinds it each time a file is parsed successfully.
columns: list = []
13
+
14
def read_file(data_file):
    """Detect the columns of an uploaded file and offer them in both dropdowns.

    The handler is wired to two dropdown outputs, so it returns one update
    per dropdown instead of a flat list of column names (which Gradio would
    try to spread across the outputs as *values*, not as choices).

    Args:
        data_file: The upload handed over by the Gradio file component:
            either a path string or an object exposing the path as ``.name``.

    Returns:
        A pair of ``gr.update`` payloads (design column, description column).
        Each replaces the dropdown's choices with the detected column names
        and clears any stale selection.

    Raises:
        gr.Error: If no file was given, the extension is not supported, or
            the file cannot be parsed. Gradio displays the message in the UI.
    """
    global columns

    readers = {
        '.csv': pd.read_csv,
        '.json': pd.read_json,
        '.xlsx': pd.read_excel,
    }

    if data_file is None:
        raise gr.Error("Invalid file format. Please upload a CSV, JSON, or Excel file.")

    # Accept both a plain path string and a file wrapper that carries the
    # path in ``.name``, so the handler does not depend on the upload type.
    path = getattr(data_file, "name", data_file)

    # Compare the extension case-insensitively so "DATA.CSV" is accepted too.
    extension = os.path.splitext(str(path))[1].lower()
    reader = readers.get(extension)
    if reader is None:
        raise gr.Error("Invalid file format. Please upload a CSV, JSON, or Excel file.")

    try:
        df = reader(path)
    except Exception as e:
        raise gr.Error(f"An error occurred: {str(e)}") from e

    # Keep the module-level cache in sync, as the original handler did.
    columns = df.columns.tolist()

    return (
        gr.update(choices=columns, value=None),
        gr.update(choices=columns, value=None),
    )
34
+
35
def _load_training_frame(data_file):
    """Load the uploaded CSV, JSON, or Excel file into a pandas DataFrame."""
    if data_file is None:
        raise ValueError("No data file provided. Please upload a CSV, JSON, or Excel file.")

    # Accept both a plain path string and a file wrapper exposing ``.name``.
    path = getattr(data_file, "name", data_file)
    extension = os.path.splitext(str(path))[1].lower()

    if extension == '.csv':
        return pd.read_csv(path)
    if extension == '.json':
        return pd.read_json(path)
    if extension == '.xlsx':
        return pd.read_excel(path)
    raise ValueError("Invalid file format. Please upload a CSV, JSON, or Excel file.")


def _save_loss_plot(log_history, output_dir):
    """Plot the logged training/validation losses and return the PNG path."""
    # Fall back to the entry index if a log entry carries no 'step' key.
    train_points = [(x.get('step', i), x['loss']) for i, x in enumerate(log_history) if 'loss' in x]
    eval_points = [(x.get('step', i), x['eval_loss']) for i, x in enumerate(log_history) if 'eval_loss' in x]

    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, 'training_eval_loss.png')

    # Use a dedicated figure and always close it: drawing on pyplot's implicit
    # global figure would stack the curves of successive training runs.
    fig, ax = plt.subplots()
    try:
        ax.plot([s for s, _ in train_points], [v for _, v in train_points], label='Training Loss')
        ax.plot([s for s, _ in eval_points], [v for _, v in eval_points], marker='o', label='Validation Loss')
        ax.set_xlabel('Steps')
        ax.set_ylabel('Loss')
        ax.set_title('Training and Validation Loss')
        ax.legend()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path


def train_model(data_file, model_name, epochs, batch_size, learning_rate, output_dir, prompt_col, description_col):
    """Fine-tune a GPT-2 checkpoint on "<prompt>: <description>" lines.

    Args:
        data_file: Uploaded CSV/JSON/Excel file (path string or object with
            a ``.name`` path attribute).
        model_name: Name or path passed to ``from_pretrained``.
        epochs: Number of training epochs (converted with ``int``).
        batch_size: Per-device train/eval batch size (converted with ``int``).
        learning_rate: Learning rate (converted with ``float``).
        output_dir: Directory receiving checkpoints, the final model, the
            tokenizer and the loss plot.
        prompt_col: Name of the column used as the text before ``': '``.
        description_col: Name of the column used as the text after ``': '``.

    Returns:
        A ``(message, plot_path)`` pair matching the two UI outputs (status
        textbox, loss image). On failure the message describes the error and
        ``plot_path`` is ``None``.
    """
    try:
        df = _load_training_frame(data_file)

        # Fail early with a readable message when a column was not selected.
        for role, col in (("prompt", prompt_col), ("description", description_col)):
            if col is None or col not in df.columns:
                raise ValueError(f"The {role} column {col!r} was not found in the uploaded file.")

        # Data preview shown back to the user.
        preview = df.head().to_string(index=False)

        # Build the training text; cast to str so numeric columns concatenate.
        df['text'] = df[prompt_col].astype(str) + ': ' + df[description_col].astype(str)
        dataset = Dataset.from_pandas(df[['text']])

        # Initialise the GPT-2 tokenizer and model.
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)

        # Add a padding token if the tokenizer has none.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model.resize_token_embeddings(len(tokenizer))

        def tokenize_function(examples):
            tokens = tokenizer(examples['text'], padding="max_length", truncation=True, max_length=128)
            # Mask padding positions with -100 so they are excluded from the
            # loss; otherwise most of the signal is "predict [PAD]".
            tokens['labels'] = [
                [tok if keep else -100 for tok, keep in zip(ids, mask)]
                for ids, mask in zip(tokens['input_ids'], tokens['attention_mask'])
            ]
            return tokens

        tokenized_datasets = dataset.map(tokenize_function, batched=True)

        # NOTE(review): newer transformers releases appear to rename
        # `evaluation_strategy` to `eval_strategy` - confirm against the
        # installed version before changing it.
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=1000,
            weight_decay=0.01,
            learning_rate=float(learning_rate),
            logging_dir="./logs",
            logging_steps=10,
            save_steps=500,
            save_total_limit=2,
            evaluation_strategy="steps",
            eval_steps=500,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss"
        )

        # The same tokenized data is used for training and evaluation.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets,
            eval_dataset=tokenized_datasets,
        )

        # Train, then run a final evaluation pass.
        trainer.train()
        eval_results = trainer.evaluate()

        # Save the fine-tuned model and tokenizer.
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        plot_path = _save_loss_plot(trainer.state.log_history, output_dir)

        # The second UI output is an image, so the metrics go into the text
        # message and the plot path is returned for the image component.
        metrics = "\n".join(f"{key}: {value}" for key, value in eval_results.items())
        message = (
            f"Training completed successfully.\nPreview of data:\n{preview}"
            f"\n\nEvaluation results:\n{metrics}"
        )
        return message, plot_path
    except Exception as e:
        # Always return one value per UI output, even on failure.
        return f"An error occurred: {str(e)}", None
121
+
122
def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma, model_dir="./fine-tuned-gpt2"):
    """Generate a continuation of ``prompt`` with the fine-tuned GPT-2 model.

    Args:
        prompt: Text to continue.
        temperature: Sampling temperature (converted with ``float``).
        top_k: Number of highest-probability tokens kept while sampling
            (converted with ``int``).
        max_length: Maximum total length, prompt included, passed to
            ``generate`` (converted with ``int``).
        repetition_penalty: Repetition penalty (converted with ``float``).
        use_comma: When true, every ``'.'`` in the prompt is replaced by
            ``','`` before generation.
        model_dir: Directory holding the fine-tuned model and tokenizer.
            Defaults to the directory the training tab writes to by default.

    Returns:
        The decoded generated text, or a human-readable message when the
        prompt is empty or generation fails.
    """
    # An empty prompt yields no input ids for the model; reject it up front
    # instead of loading the model only to fail inside generate().
    if prompt is None or not str(prompt).strip():
        return "Please enter a prompt."

    try:
        tokenizer = GPT2Tokenizer.from_pretrained(model_dir)
        model = GPT2LMHeadModel.from_pretrained(model_dir)

        if use_comma:
            prompt = prompt.replace('.', ',')

        # A single prompt needs no padding, so none is requested here.
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=int(max_length),
            # Sampling must be switched on explicitly for temperature and
            # top_k to have any effect on the output.
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        return f"An error occurred: {str(e)}"
147
+
148
def set_preset(preset):
    """Return the hyperparameters associated with a training preset.

    Args:
        preset: Preset name: "Default", "Fast Training" or "High Accuracy".

    Returns:
        A tuple ``(epochs, batch_size, learning_rate)``. An unrecognised
        value (including ``None``) falls back to the "Default" values, so
        the three bound outputs always receive a value instead of ``None``.
    """
    presets = {
        "Default": (5, 8, 3e-5),
        "Fast Training": (3, 16, 5e-5),
        "High Accuracy": (10, 4, 1e-5),
    }
    return presets.get(preset, presets["Default"])
156
+
157
# Gradio interface: one tab to fine-tune the model, one tab to generate text.
with gr.Blocks() as ui:
    gr.Markdown("# Model-Fine-Tuner | by Dimonapatrick243")

    with gr.Tab("Train Model"):
        # Data source, base model and destination directory.
        with gr.Row():
            data_file = gr.File(label="Upload Data File (CSV, JSON, Excel)")
            model_name = gr.Textbox(label="Model Name", value="gpt2")
            output_dir = gr.Textbox(label="Output Directory", value="./fine-tuned-gpt2")

        # Hyperparameters, optionally filled in from a preset.
        with gr.Row():
            preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
            epochs = gr.Number(label="Epochs", value=5)
            batch_size = gr.Number(label="Batch Size", value=8)
            learning_rate = gr.Number(label="Learning Rate", value=3e-5)

        preset.change(
            fn=set_preset,
            inputs=preset,
            outputs=[epochs, batch_size, learning_rate],
        )

        # Selectors for the two columns used to build the training text.
        with gr.Row():
            design_col = gr.Dropdown(label="Design Column")
            description_col = gr.Dropdown(label="Description Column")

        # Detect the columns as soon as a file has been uploaded.
        data_file.upload(
            fn=read_file,
            inputs=data_file,
            outputs=[design_col, description_col],
        )

        train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_graph = gr.Image(label="Training and Validation Loss Graph")

        train_button.click(
            fn=train_model,
            inputs=[
                data_file,
                model_name,
                epochs,
                batch_size,
                learning_rate,
                output_dir,
                design_col,
                description_col,
            ],
            outputs=[train_output, train_graph],
        )

    with gr.Tab("Generate Text"):
        with gr.Row():
            # Left column: sampling controls.
            with gr.Column():
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
                max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
                repetition_penalty = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2)
                use_comma = gr.Checkbox(label="Use Comma", value=True)

            # Right column: prompt input and generated output.
            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                generate_button = gr.Button("Generate Text")
                generated_text = gr.Textbox(label="Generated Text")

        generate_button.click(
            fn=generate_text,
            inputs=[
                prompt,
                temperature,
                top_k,
                max_length,
                repetition_penalty,
                use_comma,
            ],
            outputs=generated_text,
        )

ui.launch()