电脑基础 · 2023年4月19日

使用MobileViT替换YOLOv5主干网络

使用MobileViT替换YOLOv5主干网络,并训练

  • 前述
    • 使用MobileViT替换YOLOv5主干网络
    • 训练

前述

读了MobileViT这篇论文之后觉得文章里面提到的技巧很新奇,所以就尝试把它替换到YOLOv5中主干网络上去,看看会不会有那么大的提升,但是结果并不是让我很满意,在实施过程中也确实遇到一些问题,于是就写下来和大家分享一下。
相比较于其他的transformer变体,MobileViT这篇文章给出的改动技巧很简单高效,它解决的ViT中因为像素摊平操作导致的位置信息损失问题,将卷积的局部信息提取优势和自注意力机制的全局信息提取能力结合起来,并且根据论文描述具有高度轻量化+极快的推理速度,具体的大佬们自己去读读,本菜鸡好久之前读的了,印象不是很深了…
使用MobileViT替换YOLOv5主干网络

使用MobileViT替换YOLOv5主干网络

话不多说,直入主题吧!
MobileViT代码(PyTorch版)
和在YOLOv5中添加其他模块一样,将mobilevit代码下载下来放在YOLOv5项目下的文件夹里,然后我们只需要在yolo.py文件中导入mobilevit相关模块(操作如下两带箭头图所示)、在yaml文件中添加上相应的结构名称就可以了,如下代码所示。
使用MobileViT替换YOLOv5主干网络
使用MobileViT替换YOLOv5主干网络

# parameters
nc: 10  # number of classes
depth_multiple: 1  # model depth multiple
width_multiple: 1  # layer channel multiple
# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32
# YOLOv5 backbone
backbone:
  # [from, number, module, args] 640 x 640
  [[-1, 1, Focus, [32, 3]],  # 0-P1/2  320 x 320
   [-1, 1, MV2Block, [32, 1, 2]],  # 1-P2/4
   [-1, 1, MV2Block, [48, 2, 2]],  # 160 x 160
   [-1, 2, MV2Block, [48, 1, 2]],
   [-1, 1, MV2Block, [64, 2, 2]],  # 80 x 80
   [-1, 1, MobileViTAttention, [64, 96, 2, 3, 2, 192]], # 5 in_channel, dim, depth, kernel_size, patch_size, mlp_dim
   [-1, 1, MV2Block, [80, 2, 2]],  # 40 x 40
   [-1, 1, MobileViTAttention, [80, 120, 2, 3, 4, 480]], # 7
   [-1, 1, MV2Block, [96, 2, 2]],   # 20 x 20
   [-1, 1, MobileViTAttention, [96, 144, 2, 3, 3, 576]], # 11-P2/4 # 9
  ]
# YOLOv5 head
head:
  [[-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 7], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [256, False]],  # 13
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [128, False]],  # 17 (P3/8-small)
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [256, False]],  # 20 (P4/16-medium)
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [512, False]],  # 23 (P5/32-large)
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

训练

训练过程中需要注意的有两点:
1、由于transformer自注意力机制结构中存在大量的矩阵乘操作,对显存的消耗特别大,如果显卡显存不是特别大,建议将in_channel、dim、 depth, 、mlp_dim这其中的参数选几个适当调小一点(*[-1, 1, MobileViTAttention, [64, 96, 2, 3, 2, 192]](对应参数:in_channel, dim, depth, kernel_size, patch_size, mlp_dim)),我训练时使用的是一块titan RTX 24GB的卡,直接套用官方MobileViT代码结构上去,显存居然吃不消!
2、由于mobilevit中的transformer部分把输入特征分成若干个小方块(论文中说最好的结果是patch_size=2),所以训练时不要采用矩形训练(验证和测试时也不要采用rect模式),否则会出错,解决办法将rect=False即可。