---
library_name: transformers
language:
- en
license: mit
base_model: FacebookAI/roberta-base
tags:
- generated_from_trainer
datasets:
- swag
metrics:
- accuracy
model-index:
- name: swag_base
  results:
  - task:
      name: Multiple Choice
      type: multiple-choice
    dataset:
      name: SWAG
      type: swag
      args: regular
    metrics:
    - name: Accuracy
      type: accuracy
      value: 0.7521243691444397
---

# swag_base

This model is a fine-tuned version of [FacebookAI/roberta-base](https://huggingface.co/FacebookAI/roberta-base) on the SWAG (Situations With Adversarial Generations) dataset.

## Model description

The model performs multiple-choice reasoning about real-world situations: given a context and four candidate continuations, it predicts the most plausible ending based on commonsense understanding.

Key features:
- Base model: RoBERTa-base
- Task: multiple-choice prediction
- Training dataset: SWAG
- Performance: 75.21% accuracy on the evaluation set
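
To make the input format concrete, here is a minimal sketch of how a SWAG example pairs one context with four candidate endings, assuming the `datasets` library and the standard field names of the Hub's `swag` dataset:

```python
from datasets import load_dataset

# Load the "regular" SWAG configuration used for fine-tuning
swag = load_dataset("swag", "regular")
example = swag["train"][0]

# Each example pairs a shared context with four candidate endings;
# the model scores all four (context, ending) pairs jointly.
context = example["sent1"] + " " + example["sent2"]
endings = [example[f"ending{i}"] for i in range(4)]
label = example["label"]  # index of the gold ending (0-3)
print(context, "->", endings[label])
```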

## Training procedure

### Training hyperparameters
- Learning rate: 5e-05
- Batch size: 16
- Number of epochs: 3
- Optimizer: AdamW
- Learning rate scheduler: linear
- Training samples: 73,546
- Training time: 17m 53s
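
As a rough illustration (the exact training script is not included in this repo), these settings map onto `TrainingArguments` along these lines:

```python
from transformers import TrainingArguments

# Hedged sketch: only the hyperparameters listed above come from the
# training run; output_dir and the eval settings are assumptions.
training_args = TrainingArguments(
    output_dir="swag_base",          # assumed
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,   # assumed to match the train batch size
    num_train_epochs=3,
    optim="adamw_torch",             # AdamW
    lr_scheduler_type="linear",
    eval_strategy="epoch",           # assumed
)
```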

### Training results
- Training loss: 0.73
- Evaluation loss: 0.7362
- Evaluation accuracy: 0.7521
- Training samples/second: 205.623
- Training steps/second: 12.852
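
The evaluation accuracy is the fraction of examples where the highest-scoring choice matches the gold label. With a `Trainer`, it could be computed via a `compute_metrics` hook like this sketch (the exact hook used for this run is an assumption):

```python
import numpy as np

def compute_metrics(eval_pred):
    # logits: (num_examples, 4); labels: gold ending index per example
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=1)
    return {"accuracy": float((predictions == labels).mean())}
```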

## Usage example

Here's how to use the model:

```python
from transformers import AutoTokenizer, AutoModelForMultipleChoice
import torch

# Load model and tokenizer
model_path = "real-jiakai/roberta-base-uncased-finetuned-swag"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMultipleChoice.from_pretrained(model_path)
model.eval()

def predict_swag(context, endings, model, tokenizer):
    # Pair the same context with every candidate ending
    encoding = tokenizer(
        [context] * len(endings),
        endings,
        truncation=True,
        max_length=128,
        padding="max_length",
        return_tensors="pt"
    )

    # Add a batch dimension: shape (1, num_choices, seq_len)
    input_ids = encoding['input_ids'].unsqueeze(0)
    attention_mask = encoding['attention_mask'].unsqueeze(0)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits

    predicted_idx = torch.argmax(logits).item()

    return {
        'context': context,
        'all_endings': endings,
        'predicted_ending': endings[predicted_idx],
        'probabilities': torch.softmax(logits, dim=1)[0].tolist()
    }

# Example scenarios
test_examples = [
    {
        'context': "Stephen Curry dribbles the ball at the three-point line",
        'endings': [
            "He quickly releases a perfect shot that swishes through the net",  # Most plausible
            "He suddenly starts dancing ballet on the court",
            "He transforms the basketball into a pizza",
            "He flies to the moon with the basketball"
        ]
    },
    {
        'context': "Elon Musk walks into a SpaceX facility and looks at a rocket",
        'endings': [
            "He discusses technical details with the engineering team",  # Most plausible
            "He turns the rocket into a giant chocolate bar",
            "He starts playing basketball with the rocket",
            "He teaches the rocket to speak French"
        ]
    }
]

for i, example in enumerate(test_examples, 1):
    result = predict_swag(example['context'], example['endings'], model, tokenizer)

    print(f"\n=== Test Scenario {i} ===")
    print(f"Initial Context: {result['context']}")
    print(f"\nPredicted Most Likely Ending: {result['predicted_ending']}")
    print("\nProbabilities for All Options:")
    for idx, (ending, prob) in enumerate(zip(result['all_endings'], result['probabilities'])):
        print(f"Option {idx}: {ending}")
        print(f"Probability: {prob:.3f}")
    print("\n" + "=" * 50)
```

## Limitations and biases

- The model's performance is limited by its training data and may not generalize well to all domains.
- Performance might vary depending on the complexity and domain of the input scenarios.
- The model may exhibit biases present in the training data.

## Framework versions

- Transformers 4.47.0.dev0
- PyTorch 2.5.1+cu124
- Datasets 3.1.0
- Tokenizers 0.20.3

## Citation

If you use this model, please cite:

```bibtex
@inproceedings{zellers2018swagaf,
  title     = {SWAG: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference},
  author    = {Zellers, Rowan and Bisk, Yonatan and Schwartz, Roy and Choi, Yejin},
  booktitle = {Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year      = {2018}
}
```