Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support unified checkpoint for expert_parallel #8591

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from

Conversation

DesmonDay
Copy link
Contributor

PR types

New features

PR changes

Others

Description

Support unified checkpoint for expert_parallel.

Copy link

paddle-bot bot commented Jun 12, 2024

Thanks for your contribution!

Copy link

codecov bot commented Jun 12, 2024

Codecov Report

Attention: Patch coverage is 4.00000% with 120 lines in your changes missing coverage. Please review.

Project coverage is 55.74%. Comparing base (4e3f60d) to head (c555c30).
Report is 48 commits behind head on develop.

Files Patch % Lines
paddlenlp/trainer/plugins/unified_checkpoint.py 4.00% 120 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8591      +/-   ##
===========================================
+ Coverage    53.86%   55.74%   +1.88%     
===========================================
  Files          620      620              
  Lines        97110    96741     -369     
===========================================
+ Hits         52306    53930    +1624     
+ Misses       44804    42811    -1993     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@DesmonDay DesmonDay force-pushed the moe_add_uc branch 2 times, most recently from 6415e4c to 92d5432 Compare June 25, 2024 08:20
@DesmonDay DesmonDay changed the title [WIP] Support unified checkpoint for expert_parallel Support unified checkpoint for expert_parallel Jul 1, 2024
@@ -22,6 +22,7 @@
import paddle
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.framework import core
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个最好加个测试,框架可能需要有个模型支持一下expert_parallel

if model_state_dict[key_name[0]].dtype != core.VarDesc.VarType.FP32:
key_name = "_".join([static_name, FP32_MASTER, key_name[1]])
else:
key_name = "_".join([static_name, key_name[1]])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP32 情况,加单测


need_files = set()
state_dict = get_expected_state_dict(model)
for key in state_dict.keys():
filename = index["weight_map"][key]
# When using expert parallel, there's no need to check tensors with `no_sync=False` when dp_rank > 0.
if args.use_expert_parallel and dp_rank > 0 and not getattr(state_dict[key], "no_sync", False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是跳过 no_sync 参数吗?

@@ -962,6 +1015,7 @@ def save_single_card_optimizer(args, model, optimizer, output_dir):
if master_weights is not None:
for key in list(master_weights.keys()):
master_weights[static2struct_name_mappings[key]] = master_weights.pop(key)
master_weights.update(fp32_weight)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个加载的时候,会pop吗?

else:
shard_file = file_name.replace(
".pdparams",
f"-{args.logical_process_index + 1:05d}-of-{args.world_size//args.dataset_world_size:05d}.pdparams",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前说的 序号不对问题,还有吗?

)
shard_file = shard_file.replace(
".safetensors",
f"-{args.logical_process_index + 1:05d}-of-{args.world_size//sd_degree:05d}.safetensors",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里太长了,合并一下吧,简化一下代码吧

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants