---
tags:
- pytorch
- transformers
- masked-lm
- persian
- modernbert
- flash-attention
library_name: transformers
datasets:
- custom
license: apache-2.0
language:
- fa
base_model:
- answerdotai/ModernBERT-base
pipeline_tag: fill-mask
---
# ModernBERT Fine-Tuned on Persian Data
This model is a Persian masked language model (MLM) based on ModernBERT, fine-tuned with a custom tokenizer on a corpus of **2.5 billion tokens**, roughly twice the **1.3 billion tokens** used to train ParsBERT. It pairs Flash Attention 2 with bfloat16 precision for efficient training and inference.
## Model Details
- **Base Model**: [`answerdotai/ModernBERT-base`](https://huggingface.co/answerdotai/ModernBERT-base)
- **Tokenizer**: Custom, optimized for Persian
- **Corpus**: 2.5 billion Persian tokens from diverse sources
- **Objective**: Masked Language Modeling (MLM)
- **Attention Mechanism**: Flash Attention v2
- **Precision**: `torch.bfloat16` for efficient computation on modern hardware
## Usage
### Load the Model and Tokenizer
```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base")
```
### Example: Masked Token Prediction
```python
import torch

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits

# Find the [MASK] token and decode the top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```
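Alternatively, the `fill-mask` pipeline wraps the same steps (tokenization, mask lookup, decoding) in a single call; a quick illustrative variant:
```python
from transformers import pipeline

# The pipeline handles tokenization, mask lookup and decoding internally.
fill_mask = pipeline("fill-mask", model="myrkur/Persian-ModernBert-base")

for prediction in fill_mask("حال و [MASK] مردم خوب است.", top_k=5):
    print(prediction["sequence"], prediction["score"])
```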
## Training Details
### Dataset
The model was fine-tuned on a custom dataset with **2.5 billion Persian tokens**. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian.
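The card does not describe how the custom tokenizer was built. One common approach, shown here purely as a hypothetical sketch, is to retrain the base ModernBERT tokenizer on the Persian corpus with `train_new_from_iterator`; the `persian_corpus.txt` path and `vocab_size` below are illustrative assumptions, not the actual setup.
```python
from transformers import AutoTokenizer

# Hypothetical sketch: learn a Persian vocabulary using the same algorithm and
# special tokens as the base ModernBERT tokenizer.
base_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

def persian_corpus(path="persian_corpus.txt", batch_size=1000):
    # Yield batches of raw Persian text lines from a local corpus file.
    with open(path, encoding="utf-8") as f:
        batch = []
        for line in f:
            batch.append(line)
            if len(batch) == batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

persian_tokenizer = base_tokenizer.train_new_from_iterator(persian_corpus(), vocab_size=50_000)
persian_tokenizer.save_pretrained("persian-modernbert-tokenizer")
```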
### Training Configuration
- **Optimizer**: AdamW
- **Learning Rate**: 6e-4
- **Batch Size**: 32
- **Epochs**: 2
- **Scheduler**: Inverse square root
- **Precision**: bfloat16 for faster computation and lower memory usage
- **Masking Strategy**: Whole Word Masking (WWM) with a probability of 30%
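The full training script is not included in this card, but the configuration above maps directly onto the `transformers` Trainer API. A minimal sketch under stated assumptions: `model` and `tokenizer` are the objects loaded in the Usage section, and `tokenized_dataset` is a hypothetical pre-tokenized Persian corpus.
```python
from transformers import DataCollatorForWholeWordMask, Trainer, TrainingArguments

# Whole Word Masking with a 30% masking probability, as listed above.
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.3)

training_args = TrainingArguments(
    output_dir="persian-modernbert-mlm",   # illustrative output path
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    num_train_epochs=2,
    lr_scheduler_type="inverse_sqrt",
    bf16=True,                             # bfloat16 mixed precision
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,       # hypothetical pre-tokenized corpus
    data_collator=data_collator,
)
trainer.train()
```
The Trainer's default optimizer is AdamW, matching the configuration listed above.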
### Efficient Training with Flash Attention
The model uses the `flash_attention_2` implementation, significantly reducing memory overhead while accelerating training on large datasets.
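In recent `transformers` releases the attention implementation is selected when the model is instantiated. A minimal sketch of requesting Flash Attention 2 together with bfloat16 at load time, assuming a CUDA GPU and an installed `flash-attn` package:
```python
import torch
from transformers import AutoModelForMaskedLM

# Flash Attention 2 lowers attention memory use and speeds up processing;
# bfloat16 further reduces memory on supported hardware.
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")
```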