myrkur committed · Commit 82429cc · verified · 1 Parent(s): 9923f9f

Update README.md

Files changed (1): README.md (+79 −3)
---
tags:
- pytorch
- transformers
- masked-lm
- persian
- modernbert
- flash-attention
library_name: transformers
datasets:
- custom
license: apache-2.0
language:
- fa
base_model:
- answerdotai/ModernBERT-base
pipeline_tag: fill-mask
---
# ModernBERT Fine-Tuned on Persian Data

ModernBERT is a Persian-language masked language model (MLM) fine-tuned with a custom tokenizer on a corpus of **2.5 billion tokens**, exceeding the **1.3 billion tokens** ParsBERT was trained on. The model leverages ModernBERT's modern attention mechanisms, including Flash Attention 2.
## Model Details

- **Base Model**: [`answerdotai/ModernBERT-base`](https://huggingface.co/answerdotai/ModernBERT-base)
- **Tokenizer**: Custom, optimized for Persian
- **Corpus**: 2.5 billion Persian tokens from diverse sources
- **Objective**: Masked Language Modeling (MLM)
- **Attention Mechanism**: Flash Attention v2
- **Precision**: `torch.bfloat16` for efficient computation on modern hardware
## Usage

### Load the Model and Tokenizer

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the custom tokenizer and the fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base")
```
### Example: Masked Token Prediction

```python
import torch

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
token_logits = model(**inputs).logits

# Find the [MASK] token and decode the top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")
```
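
The same prediction can also be run through the `fill-mask` pipeline, which wraps tokenization, the forward pass, and decoding in one call (a minimal sketch of equivalent usage):

```python
from transformers import pipeline

# The pipeline handles masking, inference, and decoding end to end
fill_mask = pipeline("fill-mask", model="myrkur/Persian-ModernBert-base")

for prediction in fill_mask("حال و [MASK] مردم خوب است."):
    print(prediction["sequence"], prediction["score"])
```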
## Training Details

### Dataset

The model was fine-tuned on a custom dataset with **2.5 billion Persian tokens**. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian.
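
The tokenizer training script is not included in this card; the sketch below shows one plausible way such a Persian-specific tokenizer could be derived from the base ModernBERT tokenizer with `train_new_from_iterator`. The corpus placeholder and the reuse of the base vocabulary size are illustrative assumptions, not the documented procedure:

```python
from transformers import AutoTokenizer

# Placeholder corpus: in practice this would stream the 2.5B-token Persian dataset
persian_corpus = [
    "حال و هوای مردم خوب است.",
    "زبان فارسی یکی از زبان‌های هندواروپایی است.",
]

def corpus_batches(texts, batch_size=1000):
    # Yield batches of raw text for tokenizer training
    for i in range(0, len(texts), batch_size):
        yield texts[i : i + batch_size]

base_tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Retrain the vocabulary on Persian text; keeping the base vocabulary size is an
# assumption so the embedding matrix shape stays compatible with the base model
custom_tokenizer = base_tokenizer.train_new_from_iterator(
    corpus_batches(persian_corpus), vocab_size=base_tokenizer.vocab_size
)
custom_tokenizer.save_pretrained("persian-modernbert-tokenizer")
```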
### Training Configuration

- **Optimizer**: AdamW
- **Learning Rate**: 6e-4
- **Batch Size**: 32
- **Epochs**: 3
- **Scheduler**: Inverse square root
- **Precision**: bfloat16 for faster computation and lower memory usage
- **Masking Strategy**: Whole Word Masking (WWM) with a probability of 30%
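
These hyperparameters map onto the standard `transformers` `Trainer` roughly as follows. This is a minimal sketch, not the original training script: the toy dataset, output path, and the use of `DataCollatorForWholeWordMask` as the whole-word-masking collator are assumptions.

```python
from datasets import Dataset
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForWholeWordMask,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")

# Toy stand-in for the 2.5B-token Persian corpus
raw = Dataset.from_dict({"text": ["حال و هوای مردم خوب است.", "هوا امروز آفتابی است."]})
tokenized_dataset = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=128),
    batched=True,
    remove_columns=["text"],
)

# Whole Word Masking with a 30% masking probability
data_collator = DataCollatorForWholeWordMask(tokenizer=tokenizer, mlm_probability=0.3)

training_args = TrainingArguments(
    output_dir="persian-modernbert",      # placeholder output path
    learning_rate=6e-4,
    per_device_train_batch_size=32,
    num_train_epochs=3,
    lr_scheduler_type="inverse_sqrt",     # inverse square root schedule
    bf16=True,                            # bfloat16 mixed precision
    optim="adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)
trainer.train()
```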
### Efficient Training with Flash Attention

The model uses the `flash_attention_2` implementation, significantly reducing memory overhead while accelerating training on large datasets.
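
To load the model with the same configuration (a minimal sketch; it assumes the `flash-attn` package is installed and a Flash-Attention-capable GPU is available):

```python
import torch
from transformers import AutoModelForMaskedLM

# Select Flash Attention 2 and bfloat16 at load time
model = AutoModelForMaskedLM.from_pretrained(
    "myrkur/Persian-ModernBert-base",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")
```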