System Info
transformers version: 4.34.0
torch version: 2.3.0
python version: 3.10.12
Who can help?
No response
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
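A minimal sketch of the failing call (the checkpoint and the variable names example_input and mistraloutput are assumptions reconstructed from the traceback below):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint is an assumption; torchscript=True makes the model return
# plain tuples, so outputs[1] is past_key_values when use_cache=True.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
mistral_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-v0.1", torchscript=True
).eval()

example_input = tokenizer("Hello, world", return_tensors="pt").input_ids

with torch.no_grad():
    mistraloutput = mistral_model(example_input, use_cache=True)

# Fails: torch.jit.trace passes the list positionally, and the second
# positional argument of MistralForCausalLM.forward is attention_mask,
# so the past_key_values tuple lands where a mask tensor is expected.
traced_script_module = torch.jit.trace(
    mistral_model, [example_input, mistraloutput[1]]
)
```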
Error output:
AttributeError Traceback (most recent call last)
in <cell line: 1>()
----> 1 traced_script_module = torch.jit.trace(mistral_model, [example_input, mistraloutput[1]])
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
818 else:
819 raise RuntimeError("example_kwarg_inputs should be a dict")
--> 820 return trace_module(
821 func,
822 {"forward": example_inputs},
/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
1086 else:
1087 example_inputs = make_tuple(example_inputs)
-> 1088 module._c._create_method_from_trace(
1089 method_name,
1090 func,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
1043
1044 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1045 outputs = self.model(
1046 input_ids=input_ids,
1047 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _slow_forward(self, *input, **kwargs)
1520 recording_scopes = False
1521 try:
-> 1522 result = self.forward(*input, **kwargs)
1523 finally:
1524 if recording_scopes:
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
886 )
887
--> 888 attention_mask = self._prepare_decoder_attention_mask(
889 attention_mask,
890 (batch_size, seq_length),
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window)
804 if attention_mask is not None:
805 # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
--> 806 expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
807 inputs_embeds.device
808 )
/usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py in _expand_mask(mask, dtype, tgt_len)
99     Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
100     """
--> 101 bsz, src_len = mask.size()
102 tgt_len = tgt_len if tgt_len is not None else src_len
103
AttributeError: 'tuple' object has no attribute 'size'
Expected behavior
I expect traced_script_module to be initialized correctly.
If I do the same thing with a GPT-2 model, as in the code snippet below, it works, even though the structure of past_key_values is the same:
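(A minimal sketch of the GPT-2 version; the checkpoint and variable names are assumptions, not the original snippet.)

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

example_input = tokenizer("Hello, world", return_tensors="pt").input_ids

with torch.no_grad():
    gpt2output = gpt2_model(example_input, use_cache=True)

# Tracing with [input_ids, past_key_values] works here because GPT-2's
# forward signature is (input_ids, past_key_values, attention_mask, ...),
# i.e. the cache is the *second* positional argument, whereas Mistral's is
# (input_ids, attention_mask, position_ids, past_key_values, ...).
traced = torch.jit.trace(gpt2_model, [example_input, gpt2output[1]])
```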
What should I do to make the script work for Mistral as well?
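One workaround sketch that follows Mistral's positional signature (input_ids, attention_mask, position_ids, past_key_values) — hypothetical, not confirmed to trace successfully:

```python
# Sketch only: feed one new token plus the cache, in the positional order
# that MistralForCausalLM.forward expects, with an attention mask that
# covers both the cached and the new tokens.
past = mistraloutput[1]
past_len = past[0][0].shape[2]          # cached sequence length
next_token = example_input[:, -1:]      # any single token id works for tracing
attention_mask = torch.ones(1, past_len + 1, dtype=torch.long)
position_ids = torch.tensor([[past_len]])

traced_script_module = torch.jit.trace(
    mistral_model, [next_token, attention_mask, position_ids, past]
)
```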
Thanks!