Transformer models are not deterministic when using Flash Attention 2 #31787
You're running a script to test the consistency of different attention implementations using PyTorch and Flash Attention 2. While sdpa and eager implementations work as expected, flash_attention_2 is giving inconsistent results despite following the PyTorch reproducibility guidelines. Here's what you can do next:

Is this AI-generated @Varma0604? Thanks for your report @YunfanZhang42! @ArthurZucker will be able to help once he's back from his holiday (next week). Thank you for your patience!
@LysandreJik Thank you for your reply. No need to rush here as this is not an urgent issue. And yes, I think @Varma0604 is spam posting using LLMs. I did a few experiments on decoder-only models, covering both the forward and the backward pass, and the results are interesting. Here are the steps to reproduce:
```python
import random

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2"):
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        device_map="cuda:0",
    )
    # Define the prompt
    prompt = "My favourite condiment is"
    # Tokenize the input and send it to the appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")
    # Perform multiple runs and check for consistency
    output_consistency = True
    gradient_consistency = True
    first_output = None
    first_gradients = None
    # Run the model 10 times
    for _ in range(10):
        # Ensure the model is in training mode
        model.train()
        # Clear previous gradients
        model.zero_grad()
        # Get the model's output (logits)
        outputs = model(**model_inputs)
        # Use the sum of the logits as a simple loss
        loss = outputs.logits.sum()
        # Compute gradients
        loss.backward()
        # Extract gradients and store them on the CPU to avoid OOM
        current_gradients = {name: param.grad.cpu() for name, param in model.named_parameters() if param.grad is not None}
        # On the first run, store the output logits
        if first_output is None:
            first_output = outputs.logits
        else:
            # Compare the current output with the first output
            if not torch.equal(first_output, outputs.logits):
                output_consistency = False
        # On the first run, store the gradients
        if first_gradients is None:
            first_gradients = current_gradients
        else:
            # Compare the current gradients with the first gradients
            for name, grad in current_gradients.items():
                if not torch.equal(grad, first_gradients[name]):
                    gradient_consistency = False
    # Print the results
    print(f"Model: {model_name}, Attention Implementation: {attn_implementation}, Output Consistency: {output_consistency}, Gradient Consistency: {gradient_consistency}")


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    test_consistency(model_name="gpt2", attn_implementation="flash_attention_2")
    # GPT-2 does not support the sdpa implementation
    # test_consistency(model_name="gpt2", attn_implementation="sdpa")
    test_consistency(model_name="gpt2", attn_implementation="eager")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="sdpa")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="eager")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="eager")
```

What I get is:
So GPT-2 is not deterministic under either Flash Attention 2 or the default attention implementation, while Mistral and Llama 3 demonstrate deterministic behavior regardless of which attention implementation is used.

I also checked the Flash Attention 2 documentation, and it seems to me that after Flash Attention v2.5.0 the forward pass should always be deterministic, and the backward pass can be made deterministic by passing the `deterministic=True` flag.

So my suspicions are either (1) BART and GPT-2 are FP32 models, but we are using BF16 mode, which could make the numerical stability issues more pronounced, or (2) transformers' attention implementation for certain models might be problematic and could trigger unexpected behavior under some conditions. Let me know how I can help further, and thanks a lot!
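(Adding a note here for anyone reading along, not part of the original report: a minimal sketch of probing the `flash-attn` kernel directly with its `deterministic` flag. The shapes and dtype below are arbitrary assumptions chosen to mirror the BF16 setup above.)

```python
import torch
from flash_attn import flash_attn_func

# Arbitrary shapes: (batch, seqlen, nheads, headdim), BF16 on GPU as in the report above.
q, k, v = [torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3)]

# `deterministic=True` (available in flash-attn >= 2.4.1) requests a deterministic backward pass;
# the forward pass is documented as deterministic since v2.5.0.
out1 = flash_attn_func(q, k, v, causal=True, deterministic=True)
out2 = flash_attn_func(q, k, v, causal=True, deterministic=True)
print(torch.equal(out1, out2))  # expected: True if the forward pass is bitwise deterministic
```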
System Info
`transformers` version: 4.41.2

Who can help?
@ArthurZucker @stevhliu
Information

Tasks

An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)

Reproduction
Set `export CUBLAS_WORKSPACE_CONFIG=:4096:8` in the environment, as required by the PyTorch Reproducibility Guide.
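(An aside not in the original report: a minimal sketch of setting this variable from inside the script rather than the shell; it has the same effect as long as it runs before the first cuBLAS call.)

```python
import os

# Equivalent to `export CUBLAS_WORKSPACE_CONFIG=:4096:8`; must be set before the
# first cuBLAS call so that torch.use_deterministic_algorithms(True) can take effect.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import torch

torch.use_deterministic_algorithms(True, warn_only=True)
```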
Expected behavior

Since we have closely followed the PyTorch reproducibility guide, we expect to see:
In reality, we would get:
Based on this experiment, it seems Flash Attention 2 is not deterministic in the forward pass, and according to Dao-AILab/flash-attention#414, Flash Attention 2 is not deterministic for the backward pass either, so this also affects training.
It's worth noting that the PyTorch `sdpa` implementation may also select a non-deterministic execution path depending on the input dimensions, but it throws an error/warning when `torch.use_deterministic_algorithms` is set to `True`, so it does not fail silently.

Therefore, I think we should throw a similar error when `attn_implementation` is set to `flash_attention_2` and `use_deterministic_algorithms` is set to `True`. At the very least, this behavior should be documented in the Transformers docs.

Finally, thank you for your contribution to the deep learning and open source community. Please let me know how I can contribute here.
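(A sketch added for illustration, not an existing Transformers API: roughly what such a guard could look like, with a hypothetical helper that would be called at model load time.)

```python
import torch


def _check_deterministic_support(attn_implementation: str) -> None:
    """Hypothetical guard: fail loudly instead of silently returning non-deterministic results."""
    if attn_implementation == "flash_attention_2" and torch.are_deterministic_algorithms_enabled():
        raise RuntimeError(
            "attn_implementation='flash_attention_2' does not guarantee deterministic results; "
            "use 'eager' or 'sdpa', or disable torch.use_deterministic_algorithms()."
        )
```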