Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GroundingDino] Fix grounding dino loss #31828

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

EduardoPach
Copy link
Contributor

@EduardoPach EduardoPach commented Jul 7, 2024

What does this PR do?

Fixes #31434

As the original repo doesn't provide the loss implementation I'm using the one implemented here as a baseline since it was mentioned by the original repo, on this issue IDEA-Research/GroundingDINO#241, as a reliable source if one wants to train a GroundingDino model

TODO:

  • Test GroundingDinoMatcher and GroundingDinoLoss are working properly

Explanation of the Issue and Solution

So the issue was that GroundingDinoLoss and GroundingDinoHungarianMatcher were just a copy from DeformableDetr which is used for closed-set object detection (i.e. a fixed set of categories). Whereas in GroundingDino there's no limited amount of categories and the output logits are d_model dimensional where the first seq_len elements have a specified value and the subsequent are nan. The main differences are:

  1. class_labels are associated with the text prompt used
  2. The logits are asscoaited with the tokens of the text so it's not necessarily 1-to-1

For instance if an image with bounding boxes with fishes and jellyfishes using a prompt "fish. jellyfish." fish should have class_label 0 assigned to it and jellyfish should have 1 assigned. If the position of jellyfish and fish in the prompt swapped then the class_labels would swap as well. Moreover, jellyfish is represented by two tokens ([20919, 7529]) and fish by one token ([3869]) therefore we need to select the appropriate logits for each class.

As the original implementation doesn't provide the training loop or the loss implementation, but does recommend other implementations for training GroundingDino on this issue IDEA-Research/GroundingDINO#241, I took as baseline the implementation from Open-GroundingDino as it supports both visual grounding and object detection and they've trained their own GroundingDino using their code base achieving good performance.

Things added in this PR are:

  • build_label_maps which generates a list of torch.Tensor with lenght batch_size mapping each category to its corresponding tokens based on the input_ids
  • build_text_mask just expand the attention_mask to select the appropriate tokens when computing GroundingDino.loss_labels
  • Added enc_topk_proposals, encoder_logits and encoder_pred_boxes to GroundingDinoModelOutput and GroundingDinoObjectDetectionOutput to compute first stage loss
  • Added class_loss_coefficient (with correct default value) and class_loss_reduction to GroundingDinoConfig. class_loss_reduction was added because in sigmoid_focal_loss from the baseline implementation they reduced loss_ce with a simple sum, but that makes the losses imbalanced most of the time and in the original implementation they do have a sigmoid_focal_loss implemented, but using mean reduction, therefore I made I decided to make it configurable and use the sum one for testing reasons
  • Modifications to GroundingDinoLoss and GroundingDinoHungarianMatcher

Also added a new integration test called test_grounding_dino_loss where I compare the loss obtained from 2 sample images with the baseline implementation from Open-GroundingDino.

c.c. @amyeroberts

@EduardoPach EduardoPach changed the title WIP - [GroundingDino] Fix grounding dino loss [GroundingDino] Fix grounding dino loss Jul 14, 2024
@EduardoPach
Copy link
Contributor Author

@amyeroberts FYI for some reason, when testing locally, test_cross_attention_mask is failing on this branch, but when I tested using the main branch it was also failing (locally)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

GroundingDino - Loss calculation exceptions
2 participants