Transformer models are not deterministic when using Flash Attention 2 #31787

Open
YunfanZhang42 opened this issue Jul 4, 2024 · 3 comments

YunfanZhang42 commented Jul 4, 2024

System Info

  • transformers version: 4.41.2
  • Platform: Linux-6.5.0-27-generic-x86_64-with-glibc2.35
  • Python version: 3.10.12
  • Huggingface_hub version: 0.23.4
  • Safetensors version: 0.4.2
  • Accelerate version: 0.29.3
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.2.2+cu118 (True)
  • Tensorflow version (GPU?): 2.16.1 (False)
  • Flax version (CPU?/GPU?/TPU?): 0.7.0 (gpu)
  • Jax version: 0.4.18
  • JaxLib version: 0.4.18
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help?

@ArthurZucker @stevhliu

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Install Flash Attention 2. I am using version 2.5.6, but this should not matter.
  2. Run export CUBLAS_WORKSPACE_CONFIG=:4096:8 as required by the PyTorch Reproducibility Guide.
  3. Run the following script:
import random
import torch
import numpy as np
from transformers import BartForConditionalGeneration, BartTokenizer


def test_consistency(attn_implementation="flash_attention_2"):
    # Load the model and tokenizer
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base", torch_dtype=torch.bfloat16, attn_implementation=attn_implementation, device_map="cuda:0")

    # Define the prompt
    prompt = "My favourite condiment is"

    # Tokenize the input and send to appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")

    # Disable gradient calculations
    with torch.no_grad():
        # Store the first output to compare with subsequent outputs
        first_output = None
        consistent = True

        # Perform inference 1000 times
        for i in range(1000):
            # Get the model's output (logits)
            outputs = model(**model_inputs)

            # Get logits of the last token from the output
            logits = outputs.logits[:, -1, :]

            # If it's the first run, store the output logits
            if first_output is None:
                first_output = logits
            else:
                # Compare current output with the first output
                if not torch.equal(first_output, logits):
                    consistent = False
                    break

        # Output whether all runs produced the same probabilities
        print(f"Using attention implementation {attn_implementation}, Consistent: {consistent}")


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    test_consistency("flash_attention_2")
    test_consistency("sdpa")
    test_consistency("eager")

Expected behavior

Since we have closely followed the PyTorch reproducibility guide, we expect to see:

Using attention implementation flash_attention_2, Consistent: True
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True

In reality, we get:

Using attention implementation flash_attention_2, Consistent: False
Using attention implementation sdpa, Consistent: True
Using attention implementation eager, Consistent: True

Based on this experiment, it seems Flash Attention 2 is not deterministic in the forward pass, and according to Dao-AILab/flash-attention#414, it is not deterministic in the backward pass either, so this also affects training.

It's worth noting that PyTorch's sdpa implementation may also select a non-deterministic execution path depending on the input dimensions, but it emits the following warning when torch.use_deterministic_algorithms is set to True, so it does not fail silently.

UserWarning: Memory Efficient attention defaults to a non-deterministic algorithm. To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False). (Triggered internally at ../aten/src/ATen/native/transformers/cuda/attention_backward.cu:449.)
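
For reference, here is a minimal sketch that reproduces this warning in isolation. It assumes PyTorch 2.2, where torch.backends.cuda.sdp_kernel is the backend selector (newer releases expose torch.nn.attention.sdpa_kernel instead), and it forces the memory-efficient backend, whose backward pass is the non-deterministic one:

import torch
import torch.nn.functional as F

torch.use_deterministic_algorithms(True, warn_only=True)

# Random attention inputs: (batch, nheads, seqlen, headdim), bf16 on GPU.
q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Restrict sdpa to the memory-efficient backend so its backward path is exercised.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)
    out.sum().backward()  # on affected setups this emits the determinism warning quoted above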

Therefore, I think we should raise a similar error or warning when attn_implementation is set to flash_attention_2 and use_deterministic_algorithms is set to True. At the very least, this behavior should be documented in the Transformers docs.
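
For illustration, a minimal sketch of the kind of guard I have in mind. The helper name is hypothetical and this is not actual transformers internals; it only relies on torch.are_deterministic_algorithms_enabled(), which PyTorch already provides:

import warnings
import torch

def warn_if_flash_attention_nondeterministic(attn_implementation: str) -> None:
    # Hypothetical guard: warn when deterministic algorithms were requested
    # but the selected attention backend cannot guarantee them.
    if attn_implementation == "flash_attention_2" and torch.are_deterministic_algorithms_enabled():
        warnings.warn(
            "flash_attention_2 does not guarantee bitwise-deterministic results; "
            "consider attn_implementation='sdpa' or 'eager' if reproducibility is required."
        )

# Example: would warn under the settings used in the script above.
torch.use_deterministic_algorithms(True, warn_only=True)
warn_if_flash_attention_nondeterministic("flash_attention_2")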

Finally, thank you for your contributions to the deep learning and open-source community. Please let me know how I can contribute here.

@Varma0604

You’re running a script to test the consistency of different attention implementations using PyTorch and Flash Attention 2. While the sdpa and eager implementations work as expected, flash_attention_2 gives inconsistent results despite following the PyTorch reproducibility guidelines.

Here’s what you can do next:

1.	Double-check your settings: Ensure all flags and settings for determinism are correctly applied. Small misconfigurations can cause issues.
2.	Review documentation: Check the Flash Attention 2 documentation for any notes on determinism and if there are specific configurations needed.
3.	Update your tools: Make sure you’re using the latest versions of PyTorch and Flash Attention 2, as updates may resolve your issue.
4.	Community support: Post your issue on PyTorch forums or Flash Attention 2’s GitHub repository. The community might have encountered and solved similar issues.
5.	Contribute: If you find that Flash Attention 2 is inherently non-deterministic, consider suggesting a change or adding a warning in the documentation or code to help others.

@LysandreJik
Member

Is this AI-generated, @Varma0604?


Thanks for your report @YunfanZhang42! @ArthurZucker will be able to help once he's back from his holiday (next week). Thank you for your patience!

@YunfanZhang42
Author

@LysandreJik Thank you for your reply. No need to rush here as this is not an urgent issue. And yes, I think @Varma0604 is spam posting using LLMs.

I did a few experiments on decoder-only models for both the forward and the backward pass, and the results are interesting. Here are the steps to reproduce:

  1. Install Flash Attention 2. I am using version 2.5.6.
  2. Run export CUBLAS_WORKSPACE_CONFIG=:4096:8 as required by the PyTorch Reproducibility Guide.
  3. Run the following script:
import random
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer


def test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2"):
    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation, device_map="cuda:0")

    # Define the prompt
    prompt = "My favourite condiment is"

    # Tokenize the input and send to appropriate device
    model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda:0")

    # Perform multiple runs and check for consistency
    output_consistency = True
    gradient_consistency = True
    first_output = None
    first_gradients = None

    # Run the model 10 times
    for _ in range(10):
        # Ensure model is in training mode
        model.train()

        # Clear previous gradients
        model.zero_grad()

        # Get the model's output (logits)
        outputs = model(**model_inputs)

        # Use the sum of logits as a simple loss
        loss = outputs.logits.sum()

        # Compute gradients
        loss.backward()

        # Extract gradients and store them on CPU to avoid OOM
        current_gradients = {name: param.grad.cpu() for name, param in model.named_parameters() if param.grad is not None}

        # If it's the first run, store the output logits
        if first_output is None:
            first_output = outputs.logits
        else:
            # Compare current output with the first output
            if not torch.equal(first_output, outputs.logits):
                output_consistency = False

        # If it's the first run, store the gradients
        if first_gradients is None:
            first_gradients = current_gradients
        else:
            # Compare current gradients with the first gradients
            for name, grad in current_gradients.items():
                if not torch.equal(grad, first_gradients[name]):
                    gradient_consistency = False
                
    # Print the results
    print(f"Model: {model_name}, Attention Implementation: {attn_implementation}, Output Consistency: {output_consistency}, Gradient Consistency: {gradient_consistency}")


if __name__ == "__main__":
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.benchmark = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    test_consistency(model_name="gpt2", attn_implementation="flash_attention_2")
    # GPT-2 does not support the sdpa implementation
    # test_consistency(model_name="gpt2", attn_implementation="sdpa")
    test_consistency(model_name="gpt2", attn_implementation="eager")

    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="flash_attention_2")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="sdpa")
    test_consistency(model_name="mistralai/Mistral-7B-v0.1", attn_implementation="eager")

    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="flash_attention_2")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="sdpa")
    test_consistency(model_name="meta-llama/Meta-Llama-3-8B", attn_implementation="eager")

What I get is:

Model: gpt2, Attention Implementation: flash_attention_2, Output Consistency: False, Gradient Consistency: False
Model: gpt2, Attention Implementation: eager, Output Consistency: False, Gradient Consistency: False
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: mistralai/Mistral-7B-v0.1, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: flash_attention_2, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: sdpa, Output Consistency: True, Gradient Consistency: True
Model: meta-llama/Meta-Llama-3-8B, Attention Implementation: eager, Output Consistency: True, Gradient Consistency: True

So GPT-2 is not deterministic under either Flash Attention 2 or the default (eager) attention implementation, while Mistral and Llama 3 demonstrate deterministic behavior regardless of which attention implementation is used.

I also checked the Flash Attention 2 documentation, and it seems to me that after Flash Attention v2.5.0 the forward pass should always be deterministic, while the backward pass can be made deterministic by passing deterministic=True, which as far as I can tell we do not do for either Mistral or Llama.
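
To sanity-check that flag at the kernel level, here is a small sketch that calls flash_attn_func directly with deterministic=True and compares gradients across repeated backward passes. It assumes the installed flash-attn version exposes the deterministic keyword argument (added in recent 2.x releases), so treat the exact signature as an assumption:

import torch
from flash_attn import flash_attn_func

# Random attention inputs: (batch, seqlen, nheads, headdim), bf16 on GPU.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

reference_grad = None
consistent = True
for _ in range(10):
    q.grad = None
    # deterministic=True is assumed to request the deterministic backward kernel.
    out = flash_attn_func(q, k, v, causal=True, deterministic=True)
    out.sum().backward()
    if reference_grad is None:
        reference_grad = q.grad.clone()
    elif not torch.equal(reference_grad, q.grad):
        consistent = False
print(f"flash_attn_func backward consistent: {consistent}")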

So my suspicion is that either (1) BART and GPT-2 are FP32 models but we are running them in BF16, which could make numerical stability issues more pronounced, or (2) transformers' attention implementation for certain models might be problematic and could trigger unexpected behavior under some conditions.
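
If it helps, here is a quick sketch to probe suspicion (1) by rerunning only the GPT-2 forward-pass check in FP32 versus BF16 (the function name and structure are just for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def gpt2_forward_consistent(dtype, runs=100):
    # Forward-pass-only consistency check for GPT-2 in a given dtype.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", torch_dtype=dtype, attn_implementation="eager", device_map="cuda:0"
    )
    inputs = tokenizer(["My favourite condiment is"], return_tensors="pt").to("cuda:0")
    with torch.no_grad():
        first = model(**inputs).logits
        return all(torch.equal(first, model(**inputs).logits) for _ in range(runs))

print("GPT-2 FP32 consistent:", gpt2_forward_consistent(torch.float32))
print("GPT-2 BF16 consistent:", gpt2_forward_consistent(torch.bfloat16))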

Let me know how I can help further, and thanks a lot!
