[20.10] Deformable DETR
Deformable Attention Mechanism
Deformable DETR: Deformable Transformers for End-to-End Object Detection
DETR left an excellent foundation for researchers to build upon.
In DETR, the authors used only the basic Transformer architecture for the simplest form of object detection, without employing any advanced tricks. They simply fed images into the model, and out came the object locations and categories.
What a brilliant paper!
This leaves ample room for researchers to improve the model, allowing for countless future publications based on this foundation.
Problem Definition
Compared to existing object detectors, DETR is slow. Running at roughly half the inference speed of the popular Faster R-CNN is tolerable on its own.
But the training convergence is about 20 times slower!
What used to take just a day to train now takes 20 days. This is a significant problem: time is precious, and no one wants to wait weeks for a model to converge.
Clearly, this issue needs to be addressed.
Solution
The authors believe the issue lies in the attention mechanism of the Transformer. In the Transformer, every pixel attends to every other pixel, so most of the computation is spent on irrelevant regions; at initialization the attention weights are spread almost uniformly across the feature map, and the model needs many epochs just to learn where to look.
Thus, instead of using the original Transformer attention mechanism, they borrowed the idea of deformable convolutions and created a "deformable attention mechanism."
At the time of this research, Vision Transformers (ViT) hadn’t been published yet, so the operations were based on individual pixels rather than image patches.
Deformable Attention
In the feature map, for each query element, the authors select a reference point and compute attention only around a few important sampling points near that reference. This differs from the traditional Transformer, which calculates attention over all points in the spatial domain.
Suppose the input feature map is $x \in \mathbb{R}^{C \times H \times W}$, where $C$ is the number of channels, and $H$ and $W$ are the height and width of the feature map, respectively.

Each query element $q$ consists of a content feature $z_q$ and a 2D reference point $p_q$. The deformable attention feature is computed as:

$$
\text{DeformAttn}(z_q, p_q, x) = \sum_{m=1}^{M} W_m \left[ \sum_{k=1}^{K} A_{mqk} \cdot W'_m \, x(p_q + \Delta p_{mqk}) \right]
$$
Where:

- $M$ is the number of attention heads.
- $K$ is the number of sampling points selected for each query, chosen from a small region around the reference point.
- $A_{mqk}$ is the attention weight of the $k$-th sampling point in the $m$-th attention head, constrained to $[0, 1]$ and normalized such that $\sum_{k=1}^{K} A_{mqk} = 1$.
- $\Delta p_{mqk}$ is the offset of the $k$-th sampling point in the $m$-th attention head, which can take arbitrary real values.
- $W_m$ and $W'_m$ are learnable weight matrices responsible for linear transformations of the input features.
- $x(p_q + \Delta p_{mqk})$ represents the feature value at location $p_q + \Delta p_{mqk}$. Since this position is fractional (i.e., not a discrete grid point), bilinear interpolation is used to compute it.
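Since the sampling locations are fractional, the feature values are gathered by bilinear interpolation. Here is a minimal sketch of that single step, assuming PyTorch (`F.grid_sample` performs exactly this interpolation and keeps the operation differentiable); all values are illustrative:

```python
import torch
import torch.nn.functional as F

# A tiny feature map and one fractional sampling location (illustrative values)
x = torch.arange(16.0).view(1, 1, 4, 4)   # (B, C, H, W)
loc = torch.tensor([[[[0.25, 0.60]]]])    # normalized (x, y) in [0, 1], shape (1, 1, 1, 2)

# grid_sample expects coordinates in [-1, 1]; bilinear mode blends the
# four grid points surrounding the fractional location
grid = 2 * loc - 1
val = F.grid_sample(x, grid, mode="bilinear", align_corners=False)
print(val.shape)  # torch.Size([1, 1, 1, 1]): one interpolated feature value
```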
The query feature $z_q$ undergoes a linear projection that outputs a tensor of size $3MK$:

- The first $2MK$ channels encode the offsets $\Delta p_{mqk}$ for each sampling point.
- The remaining $MK$ channels are passed through a softmax function to compute the attention weights $A_{mqk}$.
This design ensures that both the offsets and attention weights are learned from the query feature rather than based on fixed rules.
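Putting the pieces together, the following is a minimal single-scale sketch of the module in PyTorch. It is an illustration under assumptions, not the reference implementation: all class, layer, and variable names are made up, and offsets are interpreted here in pixel units before being normalized for sampling:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DeformableAttention(nn.Module):
    """Single-scale deformable attention sketch (names are illustrative)."""

    def __init__(self, d_model=256, n_heads=8, n_points=4):
        super().__init__()
        self.n_heads = n_heads                          # M
        self.n_points = n_points                        # K
        self.head_dim = d_model // n_heads
        self.value_proj = nn.Linear(d_model, d_model)   # W'_m, fused across heads
        self.output_proj = nn.Linear(d_model, d_model)  # W_m, merges the heads
        # Offsets (2MK) and attention weights (MK), both predicted from the query
        self.sampling_offsets = nn.Linear(d_model, n_heads * n_points * 2)
        self.attention_weights = nn.Linear(d_model, n_heads * n_points)

    def forward(self, query, ref_points, x):
        """
        query:      (B, Nq, C)   content features z_q
        ref_points: (B, Nq, 2)   reference points p_q, normalized (x, y) in [0, 1]
        x:          (B, C, H, W) input feature map
        """
        B, Nq, C = query.shape
        M, K, Dh = self.n_heads, self.n_points, self.head_dim
        H, W = x.shape[-2:]

        # Project the values and split into heads: (B*M, Dh, H, W)
        value = self.value_proj(x.flatten(2).transpose(1, 2))     # (B, H*W, C)
        value = value.transpose(1, 2).reshape(B * M, Dh, H, W)

        # Predict offsets Δp_mqk (pixel units) and weights A_mqk from the query
        offsets = self.sampling_offsets(query).view(B, Nq, M, K, 2)
        weights = self.attention_weights(query).view(B, Nq, M, K)
        weights = weights.softmax(-1)   # normalize over the K sampling points

        # Sampling locations p_q + Δp_mqk, converted to normalized coordinates
        scale = torch.tensor([W, H], dtype=query.dtype, device=query.device)
        locs = ref_points[:, :, None, None, :] + offsets / scale
        grid = (2 * locs - 1).permute(0, 2, 1, 3, 4).reshape(B * M, Nq, K, 2)

        # Bilinear interpolation at the fractional locations: (B*M, Dh, Nq, K)
        sampled = F.grid_sample(value, grid, mode="bilinear", align_corners=False)

        # Weighted sum over the K points, then merge heads and project out
        weights = weights.permute(0, 2, 1, 3).reshape(B * M, 1, Nq, K)
        out = (sampled * weights).sum(-1)               # (B*M, Dh, Nq)
        out = out.view(B, M * Dh, Nq).transpose(1, 2)   # (B, Nq, C)
        return self.output_proj(out)
```

The official implementation performs the sampling with a custom CUDA kernel for speed; `grid_sample` expresses the same bilinear-sampling math in plain PyTorch.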
Multi-Scale Computation
Modern object detection frameworks often use multi-scale feature maps to detect objects at different scales. The deformable attention module naturally extends to a multi-scale version, allowing sampling and operations on multiple feature map layers simultaneously.
Assume the input multi-scale feature maps are denoted as $\{x^l\}_{l=1}^{L}$, where each feature map $x^l \in \mathbb{R}^{C \times H_l \times W_l}$ shares the channel dimension $C$ but has its own spatial resolution $H_l \times W_l$.
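Because every level has its own resolution, the reference points are kept in normalized $[0, 1]$ coordinates, so a single point addresses all scales at once. A small illustrative sketch (shapes and values are made up):

```python
import torch

# Hypothetical multi-scale feature maps {x^l}: shared channels C, own H_l x W_l
C = 256
feats = [torch.randn(1, C, 64, 64),
         torch.randn(1, C, 32, 32),
         torch.randn(1, C, 16, 16)]

# A reference point in normalized coordinates maps onto every level
p_q = torch.tensor([0.3, 0.7])                    # (x, y), resolution-independent
for l, x_l in enumerate(feats):
    H_l, W_l = x_l.shape[-2:]
    px, py = float(p_q[0]) * W_l, float(p_q[1]) * H_l   # fractional pixels at level l
    print(f"level {l}: sample near ({px:.1f}, {py:.1f}) in a {H_l}x{W_l} map")
```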