Spaces:
Runtime error
Runtime error
File size: 1,912 Bytes
5ca90cd 5a974a6 a31291c 9277ff4 4aae3f6 99d2684 4aae3f6 7bc6dcb 4aae3f6 91149d9 1a67d0c 7623f06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import datasets
from datasets import load_dataset
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments, pipeline
# Base checkpoint; the same name is used for the model and its tokenizer.
model_name = "bert-base-uncased"

# Pretrained BERT encoder with a freshly initialized 2-class head
# (2 classes: positive and negative). The id2label/label2id maps follow the
# GLUE SST-2 label encoding (0 = negative, 1 = positive), so the pipeline
# reports readable labels instead of LABEL_0 / LABEL_1.
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label={0: "negative", 1: "positive"},
    label2id={"negative": 0, "positive": 1},
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Training split of GLUE SST-2; later code reads its "sentence" and "label"
# columns.
sst2_dataset = load_dataset("glue", "sst2", split="train")
def encode(examples, max_length=128):
    """Tokenize a batch of SST-2 examples for BERT.

    SST-2 is a single-sentence task: each example has one ``sentence``
    column (there is no ``sentence1``/``sentence2`` pair as in paired GLUE
    tasks), so only that column is passed to the tokenizer.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``, i.e. a
            mapping whose ``"sentence"`` value is a list of strings.
        max_length: Length every sequence is padded/truncated to. Without
            it, ``padding="max_length"`` pads to the model's maximum input
            length (512 for bert-base-uncased), which is far longer than
            SST-2's short sentences and makes training much slower.

    Returns:
        The tokenizer output (``input_ids``, ``attention_mask``, ...), which
        ``map`` merges into the dataset as new columns.
    """
    return tokenizer(
        examples["sentence"],
        truncation=True,
        padding="max_length",  # fixed-length rows work with the default collator
        max_length=max_length,
    )
def _prepare_split(split):
    """Tokenize one SST-2 split and copy its targets into a "labels" column."""
    split = split.map(encode, batched=True)
    # The model's forward() takes the targets under the keyword "labels".
    return split.map(lambda examples: {"labels": examples["label"]}, batched=True)


sst2_dataset = _prepare_split(sst2_dataset)

# GLUE's SST-2 "test" split is unlabeled (every label is -1), which would
# break loss computation during evaluation. Evaluate on the labeled
# "validation" split instead.
sst2_eval_dataset = _prepare_split(load_dataset("glue", "sst2", split="validation"))

training_args = TrainingArguments(
    per_device_train_batch_size=8,
    # "eval_strategy" replaced the deprecated "evaluation_strategy" keyword,
    # which recent transformers releases reject (requires transformers >= 4.41).
    eval_strategy="epoch",
    logging_dir="./logs",
    output_dir="./results",
    num_train_epochs=3,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=sst2_dataset,
    eval_dataset=sst2_eval_dataset,
)
import os
# Directory holding the fine-tuned model and tokenizer. Training runs only
# when it does not exist yet; later runs reuse the saved weights.
FINE_TUNED_DIR = "./fine_tuned_model"

if not os.path.exists(FINE_TUNED_DIR):
    trainer.train()
    # Save the fine-tuned model and the tokenizer.
    model.save_pretrained(FINE_TUNED_DIR)
    tokenizer.save_pretrained(FINE_TUNED_DIR)
else:
    # Load the fine-tuned model. The Auto* classes read the architecture from
    # the saved config, so they restore the BERT classifier saved above.
    model = AutoModelForSequenceClassification.from_pretrained(FINE_TUNED_DIR)
    tokenizer = AutoTokenizer.from_pretrained(FINE_TUNED_DIR)

sentiment_analysis = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)
def generate_response(message, history=None):
    """Classify the sentiment of one chat message.

    Args:
        message: The user's text.
        history: The conversation so far. ``gr.ChatInterface`` always passes
            it as a second positional argument. It is ignored here because
            each message is classified on its own. It defaults to ``None`` so
            the function can still be called with the message alone.

    Returns:
        A string of the form ``"Label: <label>, Score: <score>"`` built from
        the top prediction of the ``sentiment_analysis`` pipeline.
    """
    result = sentiment_analysis(message)[0]
    return f"Label: {result['label']}, Score: {result['score']}"
# Gradio is needed only for the UI, so it is imported next to its use.
import gradio as gr

# Chat UI: every user message is answered by generate_response.
demo = gr.ChatInterface(fn=generate_response)
demo.launch()
|