"""Gradio app for fine-tuning a GPT-2 model on tabular data and generating text.

The "Train Model" tab loads a CSV/JSON/Excel file, joins two user-selected
columns into training text, fine-tunes a GPT-2 checkpoint and shows a loss
plot.  The "Generate Text" tab samples text from the fine-tuned checkpoint.
"""

import io
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import torch
from datasets import Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments

# Directory the generation tab loads the fine-tuned model from; it is also the
# default value of the "Output Directory" textbox in the training tab.
DEFAULT_MODEL_DIR = "./fine-tuned-gpt2"

# Every training example is padded/truncated to this many tokens.
MAX_SEQUENCE_LENGTH = 128

# Name of the loss plot written into the training output directory.
LOSS_PLOT_FILENAME = "training_eval_loss.png"

# Supported upload formats: lower-cased file extension -> pandas reader.
_READERS = {
    ".csv": pd.read_csv,
    ".json": pd.read_json,
    ".xlsx": pd.read_excel,
}

# Hyperparameter presets: name -> (epochs, batch size, learning rate).
PRESETS = {
    "Default": (5, 8, 3e-5),
    "Fast Training": (3, 16, 5e-5),
    "High Accuracy": (10, 4, 1e-5),
}

# Global storage for the columns detected in the most recently uploaded file.
columns = []


class DataFileError(ValueError):
    """A problem with the uploaded file or column selection, safe to show as-is."""


def _resolve_path(data_file):
    """Return the filesystem path of an uploaded file.

    Accepts either a plain path string or an object exposing the path as
    ``.name`` (which of the two ``gr.File`` passes presumably depends on the
    Gradio version/configuration).

    Raises:
        DataFileError: if no file was provided.
    """
    if data_file is None:
        raise DataFileError("No data file was uploaded.")
    return getattr(data_file, "name", data_file)


def _load_dataframe(data_file):
    """Load the uploaded CSV, JSON or Excel file into a DataFrame.

    Raises:
        DataFileError: if no file was provided or the extension is unsupported.
    """
    path = _resolve_path(data_file)
    # Compare case-insensitively so that e.g. "DATA.CSV" is accepted too.
    extension = os.path.splitext(path)[1].lower()
    reader = _READERS.get(extension)
    if reader is None:
        raise DataFileError(
            "Invalid file format. Please upload a CSV, JSON, or Excel file."
        )
    return reader(path)


def read_file(data_file):
    """Detect the columns of the uploaded file and offer them in both dropdowns.

    Also stores the detected column names in the module-level ``columns``.

    Args:
        data_file: the value passed by the ``gr.File`` component.

    Returns:
        Two ``gr.update`` payloads, one per column dropdown, each setting the
        dropdown choices to the detected column names and clearing its value.

    Raises:
        gr.Error: if the file is missing, unsupported or cannot be parsed.
    """
    global columns
    try:
        df = _load_dataframe(data_file)
    except DataFileError as e:
        raise gr.Error(str(e)) from e
    except Exception as e:
        raise gr.Error(f"An error occurred: {str(e)}") from e

    columns = df.columns.tolist()
    # The event is wired to two outputs, so one update per dropdown is needed.
    return (
        gr.update(choices=columns, value=None),
        gr.update(choices=columns, value=None),
    )


def _build_text_dataset(df, prompt_col, description_col):
    """Build a single-column ``text`` dataset of "<prompt>: <description>" rows.

    Adds a ``text`` column to ``df`` in place.

    Raises:
        DataFileError: if either selected column is not present in ``df``.
    """
    for col in (prompt_col, description_col):
        if col not in df.columns:
            raise DataFileError(
                f"Column {col!r} was not found in the uploaded file. "
                "Please select a valid column."
            )
    # Cast to str so that numeric or mixed-type columns can be concatenated.
    df['text'] = df[prompt_col].astype(str) + ': ' + df[description_col].astype(str)
    return Dataset.from_pandas(df[['text']])


def _loss_series(log_history, key):
    """Return ``(steps, values)`` for every log entry that contains ``key``.

    Falls back to the entry's position in the log when it has no ``step``.
    """
    steps = []
    values = []
    for index, entry in enumerate(log_history):
        if key in entry:
            steps.append(entry.get('step', index))
            values.append(entry[key])
    return steps, values


def _save_loss_plot(log_history, output_dir):
    """Plot training and validation loss against step and save it as a PNG.

    Returns:
        The path of the written image.
    """
    os.makedirs(output_dir, exist_ok=True)
    plot_path = os.path.join(output_dir, LOSS_PLOT_FILENAME)

    # Use a dedicated figure so curves from earlier runs are not drawn again.
    fig, ax = plt.subplots()
    try:
        ax.plot(*_loss_series(log_history, 'loss'), label='Training Loss')
        ax.plot(*_loss_series(log_history, 'eval_loss'), label='Validation Loss')
        ax.set_xlabel('Steps')
        ax.set_ylabel('Loss')
        ax.set_title('Training and Validation Loss')
        ax.legend()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)
    return plot_path


def train_model(data_file, model_name, epochs, batch_size, learning_rate,
                output_dir, prompt_col, description_col):
    """Fine-tune a GPT-2 checkpoint on the uploaded data.

    Args:
        data_file: the value passed by the ``gr.File`` component.
        model_name: name or path of the GPT-2 checkpoint to start from.
        epochs: number of training epochs (converted with ``int``).
        batch_size: per-device batch size (converted with ``int``).
        learning_rate: optimizer learning rate (converted with ``float``).
        output_dir: directory receiving checkpoints, the final model,
            the tokenizer and the loss plot.
        prompt_col: column whose values form the text before ``": "``.
        description_col: column whose values form the text after ``": "``.

    Returns:
        A ``(message, image_path)`` pair matching the two UI outputs: a status
        message (with a data preview and the evaluation metrics) and the path
        of the loss plot, or ``(error message, None)`` on failure.
    """
    try:
        df = _load_dataframe(data_file)

        # Data preview, taken before the helper 'text' column is added.
        preview = df.head().to_string(index=False)

        dataset = _build_text_dataset(df, prompt_col, description_col)

        # Initialise the GPT-2 tokenizer and model.
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)

        # Add a padding token if the tokenizer does not define one.
        if tokenizer.pad_token is None:
            tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            model.resize_token_embeddings(len(tokenizer))

        def tokenize_function(examples):
            tokens = tokenizer(
                examples['text'],
                padding="max_length",
                truncation=True,
                max_length=MAX_SEQUENCE_LENGTH,
            )
            # Labels are a copy of the input ids.
            tokens['labels'] = tokens['input_ids'].copy()
            return tokens

        tokenized_datasets = dataset.map(tokenize_function, batched=True)

        # Hyperparameters.
        training_args = TrainingArguments(
            output_dir=output_dir,
            overwrite_output_dir=True,
            num_train_epochs=int(epochs),
            per_device_train_batch_size=int(batch_size),
            per_device_eval_batch_size=int(batch_size),
            warmup_steps=1000,
            weight_decay=0.01,
            learning_rate=float(learning_rate),
            logging_dir="./logs",
            logging_steps=10,
            save_steps=500,
            save_total_limit=2,
            evaluation_strategy="steps",
            eval_steps=500,
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
        )

        # NOTE: the same dataset is used for training and evaluation.
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets,
            eval_dataset=tokenized_datasets,
        )

        # Train and evaluate.
        trainer.train()
        eval_results = trainer.evaluate()

        # Save the fine-tuned model and its tokenizer.
        model.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)

        plot_path = _save_loss_plot(trainer.state.log_history, output_dir)

        metrics = "\n".join(f"{name}: {value}" for name, value in eval_results.items())
        message = (
            f"Training completed successfully.\nPreview of data:\n{preview}"
            f"\nEvaluation results:\n{metrics}"
        )
        return message, plot_path
    except DataFileError as e:
        return str(e), None
    except Exception as e:
        # UI boundary: report the failure in the output textbox.
        return f"An error occurred: {str(e)}", None


def generate_text(prompt, temperature, top_k, max_length, repetition_penalty, use_comma):
    """Generate text from the fine-tuned model stored in ``DEFAULT_MODEL_DIR``.

    Args:
        prompt: text to continue.
        temperature: sampling temperature (converted with ``float``).
        top_k: number of highest-probability tokens kept when sampling
            (converted with ``int``).
        max_length: ``max_length`` passed to ``generate`` (converted with ``int``).
        repetition_penalty: repetition penalty (converted with ``float``).
        use_comma: when true, every ``.`` in the prompt is replaced by ``,``.

    Returns:
        The decoded generated text, or an error message string on failure.
    """
    try:
        model_name = DEFAULT_MODEL_DIR
        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)

        if use_comma:
            prompt = prompt.replace('.', ',')

        inputs = tokenizer(prompt, return_tensors="pt", padding=True)
        outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=int(max_length),
            # temperature and top_k only apply when sampling is enabled.
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as e:
        # UI boundary: report the failure in the output textbox.
        return f"An error occurred: {str(e)}"


def set_preset(preset):
    """Return ``(epochs, batch_size, learning_rate)`` for the named preset.

    An unknown or cleared selection falls back to the "Default" preset so the
    three output fields always receive a value.
    """
    return PRESETS.get(preset, PRESETS["Default"])


# Gradio interface
with gr.Blocks() as ui:
    gr.Markdown("# Model-Fine-Tuner | by Dimonapatrick243")

    with gr.Tab("Train Model"):
        with gr.Row():
            data_file = gr.File(label="Upload Data File (CSV, JSON, Excel)")
            model_name = gr.Textbox(label="Model Name", value="gpt2")
            output_dir = gr.Textbox(label="Output Directory", value=DEFAULT_MODEL_DIR)

        with gr.Row():
            preset = gr.Radio(["Default", "Fast Training", "High Accuracy"], label="Preset")
            epochs = gr.Number(label="Epochs", value=5)
            batch_size = gr.Number(label="Batch Size", value=8)
            learning_rate = gr.Number(label="Learning Rate", value=3e-5)

        preset.change(set_preset, preset, [epochs, batch_size, learning_rate])

        # Dropdowns for choosing the columns used to build the training text
        with gr.Row():
            design_col = gr.Dropdown(label="Design Column")
            description_col = gr.Dropdown(label="Description Column")

        # Detect the columns as soon as a file is uploaded
        data_file.upload(read_file, inputs=data_file, outputs=[design_col, description_col])

        train_button = gr.Button("Train Model")
        train_output = gr.Textbox(label="Training Output")
        train_graph = gr.Image(label="Training and Validation Loss Graph")

        train_button.click(
            train_model,
            inputs=[data_file, model_name, epochs, batch_size, learning_rate,
                    output_dir, design_col, description_col],
            outputs=[train_output, train_graph],
        )

    with gr.Tab("Generate Text"):
        with gr.Row():
            with gr.Column():
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, value=0.7)
                top_k = gr.Slider(label="Top K", minimum=1, maximum=100, value=50)
                max_length = gr.Slider(label="Max Length", minimum=10, maximum=1024, value=128)
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.2
                )
                use_comma = gr.Checkbox(label="Use Comma", value=True)

            with gr.Column():
                prompt = gr.Textbox(label="Prompt")
                generate_button = gr.Button("Generate Text")
                generated_text = gr.Textbox(label="Generated Text")

        generate_button.click(
            generate_text,
            inputs=[prompt, temperature, top_k, max_length, repetition_penalty, use_comma],
            outputs=generated_text,
        )


# Start the server only when run as a script, so importing this module
# (e.g. from tooling or tests) does not launch the UI.
if __name__ == "__main__":
    ui.launch()