
LLaVa HF format + accelerate load_checkpoint_and_dispatch warns that tie_weights is needed even after calling tie_weights #31801

btruhand opened this issue Jul 5, 2024 · 1 comment

btruhand commented Jul 5, 2024

System Info

transformers version: 4.42.3
python version: 3.10.14
accelerate version: 0.31.0

Who can help?

@muellerzr

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Here is a snippet of my code:

import os
from typing import Union

import torch
from typing_extensions import Self

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import (
    AutoConfig,
    AutoProcessor,
    LlavaForConditionalGeneration,
    LlavaProcessor,
)


class Wrapper:
    def __init__(self, model: LlavaForConditionalGeneration, processor: LlavaProcessor):
        self.model: LlavaForConditionalGeneration = model
        self.processor: LlavaProcessor = processor

    @classmethod
    def from_pretrained(cls, model_path_or_id: str, **loading_config) -> Self:
        """Load the model from the given path or Hugging Face model id

        Args:
            model_path_or_id (str): Path to the model or model id
        """
        model = LlavaForConditionalGeneration.from_pretrained(
            model_path_or_id, **loading_config
        )
        processor = AutoProcessor.from_pretrained(model_path_or_id)
        return cls(model, processor)

    @classmethod
    def from_config(cls, model_path_or_id: str, **loading_config) -> Self:
        """Loads the model's configuration and architecture from the given path or Hugging Face model id

        The model's weights are not loaded yet, and should be loaded via a call to the `load` method.

        Args:
            model_path_or_id (str): Path to the model or model id
        """
        if os.path.isdir(model_path_or_id):
            weights_location = model_path_or_id
        else:
            from huggingface_hub import snapshot_download

            weights_location = snapshot_download(model_path_or_id)

        config = AutoConfig.from_pretrained(weights_location)
        with init_empty_weights():
            model = LlavaForConditionalGeneration._from_config(config, **loading_config)

        processor = AutoProcessor.from_pretrained(weights_location)

        return cls(model, processor)

    def load(
        self,
        device_map: Union[str, dict[str, Union[int, str]]] = "auto",
        **kwargs,
    ):
        """Load the model weights according to the specified device map. See Accelerate's documentation for more details.
        https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference#designing-a-device-map
        :param device_map: device map to load the model weights, defaults to "auto"
        :type device_map: str | dict[str, int]
        """
        kwargs["device_map"] = device_map
        kwargs["dtype"] = self.model.config.torch_dtype if kwargs.get('dtype') is None else kwargs.get('dtype')

        self.model.tie_weights()

        self.model = load_checkpoint_and_dispatch(
            self.model,
            checkpoint=self.model.config._name_or_path,
            no_split_module_classes=self.model._get_no_split_modules(device_map),
            **kwargs,
        )

Then I run:

import torch

wrapper = Wrapper.from_config('llmingbt/product-taxonomy-llava-hf-v1.6-13b', attn_implementation='flash_attention_2', torch_dtype=torch.float16)
wrapper.load(device_map={'': 'cuda:0'}, dtype=torch.float16)

I then get the warning "The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function." even though load has already called tie_weights.
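To check whether the tying actually takes effect on the meta-device model, accelerate's find_tied_parameters helper can be run right after tie_weights. This is a small diagnostic sketch reusing the Wrapper class above:

import torch
from accelerate.utils import find_tied_parameters

wrapper = Wrapper.from_config('llmingbt/product-taxonomy-llava-hf-v1.6-13b')
wrapper.model.tie_weights()
# If tying worked, this should list groups of parameter names that share
# one tensor (e.g. the language model's input and output embeddings);
# an empty result would match the "weights are not tied" warning.
print(find_tied_parameters(wrapper.model))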

The warning is not shown when I load the same checkpoint with LlavaForConditionalGeneration.from_pretrained using the same settings.
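For comparison, this is roughly the from_pretrained path I mean; it loads the same checkpoint with the same settings and does not produce the warning:

import torch
from transformers import LlavaForConditionalGeneration

# Loads weights directly onto the GPU in one step, no separate
# load_checkpoint_and_dispatch call needed.
model = LlavaForConditionalGeneration.from_pretrained(
    'llmingbt/product-taxonomy-llava-hf-v1.6-13b',
    device_map={'': 'cuda:0'},
    torch_dtype=torch.float16,
    attn_implementation='flash_attention_2',
)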

Expected behavior

No tie_weights warning should be shown by accelerate's load_checkpoint_and_dispatch when tie_weights has already been called.

LysandreJik (Member) commented

cc @SunMarc as well :)
