
can't create torch_jit_trace script for a Mistral model #31807

Open

helloWorld199 opened this issue Jul 5, 2024 · 3 comments

Comments

helloWorld199 commented Jul 5, 2024

System Info

transformers version: 4.34.0
torch version: 2.3.0
python version: 3.10.12

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import torch
from transformers import MistralConfig, MistralForCausalLM

configuration = MistralConfig(hidden_size=512, intermediate_size=512 * 4, num_attention_heads=16, num_hidden_layers=6, num_key_value_heads=4, vocab_size=700)
mistral_model = MistralForCausalLM(configuration)

example_input = torch.randint(0, 700, (1, 10))  # dummy batch of token ids (definition missing from the original snippet)

mistral_output = mistral_model(input_ids=example_input)
traced_script_module = torch.jit.trace(mistral_model, [example_input, mistral_output.past_key_values])

Output error


AttributeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 traced_script_module = torch.jit.trace(mistral_model, [example_input, mistraloutput[1]])

11 frames
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
818 else:
819 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 820 return trace_module(
821 func,
822 {"forward": example_inputs},

/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1086 else:
1087 example_inputs = make_tuple(example_inputs)
-> 1088 module._c._create_method_from_trace(
1089 method_name,
1090 func,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1043
1044 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1045 outputs = self.model(
1046 input_ids=input_ids,
1047 attention_mask=attention_mask,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
886 )
887
--> 888 attention_mask = self._prepare_decoder_attention_mask(
889 attention_mask,
890 (batch_size, seq_length),

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window)
804 if attention_mask is not None:
805 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
--> 806 expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
807 inputs_embeds.device
808 )

/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in _expand_mask(mask, dtype, tgt_len)
99 Expands attention_mask from [bsz, seq_len] to [bsz, 1, tgt_seq_len, src_seq_len].
100 """
--> 101 bsz, src_len = mask.size()
102 tgt_len = tgt_len if tgt_len is not None else src_len
103

AttributeError: 'tuple' object has no attribute 'size'

Expected behavior

I expect traced_script_module to be created without errors.

If I do the same thing with a GPT2 model, as in the snippet below, it works, even though the structure of past_key_values is the same:

import torch
from transformers import GPT2LMHeadModel, GPT2Config

gpt2_config = GPT2Config()
gpt2_model = GPT2LMHeadModel(gpt2_config)
gpt2_output = gpt2_model(example_input)
traced_script_module = torch.jit.trace(gpt2_model, [example_input, gpt2_output.past_key_values])

What should I do to make the script work for Mistral as well?

Thanks!

@LysandreJik (Member)

cc @ArthurZucker, I don't think the Mistral model is jittable just yet

@ArthurZucker (Collaborator)

Hey! The _prepare_decoder_attention_mask is the old function, which was removed in favor of

causal_mask = self._update_causal_mask(
    attention_mask, inputs_embeds, cache_position, past_key_values, use_cache, output_attentions
)

you are probably not using the latest version of transformers!
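
For reference, a minimal sketch of what tracing could look like on a recent transformers release, assuming torchscript=True in the config (so the model returns plain tuples instead of ModelOutput objects) and use_cache=False (so no cache object reaches the tracer); example_input is a made-up batch of token ids, and this is a sketch rather than a verified recipe:

import torch
from transformers import MistralConfig, MistralForCausalLM

# torchscript=True: return plain tuples, which torch.jit.trace can handle
# use_cache=False: do not return past_key_values, avoiding cache objects in the traced outputs
configuration = MistralConfig(
    hidden_size=512,
    intermediate_size=512 * 4,
    num_attention_heads=16,
    num_hidden_layers=6,
    num_key_value_heads=4,
    vocab_size=700,
    torchscript=True,
    use_cache=False,
)
mistral_model = MistralForCausalLM(configuration).eval()

example_input = torch.randint(0, 700, (1, 10))  # assumed dummy input ids

with torch.no_grad():
    traced_script_module = torch.jit.trace(mistral_model, [example_input])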

@ArthurZucker (Collaborator)

pip install -U transformers should fix it
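
To double-check that the upgraded package is the one actually being imported (purely illustrative; the printed version depends on your environment):

import transformers

print(transformers.__version__)  # the issue was filed against 4.34.0; the _update_causal_mask path needs a newer release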
