tags:
  - pytorch
  - transformers
  - masked-lm
  - persian
  - modernbert
  - flash-attention
library_name: transformers
datasets:
  - custom
license: apache-2.0
language:
  - fa
base_model:
  - answerdotai/ModernBERT-base
pipeline_tag: fill-mask

ModernBERT Fine-Tuned on Persian Data

Persian ModernBERT is a Persian masked language model (MLM) fine-tuned from ModernBERT-base with a custom tokenizer on a corpus of 2.5 billion tokens, larger than the 1.3 billion tokens ParsBERT was trained on. It uses Flash Attention 2 for efficient attention computation.

Model Details

  • Base Model: answerdotai/ModernBERT-base
  • Tokenizer: Custom, optimized for Persian
  • Corpus: 2.5 billion Persian tokens from diverse sources
  • Objective: Masked Language Modeling (MLM)
  • Attention Mechanism: Flash Attention v2
  • Precision: torch.bfloat16 for efficient computation on modern hardware

Usage

You can use this model directly with the transformers library. ModernBERT support requires a recent transformers release; if your installed version does not include it yet, install transformers from main:

pip install git+https://github.com/huggingface/transformers.git

Since ModernBERT is a Masked Language Model (MLM), you can use the fill-mask pipeline or load it via AutoModelForMaskedLM. To use ModernBERT for downstream tasks like classification, retrieval, or QA, fine-tune it following standard BERT fine-tuning recipes.
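The fill-mask pipeline route mentioned above can be sketched as follows. This is a minimal example rather than the author's reference code: the helper names format_predictions and run_fill_mask are hypothetical, and the [MASK] string is taken from the examples later in this card. The transformers import sits inside the function so the formatting helper can be reused without loading the library.

```python
def format_predictions(results):
    """Turn fill-mask pipeline output into printable lines.

    `results` is the list of dicts the pipeline returns for a text with a
    single mask token; each dict has "score", "token_str" and "sequence" keys.
    """
    return [f"{r['score']:.3f}  {r['token_str']}  ->  {r['sequence']}" for r in results]


def run_fill_mask(text, top_k=5):
    """Load the model through the fill-mask pipeline and return formatted predictions."""
    # Imported here so format_predictions stays usable without transformers loaded.
    from transformers import pipeline

    fill_mask = pipeline("fill-mask", model="myrkur/Persian-ModernBert-base")
    return format_predictions(fill_mask(text, top_k=top_k))


# Usage (downloads the model on first call):
# for line in run_fill_mask("حال و [MASK] مردم خوب است."):
#     print(line)
```

The pipeline handles tokenization, mask lookup, and decoding internally, so it is the shorter path when you only need the top predictions for one masked position.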

⚠️ If your GPU supports it, we recommend using ModernBERT with Flash Attention 2 to reach the highest efficiency. To do so, install Flash Attention as follows, then use the model as normal:

pip install flash-attn
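If the same script has to run on machines with and without Flash Attention, one option is to choose the attn_implementation value at runtime. The sketch below is an assumption about how you might do that, not part of the original card: both function names are hypothetical, and it only checks that CUDA is available and that the flash_attn package is installed. It does not verify that the GPU architecture is supported by Flash Attention 2.

```python
import importlib.util


def pick_attn_implementation(has_cuda, has_flash_attn):
    """Return a value for the `attn_implementation` argument of `from_pretrained`."""
    if has_cuda and has_flash_attn:
        return "flash_attention_2"
    # Fall back to the implementation used in the CPU example of this card.
    return "eager"


def detect_attn_implementation():
    """Inspect the current environment and pick an attention implementation."""
    # find_spec returns None when a package is not installed.
    has_flash_attn = importlib.util.find_spec("flash_attn") is not None
    has_cuda = False
    if importlib.util.find_spec("torch") is not None:
        import torch

        has_cuda = torch.cuda.is_available()
    return pick_attn_implementation(has_cuda, has_flash_attn)
```

The returned string can then be passed as attn_implementation=detect_attn_implementation() when calling AutoModelForMaskedLM.from_pretrained.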

Inference on CPU

Load the Model and Tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base", attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="cpu")

Example: Masked Token Prediction

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cpu() for k, v in inputs.items()}
# Inference only, so skip gradient tracking
with torch.no_grad():
    token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

Inference on GPU

Load the Model and Tokenizer

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load custom tokenizer and fine-tuned model
tokenizer = AutoTokenizer.from_pretrained("myrkur/Persian-ModernBert-base")
model = AutoModelForMaskedLM.from_pretrained("myrkur/Persian-ModernBert-base", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, device_map="cuda")

Example: Masked Token Prediction

text = "حال و [MASK] مردم خوب است."
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
# Inference only, so skip gradient tracking
with torch.no_grad():
    token_logits = model(**inputs).logits

# Find the [MASK] token and decode top predictions
mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
mask_token_logits = token_logits[0, mask_token_index, :]
top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()

for token in top_5_tokens:
    print(f"Prediction: {text.replace(tokenizer.mask_token, tokenizer.decode([token]))}")

Training Details

Dataset

The model was fine-tuned on a custom dataset with 2.5 billion Persian tokens. The dataset was preprocessed and tokenized using a custom tokenizer designed to maximize efficiency and coverage for Persian.

Training Configuration

  • Optimizer: AdamW
  • Learning Rate: 6e-4
  • Batch Size: 32
  • Epochs: 2
  • Scheduler: Inverse square root
  • Precision: bfloat16 for faster computation and lower memory usage
  • Masking Strategy: Whole Word Masking (WWM) with a probability of 30%
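To make the scheduler and masking entries above concrete, here is a rough sketch of both in plain Python. It is an illustration under stated assumptions, not the training code used for this model: the warmup length (1000 steps) is a placeholder because the card does not state one, the 30% is interpreted as an independent per-word selection probability, and the function names are hypothetical. The word_ids input follows the format returned by word_ids() on a fast-tokenizer encoding.

```python
import math
import random


def inverse_sqrt_lr(step, peak_lr=6e-4, warmup_steps=1000):
    """Linear warmup to peak_lr, then decay proportional to 1/sqrt(step).

    warmup_steps=1000 is a placeholder; the model card does not state it.
    """
    step = max(step, 1)
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * math.sqrt(warmup_steps / step)


def whole_word_mask_positions(word_ids, mask_prob=0.30, rng=None):
    """Pick token positions to mask so that words are masked as whole units.

    word_ids has one entry per token: the index of the word the token belongs
    to, or None for special tokens.
    """
    rng = rng or random.Random()
    # Group token positions by the word they belong to.
    words = {}
    for pos, wid in enumerate(word_ids):
        if wid is not None:
            words.setdefault(wid, []).append(pos)
    masked = []
    for positions in words.values():
        # Each word is selected independently; all of its sub-tokens go together.
        if rng.random() < mask_prob:
            masked.extend(positions)
    return sorted(masked)
```

For example, with word_ids = [None, 0, 0, 1, None], the two sub-tokens of word 0 (positions 1 and 2) are always masked together or left intact together, which is the property that distinguishes whole word masking from per-token masking.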

Efficient Training with Flash Attention

The model was trained with the flash_attention_2 attention implementation, which lowers memory overhead and speeds up training on large datasets.