Some Bugs in JetMoE #31791

Open
2 of 4 tasks
Phoenix-Shen opened this issue Jul 4, 2024 · 4 comments


@Phoenix-Shen

System Info

transformers version: 4.43.0.dev0 (installed from source)

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Outline:
Several bugs prevent JetMoE from returning the gating (router) logits and from calculating aux_loss.

  1. Code
    I want the model to output the gating logits.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    AutoModelForSequenceClassification,
)
import os
import torch

BASE_DIR = "model_ckpt"
# from jetmoe import JetMoEForCausalLM, JetMoEConfig, JetMoEForSequenceClassification
model_name = os.path.join(BASE_DIR, "jetmoe-8b")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

output = model.forward(
    torch.zeros(32, 12, device="cuda", dtype=torch.long),
    output_router_logits=True,
    return_dict=True,
)
  2. Running it raises an error:
    Traceback (most recent call last):
    File "/home/ubuntu/ssk/test_jetmoe.py", line 18, in <module>
    output = model.forward(
    File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/hooks.py", line 166, in new_forward
    output = module._old_forward(*args, **kwargs)
    File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/models/jetmoe/modeling_jetmoe.py", line 1365, in forward
    self.num_experts,
    File "/home/ubuntu/miniconda3/envs/pytorch/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1709, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
    AttributeError: 'JetMoeForCausalLM' object has no attribute 'num_experts'
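
The failure in the traceback can be reproduced without the model weights. The following is a minimal sketch in plain Python. `ToyModule` and `ToyCausalLM` are hypothetical stand-ins, not transformers classes, and the sketch assumes the attribute is read only when router logits are requested. `forward` reads an attribute that `__init__` never set, and the fallback `__getattr__` produces the same kind of message as `torch.nn.Module`.

```python
class ToyModule:
    """Hypothetical stand-in that mimics the attribute fallback of torch.nn.Module."""

    def __getattr__(self, name):
        # Python calls __getattr__ only after normal attribute lookup has failed.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}'"
        )


class ToyCausalLM(ToyModule):
    def __init__(self, num_local_experts, define_attrs):
        if define_attrs:
            # The quick fix: copy the value from the config in __init__.
            self.num_experts = num_local_experts

    def forward(self, output_router_logits=False):
        if output_router_logits:
            # Only this branch touches self.num_experts, so the missing
            # attribute goes unnoticed until router logits are requested.
            return self.num_experts
        return None


def try_forward(model):
    """Return the result of forward, or the error message if it fails."""
    try:
        return model.forward(output_router_logits=True)
    except AttributeError as err:
        return str(err)


broken = try_forward(ToyCausalLM(8, define_attrs=False))
fixed = try_forward(ToyCausalLM(8, define_attrs=True))
print(broken)
print(fixed)
```

In this toy setup the default call path (`output_router_logits=False`) never touches the missing attribute, which would explain why the problem only shows up when router logits are requested.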

  3. Analysis
    After examining the code (https://github.com/huggingface/transformers/blob/main/src/transformers/models/jetmoe/modeling_jetmoe.py), I found several mistakes:

  • self.num_experts and self.num_experts_per_tok are not defined in the JetMoeForCausalLM class.
  • The code does not pass the output_router_logits argument to the forward function of self.model in the JetMoeForCausalLM class (see lines 1310 and 1341 of modeling_jetmoe.py).
  • The JetMoeForSequenceClassification class never calculates aux_loss and also does not pass the output_router_logits argument to self.model.forward.
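
The dropped-argument problem in the list above can be illustrated in isolation. In this sketch all class names are hypothetical and only model the relationship between a wrapper and its inner model. It is not the transformers implementation. The inner model returns router logits only when the flag reaches it, so a wrapper that accepts the keyword but does not forward it silently gets `None`.

```python
class ToyInnerModel:
    """Hypothetical inner model: returns router logits only on request."""

    def forward(self, input_ids, output_router_logits=False):
        hidden = [float(t) for t in input_ids]
        # One fake pair of logits per token, produced only when asked for.
        router_logits = (
            [[0.0, 1.0] for _ in input_ids] if output_router_logits else None
        )
        return {"hidden": hidden, "router_logits": router_logits}


class BrokenWrapper:
    def __init__(self):
        self.model = ToyInnerModel()

    def forward(self, input_ids, output_router_logits=False):
        # Bug: the flag is accepted here but never forwarded.
        return self.model.forward(input_ids)


class FixedWrapper(BrokenWrapper):
    def forward(self, input_ids, output_router_logits=False):
        # Fix: pass the flag through to the inner model.
        return self.model.forward(
            input_ids, output_router_logits=output_router_logits
        )


broken_out = BrokenWrapper().forward([1, 2, 3], output_router_logits=True)
fixed_out = FixedWrapper().forward([1, 2, 3], output_router_logits=True)
print(broken_out["router_logits"])
print(len(fixed_out["router_logits"]))
```

Because the broken wrapper raises no error on its own, this kind of bug is easy to miss until something downstream, such as an auxiliary loss, needs the logits.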
  4. Quick fix for the JetMoeForCausalLM class
  • Add self.num_experts = config.num_local_experts and self.num_experts_per_tok = config.num_experts_per_tok to the __init__ function of JetMoeForCausalLM.
  • Pass output_router_logits to self.model.forward (line 1331):
          # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
          outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
              cache_position=cache_position,
              output_router_logits=output_router_logits,  # Add this line.
          )
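
To show why the loss computation needs both values added in `__init__`, here is a sketch of a Switch-Transformer-style load-balancing loss in plain Python. It is an illustration under assumptions, not the function used in modeling_jetmoe.py: the helper names and the exact normalization are hypothetical, and the library code works on tensors gathered from all layers rather than on nested lists.

```python
import math


def softmax(row):
    """Numerically stable softmax over one token's router logits."""
    peak = max(row)
    exps = [math.exp(x - peak) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def load_balancing_loss(router_logits, num_experts, top_k):
    """Auxiliary loss for one layer (hypothetical helper).

    router_logits: list of per-token lists, each of length num_experts.
    """
    num_tokens = len(router_logits)
    selected = [0.0] * num_experts  # how often each expert is in a token's top-k
    prob_sum = [0.0] * num_experts  # summed routing probability per expert

    for row in router_logits:
        probs = softmax(row)
        # Indices of the top_k most probable experts for this token.
        top = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)[:top_k]
        for e in top:
            selected[e] += 1.0
        for e in range(num_experts):
            prob_sum[e] += probs[e]

    # Per expert: fraction of tokens routed to it times its mean routing probability.
    loss = sum(
        (selected[e] / num_tokens) * (prob_sum[e] / num_tokens)
        for e in range(num_experts)
    )
    return num_experts * loss


# Four tokens, four experts, each token strongly prefers a different expert.
balanced = [[10.0 if e == t else 0.0 for e in range(4)] for t in range(4)]
# Every token prefers expert 0: routing has collapsed onto one expert.
collapsed = [[10.0, 0.0, 0.0, 0.0] for _ in range(4)]

balanced_loss = load_balancing_loss(balanced, num_experts=4, top_k=1)
collapsed_loss = load_balancing_loss(collapsed, num_experts=4, top_k=1)
print(balanced_loss, collapsed_loss)
```

In this sketch the loss is lowest when tokens are spread evenly across experts and grows as routing collapses onto a single expert. Both `num_experts` and `top_k` enter the computation directly, which is why the causal LM class needs them as attributes.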

Expected behavior

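With output_router_logits=True, the forward pass should return the router logits and calculate aux_loss without raising an error. The fix is described in the previous section.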

@LysandreJik
Member

Thanks @Phoenix-Shen! Let me cc @yikangshen, who contributed the model.

@yikangshen
Contributor

yikangshen commented Jul 6, 2024

Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?

@Phoenix-Shen
Author

> Hi @Phoenix-Shen, thanks for bringing up the issue! Your fix looks good to me. Would you like to submit a PR?

Ok, I've fixed all the bugs and am ready to submit a PR.

@ArthurZucker
Collaborator

Thanks, reviewed!
