
Speedup model init on CPU (by 10x+ for llama-3-8B as one example) #31771

Merged
merged 42 commits into main from muellerzr-speedup-inference on Jul 16, 2024

Conversation

muellerzr
Contributor

@muellerzr muellerzr commented Jul 3, 2024

What does this PR do?

This PR introduces the use of _assign_to_params_buffers to speed up weight loading when the checkpoint dtype matches the model dtype. By doing so, we can lazily load model weights on the fly when an input is passed in, decreasing the TTL of both training and inference (gated by the speed of your disk).

A further benefit: low_cpu_mem_usage=True and this approach now have roughly the same memory usage, if and only if the checkpoint precision == the instantiated model precision.

For example, this only kicks in if you load llama-3-8B in bfloat16, since both the checkpoint weights and the instantiated model are in bfloat16.
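For readers unfamiliar with the mechanism, here is a minimal sketch of the torch-level behaviour this builds on - not the exact from_pretrained code path, and the model id, file name, and dtype are placeholders:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton on the meta device: no real storage is allocated yet.
config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")
with torch.device("meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

# mmap the checkpoint: tensor data stays on disk until it is first touched.
state_dict = torch.load("consolidated.bin", mmap=True, weights_only=True)

# assign=True swaps the module's (meta) parameters for the checkpoint tensors
# instead of copying values into pre-allocated buffers; this is only a faithful
# fast path when the checkpoint dtype already matches the model dtype.
model.load_state_dict(state_dict, assign=True)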

Unsupported models

Some models do not support param/buffer assignment. I've added a new _supports_param_buffer_assignment attribute to the specific models that don't (eventually it would be good to investigate why). If any model fails the test_from_pretrained_no_checkpoint tests, it needs to set this attribute on its model class (similar to how VisionEncoderDecoderModel sets supports_gradient_checkpointing=False).
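For illustration, the opt-out is a class-level flag, roughly like this (the class name is hypothetical; the attribute name is the one added by this PR):

from transformers import PreTrainedModel

class MyQuirkyModel(PreTrainedModel):
    # Opt out of the fast assign path; loading falls back to the regular
    # copying loader for this architecture.
    _supports_param_buffer_assignment = False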

Example model init time:

| Model       | Before fix (s) | After fix (s) |
|-------------|----------------|---------------|
| llama-3-8B  | 1.858          | 0.183         |
| llama-3-70B | 30.36          | 1.238         |

Example model throughput:

First pass

Note: the first pass is special, as the model weights are fully loaded in here, so it will be slower.

| Model       | Before fix (tok/s) | After fix (tok/s) |
|-------------|--------------------|-------------------|
| llama-3-8B  | 2.416              | 2.382             |
| llama-3-70B | 0.286              | 0.256             |

Afterwards, lazy-loaded and non-lazy-loaded inference times are the same (since we no longer need to read the weights in).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LysandreJik @amyeroberts @SunMarc

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@muellerzr
Contributor Author

Note: we're seeing some failures with encoder/decoder models that don't have tied weights. Not fully sure what's up there but @SunMarc is investigating

@muellerzr
Contributor Author

This can allegedly also increase throughput from model.generate()... or at least that's what I'm seeing.

Setup:

import time
from accelerate.utils import set_seed
from transformers import LlamaForCausalLM, AutoTokenizer

set_seed(42)


file_size = 132 # Size in GB of the weights
factory_model = LlamaForCausalLM.from_pretrained("/mnt/superfast/llama-3-8B")

tokenizer = AutoTokenizer.from_pretrained("/mnt/superfast/llama-3-8B")
inputs = tokenizer("Blue is my favorite color. What is my favorite color?", return_tensors="pt")

start_time = time.time()
output = factory_model.generate(**inputs, max_new_tokens=20, num_return_sequences=1)

end_time = time.time()
time_taken = end_time - start_time
print(f"inference time={time_taken:.3f} seconds")
print(f"speed={file_size/time_taken:.3f} GB/second")

new_tokens = len(output[0]) - inputs.input_ids.shape[1]

print(f'tok/s={new_tokens/time_taken:.3f}')

Current setup in HF:

inference time=24.841 seconds
speed=5.314 GB/second
tok/s=0.805

New version:

inference time=9.205 seconds
speed=14.341 GB/second
tok/s=2.173

@muellerzr muellerzr changed the title Speedup model loading by 1,100%+! Speedup model loading (by 1,100%+) and .generate() on CPU (by 925%+)! Jul 3, 2024
@muellerzr
Contributor Author

Did some tests with optimum-benchmark, new throughput results on CPU:

  • transformers main: 0.24 tokens/s
  • transformers my branch: 2.46 tokens/s

Collaborator

@amyeroberts amyeroberts left a comment

Wow - what an impactful PR 🔥 !

Only question is about deepspeed compatibility and compatibility across pytorch versions.

cc @gante for reference for the generate speedups

Comment on lines 697 to 707
if len(params_to_gather) > 0:
    # because zero3 puts placeholders in model params, this context
    # manager gathers (unpartitions) the params of the current layer, then loads from
    # the state dict and then re-partitions them again
    with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            module._load_from_state_dict(*args)
Collaborator

Have we tested the new code with deepspeed too?

Contributor Author

Not yet! To come next week :)

Contributor

@stas00 stas00 Jul 4, 2024

I don't think this can work if the model has already been sharded under DeepSpeed ZeRO-3, because ZeRO-3 hijacks the param tensors and the loading will either fail (or, worse, silently remain random).

But I'd suggest checking in with the DeepSpeed team - perhaps they have some more recent tricks that will accomplish this faster.

Contributor

@stas00 stas00 Jul 4, 2024

But in general - under zero.Init context the model is already spread out across the gpus, so you can't just overwrite its shards - w/o the machinery you propose to delete.

I think the just published Universal Checkpoint might be usable here to broadcast the updated tensor shards to each gpu, w/o needing to gather their content first.

@tjruwase, if possible could you please assist if there is a way to update the already sharded tensors in a faster way?

Contributor

Of course, if this doesn't work, then you'd need two code branches. The users who use zero.Init won't mind waiting, because the huge models they want to load won't fit onto a single GPU, so the cost of slower loading is an acceptable trade-off here.

Contributor

@tjruwase, if possible could you please assist if there is a way to update the already sharded tensors in a faster way?

@stas00, I am just catching up on this, but my initial thought is wouldn't the following feature help for this case?
https://deepspeed.readthedocs.io/en/latest/zero3.html#modifying-partitioned-states

Comment on lines 690 to 706
if is_deepspeed_zero3_enabled():
    import deepspeed

    # In sharded models, each shard has only part of the full state_dict, so only gather
    # parameters that are in the current state_dict.
    named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
    params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
    if len(params_to_gather) > 0:
        # because zero3 puts placeholders in model params, this context
        # manager gathers (unpartitions) the params of the current layer, then loads from
        # the state dict and then re-partitions them again
        with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
            if torch.distributed.get_rank() == 0:
Contributor Author

I believe this should be as simple as migrating this chunk and then having the load_state_dict occur under this context manager.

@muellerzr
Contributor Author

@msaroufim if you have a moment, could you give this a look to check that everything makes sense here per my understanding of how we should be loading in model weights, etc? Would be very appreciative of your eyes/take on this

@muellerzr
Contributor Author

muellerzr commented Jul 3, 2024

For transparency, here is the script I'm using: https://gist.github.com/muellerzr/7239668f61baff5726f556d30d2af5f5

@gante
Member

gante commented Jul 4, 2024

💛 💛 💛 (local-gemma is very happy with this)

@muellerzr muellerzr changed the title Speedup model loading (by 1,100%+) and .generate() on CPU (by 925%+)! Speedup model loading (by ~10x) and .generate() on CPU (by ~10x)! Jul 4, 2024
@muellerzr
Contributor Author

Got confirmation from Mark S (thanks Mark for looking this over) and this is indeed correct 🔥

@muellerzr
Contributor Author

A very important caveat @SunMarc and I discovered today: set_module_tensor_to_device is very slow, so users who pass device_map="auto" will not see this speedup - loading will still be slow for them.

This works for now as a fix for users who load everything on CPU first / don't use device_map="auto", but there is more work we need to do.
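For clarity, a sketch of the two load paths being compared here (the local path is the placeholder used elsewhere in this thread):

import torch
from transformers import LlamaForCausalLM

# Fast path (this PR): the whole model is materialized on CPU via assign.
model = LlamaForCausalLM.from_pretrained(
    "/mnt/superfast/llama-3-8B", torch_dtype=torch.bfloat16
)

# Slow path: dispatching with device_map goes through accelerate's
# set_module_tensor_to_device for each tensor, so the speedup does not apply.
model = LlamaForCausalLM.from_pretrained(
    "/mnt/superfast/llama-3-8B", torch_dtype=torch.bfloat16, device_map="auto"
)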

@SunMarc
Member

SunMarc commented Jul 5, 2024

About the failing tests, we had the following ones:

FAILED tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py::SeamlessM4Tv2ModelWithSpeechInputTest::test_load_save_without_tied_weights - AssertionError: SeamlessM4Tv2Model: Tensor text_encoder.embed_tokens.weight: Tensor-likes are not close!

Mismatched elements: 114 / 120 (95.0%)
Greatest absolute difference: 0.10039517283439636 at index (12, 5) (up to 1e-05 allowed)
Greatest relative difference: 63.95556640625 at index (16, 0) (up to 1.3e-06 allowed)
FAILED tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersModelTest::test_load_save_without_tied_weights - AssertionError: SwitchTransformersModel: Tensor encoder.embed_tokens.weight: Tensor-likes are not close!

Mismatched elements: 3163 / 3168 (99.8%)
Greatest absolute difference: 0.010853035375475883 at index (47, 11) (up to 1e-05 allowed)
Greatest relative difference: 8950.5673828125 at index (4, 15) (up to 1.3e-06 allowed)
FAILED tests/models/m2m_100/test_modeling_m2m_100.py::M2M100ModelTest::test_load_save_without_tied_weights - AssertionError: M2M100Model: Tensor encoder.embed_tokens.weight: Tensor-likes are not close!

Mismatched elements: 1566 / 1584 (98.9%)
Greatest absolute difference: 0.0897781103849411 at index (57, 3) (up to 1e-05 allowed)
Greatest relative difference: 2123.130126953125 at index (23, 3) (up to 1.3e-06 allowed)
FAILED tests/models/switch_transformers/test_modeling_switch_transformers.py::SwitchTransformersEncoderOnlyModelTest::test_load_save_without_tied_weights - AssertionError: SwitchTransformersEncoderModel: Tensor encoder.embed_tokens.weight: Tensor-likes are not close!

Mismatched elements: 3159 / 3168 (99.7%)
Greatest absolute difference: 0.010908672586083412 at index (1, 3) (up to 1e-05 allowed)
Greatest relative difference: 4020.068115234375 at index (39, 5) (up to 1.3e-06 allowed)
FAILED tests/models/seamless_m4t_v2/test_modeling_seamless_m4t_v2.py::SeamlessM4Tv2ModelWithTextInputTest::test_load_save_without_tied_weights - AssertionError: SeamlessM4Tv2Model: Tensor text_encoder.embed_tokens.weight: Tensor-likes are not close!

Mismatched elements: 114 / 120 (95.0%)
Greatest absolute difference: 0.08684197068214417 at index (9, 0) (up to 1e-05 allowed)
Greatest relative difference: 1055.23974609375 at index (1, 0) (up to 1.3e-06 allowed)
FAILED tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py::Wav2Vec2BertModelTest::test_save_and_load_from_pretrained - AssertionError: 0.6309062 not less than or equal to 1e-05
=========== 6 failed, 2815 passed, 3759 skipped in 68.00s (0:01:08) ============

It is a bit complicated, but basically these tests were not supposed to pass initially. They passed in the end only because the weights were tied by default (even when config.tie_word_embeddings=False), so modifying the shared layer modified the other layers too, without needing to re-tie the weights.

For example, this is the architecture of SeamlessM4Tv2Model. We see that we have a shared layer by default without any config attribute to make it optional.

    def __init__(self, config, current_modality="text"):
        super().__init__(config)

        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)

        self.text_encoder = SeamlessM4Tv2Encoder(config, self.shared)
        self.speech_encoder = SeamlessM4Tv2SpeechEncoder(config)
        self.text_decoder = SeamlessM4Tv2Decoder(config, self.shared)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

With this PR, we've set assign=True, and by doing this we recreated the shared layer, breaking the tied weights. Since tie_weights does nothing (config.tie_word_embeddings=False), we get different values in the end.

assign (bool, optional): When ``False``, the properties of the tensors
                in the current module are preserved while when ``True``, the
                properties of the Tensors in the state dict are preserved. The only
                exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s
                for which the value from the module is preserved.
                Default: ``False``
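To make the mechanism concrete, here is a tiny self-contained toy (all names made up) showing how assign=True unties shared weights:

import torch.nn as nn

class TiedToy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(4, 2)
        self.head = nn.Linear(2, 4, bias=False)
        self.head.weight = self.embed.weight  # tie the two modules to one tensor

model = TiedToy()
print(model.head.weight is model.embed.weight)  # True

# A checkpoint carries two independent (identical) copies of the tied tensor.
state_dict = {k: v.clone() for k, v in model.state_dict().items()}

model.load_state_dict(state_dict, assign=True)
print(model.head.weight is model.embed.weight)  # False -> tie is broken until re-tied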

So, the conclusion is that there should be no problem with the modification you did, I just need to skip/modify the tests. cc @muellerzr

In the future, we just need to make sure that when a model has shared weights by default, we either skip these tests or add the option to remove the shared weights.

@muellerzr
Contributor Author

Okay ran some traces and I think this makes sense to me now.

Compare the following first calls to .generate():

Baseline:

-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::clone         0.03%       6.590ms         0.20%      45.418ms      33.844us     359.36 Mb     -54.77 Mb          1342  
                                        aten::reshape         0.02%       4.874ms         0.22%      49.277ms       6.960us     330.34 Mb     344.00 Kb          7080  
                                     aten::empty_like         0.01%       2.366ms         0.02%       5.526ms       4.118us     325.24 Mb     125.18 Mb          1342  
                                          aten::empty         0.02%       3.810ms         0.02%       3.810ms       0.932us     266.56 Mb     266.56 Mb          4087  
                                         aten::matmul         0.11%      25.907ms        96.87%       22.083s       4.296ms     149.42 Mb    -344.00 Kb          5140  
                                         aten::linear         2.79%     635.090ms        96.92%       22.094s       4.910ms     149.22 Mb       3.75 Mb          4500  

Fix:

-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-----------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          aten::clone         0.07%       5.631ms         0.25%      20.631ms      15.373us     194.36 Mb     -31.73 Mb          1342  
                                        aten::reshape         0.14%      11.557ms         0.41%      34.781ms       4.913us     165.00 Mb           0 b          7080  
                                     aten::empty_like         0.04%       3.059ms         0.06%       4.894ms       3.647us     182.18 Mb      57.70 Mb          1342  
                                          aten::empty         0.03%       2.357ms         0.03%       2.357ms       0.577us     160.70 Mb     160.70 Mb          4083  
                                         aten::matmul         0.53%      44.745ms        88.49%        7.438s       1.447ms      74.81 Mb           0 b          5140  
                                         aten::linear         0.70%      58.470ms        88.82%        7.466s       1.659ms      74.61 Mb     550.00 Kb          4500  

What this hints at, I believe, is that because we are using mmap, the time to read the file from the mmap'd pages plus perform the operation itself is lower than with the current approach. Consistent with that, the first run was ~0.2s longer than later runs, which I think comes from reading from disk/allocating memory - reasonable given it had ~8s of compute in which to read all of this in.

And because it's mmap'd, we only ever need a layer at a time to replace the weights directly, which is why we appear to use less memory.
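A minimal sketch of what "mmap'd" means here (the file name is a placeholder): nothing is bulk-read at load time, and pages are only faulted in from disk the first time a tensor's data is actually used, e.g. by the first matmul:

import torch

# With mmap=True the returned tensors are backed by the mapped file;
# no bulk read happens on this line.
state_dict = torch.load("model.bin", mmap=True, weights_only=True)

first_weight = next(iter(state_dict.values()))
_ = first_weight.float().sum()  # touching the data is what triggers the disk I/O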

@tjruwase
Contributor

tjruwase commented Jul 8, 2024

@muellerzr, I am curious about the I/O speeds in your OP. Can you please confirm that you are transferring weights from NVMe to HBM at 75-90GB/sec? Are you able to share PCIe and m.2 specs? Thanks

@muellerzr
Contributor Author

muellerzr commented Jul 8, 2024

@tjruwase the answer I've come to (as in my last post) is that mmapping covers the transfer from the M.2 until we first pass an input through the model (I think). At load time the weights are only given their space, not fully read in. Hence I'm seeing numbers far above what my actual M.2 drive can deliver - but ~8s to bring in said weights during the first pass is entirely reasonable!

(Because yes, I'd love to know what planet has a 75-90GB/s non-RAID M.2 as well!)

My setup:

  • NVME: Crucial T705 2TB
  • Memory: 192GB DDR5 running at 3600 MT/s
  • CPU: AMD Ryzen 9 7950X
  • Motherboard: an Asus ProArt X670E-CREATOR; my 2x 4090s are running at x8/x8

Let me know how much more specific I can get with this for you!

@tjruwase
Contributor

tjruwase commented Jul 8, 2024

(Because yes, I'd love to know what planet has a 75-90GB/s non-RAID M.2 as well!)

@muellerzr, thanks for the clarification. As you may have guessed fast I/O is a passion, and I am also awaiting the above :).

@muellerzr
Contributor Author

@tjruwase do let me know if you see anything else odd about what I've done here, or if you have other insights. I'll look into the DeepSpeed stuff in a few days!

@tjruwase
Contributor

tjruwase commented Jul 8, 2024

[Screenshot: NVMe disk benchmark showing ~14 GB/s sequential reads]

@muellerzr, your NVMe is blazingly fast, ~14GB/sec reads. May I request your contribution to the following?
microsoft/DeepSpeed#998

@tjruwase
Contributor

tjruwase commented Jul 8, 2024

@tjruwase do let me know if you see anything else odd about what I’ve done here etc too/if you have insights. I’ll look into the DeepSpeed stuff in a few days!

@muellerzr, nothing looks odd to me. This is truly amazing work that you have done here, kudos!

Do let me know if my suggestion for updating sharded DeepSpeed weights above is insufficient or problematic.

@muellerzr
Contributor Author

Okay! After a ton of thorough testing I've proven that:

  1. When loading via device_map="auto", it's the same speed as though loading the model in half precision and doing .cuda()
  2. When loading via device_map="auto" and on CPU, it's the same speed again (the faster speed)
  3. Part of my issue was not specifying torch_dtype=torch.bfloat16 when running the llama tests, so a ton of time was wasted upcasting to float32 - something others may also do by accident, since bfloat16 is not used by default (see the dtype-pinning snippet after this list). (Not sure what we can do about that, just something I noticed.)
  4. I did notice a slight speedup when doing model = LlamaForCausalLM.from_pretrained(llama_path, torch_dtype=torch.bfloat16).cuda() during model loading when compared to our current implementation, so it still can help.
  5. When doing factory_model = LlamaForCausalLM.from_pretrained("/mnt/superfast/llama-3-8B"), the model weights are loaded in bf16; Marc mentioned this might be a bad bug, cc @ArthurZucker. Only if we pass device_map="auto" are they loaded in fp32 (which I found also slows down model loading - which makes sense given the larger tensors).
  6. Given that those are the only changes, and OOTB this should just work, this PR from a non-DeepSpeed standpoint is good to merge.

Users will not see much speedup if they do device_map="auto" for the aforementioned reasons, but this still helps other folks out too!
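As a practical takeaway from points 3-5, explicitly pinning the dtype avoids the accidental fp32 upcast (the path is the placeholder used elsewhere in this thread):

import torch
from transformers import LlamaForCausalLM

# Pin the dtype so no upcast happens and the fast assign path can kick in.
model = LlamaForCausalLM.from_pretrained(
    "/mnt/superfast/llama-3-8B", torch_dtype=torch.bfloat16
)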

@muellerzr
Contributor Author

muellerzr commented Jul 9, 2024

When I eventually ripped everything out to test, here's my full code:

from transformers import LlamaForCausalLM, AutoConfig, AutoTokenizer
from accelerate.utils import set_seed
from accelerate.big_modeling import init_empty_weights
from safetensors.torch import load_file
from pathlib import Path
import json
from safetensors import safe_open
from accelerate.utils import retie_parameters
from transformers import GenerationConfig
from transformers.utils.hub import get_checkpoint_shard_files
import time

set_seed(42)

llama_path = Path("/mnt/superfast/llama-3-8B")

tokenizer = AutoTokenizer.from_pretrained(llama_path)
inputs = tokenizer("Tell me about a girl that", return_tensors="pt")


config = AutoConfig.from_pretrained(llama_path)
use_keep_in_fp32_modules = False

resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
                llama_path,
                llama_path/"model.safetensors.index.json"
            )
loaded_state_dict_keys = sharded_metadata["all_checkpoint_keys"]

config = LlamaForCausalLM._autoset_attn_implementation(
            config, use_flash_attention_2=False, torch_dtype=None, device_map=None
        )
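# Instantiate the model skeleton on the meta device: no real weights are allocated yet.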
with init_empty_weights():
    factory_model = LlamaForCausalLM(config)

index_filename = llama_path / "model.safetensors.index.json"

with open(index_filename, "r") as f:
    index = json.load(f)

if "weight_map" in index:
    index = index["weight_map"]
checkpoint_files = sorted(list(set(index.values())))
checkpoint_files = [llama_path / f for f in checkpoint_files]

model_keys = set(factory_model.state_dict().keys())
new_state_dict = {}

for checkpoint_file in checkpoint_files:
    with safe_open(checkpoint_file, framework="pt") as f:
        metadata = f.metadata()
        weight_names = f.keys()
    file_state = load_file(checkpoint_file)
    new_state_dict.update(file_state)
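# assign=True points the model's parameters at these loaded tensors directly instead of copying values into freshly allocated ones.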
factory_model.load_state_dict(new_state_dict, strict=True, assign=True)

retie_parameters(factory_model, [["lm_head.weight"]])

factory_model.eval()

factory_model.generation_config = GenerationConfig.from_pretrained(
                    llama_path
                )

start_time = time.time()
output = factory_model.generate(**inputs, max_new_tokens=20, num_return_sequences=1)

end_time = time.time()
time_taken = end_time - start_time
new_tokens = len(output[0]) - inputs.input_ids.shape[1]
print(f"{time_taken:.3f}s | {new_tokens/time_taken:.3f} tokens/second | {tokenizer.batch_decode(output, skip_special_tokens=True)} | ")

@muellerzr muellerzr changed the title Speedup model loading (by ~10x) and .generate() on CPU (by ~10x)! Speedup model init on CPU (by 30x+) Jul 9, 2024
@muellerzr
Contributor Author

@SunMarc @LysandreJik @ArthurZucker I've adjusted the title to reflect what is really happening here. See the new updated table: basically we "borrow" a little time later, during the first pass, to load the weights in rather than doing so immediately, which lets models load much faster - and after the 1st pass things are still quick.

On CUDA I saw nearly no time changes either, aside from loading the model in 0.185s rather than 2s for llama-3-8B, so that's safe too :)

@muellerzr muellerzr changed the title Speedup model init on CPU (by 30x+) Speedup model init on CPU (by 30x+ for llama-3-70B) Jul 9, 2024
@muellerzr muellerzr changed the title Speedup model init on CPU (by 30x+ for llama-3-70B) Speedup model init on CPU (by 30x+ for llama-3-70B as one example) Jul 9, 2024
@muellerzr
Contributor Author

muellerzr commented Jul 9, 2024

So that we can merge this, for now I've kept the old deepspeed behavior in

Collaborator

@ArthurZucker ArthurZucker left a comment

Simple and efficient! I think we need a guard on the torch version (do we still support the 2-year-old one?), and can you rebase to make sure the failing tests are unrelated?

@muellerzr muellerzr force-pushed the muellerzr-speedup-inference branch from 695d512 to 4335956 on July 15, 2024 20:06
Collaborator

@amyeroberts amyeroberts left a comment

178 lines of diff --> 10x improvement in loading speed - what a ratio! 🔥🔥 🔥 🔥

Main comment is about having a _supports_assign_param_buffer attribute always set by default. Otherwise mostly nits.

The documentation should be updated to reflect the new logic. From a quick search, there are at least two places I saw to update, but there may be more:

In the PR description - there's one bit I didn't understand in the First Pass section. What does "Afterwards they are ~ the same" mean?

Review comments were left on: tests/models/bart/test_modeling_bart.py, tests/utils/test_modeling_utils.py, and src/transformers/modeling_utils.py.
if start_prefix + first_key in state_dict:
    return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
else:
    # For cases when the `state_dict` doesn't have any real weights (`albert`)
    return False
Collaborator

This is weird - what do we mean by "real weights" here? If I look at the safetensors file for a checkpoint it looks like there are weights

Contributor Author

It's in reference to test_model_weights_reload_no_missing_tied_weights, where we have nuked the saved tensors, so they don't exist in the state dict at all. In this case we should return False. I've changed the comment to reference the specific test.

(In most real world cases, we shouldn't get to this point)

@muellerzr
Contributor Author

@amyeroberts reworked the PR description, let me know if everything makes sense now 🤗

Collaborator

@amyeroberts amyeroberts left a comment

Beautiful ❤️

@@ -894,32 +895,42 @@ def test_from_pretrained_low_cpu_mem_usage_functional(self):
     @require_usr_bin_time
     @require_accelerate
     @mark.accelerate_tests
-    def test_from_pretrained_low_cpu_mem_usage_measured(self):
-        # test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
+    def test_from_pretrained_low_cpu_mem_usage_slower(self):
Collaborator

Shouldn't it be faster?

Suggested change:
-    def test_from_pretrained_low_cpu_mem_usage_slower(self):
+    def test_from_pretrained_low_cpu_mem_usage_faster(self):

Contributor Author

Was running the tests - it's a bit slower due to the added hooks IMO (which is fine, as low_cpu_mem_usage=True is still needed 99% of the time, i.e. when the checkpoint dtype doesn't match the model dtype).

@muellerzr muellerzr merged commit e0dfd7b into main Jul 16, 2024
24 checks passed
@muellerzr muellerzr deleted the muellerzr-speedup-inference branch July 16, 2024 13:32