[21.01] RepVGG
Making VGG Great Again
RepVGG: Making VGG-style ConvNets Great Again
Is it true that a model with fewer parameters runs faster?
You might know the answer is: "Not necessarily!"
This is because we also need to consider the computational load, specifically the model's FLOPs (the number of floating-point operations).
So let's ask another question: "Does a model with fewer FLOPs run faster?"
The answer is still "Not necessarily!"
Defining the Problem
There are many factors influencing inference speed!
Many research papers focus on reducing a model's FLOPs, yet when those models are deployed in practice, inference speed doesn't always improve, and sometimes it even gets worse.
This is a typical gap between academia and industry, and as engineers, we can't let ourselves be misled by FLOP counts alone.
High Memory Usage
With the advent of Inception, ResNet, and DenseNet, many researchers shifted their focus to carefully designed architectures, making models increasingly complex. Moreover, some architectures are based on NAS (Neural Architecture Search) or use compound scaling strategies.
Although these network architectures achieve higher accuracy, their complex branching leads to increased memory usage, as shown in the figure above.
In residual structures, the input to a block must be kept in memory until the branches are added back together, so inference needs extra memory on top of the convolutions themselves. If you have multiple branches (yes, Inception, looking at you!), the memory usage climbs even higher.
Slow Inference Speed
Some operations lack efficient hardware and library support, such as the depthwise convolution in MobileNets. Depthwise separable convolution cuts the parameter count dramatically compared to a standard convolution (typically by 80–90%), yet it often runs slower in practice.
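To see where a number like that comes from, here is a quick parameter count in PyTorch; the channel count of 64 is my own illustrative choice, and the exact reduction depends on the configuration and on whether the pointwise convolution is included:

```python
import torch.nn as nn

c = 64  # illustrative channel count (not from the paper)

# Standard 3x3 convolution: every output channel sees every input channel.
standard = nn.Conv2d(c, c, kernel_size=3, padding=1, bias=False)

# Depthwise separable convolution: per-channel 3x3 (groups=c) followed by a 1x1.
depthwise_separable = nn.Sequential(
    nn.Conv2d(c, c, kernel_size=3, padding=1, groups=c, bias=False),  # depthwise
    nn.Conv2d(c, c, kernel_size=1, bias=False),                       # pointwise
)

def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())

print(n_params(standard))             # 36864
print(n_params(depthwise_separable))  # 4672 -> roughly 87% fewer parameters
```

Despite the much smaller parameter count, the depthwise step has a very low ratio of computation to memory access, so it tends to be memory-bound and under-utilizes the hardware.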
Similarly, the channel shuffle operation in ShuffleNet, although it involves almost no computation, also hurts inference speed.
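Channel shuffle itself is just a reshape, a transpose, and another reshape; a minimal sketch of the standard ShuffleNet formulation looks like this:

```python
import torch

def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    """Interleave channels across groups: reshape -> transpose -> flatten back."""
    n, c, h, w = x.shape
    x = x.view(n, groups, c // groups, h, w)  # split channels into groups
    x = x.transpose(1, 2).contiguous()        # the costly part: pure memory movement
    return x.view(n, c, h, w)

x = torch.randn(1, 8, 4, 4)
y = channel_shuffle(x, groups=2)  # hardly any FLOPs, yet far from free
```

There is essentially no arithmetic here, but the `contiguous()` call rewrites the whole tensor in memory, which is exactly the kind of cost that FLOP counts never capture.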
However...
Multi-branch structures are very effective!
Multi-branch architectures allow the model to become an implicit ensemble of many shallow models.
Residual connections provide a shortcut: for a ResNet with $N$ residual blocks, the model can be interpreted as an ensemble of $2^N$ sub-models, since each block offers two possible paths (go through the convolution or skip it). For example, just 3 blocks already yield $2^3 = 8$ implicit paths.
Additionally, shortcuts offer better gradient flow, alleviating the training difficulty caused by vanishing gradients.
ResNet's success has already proven the effectiveness of residual connections, so why give them up?
Solving the Problem
Model Architecture
To address the above issues, this paper first modifies the residual structure of ResNet.
The authors design the RepVGG unit around a simple observation: the multi-branch architecture is only needed to boost performance during training. Hence, the training-time structure, shown in Figure (b), consists of:
- A 3x3 convolution
- A 1x1 convolution
- An identity connection
But during inference, we need to remove all branches, as shown in Figure (c).
The training structure is defined this way to get the benefits of a multi-branch architecture during training while still meeting the model's inference requirements, thanks to one key technique:
- Reparameterization of the model.
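Before we get to how the branches are removed, here is a rough PyTorch sketch of the training-time unit. The class and attribute names are my own, and details such as stride and channel changes are omitted (when input and output shapes differ, the identity branch is simply dropped):

```python
import torch.nn as nn

class RepVGGBlock(nn.Module):
    """Training-time RepVGG unit: 3x3 conv-BN + 1x1 conv-BN + identity BN, then ReLU."""

    def __init__(self, channels: int):
        super().__init__()
        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.branch_1x1 = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.branch_id = nn.BatchNorm2d(channels)  # the identity branch is just a BN
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.branch_3x3(x) + self.branch_1x1(x) + self.branch_id(x))
```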
Reparameterization
Removing branches isn't about directly deleting them; past research on model pruning has shown that direct deletion leads to performance degradation.
The authors propose a new method: Reparameterization.
In the RepVGG unit, each convolution operation is followed by Batch Normalization. This structure can be reparameterized into an equivalent convolution operation.
The entire process is as follows:
- Merge the 3x3 convolution and BatchNorm into an equivalent convolution operation.
- Expand the 1x1 convolution into an equivalent 3x3 convolution using zero-padding, then merge it with BatchNorm into an equivalent convolution operation.
- Expand the identity connection into an equivalent 3x3 convolution, then merge it with BatchNorm into an equivalent convolution operation.
- Finally, add the three branches together into an equivalent convolution operation.
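Continuing the `RepVGGBlock` sketch from above, these four steps can be written out as follows. This is a hedged sketch of the idea rather than the official implementation; the BN folding uses running statistics, so the equivalence holds in eval mode:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fuse_conv_bn(conv_weight: torch.Tensor, bn: nn.BatchNorm2d):
    """Fold a BN layer into the preceding conv: W' = (gamma/sigma) * W, b' = beta - mu * gamma / sigma."""
    std = torch.sqrt(bn.running_var + bn.eps)
    scale = bn.weight / std
    return conv_weight * scale.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * scale

def reparameterize(block: "RepVGGBlock"):
    """Collapse the three training-time branches into a single 3x3 kernel and bias."""
    # 1. Merge the 3x3 conv with its BN.
    w3, b3 = fuse_conv_bn(block.branch_3x3[0].weight, block.branch_3x3[1])

    # 2. Zero-pad the 1x1 kernel to 3x3, then merge it with its BN.
    w1 = F.pad(block.branch_1x1[0].weight, [1, 1, 1, 1])  # (C, C, 1, 1) -> (C, C, 3, 3)
    w1, b1 = fuse_conv_bn(w1, block.branch_1x1[1])

    # 3. Write the identity as a 3x3 conv whose kernel is 1 at the centre of
    #    channel i -> channel i, then merge it with its BN.
    c = block.branch_id.num_features
    w_id = torch.zeros(c, c, 3, 3)
    for i in range(c):
        w_id[i, i, 1, 1] = 1.0
    w_id, b_id = fuse_conv_bn(w_id, block.branch_id)

    # 4. Add the three branches: kernels add up, and so do biases.
    return w3 + w1 + w_id, b3 + b1 + b_id
```

A quick sanity check that the fused convolution matches the three-branch block (before the ReLU):

```python
block = RepVGGBlock(8).eval()          # eval mode so BN uses running statistics
w, b = reparameterize(block)
x = torch.randn(1, 8, 16, 16)
fused = F.conv2d(x, w, b, padding=1)
multi = block.branch_3x3(x) + block.branch_1x1(x) + block.branch_id(x)
print(torch.allclose(fused, multi, atol=1e-5))  # True
```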
Even if you're not fond of equations, let's still walk through the specific formulation given in the paper:
- Let $W^{(3)} \in \mathbb{R}^{C_2 \times C_1 \times 3 \times 3}$ denote the kernel of the $3 \times 3$ convolution layer with $C_1$ input channels and $C_2$ output channels.
- Let $W^{(1)} \in \mathbb{R}^{C_2 \times C_1}$ denote the kernel of the $1 \times 1$ convolution layer with $C_1$ input channels and $C_2$ output channels.
- Let $\mu^{(3)}, \sigma^{(3)}, \gamma^{(3)}, \beta^{(3)}$ denote the mean, standard deviation, learnable scaling factor, and bias of the BN layer following the $3 \times 3$ convolution layer.
- Let $\mu^{(1)}, \sigma^{(1)}, \gamma^{(1)}, \beta^{(1)}$ denote the corresponding quantities of the BN layer following the $1 \times 1$ convolution layer.
- Let $\mu^{(0)}, \sigma^{(0)}, \gamma^{(0)}, \beta^{(0)}$ denote those of the BN layer on the identity branch.

Let $M^{(1)} \in \mathbb{R}^{N \times C_1 \times H_1 \times W_1}$ and $M^{(2)} \in \mathbb{R}^{N \times C_2 \times H_2 \times W_2}$ represent the input and output, respectively, and let $*$ denote the convolution operation.

If $C_1 = C_2$, $H_1 = H_2$, $W_1 = W_2$, then: