---
tags:
- pytorch
- transformers
- masked-lm
- persian
- modernbert
- flash-attention
library_name: transformers
datasets:
- custom
license: apache-2.0
language:
- fa
base_model:
- answerdotai/ModernBERT-base
pipeline_tag: fill-mask
---

# ModernBERT Fine-Tuned on Persian Data

Persian ModernBERT is a Persian-language Masked Language Model (MLM) fine-tuned with a custom tokenizer on a corpus of **2.5 billion tokens**, nearly twice the **1.3 billion tokens** ParsBERT was trained on. It builds on the ModernBERT architecture and uses Flash Attention 2 for efficient training and inference.

## Model Details

- **Base Model**: [`answerdotai/ModernBERT-base`](https://huggingface.co/answerdotai/ModernBERT-base)
- **Tokenizer**: Custom, optimized for Persian
- **Corpus**: 2.5 billion Persian tokens from diverse sources
- **Objective**: Masked Language Modeling (MLM)
- **Attention Mechanism**: Flash Attention v2
- **Precision**: `torch.bfloat16` for efficient computation on modern hardware

## Usage

You can use this model directly with the `transformers` library. ModernBERT support requires `transformers` v4.48.0 or later; if your installed release does not include it yet, install `transformers` from `main`:

```sh
pip install git+https://github.com/huggingface/transformers.git
```

Since ModernBERT is a Masked Language Model (MLM), you can use the `fill-mask` pipeline or load it via `AutoModelForMaskedLM`. To use ModernBERT for downstream tasks like classification, retrieval, or QA, fine-tune it following standard BERT fine-tuning recipes.
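
For a quick check, the `fill-mask` pipeline wraps the tokenizer and model in a single call. This is a minimal sketch using the same checkpoint id and example sentence as the loading examples below:

```python
from transformers import pipeline

# Fill-mask pipeline with the Persian ModernBERT checkpoint
fill_mask = pipeline("fill-mask", model="myrkur/Persian-ModernBert-base")

# Example Persian sentence with one word masked (roughly: "The people are doing well.")
for prediction in fill_mask("حال و [MASK] مردم خوب است."):
    print(prediction["token_str"], prediction["score"])
```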

**⚠️ If your GPU supports it, we recommend using ModernBERT with Flash Attention 2 to reach the highest efficiency. To do so, install Flash Attention as follows, then use the model as normal:**

```bash
pip install flash-attn
```

### Inference on CPU

#### Load the Model and Tokenizer 

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    attn_implementation="eager",  # standard attention; Flash Attention 2 requires a supported GPU
    torch_dtype=torch.bfloat16,
    device_map="cpu",
)
```

#### Example: Masked Token Prediction

```python
text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cpu() for k, v in inputs.items()}
token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```

### Inference on GPU

#### Load the Model and Tokenizer 

```python
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    attn_implementation="flash_attention_2",  # requires the flash-attn package and a supported GPU
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```

#### Example: Masked Token Prediction

```python
text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```

## Training Details

### Dataset

The model was fine-tuned on a custom dataset with **2.5 billion Persian tokens**. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian.
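
To get a sense of how the custom tokenizer segments Persian text, you can inspect its output directly. This is an illustrative sketch; the printed values depend on the released tokenizer:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")

# Persian for "The people are doing well."
sample = "حال و احوال مردم خوب است."
tokens = tokenizer.tokenize(sample)

print("vocabulary size:", tokenizer.vocab_size)
print("tokens:", tokens)
print("tokens per whitespace word:", len(tokens) / len(sample.split()))
```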

### Training Configuration

- **Optimizer**: AdamW
- **Learning Rate**: 6e-4
- **Batch Size**: 32
- **Epochs**: 2
- **Scheduler**: Inverse square root
- **Precision**: bfloat16 for faster computation and lower memory usage
- **Masking Strategy**: Whole Word Masking (WWM) with a probability of 30%
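
The original training script is not part of this card, but the configuration above can be approximated with the `transformers` `Trainer`. The sketch below is an assumption-laden illustration: the hyperparameters come from the list above, while the dataset handling, collator choice, and variable names are placeholders rather than the authors' code.

```python
import torch
from datasets import Dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForWholeWordMask,
    Trainer,
    TrainingArguments,
)

# Custom Persian tokenizer and the English ModernBERT base checkpoint
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained(
    "answerdotai/ModernBERT-base",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
# The custom tokenizer has its own vocabulary, so the embedding matrix must be resized
model.resize_token_embeddings(len(tokenizer))

# Tiny placeholder corpus; the real run used ~2.5B tokens of Persian text
corpus = Dataset.from_dict({"text": ["حال و احوال مردم خوب است."]})
tokenized = corpus.map(
    lambda batch: tokenizer(batch["text"], truncation=True),
    batched=True,
    remove_columns=["text"],
)

# Whole Word Masking at 30%; this collator assumes a WordPiece-style ("##") tokenizer
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.3)

training_args = TrainingArguments(
    output_dir="persian-modernbert-mlm",
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    num_train_epochs=2,
    lr_scheduler_type="inverse_sqrt",
    bf16=True,
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=data_collator,
)
trainer.train()
```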

### Efficient Training with Flash Attention

The model uses the `flash_attention_2` implementation, significantly reducing memory overhead while accelerating training on large datasets.