NaN or 0.0 loss when training with flash attention

#59
by roadtoagi - opened

With attn_implementation="sdpa":

{'loss': 0.3681, 'grad_norm': 5.589271545410156, 'learning_rate': 4e-05, 'epoch': 1.0}
{'eval_loss': 0.2998541593551636, 'eval_runtime': 2.4136, 'eval_samples_per_second': 20.716, 'eval_steps_per_second': 5.386, 'epoch': 1.0}
{'loss': 0.1703, 'grad_norm': 1.1856054067611694, 'learning_rate': 0.0, 'epoch': 2.0}
{'eval_loss': 0.21692198514938354, 'eval_runtime': 1.0645, 'eval_samples_per_second': 46.97, 'eval_steps_per_second': 12.212, 'epoch': 2.0}
{'train_runtime': 86.2749, 'train_samples_per_second': 10.432, 'train_steps_per_second': 2.62, 'train_loss': 0.2692194153777266, 'epoch': 2.0}
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 226/226 [01:26<00:00, 2.62it/s]
Model saved successfully.

With attn_implementation="flash_attention_2":

C:\Users\Admin\Desktop\bert\train_fin.py:152: FutureWarning: tokenizer is deprecated and will be removed in version 5.0.0 for Trainer.__init__. Use processing_class instead.
trainer = Trainer(
{'loss': 0.4619, 'grad_norm': nan, 'learning_rate': 4e-05, 'epoch': 1.0}
{'eval_loss': nan, 'eval_runtime': 1.04, 'eval_samples_per_second': 48.079, 'eval_steps_per_second': 12.501, 'epoch': 1.0}
{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 0.0, 'epoch': 2.0}
{'eval_loss': nan, 'eval_runtime': 0.7291, 'eval_samples_per_second': 68.58, 'eval_steps_per_second': 17.831, 'epoch': 2.0}
{'train_runtime': 74.6236, 'train_samples_per_second': 12.061, 'train_steps_per_second': 3.029, 'train_loss': 0.23094445625237658, 'epoch': 2.0}
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 226/226 [01:14<00:00, 3.03it/s]
Model saved successfully.


import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)

# Split dataset (df_sampled, binary_labels and unique_labels come from earlier preprocessing)
train_texts, test_texts, train_labels, test_labels = train_test_split(
    df_sampled["prompt"].tolist(), 
    binary_labels,
    test_size=0.1,
    random_state=42
)

# Create Hugging Face datasets
def create_hf_dataset(texts, labels):
    return Dataset.from_dict({
        "text": texts,
        "labels": labels.astype(np.float32)   # Now contains float32 values
    })
dataset_train = create_hf_dataset(train_texts, train_labels)
print(dataset_train[0]["labels"])  # Should show [1.0, 0.0, ...] instead of [1, 0, ...]
dataset_test = create_hf_dataset(test_texts, test_labels)

# Load tokenizer
checkpoint = "answerdotai/ModernBERT-large"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Tokenization function (padding here is redundant with DataCollatorWithPadding below, but harmless)
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding=True)

# Tokenize datasets
tokenized_train = dataset_train.map(preprocess_function, batched=True)
tokenized_test = dataset_test.map(preprocess_function, batched=True)

# Load model for multi-label classification
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=len(unique_labels),
    id2label={i: label for i, label in enumerate(unique_labels)},
    label2id={label: i for i, label in enumerate(unique_labels)},
    problem_type="multi_label_classification"
)

train_bsz, val_bsz = 4, 4
lr = 8e-5
betas = (0.9, 0.98)
n_epochs = 2
eps = 1e-6
wd = 8e-6

# Training setup (optimized for memory efficiency)
training_args = TrainingArguments(
    output_dir="modernbert_finetuned",
    learning_rate=lr,
    per_device_train_batch_size=train_bsz,
    per_device_eval_batch_size=val_bsz,
    num_train_epochs=n_epochs,
    lr_scheduler_type="linear",
    optim="adamw_torch",
    adam_beta1=betas[0],
    adam_beta2=betas[1],
    adam_epsilon=eps,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    bf16=True,
    bf16_full_eval=True,
    push_to_hub=False,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,  # deprecated in recent transformers; processing_class=tokenizer avoids the FutureWarning above
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)

# Train and save the model
trainer.train()
model.save_pretrained("modernbert_finetuned_model")
tokenizer.save_pretrained("modernbert_finetuned_model")
print("Model saved successfully.")

When I enable flash attention on Windows, speed improves greatly, but the resulting model is broken (NaN values during eval/inference).
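For reference, a minimal sketch of how the two runs above differ at model-load time. This is an assumption about how flash attention was enabled, not the exact script: the flash path needs the flash-attn package and a half-precision dtype, and num_labels=2 is a placeholder for len(unique_labels).

import torch
from transformers import AutoModelForSequenceClassification

# Same call as in the script above, with the attention backend made explicit.
# Swap "flash_attention_2" for "sdpa" to get the working (slower) run.
model = AutoModelForSequenceClassification.from_pretrained(
    "answerdotai/ModernBERT-large",
    num_labels=2,                              # placeholder; the script uses len(unique_labels)
    problem_type="multi_label_classification",
    torch_dtype=torch.bfloat16,                # flash-attn requires fp16/bf16
    attn_implementation="flash_attention_2",
)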

torch 2.5.1+cu124
transformers 4.48.2
triton 3.1.0

Any suggestions on how to fix?

Same problem on the latest transformers (4.49, built from GitHub).

I also got a problem running train_st on Google Colab:

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 6.69943941354032e-05, 'epoch': 0.2}
20% 500/2442 [1:33:18<6:00:25, 11.14s/it]

Loss 0.0 and grad norm NaN; it doesn't seem to be working...

@WoutDeRijck Seems like I'm not the only one having this problem. If you want a quick workaround though, you can switch to SDPA attention.

model = AutoModelForSequenceClassification.from_pretrained(model_name, attn_implementation="sdpa")

But for me it's about 10x slower than with FA, since FA drastically reduces memory use and allows for larger batches.
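If it helps with debugging, a quick way to double-check which attention backend actually got picked up (assuming a recent transformers release; the attribute is private and may change between versions):

print(model.config._attn_implementation)   # expect "sdpa" or "flash_attention_2"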

@WoutDeRijck

1:33:18<6:00:25

Also, your training time seems off. It takes me an hour or two to fine-tune on a 1,000,000-example dataset and about 2 minutes on 10,000 examples with flash attention enabled. Maybe your batch size is too small or too high? (I use 32.)

I was using mini batch size 64 and train batch size 512 (L4 GPU on Google Colab).

I get this error when just running the code from the repo:
ValueError: Input contains NaN.
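In case it helps narrow this down, a small sketch (the model name below is a placeholder for whatever train_st loads) to check whether the encoder itself is already producing NaNs before the evaluator complains:

import torch
from sentence_transformers import SentenceTransformer

model_name = "answerdotai/ModernBERT-large"   # placeholder; substitute the model used in train_st
model = SentenceTransformer(model_name)
emb = model.encode(["a short test sentence"], convert_to_tensor=True)
print(torch.isnan(emb).any())                 # True means the forward pass is already broken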

@WoutDeRijck

I was using mini batch size 64 and train batch size 512 (L4 GPU on Google Colab).

Maybe the batch size is too high, so it's offloading (have you tried batch size 4 and then increasing it until it slows down?). I'm not sure, since I haven't tested on an L4, but that training time seems very long.
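One way to test that theory without giving up the large effective batch is gradient accumulation; a sketch with illustrative numbers (the sentence-transformers training arguments subclass TrainingArguments, so the same fields apply):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=32,   # small enough to stay in GPU memory
    gradient_accumulation_steps=16,   # 32 * 16 = 512 effective batch
    bf16=True,
)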

Yes, that was indeed the problem with my long training time, thanks!
But still, it's not learning this way.

@WoutDeRijck

But still, its not learning this way

.from_pretrained(model_name, attn_implementation="sdpa")

It will probably train if you enable SDPA attention, just slower. Just a reminder, in case you need to train now and don't want to wait for an FA2 fix.

I'm running the example code. However, adding this still resulted in the ValueError that there are NaNs in the input:
model = SentenceTransformer(model_name, model_kwargs={"attn_implementation": "sdpa"})

It was using sdpa by default.
model = SentenceTransformer(model_name, model_kwargs={"attn_implementation": "flash_attention_2"})

Seems better already
{'loss': 3.8848, 'grad_norm': 4.208026885986328, 'learning_rate': 3.2520325203252037e-05, 'epoch': 0.02}
2% 56/2442 [05:34<3:07:58, 4.73s/it]

I'm not using flash attention, but I still get loss 0 and gradients = NaN. Enabling or disabling FP16 doesn't change anything. Any suggestions why this could be happening?

Answer.AI org

Some users experience nan when their torch is a bit dated. Updating it might help.

  • Tom Aarsen
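For anyone checking whether their stack is recent enough, a quick version dump (only the standard __version__ attributes are used here):

import torch
import transformers

print("torch:", torch.__version__, "| cuda:", torch.version.cuda)
print("transformers:", transformers.__version__)
try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)
except ImportError:
    print("flash-attn not installed")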

Dear @tomaarsen ,

I cannot thank you enough. Simply updating torch fixed the issue.

Thank you so much!!

Edit:

It seems that I ran the training once after updating torch from 2.5.1 to 2.6.0 and it seemed to be going fine, but as soon as I re-ran it multiple times, the loss is again 0 with gradient norms = NaN.
