
flash_attn ImportError breaking model loading (Florence-2-base-ft) #31793

Open
2 of 4 tasks
Laz4rz opened this issue Jul 4, 2024 · 8 comments


@Laz4rz

Laz4rz commented Jul 4, 2024

System Info

Transformers' .from_pretrained() can fail while loading models, in my case models from the Florence-2 family. It raises:

ImportError: /home/mikolaj/miniconda3/envs/gemma/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE

This is caused by a failed flash_attn import, which is strange, because until now flash_attn was not required. I'm not sure whether it is needed now, but unless flash_attn can be imported, the model will not load.

This happens for the following combination of python, transformers, torch and flash_attn:

python: 3.10.14 (main, May  6 2024, 19:42:50) [GCC 11.2.0]
transformers: 4.42.3
torch: 2.3.1+cu121
flash_attn: 2.5.9.post1
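A minimal sketch for collecting the version information listed above, assuming only the Python standard library. It reads installed distribution metadata via importlib.metadata.version, so it reports versions without importing the packages (and therefore without triggering a broken flash_attn build). The function name collect_versions and the default package list are illustrative choices, not part of any library.

```python
import importlib.metadata
import platform


def collect_versions(packages=("transformers", "torch", "flash_attn")):
    """Return a {name: version} report; missing packages are marked as such.

    Hypothetical helper: the package list is only an example.
    """
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            # Reads installed distribution metadata; does not import the package.
            report[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}: {version}")
```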

It can be fixed by upgrading Python to 3.11 or 3.12.

Who can help?

@amyeroberts

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Create an environment with the versions specified above.
  2. Run:
from transformers import AutoModelForCausalLM, AutoProcessor
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base-ft",
    trust_remote_code=True,
    revision='refs/pr/6',
).to(device) 

Expected behavior

The model loads properly, without needing to import flash_attn.

@LysandreJik
Member

Hello @Laz4rz, thanks for your report! The Florence-2 model is hosted on the Hub, with code contributed by the authors. It doesn't live within the transformers library.

Could you share your entire stack trace? I'm not sure where the code is failing, and from their code it looks like they protect the flash attention import as well, so I'm curious to see what's happening.

@Laz4rz
Author

Laz4rz commented Jul 4, 2024

If flash_attn is installed:

ImportError                               Traceback (most recent call last)
Cell In[1], line 6
      2 import torch
      4 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
----> 6 model = AutoModelForCausalLM.from_pretrained(
      7     "microsoft/Florence-2-base-ft",
      8     trust_remote_code=True,
      9     revision='refs/pr/6',
     10 ).to(device) 
     11 processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", 
     12     trust_remote_code=True,
     13     revision='refs/pr/6'
     14     )
     16 for param in model.vision_tower.parameters():

File ~/miniconda3/envs/gemma/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:551, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    549 if has_remote_code and trust_remote_code:
    550     class_ref = config.auto_map[cls.__name__]
--> 551     model_class = get_class_from_dynamic_module(
    552         class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
    553     )
    554     _ = hub_kwargs.pop("code_revision", None)
    555     if os.path.isdir(pretrained_model_name_or_path):

File ~/miniconda3/envs/gemma/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:514, in get_class_from_dynamic_module(class_reference, pretrained_model_name_or_path, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, code_revision, **kwargs)
    501 # And lastly we get the class inside our newly created module
    502 final_module = get_cached_module_file(
    503     repo_id,
    504     module_file + ".py",
   (...)
    512     repo_type=repo_type,
    513 )
--> 514 return get_class_in_module(class_name, final_module)

File ~/miniconda3/envs/gemma/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:212, in get_class_in_module(class_name, module_path)
    210     sys.modules[name] = module
    211 # reload in both cases
--> 212 module_spec.loader.exec_module(module)
    213 return getattr(module, class_name)

File <frozen importlib._bootstrap_external>:883, in exec_module(self, module)

File <frozen importlib._bootstrap>:241, in _call_with_frames_removed(f, *args, **kwds)

File ~/.cache/huggingface/modules/transformers_modules/microsoft/Florence-2-base-ft/e0b8f375661041228a6431c950adac1a5c539b98/modeling_florence2.py:63
     54 from transformers.modeling_outputs import (
     55     BaseModelOutput,
     56     BaseModelOutputWithPastAndCrossAttentions,
     57     Seq2SeqLMOutput,
     58     Seq2SeqModelOutput,
     59 )
     62 if is_flash_attn_2_available():
---> 63     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
     65 logger = logging.get_logger(__name__)
     67 _CONFIG_FOR_DOC = "Florence2Config"

File ~/miniconda3/envs/gemma/lib/python3.10/site-packages/flash_attn/__init__.py:3
      1 __version__ = "2.5.9.post1"
----> 3 from flash_attn.flash_attn_interface import (
      4     flash_attn_func,
      5     flash_attn_kvpacked_func,
      6     flash_attn_qkvpacked_func,
      7     flash_attn_varlen_func,
      8     flash_attn_varlen_kvpacked_func,
      9     flash_attn_varlen_qkvpacked_func,
     10     flash_attn_with_kvcache,
     11 )

File ~/miniconda3/envs/gemma/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py:10
      6 import torch.nn as nn
      8 # isort: off
      9 # We need to import the CUDA kernels after importing torch
---> 10 import flash_attn_2_cuda as flash_attn_cuda
     12 # isort: on
     15 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     16     # This should match the block sizes in the CUDA kernel

ImportError: /home/mikolaj/miniconda3/envs/gemma/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZN2at4_ops5zeros4callEN3c108ArrayRefINS2_6SymIntEEENS2_8optionalINS2_10ScalarTypeEEENS6_INS2_6LayoutEEENS6_INS2_6DeviceEEENS6_IbEE

If it isn't:

---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
Cell In[1], line 6
      2 import torch
      4 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
----> 6 model = AutoModelForCausalLM.from_pretrained(
      7     "microsoft/Florence-2-base-ft",
      8     trust_remote_code=True,
      9     revision='refs/pr/6',
     10 ).to(device) 
     11 processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base-ft", 
     12     trust_remote_code=True,
     13     revision='refs/pr/6'
     14     )
     16 for param in model.vision_tower.parameters():

File ~/miniconda3/envs/florence/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py:551, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    549 if has_remote_code and trust_remote_code:
    550     class_ref = config.auto_map[cls.__name__]
--> 551     model_class = get_class_from_dynamic_module(
    552         class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
    553     )
    554     _ = hub_kwargs.pop("code_revision", None)
    555     if os.path.isdir(pretrained_model_name_or_path):

File ~/miniconda3/envs/florence/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:502, in get_class_from_dynamic_module(class_reference, pretrained_model_name_or_path, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, code_revision, **kwargs)
    500     code_revision = revision
    501 # And lastly we get the class inside our newly created module
--> 502 final_module = get_cached_module_file(
    503     repo_id,
    504     module_file + ".py",
    505     cache_dir=cache_dir,
    506     force_download=force_download,
    507     resume_download=resume_download,
    508     proxies=proxies,
    509     token=token,
    510     revision=code_revision,
    511     local_files_only=local_files_only,
    512     repo_type=repo_type,
    513 )
    514 return get_class_in_module(class_name, final_module)

File ~/miniconda3/envs/florence/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:327, in get_cached_module_file(pretrained_model_name_or_path, module_file, cache_dir, force_download, resume_download, proxies, token, revision, local_files_only, repo_type, _commit_hash, **deprecated_kwargs)
    324     raise
    326 # Check we have all the requirements in our environment
--> 327 modules_needed = check_imports(resolved_module_file)
    329 # Now we move the module inside our cached dynamic modules.
    330 full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule

File ~/miniconda3/envs/florence/lib/python3.10/site-packages/transformers/dynamic_module_utils.py:182, in check_imports(filename)
    179         missing_packages.append(imp)
    181 if len(missing_packages) > 0:
--> 182     raise ImportError(
    183         "This modeling file requires the following packages that were not found in your environment: "
    184         f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
    185     )
    187 return get_relative_imports(filename)

ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`

@LysandreJik
Member

Huh, it seems like it entered the conditional statement if is_flash_attn_2_available():. Do you have FA2 installed in your env?

@Laz4rz
Author

Laz4rz commented Jul 4, 2024

OK, that was a good lead: it seems that _is_package_available("flash_attn"), called inside is_flash_attn_2_available(), is buggy.

Commenting out both checks for flash_attn lets the model load and run inference correctly. Give me a few minutes and I'll try to pinpoint what exactly is wrong.

It probably only checks whether the library is installed, not whether it can be imported. With such an import check it should work as expected, but it should still log something so users know flash_attn is not being used because of an import error.
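A minimal sketch of the kind of check proposed here, assuming it lives outside transformers: is_package_importable is a hypothetical helper, not an existing transformers function. It keeps the cheap find_spec test (the same call _is_package_available uses) and only then attempts a real import, logging a warning when an installed package fails to import so users can see why it is being skipped.

```python
import importlib
import importlib.util
import logging

logger = logging.getLogger(__name__)


def is_package_importable(pkg_name: str) -> bool:
    """Hypothetical helper: True only if pkg_name is installed AND imports cleanly."""
    # Cheap check first: is there a module spec at all (i.e. is it installed)?
    if importlib.util.find_spec(pkg_name) is None:
        return False
    try:
        # Expensive check: actually run the import, which is where a compiled
        # extension with an undefined symbol raises ImportError.
        importlib.import_module(pkg_name)
    except ImportError as exc:
        # Surface the reason instead of failing silently.
        logger.warning("%s is installed but cannot be imported: %s", pkg_name, exc)
        return False
    return True
```

Note the trade-off: the import step executes the package's code, which is far slower than find_spec alone.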

@Laz4rz
Author

Laz4rz commented Jul 4, 2024

Welp, looks like I was a little too quick: I have so many conda envs that I didn't notice I had commented out the import protection in dynamic_module_utils.py, line 181, yesterday. So that's my bad.

This doesn't change the final outcome, however. With proper handling of the flash_attn import, Florence can be used and yields the same results as with flash_attn installed. Changing the behavior of _is_package_available() would make it consistent with its name: currently it only checks whether the package is installed, not whether it can be imported and used, which the name would suggest.

So there are two different things:

  1. Florence demanding flash_attn through the check in transformers/dynamic_module_utils.py, line 162, even though it works properly without it:
import importlib
import os
from typing import List, Union

# get_imports() and get_relative_imports() are helpers defined in the same module.


def check_imports(filename: Union[str, os.PathLike]) -> List[str]:
    """
    Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
    library is missing.

    Args:
        filename (`str` or `os.PathLike`): The module file to check.

    Returns:
        `List[str]`: The list of relative imports in the file.
    """
    imports = get_imports(filename)
    missing_packages = []
    for imp in imports:
        try:
            importlib.import_module(imp)
        except ImportError:
            # Any ImportError (package absent OR installed but broken) is reported as missing.
            missing_packages.append(imp)

    if len(missing_packages) > 0:
        raise ImportError(
            "This modeling file requires the following packages that were not found in your environment: "
            f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
        )

    return get_relative_imports(filename)
  2. A leaky _is_package_available() in transformers/utils/import_utils.py, line 42, which doesn't handle libraries that are installed but cannot be imported, even though that is the behavior its name suggests and the way it already handles torch: when torch is installed but can't be imported, it returns False, with the comment "If the package can't be imported, it's not available".
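import importlib
import importlib.metadata
import importlib.util
import logging
from typing import Tuple, Union

logger = logging.getLogger(__name__)  # stand-in for the transformers logger used in import_utils.py


def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: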
    # Check if the package spec exists and grab its version to avoid importing a local directory
    package_exists = importlib.util.find_spec(pkg_name) is not None
    package_version = "N/A"
    if package_exists:
        try:
            # Primary method to get the package version
            package_version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            # Fallback method: Only for "torch" and versions containing "dev"
            if pkg_name == "torch":
                try:
                    package = importlib.import_module(pkg_name)
                    temp_version = getattr(package, "__version__", "N/A")
                    # Check if the version contains "dev"
                    if "dev" in temp_version:
                        package_version = temp_version
                        package_exists = True
                    else:
                        package_exists = False
                except ImportError:
                    # If the package can't be imported, it's not available
                    package_exists = False
            else:
                # For packages other than "torch", don't attempt the fallback and set as not available
                package_exists = False
        logger.debug(f"Detected {pkg_name} version: {package_version}")
    if return_version:
        return package_exists, package_version
    else:
        return package_exists

I'm not sure how to handle 1., since that's probably up to Microsoft, depending on whether they want to allow running the model with or without flash_attn. But I think 2. could benefit from handling all libraries the same way, so either:

a) We add an importability check for all libraries, not only for torch as is currently the case
b) We remove the importability check for torch

I guess a) would be preferred. The PR for this is #31798.

@amyeroberts
Collaborator

Hi @Laz4rz, thanks for raising this issue and opening a PR to address.

We definitely don't want to do a), as this would massively increase the time it takes to load the library. It's not obvious to me that we'd want to do b) either.

Regarding @LysandreJik's question above, what do you get when you run is_flash_attn_available, and when you try to import flash attention modules? It'd be useful to pin down whether this is a general issue with the flash attention module imports or something in your env setup.
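To illustrate why importing every candidate library (option a) would be costly, here is a small standard-library sketch; the module name lazy_check_demo is made up for the example. importlib.util.find_spec, which _is_package_available relies on, only locates a module, while importlib.import_module executes its body. For a real package, executing the body also means loading its dependencies and compiled extensions, which is where the time goes.

```python
import importlib
import importlib.util
import pathlib
import sys
import tempfile

# Write a throwaway module whose body records that it ran.
# "lazy_check_demo" is a hypothetical name used only for this illustration.
demo_dir = pathlib.Path(tempfile.mkdtemp())
(demo_dir / "lazy_check_demo.py").write_text("EXECUTED = True\n")
sys.path.insert(0, str(demo_dir))
importlib.invalidate_caches()

# find_spec only locates the module on the search path; its body is not run,
# so nothing is added to sys.modules.
spec = importlib.util.find_spec("lazy_check_demo")
found = spec is not None
loaded_after_find_spec = "lazy_check_demo" in sys.modules

# import_module executes the module body and registers it in sys.modules.
module = importlib.import_module("lazy_check_demo")
loaded_after_import = "lazy_check_demo" in sys.modules

print(found, loaded_after_find_spec, loaded_after_import, module.EXECUTED)  # True False True True
```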

@Laz4rz
Author

Laz4rz commented Jul 11, 2024

Hey, thanks for taking a look @amyeroberts.

is_flash_attn_available returns True, due to _is_package_available()'s (a bit) counterintuitive behavior.

Importing flash_attn yields the error in Dao-AILab/flash-attention#1027, which from what I've read is a common flash_attn problem. Upgrading to Python 3.11 fixes the issue, and Colab, which runs Python 3.10 (with the same library versions), has no problem with flash_attn. I tried it on a few different VMs.

@amyeroberts
Collaborator

is_flash_attn_available returns True, due to _is_package_available()'s (a bit) counterintuitive behavior.

I think the name is_flash_attn_available is fairly sensible: it checks whether the library is available in the environment, but it doesn't check whether it's compatible with other libraries or the hardware, or whether it can be imported.
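The distinction between "available" and "importable" can be shown with a simulated broken module. In this sketch, broken_ext_demo is a made-up name, and the failure is an explicit raise rather than a real undefined-symbol error from a compiled extension such as flash_attn_2_cuda. A find_spec-style check reports the module as present, yet importing it fails.

```python
import importlib
import importlib.util
import pathlib
import sys
import tempfile

# Simulate a package that is installed but cannot be imported.
# "broken_ext_demo" is hypothetical; real failures come from a compiled .so.
demo_dir = pathlib.Path(tempfile.mkdtemp())
(demo_dir / "broken_ext_demo.py").write_text(
    "raise ImportError('undefined symbol (simulated)')\n"
)
sys.path.insert(0, str(demo_dir))
importlib.invalidate_caches()

# A find_spec-based availability check says yes: the module is on disk.
available = importlib.util.find_spec("broken_ext_demo") is not None

# Actually importing it runs the module body, which fails.
error_message = None
try:
    importlib.import_module("broken_ext_demo")
    importable = True
except ImportError as exc:
    importable = False
    error_message = str(exc)

print(available, importable)  # True False
```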

Dao-AILab/flash-attention#1027

The linked issue looks like a problem with the CUDA install and its compatibility with the environment. In other issues that report the same thing, it appears to be related to the installed PyTorch, e.g. Dao-AILab/flash-attention#836 and Dao-AILab/flash-attention#919. This would explain why Python 3.10 works in some cases and not others.
