GhostNet v2(NeurIPS 2022 Spotlight)原理与代码解析

慈云数据 2024-03-15 技术支持 56 0

paper:GhostNetV2: Enhance Cheap Operation with Long-Range Attention

code:https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/ghostnetv2_pytorch

背景

智能手机和可穿戴设备上部署神经网络时,不仅要考虑模型的性能,还要考虑模型的效率,特别是实际推理速度。许多轻量模型比如MobileNet、ShuffleNet、GhostNet已经被应用到许多移动应用程序中。然而,基于卷积的轻量模型在长距离建模方面较弱,这限制了模型性能的进一步提升。Transformer引入的self-attention机制可以捕获全局信息,但是其复杂度相对于特征图的大小呈二次方的关系,对于计算非常不友好。此外,在计算attention map过程中还涉及大量的特征splitting和reshaping操作,虽然它们的理论复杂度可以忽略不计,但在实际应用中这些操作会产生更多的内存占用以及更长的延迟

本文的创新

本文提出了一种新的注意力机制(dubbed DFC attention)来捕获长距离的空间信息,同时保持了轻量型卷积神经网络的计算效率。为了简便只用了全连接层来生成atttention maps,具体来说,一个FC层被分解成了一个水平FC层和一个竖直FC层,这两个FC层沿各自的方向建模长距离的空间信息,结合这两个FC层就得到了全局的感受野。此外,作者重新研究了GhostNet中的bottleneck并加入了DFC attention来增强其中间层的特征表示,然后设计了一个新的轻量型骨干网络GhostNet v2,它可以在精度和推理速度之间获得更好的平衡。

方法介绍

A Brief Review of GhostNet

首先回顾下GhostNet,对于输入 \(X\in \mathbb{R}^{H\times W\times C}\),Ghost module将一个标准的卷积替换成两步。首先用一个1x1卷积生成intrinsic feature

其中 \(*\) 表示卷积操作,\(F_{1\times 1}\) 是point-wise卷积,\(Y'\in \mathbb{R}^{H\times W\times C'_{out}}\) 是输出的intrinsic feature,它的通道数小于原始输出的通道数,即 \(C'_{out} 1: x = self.conv_dw(x) x = self.bn_dw(x) if self.se is not None: x = self.se(x) x = self.ghost2(x) x += self.shortcut(residual) return x

GhostModuleV2的代码如下,其中self.short_conv就是DFC分支,首先avg pooling进行下采样,这里和文章也不一样,文中消融实验中提到max pooling的延迟低因此默认采用max pool。然后经过1x1卷积,接着是horizontal FC和vertical FC,这里用卷积替代两个方向的FC卷积核大小为(1, 5)、(5, 1),最终经过sigmoid得到DFC分支的输出。DFC分支的输出经过bilinear插值上采样得到原始输入大小,然后与原始ghost module的输出相乘得到最终输出。

class GhostModuleV2(nn.Module):
    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True, mode=None, args=None):
        super(GhostModuleV2, self).__init__()
        self.mode = mode
        self.gate_fn = nn.Sigmoid()
        if self.mode in ['original']:
            self.oup = oup
            init_channels = math.ceil(oup / ratio)
            new_channels = init_channels * (ratio - 1)
            self.primary_conv = nn.Sequential(
                nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
                nn.BatchNorm2d(init_channels),
                nn.ReLU(inplace=True) if relu else nn.Sequential(),
            )
            self.cheap_operation = nn.Sequential(
                nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False),
                nn.BatchNorm2d(new_channels),
                nn.ReLU(inplace=True) if relu else nn.Sequential(),
            )
        elif self.mode in ['attn']:
            self.oup = oup
            init_channels = math.ceil(oup / ratio)
            new_channels = init_channels * (ratio - 1)
            self.primary_conv = nn.Sequential(
                nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
                nn.BatchNorm2d(init_channels),
                nn.ReLU(inplace=True) if relu else nn.Sequential(),
            )
            self.cheap_operation = nn.Sequential(
                nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2, groups=init_channels, bias=False),
                nn.BatchNorm2d(new_channels),
                nn.ReLU(inplace=True) if relu else nn.Sequential(),
            )
            self.short_conv = nn.Sequential(
                nn.Conv2d(inp, oup, kernel_size, stride, kernel_size // 2, bias=False),
                nn.BatchNorm2d(oup),
                nn.Conv2d(oup, oup, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=oup, bias=False),
                nn.BatchNorm2d(oup),
                nn.Conv2d(oup, oup, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=oup, bias=False),
                nn.BatchNorm2d(oup),
            )
    def forward(self, x):
        if self.mode in ['original']:
            x1 = self.primary_conv(x)
            x2 = self.cheap_operation(x1)
            out = torch.cat([x1, x2], dim=1)
            return out[:, :self.oup, :, :]
        elif self.mode in ['attn']:
            res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
            x1 = self.primary_conv(x)
            x2 = self.cheap_operation(x1)
            out = torch.cat([x1, x2], dim=1)
            return out[:, :self.oup, :, :] * F.interpolate(self.gate_fn(res), size=(out.shape[-2], out.shape[-1]),
            mode='nearest')

微信扫一扫加客服

微信扫一扫加客服

点击启动AI问答
Draggable Icon