1 Star 1 Fork 258

jkException / PaddleDetection

forked from PaddlePaddle / PaddleDetection 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
fcos_loss.py 10.25 KB
一键复制 编辑 原始数据 按行查看 历史
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from ppdet.modeling import ops
__all__ = ['FCOSLoss']
def flatten_tensor(inputs, channel_first=False):
    """
    Merge the first three axes of a 4-D Tensor, keeping channels last.

    Args:
        inputs (Tensor): 4-D Tensor laid out as [N, C, H, W] or [N, H, W, C]
        channel_first (bool): True when ``inputs`` is [N, C, H, W]; it is
            then permuted to [N, H, W, C] before flattening. False when
            ``inputs`` is already [N, H, W, C].
    Return:
        Tensor: the channel-last Tensor with axes 0 through 2 flattened
            into a single axis, i.e. [N * H * W, C]
    """
    # Move channels to the last axis so every row of the result holds the
    # values of exactly one spatial location.
    channel_last = (paddle.transpose(
        inputs, perm=[0, 2, 3, 1]) if channel_first else inputs)
    return paddle.flatten(channel_last, start_axis=0, stop_axis=2)
@register
class FCOSLoss(nn.Layer):
    """
    FCOSLoss

    Combines three terms: a sigmoid focal loss for classification, an
    IoU-family loss for box regression, and a sigmoid cross-entropy loss for
    the quality branch (centerness or IoU).

    Args:
        loss_alpha (float): alpha in focal loss
        loss_gamma (float): gamma in focal loss
        iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
            (compared case-insensitively)
        reg_weights (float): weight for location loss.
            NOTE(review): stored on the instance but not applied anywhere
            in this class.
        quality (str): quality branch, centerness/iou
    """

    def __init__(self,
                 loss_alpha=0.25,
                 loss_gamma=2.0,
                 iou_loss_type="giou",
                 reg_weights=1.0,
                 quality='centerness'):
        super(FCOSLoss, self).__init__()
        self.loss_alpha = loss_alpha
        self.loss_gamma = loss_gamma
        self.iou_loss_type = iou_loss_type
        self.reg_weights = reg_weights
        self.quality = quality

    def __iou_loss(self,
                   pred,
                   targets,
                   positive_mask,
                   weights=None,
                   return_iou=False):
        """
        Calculate the loss for location prediction

        Args:
            pred (Tensor): bounding boxes prediction; columns 0..3 are read
                as the left, top, right and bottom distances
            targets (Tensor): targets for positive samples, same column
                layout as ``pred``
            positive_mask (Tensor): mask of positive samples; expected to
                hold 0/1 values (``forward`` passes a bool mask cast to
                float32)
            weights (Tensor): weights for each positive samples
            return_iou (bool): if True, return the masked IoU values instead
                of a loss
        Return:
            loss (Tensor): per-point location loss, or the per-point masked
                IoU when ``return_iou`` is True
        Raises:
            KeyError: if ``self.iou_loss_type`` is not one of
                linear_iou / giou / iou (only checked when a loss is
                requested)
        """
        # Zero the distances of every non-positive point.
        plw = pred[:, 0] * positive_mask
        pth = pred[:, 1] * positive_mask
        prw = pred[:, 2] * positive_mask
        pbh = pred[:, 3] * positive_mask

        tlw = targets[:, 0] * positive_mask
        tth = targets[:, 1] * positive_mask
        trw = targets[:, 2] * positive_mask
        tbh = targets[:, 3] * positive_mask
        # Targets never receive gradients.
        tlw.stop_gradient = True
        trw.stop_gradient = True
        tth.stop_gradient = True
        tbh.stop_gradient = True

        # Per-side minimum gives the intersection extents, per-side maximum
        # gives the extents of the smallest enclosing box.
        ilw = paddle.minimum(plw, tlw)
        irw = paddle.minimum(prw, trw)
        ith = paddle.minimum(pth, tth)
        ibh = paddle.minimum(pbh, tbh)

        clw = paddle.maximum(plw, tlw)
        crw = paddle.maximum(prw, trw)
        cth = paddle.maximum(pth, tth)
        cbh = paddle.maximum(pbh, tbh)

        area_predict = (plw + prw) * (pth + pbh)
        area_target = (tlw + trw) * (tth + tbh)
        area_inter = (ilw + irw) * (ith + ibh)
        # The +1.0 terms smooth the ratio and keep the denominator non-zero.
        ious = (area_inter + 1.0) / (
            area_predict + area_target - area_inter + 1.0)
        ious = ious * positive_mask

        if return_iou:
            return ious

        loss_type = self.iou_loss_type.lower()
        if loss_type == "linear_iou":
            loss = 1.0 - ious
        elif loss_type == "giou":
            area_uniou = area_predict + area_target - area_inter
            area_circum = (clw + crw) * (cth + cbh) + 1e-7
            giou = ious - (area_circum - area_uniou) / area_circum
            loss = 1.0 - giou
        elif loss_type == "iou":
            # ious is exactly 0 at masked-out points, and log(0) = -inf
            # turns into NaN once the caller multiplies by the mask. Adding
            # (1 - mask) makes those points evaluate log(1) = 0 while
            # leaving positive points untouched.
            loss = 0.0 - paddle.log(ious + (1.0 - positive_mask))
        else:
            raise KeyError(
                "Unknown iou_loss_type: {}".format(self.iou_loss_type))

        if weights is not None:
            loss = loss * weights
        return loss

    @staticmethod
    def _flatten_levels(per_level, num_lvl, channel_first):
        """Flatten the first ``num_lvl`` tensors and concatenate on axis 0."""
        return paddle.concat(
            [
                flatten_tensor(per_level[lvl], channel_first)
                for lvl in range(num_lvl)
            ],
            axis=0)

    def forward(self, cls_logits, bboxes_reg, centerness, tag_labels,
                tag_bboxes, tag_center):
        """
        Calculate the loss for classification, location and centerness

        Args:
            cls_logits (list): list of Tensor, one per level, predicted
                scores; each is treated as channel-first 4-D [N, C, H, W]
            bboxes_reg (list): list of Tensor, one per level, predicted
                offsets; treated as channel-first 4-D [N, 4, H, W]
            centerness (list): list of Tensor, one per level, predicted
                centerness / quality; treated as channel-first 4-D
                [N, 1, H, W]
            tag_labels (list): list of Tensor, which is category
                targets for each anchor point (channel-last; values > 0
                mark positive points, and the last axis is squeezed away)
            tag_bboxes (list): list of Tensor, which is bounding
                boxes targets for positive samples (channel-last)
            tag_center (list): list of Tensor, which is centerness
                targets for positive samples (channel-last)
        Return:
            loss (dict): summed scalar losses under the keys ``loss_cls``
                (classification), ``loss_box`` (bounding box regression)
                and ``loss_quality`` (centerness or IoU quality branch)
        Raises:
            ValueError: if ``self.quality`` is neither 'centerness' nor 'iou'
        """
        # Fail fast on a bad configuration before doing any tensor work.
        if self.quality not in ('centerness', 'iou'):
            raise ValueError(f'Unknown quality type: {self.quality}')

        # Predictions are channel-first, targets are already channel-last.
        num_lvl = len(cls_logits)
        cls_logits_flatten = self._flatten_levels(cls_logits, num_lvl, True)
        bboxes_reg_flatten = self._flatten_levels(bboxes_reg, num_lvl, True)
        centerness_flatten = self._flatten_levels(centerness, num_lvl, True)
        tag_labels_flatten = self._flatten_levels(tag_labels, num_lvl, False)
        tag_bboxes_flatten = self._flatten_levels(tag_bboxes, num_lvl, False)
        tag_center_flatten = self._flatten_levels(tag_center, num_lvl, False)
        tag_labels_flatten.stop_gradient = True
        tag_bboxes_flatten.stop_gradient = True
        tag_center_flatten.stop_gradient = True

        # Label 0 is background; every label > 0 is a positive point.
        mask_positive_bool = tag_labels_flatten > 0
        mask_positive_float = paddle.cast(mask_positive_bool, dtype="float32")
        mask_positive_float.stop_gradient = True

        # Both normalizers are clamped so that a batch without any positive
        # point yields finite losses instead of dividing by zero.
        num_positive_fp32 = paddle.clip(
            paddle.sum(mask_positive_float), min=1.0)
        num_positive_fp32.stop_gradient = True
        normalize_sum = paddle.clip(
            paddle.sum(tag_center_flatten * mask_positive_float), min=1e-6)
        normalize_sum.stop_gradient = True

        # 1. cls_logits: sigmoid_focal_loss
        # One-hot over (1 + num_classes) and drop column 0 so background
        # points become all-zero rows.
        num_classes = cls_logits_flatten.shape[-1]
        tag_labels_flatten = paddle.squeeze(tag_labels_flatten, axis=-1)
        tag_labels_flatten_bin = F.one_hot(
            tag_labels_flatten, num_classes=1 + num_classes)
        tag_labels_flatten_bin = tag_labels_flatten_bin[:, 1:]
        # Forward the configured focal-loss hyper-parameters; the defaults
        # of this class equal the defaults of F.sigmoid_focal_loss.
        cls_loss = F.sigmoid_focal_loss(
            cls_logits_flatten,
            tag_labels_flatten_bin,
            alpha=self.loss_alpha,
            gamma=self.loss_gamma) / num_positive_fp32

        mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
        tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
        centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)

        # 2. bboxes_reg: IoU-family loss
        if self.quality == 'centerness':
            # Weighted by the centerness target and normalized by the sum
            # of those targets over the positive points.
            reg_loss = self.__iou_loss(
                bboxes_reg_flatten,
                tag_bboxes_flatten,
                mask_positive_float,
                weights=tag_center_flatten)
            reg_loss = reg_loss * mask_positive_float / normalize_sum
            quality_target = tag_center_flatten
        else:
            # Unweighted and normalized by the number of positive points;
            # the quality branch regresses the IoU between the predicted
            # and the target box.
            reg_loss = self.__iou_loss(
                bboxes_reg_flatten,
                tag_bboxes_flatten,
                mask_positive_float,
                weights=None)
            reg_loss = reg_loss * mask_positive_float / num_positive_fp32
            quality_target = self.__iou_loss(
                bboxes_reg_flatten,
                tag_bboxes_flatten,
                mask_positive_float,
                weights=None,
                return_iou=True)

        # 3. centerness / quality: sigmoid_cross_entropy_with_logits_loss
        quality_loss = ops.sigmoid_cross_entropy_with_logits(
            centerness_flatten, quality_target)
        quality_loss = quality_loss * mask_positive_float / num_positive_fp32

        loss_all = {
            "loss_cls": paddle.sum(cls_loss),
            "loss_box": paddle.sum(reg_loss),
            "loss_quality": paddle.sum(quality_loss),
        }
        return loss_all
Python
1
https://gitee.com/jkException/PaddleDetection.git
git@gitee.com:jkException/PaddleDetection.git
jkException
PaddleDetection
PaddleDetection
release/2.6

搜索帮助