2 Star 1 Fork 1

luotianhang / yolov4

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model.py 3.25 KB
一键复制 编辑 原始数据 按行查看 历史
罗天杭 提交于 2021-12-23 17:05 . 20211223
# -*- coding: utf-8 -*-
"""
author:LTH
date:
"""
from common import *
def darknet53(pretrained, **kwargs):
    """Build the CSPDarkNet backbone and optionally load saved weights into it.

    :param pretrained: path (``str`` or ``os.PathLike``) to a state dict
        readable by ``torch.load``, or a falsy value (``None``, ``False``,
        ``""``) to keep the freshly initialised weights.
    :param kwargs: accepted for call-site compatibility; currently ignored.
    :return: a ``CSPDarkNet`` built with ``[1, 2, 8, 8, 4]`` (presumably the
        residual-block count of each stage -- confirm in ``common``).
    :raises TypeError: if ``pretrained`` is truthy but is not a path.
        ``TypeError`` subclasses ``Exception``, so existing
        ``except Exception`` handlers still catch it.
    """
    import os

    # Validate the argument before building the network so a bad value fails
    # fast instead of after constructing every layer.
    if pretrained and not isinstance(pretrained, (str, os.PathLike)):
        raise TypeError("darknet request a pretrained path. got [{}]".format(pretrained))

    model = CSPDarkNet([1, 2, 8, 8, 4])
    if pretrained:
        # map_location="cpu" lets checkpoints saved from a GPU run load on a
        # CPU-only machine; load_state_dict then copies the tensors into the
        # model's own parameters.
        # NOTE(review): torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        model.load_state_dict(torch.load(pretrained, map_location="cpu"))
    return model
class YoloBody(nn.Module):
    """YOLOv4-style detector: CSPDarkNet backbone, SPP + PANet-style neck, three YOLO heads.

    :param num_anchors: anchors predicted per grid cell at each scale.
    :param num_classes: number of object classes.

    Every head emits ``num_anchors * (5 + num_classes)`` channels: per anchor,
    4 box values + 1 objectness score + ``num_classes`` class scores
    (e.g. 3 anchors and 20 classes give 3 * 25 = 75 channels).

    NOTE: the submodule attribute names below are the ``state_dict`` keys of
    saved checkpoints; renaming any of them breaks weight loading.
    """

    def __init__(self, num_anchors, num_classes):
        super(YoloBody, self).__init__()
        # All three heads share the same output width: num_anchors * (4 + 1 + num_classes).
        final_out_filter = num_anchors * (5 + num_classes)

        # Backbone, freshly initialised (no pretrained weights loaded here).
        self.backbone = darknet53(None)

        # Deepest backbone output (1024 channels): three convs -> SPP -> three convs.
        # The SPP output is fed to a block expecting 2048 input channels.
        self.conv1 = make_three_conv([512, 1024], 1024)
        self.SPP = SpatialPyramidPooling()
        self.conv2 = make_three_conv([512, 1024], 2048)

        # Top-down path: upsample deeper features and fuse with shallower ones.
        self.upsample1 = Upsample(512, 256)
        self.conv_for_P4 = conv2d(512, 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)
        self.upsample2 = Upsample(256, 128)
        self.conv_for_P3 = conv2d(256, 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)

        # Bottom-up path: stride-2 convs carry the finer levels back down,
        # with a detection head attached at every level.
        self.yolo_head3 = yolo_head([256, final_out_filter], 128)
        self.down_sample1 = conv2d(128, 256, 3, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)
        self.yolo_head2 = yolo_head([512, final_out_filter], 256)
        self.down_sample2 = conv2d(256, 512, 3, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)
        self.yolo_head1 = yolo_head([1024, final_out_filter], 512)

    def forward(self, x):
        """Run the detector on a batch of images.

        :param x: input batch, shape (N, C_in, H_in, W_in).
        :return: ``(out0, out1, out2)``, each with
            ``num_anchors * (5 + num_classes)`` channels, ordered from the
            coarsest grid to the finest. Assuming the backbone outputs are at
            strides 8/16/32 (e.g. 52/26/13 cells for a 416x416 input):
                out0 shape is (N, num_anchors*(5+num_classes), H_in/32, W_in/32)
                out1 shape is (N, num_anchors*(5+num_classes), H_in/16, W_in/16)
                out2 shape is (N, num_anchors*(5+num_classes), H_in/8, W_in/8)
        """
        # Backbone: x2 feeds the finest level (P3), x0 the coarsest (P5).
        x2, x1, x0 = self.backbone(x)

        # Deepest level: convs + spatial pyramid pooling.
        P5 = self.conv1(x0)
        P5 = self.SPP(P5)
        P5 = self.conv2(P5)

        # Top-down: upsample P5 and fuse it with the x1 features.
        P5_upsample = self.upsample1(P5)
        P4 = self.conv_for_P4(x1)
        P4 = torch.cat([P4, P5_upsample], dim=1)
        P4 = self.make_five_conv1(P4)

        # Top-down: upsample P4 and fuse it with the x2 features.
        P4_upsample = self.upsample2(P4)
        P3 = self.conv_for_P3(x2)
        P3 = torch.cat([P3, P4_upsample], dim=1)
        P3 = self.make_five_conv2(P3)

        # Bottom-up: downsample P3 back into the P4 level.
        P3_downsample = self.down_sample1(P3)
        P4 = torch.cat([P3_downsample, P4], dim=1)
        P4 = self.make_five_conv3(P4)

        # Bottom-up: downsample P4 back into the P5 level.
        P4_downsample = self.down_sample2(P4)
        P5 = torch.cat([P4_downsample, P5], dim=1)
        P5 = self.make_five_conv4(P5)

        # Heads: out2 is the finest grid (e.g. 52x52), out0 the coarsest (e.g. 13x13).
        out2 = self.yolo_head3(P3)
        out1 = self.yolo_head2(P4)
        out0 = self.yolo_head1(P5)
        return out0, out1, out2
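# TODO
# Check whether gradients propagate back through the model.
# NOTE: YoloBody requires num_anchors and num_classes, and parms.grad is None
# until a backward pass has been run, so the snippet below needs both first.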
# model=YoloBody()
# for name , parms in model.named_parameters():
# print('-->name:', name, '-->grad_requirs:', parms.requires_grad, '--weight', torch.mean(parms.data),
# ' -->grad_value:', torch.mean(parms.grad))
Python
1
https://gitee.com/luotianhang/yolo_rebuild.git
git@gitee.com:luotianhang/yolo_rebuild.git
luotianhang
yolo_rebuild
yolov4
master

搜索帮助