|
--- |
|
tags: |
|
- pytorch |
|
- transformers |
|
- masked-lm |
|
- persian |
|
- modernbert |
|
- flash-attention |
|
library_name: transformers |
|
datasets: |
|
- custom |
|
license: apache-2.0 |
|
language: |
|
- fa |
|
base_model: |
|
- answerdotai/ModernBERT-base |
|
pipeline_tag: fill-mask |
|
--- |
|
|
|
# ModernBERT Fine-Tuned on Persian Data |
|
|
|
Persian ModernBERT is a Persian-language Masked Language Model (MLM) fine-tuned with a custom tokenizer on a corpus of **2.5 billion tokens**, compared with the **1.3 billion tokens** ParsBERT was trained on. Built on ModernBERT, it uses Flash Attention 2 and bfloat16 precision for efficient training and inference.
|
|
|
## Model Details |
|
|
|
- **Base Model**: [`answerdotai/ModernBERT-base`](https://huggingface.co/answerdotai/ModernBERT-base) |
|
- **Tokenizer**: Custom, optimized for Persian |
|
- **Corpus**: 2.5 billion Persian tokens from diverse sources |
|
- **Objective**: Masked Language Modeling (MLM) |
|
- **Attention Mechanism**: Flash Attention v2 |
|
- **Precision**: `torch.bfloat16` for efficient computation on modern hardware |
|
|
|
## Usage |
|
|
|
You can use this model directly with the `transformers` library. If your installed version of `transformers` does not yet support ModernBERT, install `transformers` from `main`:
|
|
|
```sh
pip install git+https://github.com/huggingface/transformers.git
```
|
|
|
Since ModernBERT is a Masked Language Model (MLM), you can use the `fill-mask` pipeline or load it via `AutoModelForMaskedLM`. To use ModernBERT for downstream tasks like classification, retrieval, or QA, fine-tune it following standard BERT fine-tuning recipes. |
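
For quick experimentation, the same checkpoint can also be queried through the `fill-mask` pipeline. A minimal sketch, using the example sentence from the sections below:

```python
from transformers import pipeline

# Fill-mask pipeline on the Persian ModernBERT checkpoint
fill_mask = pipeline("fill-mask", model="myrkur/Persian-ModernBert-base")

for prediction in fill_mask("حال و [MASK] مردم خوب است."):
    print(prediction["sequence"], prediction["score"])
```

For downstream classification, a common starting point is to load the checkpoint with a fresh classification head and fine-tune as with any BERT-style encoder; a sketch, with `num_labels` as a placeholder:

```python
from transformers import AutoModelForSequenceClassification

# Sketch: attach a randomly initialized classification head for fine-tuning
classifier = AutoModelForSequenceClassification.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    num_labels=2,  # placeholder: set to the number of classes in your task
)
```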
|
|
|
**⚠️ If your GPU supports it, we recommend using ModernBERT with Flash Attention 2 to reach the highest efficiency. To do so, install Flash Attention as follows, then use the model as normal:** |
|
|
|
```bash
pip install flash-attn
```
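
As a rough check of whether your GPU is likely to support Flash Attention 2 (it generally requires an Ampere-or-newer card, i.e. compute capability 8.0+):

```python
import torch

# Flash Attention 2 generally requires an Ampere-or-newer GPU (compute capability >= 8.0)
if torch.cuda.is_available():
    major, _ = torch.cuda.get_device_capability()
    print("Flash Attention 2 likely supported" if major >= 8 else "Use eager/sdpa attention instead")
else:
    print("No CUDA GPU available; see the CPU instructions below")
```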
|
### Inference on CPU |
|
|
|
#### Load the Model and Tokenizer |
|
|
|
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the custom Persian tokenizer and the fine-tuned model on CPU.
# Flash Attention 2 requires a CUDA GPU, so fall back to the eager attention implementation here.
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    attn_implementation="eager",
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
```
|
|
|
#### Example: Masked Token Prediction |
|
|
|
```python
# Persian: "The state and [MASK] of the people is good."
text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cpu() for k, v in inputs.items()}
token_logits = model(**inputs).logits

# Find the [MASK] position and decode the top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```
|
|
|
### Inference on GPU |
|
|
|
#### Load the Model and Tokenizer |
|
|
|
```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the custom Persian tokenizer and the fine-tuned model on the GPU with Flash Attention 2
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```
|
|
|
#### Example: Masked Token Prediction |
|
|
|
```python
# Persian: "The state and [MASK] of the people is good."
text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
token_logits = model(**inputs).logits

# Find the [MASK] position and decode the top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```
|
|
|
## Training Details |
|
|
|
### Dataset |
|
|
|
The model was fine-tuned on a custom dataset with **2.5 billion Persian tokens**. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian. |
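
As a rough way to inspect how the custom tokenizer segments Persian text (an illustration, not part of the original preprocessing pipeline):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")

# Inspect the subword pieces produced for a short Persian sentence
sample = "حال مردم خوب است."
tokens = tokenizer.tokenize(sample)
print(tokens)
print(len(tokens))  # fewer pieces per sentence generally indicates better coverage
```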
|
|
|
### Training Configuration |
|
|
|
- **Optimizer**: AdamW |
|
- **Learning Rate**: 6e-4 |
|
- **Batch Size**: 32 |
|
- **Epochs**: 2 |
|
- **Scheduler**: Inverse square root |
|
- **Precision**: bfloat16 for faster computation and lower memory usage |
|
- **Masking Strategy**: Whole Word Masking (WWM) with a probability of 30% (see the configuration sketch below)
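
As a rough sketch (not the original training script), the configuration above maps onto `transformers` roughly as follows; the stock `DataCollatorForWholeWordMask`, the `inverse_sqrt` scheduler name, and `output_dir` are assumptions for illustration:

```python
from transformers import AutoTokenizer, DataCollatorForWholeWordMask, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")

# Whole Word Masking with a 30% masking probability, matching the configuration above
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.3)

# Hyperparameters mirroring the list above; output_dir is a placeholder
training_args = TrainingArguments(
    output_dir="persian-modernbert-mlm",
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    num_train_epochs=2,
    lr_scheduler_type="inverse_sqrt",
    optim="adamw_torch",
    bf16=True,
)
```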
|
|
|
### Efficient Training with Flash Attention |
|
|
|
The model uses the `flash_attention_2` implementation, significantly reducing memory overhead while accelerating training on large datasets. |
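
A minimal sketch of wiring this together for (continued) MLM training with Flash Attention 2 in bfloat16, reusing `training_args` and `data_collator` from the sketch above; `train_dataset` is a placeholder for your own tokenized corpus:

```python
import torch
from transformers import AutoModelForMaskedLM, Trainer

# Load the model for MLM training with Flash Attention 2 in bfloat16
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    attn_implementation="flash_attention_2",  # requires flash-attn and a supported CUDA GPU
    torch_dtype=torch.bfloat16,
)

trainer = Trainer(
    model=model,
    args=training_args,           # from the sketch in the previous section
    data_collator=data_collator,  # whole-word-masking collator from the sketch above
    train_dataset=train_dataset,  # placeholder: your tokenized Persian corpus
)
trainer.train()
```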