textclassfication / README.md
dheeraj1019's picture
Create README.md
d2a1ee9 verified
metadata
license: afl-3.0
datasets:
  - HuggingFaceTB/cosmopedia
metrics:
  - accuracy
library_name: adapter-transformers
pipeline_tag: text-classification
tags:
  - code

Install the necessary libraries

!pip install transformers !pip install torch

import numpy as np
import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from transformers import (
    RobertaForSequenceClassification,
    RobertaTokenizer,
    Trainer,
    TrainingArguments,
    XLNetForSequenceClassification,
    XLNetTokenizer,
)

Example dataset for text classification (replace with your own dataset)

# Placeholder data: replace both lists with your own dataset before running.
# The two lists must be the same length (one label per text).
texts = [...]  # List of input texts
labels = [...]  # List of corresponding labels (0 or 1 for binary classification)

Split the dataset into training and testing sets

# Hold out 20% of the examples for evaluation; the fixed random_state makes
# the shuffle (and therefore the split) reproducible across runs.
(
    train_texts,
    test_texts,
    train_labels,
    test_labels,
) = train_test_split(
    texts,
    labels,
    test_size=0.2,
    random_state=42,
)

Define the tokenizer and model for RoBERTa

# Tokenizer and sequence-classification model are loaded from the same
# "roberta-base" checkpoint so the vocabulary matches the model's embeddings.
roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta_model = RobertaForSequenceClassification.from_pretrained("roberta-base")

Define the tokenizer and model for XLNet

# Tokenizer and sequence-classification model are loaded from the same
# "xlnet-base-cased" checkpoint so the vocabulary matches the model's embeddings.
xlnet_tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
xlnet_model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased")

Tokenize and encode the training and testing sets

# Each model needs its own encodings because the two tokenizers use different
# vocabularies. padding=True pads every sequence in a call to the longest one,
# and truncation=True cuts inputs that exceed the tokenizer's maximum length.
train_encodings_roberta = roberta_tokenizer(train_texts, truncation=True, padding=True)
test_encodings_roberta = roberta_tokenizer(test_texts, truncation=True, padding=True)

train_encodings_xlnet = xlnet_tokenizer(train_texts, truncation=True, padding=True)
test_encodings_xlnet = xlnet_tokenizer(test_texts, truncation=True, padding=True)

class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer encodings with their labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``) to a
            per-example sequence of values, indexable by example position.
        labels: Sequence of labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including ``labels``."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of examples, taken from the label sequence."""
        return len(self.labels)

# Wrap each model's encodings together with the shared label lists; the labels
# are identical for both models because they come from the same split.
train_dataset_roberta = MyDataset(train_encodings_roberta, train_labels)
test_dataset_roberta = MyDataset(test_encodings_roberta, test_labels)

train_dataset_xlnet = MyDataset(train_encodings_xlnet, train_labels)
test_dataset_xlnet = MyDataset(test_encodings_xlnet, test_labels)

Fine-tune RoBERTa model

# Hyperparameters for fine-tuning. output_dir is set explicitly because older
# transformers releases require it and it makes the checkpoint location
# obvious. The same object is reused when the XLNet trainer is built.
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    logging_dir='./logs',
    logging_steps=10,
)

trainer_roberta = Trainer(
    model=roberta_model,
    args=training_args,
    train_dataset=train_dataset_roberta,
    eval_dataset=test_dataset_roberta,
)

trainer_roberta.train()

Fine-tune XLNet model

# XLNet is trained with the very same training_args object as RoBERTa, so both
# models are fine-tuned under identical hyperparameters.
trainer_xlnet = Trainer(
    model=xlnet_model,
    args=training_args,
    train_dataset=train_dataset_xlnet,
    eval_dataset=test_dataset_xlnet,
)
trainer_xlnet.train()

Evaluate models

def evaluate_model(model, test_dataset, batch_size=8):
    """Compute accuracy, precision, recall and F1 of ``model`` on ``test_dataset``.

    Args:
        model: Sequence-classification model whose forward output exposes
            ``.logits`` and which has a ``.device`` attribute.
        test_dataset: Map-style dataset whose items are dicts of tensors with
            ``input_ids``, ``attention_mask`` and ``labels`` keys.
        batch_size: Number of examples per forward pass.

    Returns:
        Tuple ``(accuracy, precision, recall, f1)``. Precision, recall and F1
        are computed with ``average='binary'``, so labels must be binary.
    """
    # Batch the examples with a DataLoader: iterating the dataset directly
    # yields single un-batched items, whose scalar label cannot be passed to
    # list.extend and whose 1-D inputs lack the batch dimension.
    loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
    predictions = []
    labels = []
    model.eval()  # switch off dropout so evaluation is deterministic
    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['attention_mask'].to(model.device)
            logits = model(input_ids, attention_mask=attention_mask).logits
            predictions.extend(torch.argmax(logits, dim=1).tolist())
            labels.extend(batch['labels'].tolist())
    accuracy = accuracy_score(labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predictions, average='binary'
    )
    return accuracy, precision, recall, f1

# Score each fine-tuned model on its own tokenization of the held-out split.
accuracy_roberta, precision_roberta, recall_roberta, f1_roberta = evaluate_model(
    roberta_model, test_dataset_roberta
)
accuracy_xlnet, precision_xlnet, recall_xlnet, f1_xlnet = evaluate_model(
    xlnet_model, test_dataset_xlnet
)

print("RoBERTa Model Evaluation:")
print(f"Accuracy: {accuracy_roberta}")
print(f"Precision: {precision_roberta}")
print(f"Recall: {recall_roberta}")
print(f"F1 Score: {f1_roberta}")

print("\nXLNet Model Evaluation:")
print(f"Accuracy: {accuracy_xlnet}")
print(f"Precision: {precision_xlnet}")
print(f"Recall: {recall_xlnet}")
print(f"F1 Score: {f1_xlnet}")