
BEIT3 RuntimeError: Size mismatch when loading checkpoint for BEiT3ForRetrieval model #1594

Open
thanhstar260 opened this issue Jul 5, 2024 · 0 comments


Describe the bug
Model I am using (UniLM, MiniLM, LayoutLM ...): BEiT3

I get a RuntimeError when loading the beit3_large_itc_patch16_224 checkpoint into a BEiT3ForRetrieval model. The error reports a size mismatch for the parameter beit3.encoder.embed_positions.A.weight.

To Reproduce
Steps to reproduce the behavior:

  1. Download the beit3_large_itc_patch16_224 checkpoint.
  2. Initialize the model with beit3_large_patch16_384_retrieval().
  3. Load the checkpoint from its path with torch.load.
  4. Load the checkpoint's state dictionary into the model with load_state_dict.
%cd unilm/beit3
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from modeling_finetune import beit3_large_patch16_384_retrieval, beit3_base_patch16_224_retrieval

ckpt_path = r"/content/drive/MyDrive/unilm/beit3/beit3_large_itc_patch16_224.pth"  # checkpoint name indicates 224x224 input
model = beit3_large_patch16_384_retrieval()  # model name indicates 384x384 input

checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint['model'])
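To check whether embed_positions.A.weight is the only affected parameter, the shapes can be compared before calling load_state_dict. Passing strict=False would not avoid this error: it only tolerates missing and unexpected keys, and PyTorch still raises on size mismatches. The helper below is a hypothetical sketch (find_shape_mismatches is not part of BEiT3 or PyTorch); it only assumes that every value has a .shape attribute, which holds for torch tensors.

```python
def find_shape_mismatches(checkpoint_state, model_state):
    """Return {key: (checkpoint_shape, model_shape)} for every key present in
    both state dicts whose shapes differ. Keys found in only one dict are ignored."""
    mismatches = {}
    for key, ckpt_value in checkpoint_state.items():
        if key not in model_state:
            continue  # unexpected key: not a shape problem
        ckpt_shape = tuple(ckpt_value.shape)
        model_shape = tuple(model_state[key].shape)
        if ckpt_shape != model_shape:
            mismatches[key] = (ckpt_shape, model_shape)
    return mismatches


# Usage with the objects from the reproduction above (hypothetical):
# for key, (ckpt_shape, model_shape) in find_shape_mismatches(
#         checkpoint["model"], model.state_dict()).items():
#     print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")
```

If this lists only beit3.encoder.embed_positions.A.weight, the rest of the checkpoint matches the 384 model's parameter shapes.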

Error:

RuntimeError Traceback (most recent call last)
in <cell line: 12>()
10
11 checkpoint = torch.load(ckpt_path)
---> 12 model.load_state_dict(checkpoint['model'])

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict, assign)
2187
2188 if len(error_msgs) > 0:
-> 2189 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
2190 self.__class__.__name__, "\n\t".join(error_msgs)))
2191 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for BEiT3ForRetrieval:
size mismatch for beit3.encoder.embed_positions.A.weight: copying a param with shape torch.Size([199, 1024]) from checkpoint, the shape in current model is torch.Size([579, 1024]).
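The two sizes are consistent with the checkpoint and the model being built for different input resolutions. With 16x16 patches, a 224x224 image gives a 14x14 grid (196 patches) and a 384x384 image gives a 24x24 grid (576 patches). Since 199 = 196 + 3 and 579 = 576 + 3, both tables appear to hold the same 3 non-patch rows and differ only in the patch grid. A common workaround for ViT-style models is to resize the patch part of the table with 2D interpolation. The sketch below assumes the 3 non-patch rows come first and the patch rows are stored in row-major order; that layout is inferred from the numbers above, not confirmed from the BEiT3 source, and resize_pos_embed is a hypothetical helper.

```python
import math

import torch
import torch.nn.functional as F


def resize_pos_embed(weight: torch.Tensor, num_extra: int, new_grid: int) -> torch.Tensor:
    """Resize a [num_extra + old_grid**2, dim] position table to
    [num_extra + new_grid**2, dim].

    Assumption: the first `num_extra` rows are non-patch rows (copied unchanged)
    and the remaining rows form a square patch grid in row-major order.
    """
    extra, patch = weight[:num_extra], weight[num_extra:]
    old_grid = math.isqrt(patch.shape[0])
    if old_grid * old_grid != patch.shape[0]:
        raise ValueError(f"{patch.shape[0]} patch rows do not form a square grid")
    dim = patch.shape[1]
    # [N, dim] -> [1, dim, H, W] so F.interpolate resamples the spatial grid
    grid = patch.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid.float(), size=(new_grid, new_grid),
                         mode="bicubic", align_corners=False)
    # [1, dim, H', W'] -> [H' * W', dim], restoring the original dtype
    patch = grid.permute(0, 2, 3, 1).reshape(new_grid * new_grid, dim).to(weight.dtype)
    return torch.cat([extra, patch], dim=0)


# Shapes from the error message: 224 // 16 = 14 grid in the checkpoint, 384 // 16 = 24 in the model
ckpt_weight = torch.randn(199, 1024)
resized = resize_pos_embed(ckpt_weight, num_extra=3, new_grid=384 // 16)
print(resized.shape)  # torch.Size([579, 1024])
```

Under that assumption, checkpoint["model"]["beit3.encoder.embed_positions.A.weight"] would be replaced by its resized version before calling model.load_state_dict. The other option is to load the checkpoint into a model configured for 224x224 input, so that no resizing is needed.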
