- An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
Here is a snippet of my code:
```python
import os
from typing import Union

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoProcessor, LlavaForConditionalGeneration, LlavaProcessor
from typing_extensions import Self


class Wrapper:
    def __init__(self, model: LlavaForConditionalGeneration, processor: LlavaProcessor):
        self.model: LlavaForConditionalGeneration = model
        self.processor: LlavaProcessor = processor

    @classmethod
    def from_pretrained(cls, model_path_or_id: str, **loading_config) -> Self:
        """Load the model from the given path or Hugging Face model id.

        Args:
            model_path_or_id (str): Path to the model or model id
        """
        model = LlavaForConditionalGeneration.from_pretrained(
            model_path_or_id, **loading_config
        )
        processor = AutoProcessor.from_pretrained(model_path_or_id)
        return cls(model, processor)

    @classmethod
    def from_config(cls, model_path_or_id: str, **loading_config) -> Self:
        """Load the model's configuration and architecture from the given path
        or Hugging Face model id.

        The model's weights are not loaded yet, and should be loaded via a call
        to the `load` method.

        Args:
            model_path_or_id (str): Path to the model or model id
        """
        if os.path.isdir(model_path_or_id):
            weights_location = model_path_or_id
        else:
            from huggingface_hub import snapshot_download

            weights_location = snapshot_download(model_path_or_id)
        config = AutoConfig.from_pretrained(weights_location)
        with init_empty_weights():
            model = LlavaForConditionalGeneration._from_config(config, **loading_config)
        processor = AutoProcessor.from_pretrained(weights_location)
        return cls(model, processor)

    def load(
        self,
        device_map: Union[str, dict[str, Union[int, torch.dtype]]] = "auto",
        **kwargs,
    ):
        """Load the model weights according to the specified device map.

        See Accelerate's documentation for more details:
        https://huggingface.co/docs/accelerate/main/en/concept_guides/big_model_inference#designing-a-device-map

        :param device_map: device map to load the model weights, defaults to "auto"
        :type device_map: str | dict[str, int]
        """
        kwargs["device_map"] = device_map
        # Fall back to the config's dtype when the caller didn't pass one
        if kwargs.get("dtype") is None:
            kwargs["dtype"] = self.model.config.torch_dtype
        self.model.tie_weights()
        self.model = load_checkpoint_and_dispatch(
            self.model,
            checkpoint=self.model.config._name_or_path,
            no_split_module_classes=self.model._get_no_split_modules(device_map),
            **kwargs,
        )
```
I then get the warning `The model weights are not tied. Please use the tie_weights method before using the infer_auto_device function.` even though `load` already calls `tie_weights` before `load_checkpoint_and_dispatch`.
The warning is not shown when I use `LlavaForConditionalGeneration.from_pretrained` with the same setup.
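For context on what the warning checks: two weights are "tied" when two modules share the same parameter object (for example the input embeddings and the LM head). A minimal, torch-free sketch of that identity check (the class and helper names here are hypothetical; Accelerate's real check inspects the actual model's named parameters):

```python
class Param:
    """Stands in for a tensor of weights."""


class TinyLM:
    def __init__(self):
        self.embed_tokens = Param()
        self.lm_head = Param()  # separate object: weights are NOT tied

    def tie_weights(self):
        # After tying, both attributes reference the same Param object
        self.lm_head = self.embed_tokens


def weights_are_tied(model: TinyLM) -> bool:
    # Tied weights are the same object, so identity (`is`) is the test
    return model.lm_head is model.embed_tokens


model = TinyLM()
print(weights_are_tied(model))  # False: freshly built, nothing shared
model.tie_weights()
print(weights_are_tied(model))  # True: lm_head now aliases embed_tokens
```

If `tie_weights` on a meta-device model does not leave the parameters aliased this way by the time the device map is inferred, the warning would fire even though the method was called.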
Expected behavior
No warning message regarding `tie_weights` is shown when using Accelerate's `load_checkpoint_and_dispatch`.
System Info
- transformers version: 4.42.3
- python version: 3.10.14
- accelerate version: 0.31.0
Who can help?
@muellerzr