YOLOv5 Study Notes 4: Source Code Analysis, the Head
The Detect class corresponds to the detection head of YOLOv5.
The Detect class is defined at line 33 of yolo.py.
Analysis of class Detect()

```python
import torch
import torch.nn as nn


class Detect(nn.Module):
    stride = None  # strides of the detection layers, filled in when the model is built
    onnx_dynamic = False  # rebuild the grid on every call (used for ONNX export)

    def __init__(self, nc=80, anchors=(), ch=(), inplace=True):
        super().__init__()
        self.nc = nc
        self.no = nc + 5
        self.nl = len(anchors)
        self.na = len(anchors[0]) // 2
        self.grid = [torch.zeros(1)] * self.nl
        self.anchor_grid = [torch.zeros(1)] * self.nl
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)
        self.inplace = inplace

    def forward(self, x):
        z = []
        for i in range(self.nl):
            x[i] = self.m[i](x[i])
            bs, _, ny, nx = x[i].shape
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    # _make_grid() is listed separately further down
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid()
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
                else:
                    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]
                    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]
                    y = torch.cat((xy, wh, y[..., 4:]), -1)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1), x)
```
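The initializer __init__()
First, look at the initializer of this class:

```python
def __init__(self, nc=80, anchors=(), ch=(), inplace=True):
    super().__init__()
    self.nc = nc  # number of classes
    self.no = nc + 5  # number of outputs per anchor (classes + xywh + confidence)
    self.nl = len(anchors)  # number of detection layers
    self.na = len(anchors[0]) // 2  # number of anchors per layer
    self.grid = [torch.zeros(1)] * self.nl  # placeholder grids, rebuilt at inference
    self.anchor_grid = [torch.zeros(1)] * self.nl  # placeholder anchor grids
    self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape (nl, na, 2)
    self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # one 1x1 output conv per layer
    self.inplace = inplace  # decode boxes with in-place slice assignment
```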
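YOLOv5's detection head still follows an FPN structure, so self.m holds three output convolutions. Their channel changes are 128$\longrightarrow$255, 256$\longrightarrow$255 and 512$\longrightarrow$255. self.no is the output channel dimension per anchor: every anchor predicts 80 classes (COCO) + 4 box coordinates xywh + 1 confidence score, so self.no is 85. Each scale has 3 anchors per position, so each convolution outputs 85*3=255 channels. In summary, the number of detection layers (self.nl) is 3, the number of anchors per layer (self.na) is 3, and the number of outputs per anchor (self.no) is 85.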
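The bookkeeping above can be checked with a few lines of plain Python. This is only a sketch of the arithmetic, not the repository's code: `ANCHORS` is assumed to be the default COCO anchor set, the input channels (128, 256, 512) are the ones quoted above, and `head_dims` is a hypothetical helper name.

```python
# Sketch: derive nl, na, no and the conv channel pairs the way Detect.__init__ does.
# ANCHORS is an assumed example (default COCO anchors); head_dims is a hypothetical helper.
ANCHORS = (
    (10, 13, 16, 30, 33, 23),       # 3 (w, h) pairs for the first layer
    (30, 61, 62, 45, 59, 119),      # second layer
    (116, 90, 156, 198, 373, 326),  # third layer
)


def head_dims(nc, anchors, ch):
    no = nc + 5                  # outputs per anchor: classes + xywh + confidence
    nl = len(anchors)            # number of detection layers
    na = len(anchors[0]) // 2    # anchors per layer (each anchor is a w, h pair)
    convs = [(c_in, no * na) for c_in in ch]  # (in_channels, out_channels) of each 1x1 conv
    return {"no": no, "nl": nl, "na": na, "convs": convs}


dims = head_dims(nc=80, anchors=ANCHORS, ch=(128, 256, 512))
print(dims)
```

Changing `nc` to the class count of a custom dataset shows directly how the 255 output channels would change.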
The forward() function
Next, look at the forward() function of the head:

```python
def forward(self, x):
    z = []  # inference outputs, one entry per detection layer
    for i in range(self.nl):
        x[i] = self.m[i](x[i])  # 1x1 output conv
        bs, _, ny, nx = x[i].shape
        # (bs, na*no, ny, nx) -> (bs, na, ny, nx, no)
        x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

        if not self.training:  # inference only
            if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

            y = x[i].sigmoid()
            if self.inplace:
                y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
            else:
                xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
                wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                y = torch.cat((xy, wh, y[..., 4:]), -1)
            z.append(y.view(bs, -1, self.no))

    return x if self.training else (torch.cat(z, 1), x)
```
x is a list holding the inputs of the three heads. Their shapes are:
[bs, 128, 32, 32]
[bs, 256, 16, 16]
[bs, 512, 8, 8]
Each input is passed through its own output convolution to produce the raw prediction.
```python
x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
```
This reshapes each x[i] as follows:
x[0]: (bs,255,32,32) => (bs,3,32,32,85)
x[1]: (bs,255,16,16) => (bs,3,16,16,85)
x[2]: (bs,255,8,8) => (bs,3,8,8,85)
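To make the index relabelling concrete, here is a small pure-Python sketch. It relies on the fact that `view` on a contiguous tensor splits the channel index c into an anchor index a and a field index k with c = a * no + k; `split_channel` and `reshaped_shape` are hypothetical helper names used only for illustration.

```python
# Sketch: what view(bs, na, no, ny, nx).permute(0, 1, 3, 4, 2) does to indices and shapes.
def split_channel(c, no=85):
    """Map conv output channel c to (anchor index a, field index k), where c = a * no + k."""
    return divmod(c, no)


def reshaped_shape(shape, na=3):
    """Shape after the view + permute: (bs, na*no, ny, nx) -> (bs, na, ny, nx, no)."""
    bs, c, ny, nx = shape
    no = c // na
    return (bs, na, ny, nx, no)


for c in (0, 84, 85, 254):
    print(c, "->", split_channel(c))

for shape in ((1, 255, 32, 32), (1, 255, 16, 16), (1, 255, 8, 8)):
    print(shape, "=>", reshaped_shape(shape))
```

So channels 0 to 84 belong to the first anchor, 85 to 169 to the second, and 170 to 254 to the third; the permute then moves the 85 fields to the last axis.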
The _make_grid() function

```python
def _make_grid(self, nx=20, ny=20, i=0):
    d = self.anchors[i].device
    # check_version is a YOLOv5 utility helper; newer torch versions accept the indexing argument
    if check_version(torch.__version__, '1.10.0'):
        yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)], indexing='ij')
    else:
        yv, xv = torch.meshgrid([torch.arange(ny, device=d), torch.arange(nx, device=d)])
    # grid[..., y, x, :] == (x, y): the top-left corner of each cell, in grid units
    grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float()
    # anchor sizes multiplied by the stride, broadcast over every cell
    anchor_grid = (self.anchors[i].clone() * self.stride[i]) \
        .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float()
    return grid, anchor_grid
```
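The _make_grid() function prepares the grid cell coordinates. All box-center predictions are expressed in grid units rather than in pixels of the original image; forward() multiplies them by self.stride[i] to convert them to pixels. Note that every layer has its own grid size, equal to the width and height of that layer's output.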
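The grid layout can be reproduced without PyTorch. The sketch below is an illustration under two assumptions: the anchors passed in are already expressed in grid units (that is, divided by the stride, which is how I understand self.anchors to be stored), and the helper names `make_grid` and `make_anchor_sizes` are hypothetical.

```python
# Sketch: the content of grid and anchor_grid for one layer, using nested lists.
def make_grid(nx, ny):
    """grid[y][x] == (x, y), mirroring torch.stack((xv, yv), 2) with 'ij' indexing."""
    return [[(float(x), float(y)) for x in range(nx)] for y in range(ny)]


def make_anchor_sizes(anchors_in_cells, stride):
    """Scale anchors from grid units back to pixels, like self.anchors[i] * self.stride[i]."""
    return [(w * stride, h * stride) for w, h in anchors_in_cells]


grid = make_grid(nx=4, ny=3)
print(grid[2][1])  # cell in row 2, column 1

# Assumed example: (10, 13), (16, 30), (33, 23) pixel anchors divided by stride 8
anchors_in_cells = [(1.25, 1.625), (2.0, 3.75), (4.125, 2.875)]
print(make_anchor_sizes(anchors_in_cells, stride=8))
```

In the real function both results are additionally expanded to shape (1, na, ny, nx, 2) so they broadcast against the prediction tensor.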
```python
y = x[i].sigmoid()
if self.inplace:
    y[..., 0:2] = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
else:
    xy = (y[..., 0:2] * 2 - 0.5 + self.grid[i]) * self.stride[i]  # xy
    wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
    y = torch.cat((xy, wh, y[..., 4:]), -1)
z.append(y.view(bs, -1, self.no))
```
This is the core inference code, and it implements YOLOv5's bounding-box regression. The regression scheme of YOLOv5 is shown in the figure below:
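Compared with the regression in YOLOv3, the predicted x, y of the box center is now multiplied by 2 and shifted by -0.5, so its range changes from (0, 1) in YOLOv3 (an open interval) to (-0.5, 1.5). Taken at face value, YOLOv5 can predict a center up to half a cell outside its own cell, which improves recall for boxes near cell borders. A further benefit is that it removes the YOLOv3 limitation where the open interval of the sigmoid keeps the center from ever reaching the cell boundary.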
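The center decoding can be tried out on scalars. This is a minimal sketch of the formula in the code above, `(sigmoid(t) * 2 - 0.5 + grid) * stride`, written with the standard library; `decode_xy` and `xy_offset` are hypothetical helper names.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def xy_offset(t):
    """Offset of the center relative to the cell corner, in grid units: range (-0.5, 1.5)."""
    return 2.0 * sigmoid(t) - 0.5


def decode_xy(tx, ty, cx, cy, stride):
    """Decode raw predictions (tx, ty) for cell (cx, cy) into pixel coordinates."""
    return (xy_offset(tx) + cx) * stride, (xy_offset(ty) + cy) * stride


print(decode_xy(0.0, 0.0, cx=3, cy=4, stride=8))  # t = 0 puts the center in the middle of the cell
print(xy_offset(-50.0), xy_offset(50.0))          # approaches the two ends of the range
```

A raw value of 0 maps to the middle of the cell, and strongly negative or positive values approach half a cell to the left of the cell or half a cell past its right edge.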
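YOLOv5 also changes the regression of w and h. For comparison, here is the corresponding YOLOv3 source code:

```python
x = torch.sigmoid(prediction[..., 0])  # center x
y = torch.sigmoid(prediction[..., 1])  # center y
w = prediction[..., 2]  # width (no sigmoid)
h = prediction[..., 3]  # height (no sigmoid)
pred_conf = torch.sigmoid(prediction[..., 4])  # objectness confidence
pred_cls = torch.sigmoid(prediction[..., 5:])  # class predictions
```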
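Clearly YOLOv3 applies no sigmoid to w and h, whereas YOLOv5 applies a sigmoid to all of x, y, w and h. In addition, the scale factor that YOLOv5 applies to the anchor becomes $(2\sigma(t_w))^2$ for the width and $(2\sigma(t_h))^2$ for the height, where $t_w$ and $t_h$ are the raw predictions. Relative to the anchor width and height, the range changes from (0, +∞) to (0, 4). The purpose is probably to make the predicted box range more precise: the sigmoid constraint keeps the regressed box proportions within reasonable bounds.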
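The two width/height parameterizations can be compared numerically. The sketch below assumes the usual YOLOv3 form, where the anchor is multiplied by exp(t), and the YOLOv5 form from the code above, `(sigmoid(t) * 2) ** 2`; the function names are hypothetical.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def wh_scale_v5(t):
    """YOLOv5 factor applied to the anchor size: bounded to (0, 4)."""
    return (2.0 * sigmoid(t)) ** 2


def wh_scale_v3(t):
    """Assumed YOLOv3 factor applied to the anchor size: unbounded, (0, +inf)."""
    return math.exp(t)


for t in (-10.0, 0.0, 2.0, 10.0):
    print(t, wh_scale_v3(t), wh_scale_v5(t))
```

Both forms return 1 at t = 0 (the box equals the anchor), but for large t the exponential grows without bound while the YOLOv5 factor saturates below 4.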
Analysis of class Model()
Next come the functions of the Model class. The focus is on its forward pass, which involves two functions: forward() and _forward_once().
The forward() function

```python
def forward(self, x, augment=False, profile=False, visualize=False):
    if augment:
        return self._forward_augment(x)  # augmented inference (TTA)
    return self._forward_once(x, profile, visualize)  # single-scale inference or training

def _forward_augment(self, x):
    img_size = x.shape[-2:]  # height, width
    s = [1, 0.83, 0.67]  # scales
    f = [None, 3, None]  # flips (dim 3 is the width axis, i.e. left-right)
    y = []  # outputs
    for si, fi in zip(s, f):
        xi = scale_img(x.flip(fi) if fi else x, si, gs=int(self.stride.max()))
        yi = self._forward_once(xi)[0]  # forward pass on the augmented image
        yi = self._descale_pred(yi, fi, si, img_size)  # map predictions back to the original image
        y.append(yi)
    y = self._clip_augmented(y)
    return torch.cat(y, 1), None
```
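In self.forward(), augment can be understood as the switch for TTA (test-time augmentation). When enabled, the image is scaled and flipped before inference. It is off by default.
The source of scale_img is as follows: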
The scale_img() function

```python
import math

import torch.nn.functional as F


def scale_img(img, ratio=1.0, same_shape=False, gs=32):
    # Scale img of shape (bs, 3, y, x) by ratio, padded to a multiple of gs
    if ratio == 1.0:
        return img
    else:
        h, w = img.shape[2:]
        s = (int(h * ratio), int(w * ratio))  # new size
        img = F.interpolate(img, size=s, mode='bilinear', align_corners=False)  # resize
        if not same_shape:  # target size: scaled size rounded up to a multiple of gs
            h, w = (math.ceil(x * ratio / gs) * gs for x in (h, w))
        return F.pad(img, [0, w - s[1], 0, h - s[0]], value=0.447)  # pad right and bottom
```
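The scaling is done with plain bilinear interpolation, and ratio controls the scale factor. The result is then padded on the right and bottom with the constant value 0.447 (not 0). With same_shape=False (the default) the padding extends the scaled image to the next multiple of gs; with same_shape=True it extends it back to the original size.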
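The size arithmetic of scale_img() can be followed without running any interpolation. The sketch below repeats only the integer computations from the function for the default same_shape=False case; `scaled_and_padded_size` is a hypothetical helper name.

```python
import math


def scaled_and_padded_size(h, w, ratio, gs=32):
    """Return ((resized_h, resized_w), (padded_h, padded_w)) as computed in scale_img()."""
    if ratio == 1.0:
        return (h, w), (h, w)  # image is returned unchanged
    resized = (int(h * ratio), int(w * ratio))
    padded = tuple(math.ceil(x * ratio / gs) * gs for x in (h, w))
    return resized, padded


for ratio in (1, 0.83, 0.67):  # the TTA scales used in _forward_augment
    print(ratio, scaled_and_padded_size(640, 640, ratio))
```

For a 640x640 input, the 0.83 scale gives a 531x531 image padded to 544x544, and the 0.67 scale gives 428x428 padded to 448x448.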
The _forward_once() function

```python
def _forward_once(self, x, profile=False, visualize=False):
    y, dt = [], []  # saved layer outputs, profiling times
    for m in self.model:
        if m.f != -1:  # input does not come from the previous layer
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
        if profile:
            self._profile_one_layer(m, x, dt)
        x = m(x)  # run the layer
        y.append(x if m.i in self.save else None)  # keep the output only if a later layer needs it
        if visualize:
            feature_visualization(x, m.type, m.i, save_dir=visualize)
    return x
```
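self._forward_once() runs every module of the model once, in order, and returns the result. When the profile argument is enabled, it records the average execution time and FLOPs of each module, which helps to locate bottlenecks in the model, speed it up and reduce its GPU memory footprint.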
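The routing logic in the loop (the m.f, m.i and save bookkeeping) is easy to trace on a toy model. The sketch below keeps the same control flow but replaces the network layers with simple functions on numbers; `ToyLayer` and the three example layers are hypothetical stand-ins, not YOLOv5 modules.

```python
class ToyLayer:
    """Hypothetical stand-in for a model layer: i = own index, f = index(es) of its input."""

    def __init__(self, i, f, fn):
        self.i, self.f, self.fn = i, f, fn

    def __call__(self, x):
        return self.fn(x)


def forward_once(layers, save, x):
    y = []  # outputs of earlier layers, None when nobody needs them later
    for m in layers:
        if m.f != -1:  # input is not simply the previous layer's output
            x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
        x = m(x)
        y.append(x if m.i in save else None)
    return x


layers = [
    ToyLayer(0, -1, lambda x: x + 1),          # takes the previous output (the input)
    ToyLayer(1, -1, lambda x: x * 2),          # takes the previous output
    ToyLayer(2, [-1, 0], lambda xs: sum(xs)),  # merges the previous output with layer 0's output
]
result = forward_once(layers, save=[0], x=3)
print(result)  # (3 + 1) * 2 + (3 + 1)
```

Layer 2 plays the role of a Concat-style module: it receives a list built from the previous output and the saved output of layer 0, which is why index 0 has to appear in save.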
This post analyzed the forward pass and the inference code of the YOLOv5 head.
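References:
yolov5深度剖析+源码debug级讲解系列(三)yolov5 head源码解析 (YOLOv5 in depth with debug-level source walkthrough, part 3: head source code analysis)
YOLO全系列更新,YOLO的进化历程 (The full YOLO series: the evolution of YOLO)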