[GroundingDino] Fix grounding dino loss #31828
Open
+320
−144
What does this PR do?
Fixes #31434
As the original repo doesn't provide a loss implementation, I'm using the one from Open-GroundingDino as a baseline, since the original repo mentions it in IDEA-Research/GroundingDINO#241 as a reliable source if one wants to train a `GroundingDino` model.

TODO:
- Verify that `GroundingDinoMatcher` and `GroundingDinoLoss` are working properly

Explanation of the Issue and Solution
So the issue was that `GroundingDinoLoss` and `GroundingDinoHungarianMatcher` were just copies from `DeformableDetr`, which is used for closed-set object detection (i.e. a fixed set of categories). In `GroundingDino`, by contrast, there is no fixed set of categories, and the output logits are `d_model` dimensional, where the first `seq_len` elements have a specified value and the subsequent ones are `nan`. The main differences are:

- `class_labels` are associated with the text prompt used

For instance, given an image with bounding boxes for fish and jellyfish and the prompt `"fish. jellyfish."`, fish should have `class_label` 0 assigned to it and jellyfish should have 1. If the positions of jellyfish and fish in the prompt were swapped, the `class_label`s would swap as well. Moreover, jellyfish is represented by two tokens (`[20919, 7529]`) and fish by one token (`[3869]`), so we need to select the appropriate logits for each class.

As the original implementation doesn't provide the training loop or the loss implementation, but does recommend other implementations for training `GroundingDino` in IDEA-Research/GroundingDINO#241, I took as baseline the implementation from Open-GroundingDino, as it supports both visual grounding and object detection and its authors trained their own `GroundingDino` with that code base, achieving good performance.

Things added in this PR are:
- `build_label_maps`, which generates a list of `torch.Tensor` with length `batch_size`, mapping each category to its corresponding tokens based on the `input_ids`
- `build_text_mask`, which just expands the `attention_mask` to select the appropriate tokens when computing `GroundingDino.loss_labels`
- `enc_topk_proposals`, `encoder_logits` and `encoder_pred_boxes`, added to `GroundingDinoModelOutput` and `GroundingDinoObjectDetectionOutput` to compute the first-stage loss
- `class_loss_coefficient` (with the correct default value) and `class_loss_reduction`, added to `GroundingDinoConfig`. `class_loss_reduction` was added because the `sigmoid_focal_loss` in the baseline implementation reduces `loss_ce` with a simple sum, which makes the losses imbalanced most of the time. The original implementation does have a `sigmoid_focal_loss`, but it uses `mean` reduction, so I decided to make the reduction configurable and use the `sum` one for testing reasons
- `GroundingDinoLoss` and `GroundingDinoHungarianMatcher`
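
To make the category-to-token mapping concrete, here is a minimal pure-Python sketch of the idea behind `build_label_maps`. It is not the code in this PR: `sketch_label_map`, `token_targets` and `DELIMITER_IDS` are hypothetical names, plain lists stand in for `torch.Tensor`s, and the delimiter ids assume a BERT-style tokenizer where `[CLS]`, `[SEP]` and `.` are 101, 102 and 1012. The category token ids are the ones from the fish/jellyfish example above.

```python
# Assumed BERT-style ids for [CLS], [SEP] and "." (an assumption, not taken from the PR).
DELIMITER_IDS = (101, 102, 1012)


def sketch_label_map(input_ids, delimiter_ids=DELIMITER_IDS):
    """Return one 0/1 row per category; row k marks the token positions of category k."""
    spans, current = [], []
    for position, token_id in enumerate(input_ids):
        if token_id in delimiter_ids:
            if current:  # a delimiter closes the category being read
                spans.append(current)
            current = []
        else:
            current.append(position)
    if current:  # prompt did not end with a delimiter
        spans.append(current)
    return [[1 if p in span else 0 for p in range(len(input_ids))] for span in spans]


def token_targets(class_labels, label_map):
    """Turn per-box class labels into per-token 0/1 targets by row lookup."""
    return [label_map[label] for label in class_labels]


# "fish. jellyfish." -> [CLS] fish . jelly ##fish . [SEP]
input_ids = [101, 3869, 1012, 20919, 7529, 1012, 102]
label_map = sketch_label_map(input_ids)
# label_map[0] marks the single "fish" token, label_map[1] the two "jellyfish" tokens:
# [[0, 1, 0, 0, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0]]

# Swapping the prompt order swaps which class_label owns which token positions.
swapped_map = sketch_label_map([101, 20919, 7529, 1012, 3869, 1012, 102])
# [[0, 1, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0]]

# Two boxes labelled jellyfish (1) and fish (0) under the first prompt.
targets = token_targets([1, 0], label_map)
```

Each target row has the same length as the token sequence, so it can be compared position by position against the per-token logits of a matched query, which is why a class spanning several tokens needs more than a single logit.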
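
The effect of `class_loss_reduction` can be illustrated with a small stand-alone sketch of a sigmoid focal loss with a switchable reduction. This is the standard focal-loss formulation written in pure Python over flat lists, not the function in this PR or in the baseline; the name `focal_loss_sketch`, the `alpha=0.25` and `gamma=2.0` defaults, and the omission of any further normalization are assumptions made for illustration.

```python
import math


def _sigmoid(x):
    # Numerically stable sigmoid for both large positive and large negative inputs.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def focal_loss_sketch(logits, targets, alpha=0.25, gamma=2.0, reduction="sum"):
    """Sigmoid focal loss over flat lists of logits and 0/1 targets (hypothetical helper)."""
    losses = []
    for logit, target in zip(logits, targets):
        prob = _sigmoid(logit)
        # Binary cross-entropy with logits, in its numerically stable form.
        ce = max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))
        # p_t is the probability assigned to the true label.
        p_t = prob * target + (1.0 - prob) * (1 - target)
        alpha_t = alpha * target + (1.0 - alpha) * (1 - target)
        losses.append(alpha_t * ce * (1.0 - p_t) ** gamma)
    total = sum(losses)
    if reduction == "sum":
        return total
    if reduction == "mean":
        return total / len(losses)
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Same per-element values, repeated over more token positions.
few_sum = focal_loss_sketch([0.0, 0.0], [1, 0], reduction="sum")
many_sum = focal_loss_sketch([0.0, 0.0] * 100, [1, 0] * 100, reduction="sum")
few_mean = focal_loss_sketch([0.0, 0.0], [1, 0], reduction="mean")
many_mean = focal_loss_sketch([0.0, 0.0] * 100, [1, 0] * 100, reduction="mean")
```

With `sum`, the class loss grows with the number of scored positions (`many_sum` is 100 times `few_sum`), whereas with `mean` it stays at the same scale, which is one way to see why the choice of reduction changes the balance between the class loss and the box losses.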
Also added a new integration test called `test_grounding_dino_loss`, where I compare the loss obtained from 2 sample images with the baseline implementation from `Open-GroundingDino`.

c.c. @amyeroberts