Spaces:
Runtime error
Runtime error
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, pipeline | |
from datasets import load_dataset | |
import gradio as gr | |
import os | |
# Charger le jeu de données SST-2 | |
dataset = load_dataset("glue", "sst2") | |
# Charger le modèle BERT pré-entraîné et le tokenizer associé | |
model_name = "bert-base-uncased" | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2) # 2 classes : positif et négatif | |
# Prétraitement des données | |
def preprocess_function(examples): | |
return tokenizer(examples["sentence"], padding="max_length", truncation=True) | |
encoded_dataset = dataset.map(preprocess_function, batched=True) | |
# Configuration des arguments d'entraînement | |
training_args = TrainingArguments( | |
per_device_train_batch_size=8, | |
evaluation_strategy="epoch", | |
logging_dir="./logs", | |
output_dir="./results", | |
num_train_epochs=3, | |
) | |
# Entraînement du modèle | |
trainer = Trainer( | |
model=model, | |
args=training_args, | |
train_dataset=encoded_dataset["train"], | |
eval_dataset=encoded_dataset["validation"], | |
) | |
# Vérifiez si le modèle a déjà été entraîné et sauvegardé | |
if not os.path.exists("./fine_tuned_model"): | |
trainer.train() | |
# Sauvegarder le modèle fine-tuné et le tokenizer | |
model.save_pretrained("./fine_tuned_model") | |
tokenizer.save_pretrained("./fine_tuned_model") | |
else: | |
# Charger le modèle fine-tuné | |
model = AutoModelForSequenceClassification.from_pretrained("./fine_tuned_model") | |
tokenizer = AutoTokenizer.from_pretrained("./fine_tuned_model") | |
# Créer un pipeline de classification des sentiments | |
sentiment_analysis = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer) | |
# Fonction pour générer une réponse à partir du message de l'utilisateur | |
def generate_response(message): | |
result = sentiment_analysis(message)[0] | |
return f"Label: {result['label']}, Score: {result['score']}" | |
# Configurer et lancer l'interface de chat avec Gradio | |
gr.ChatInterface(fn=generate_response).launch() | |