non_max_suppression(): Remove unsupported options (#74)

Fixes #33  Delete all the code that is related to __soft_nms__ as recommended at https://github.com/Zhongdao/Towards-Realtime-MOT/issues/33#issuecomment-572886297
This commit is contained in:
Christian Clauss 2020-01-10 07:36:47 +01:00 committed by ZhongdaoWang
parent 43aad41ffe
commit e6c39ef673

View file

@ -414,13 +414,6 @@ def pooling_nms(heatmap, kernel=1):
keep = (hmax == heatmap).float() keep = (hmax == heatmap).float()
return keep * heatmap return keep * heatmap
def soft_nms(dets, sigma=0.5, Nt=0.3, threshold=0.05, method=1):
keep = cpu_soft_nms(np.ascontiguousarray(dets, dtype=np.float32),
np.float32(sigma), np.float32(Nt),
np.float32(threshold),
np.uint8(method))
return keep
def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4, method='standard'): def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4, method='standard'):
""" """
Removes detections with lower object confidence score than 'conf_thres' Removes detections with lower object confidence score than 'conf_thres'
@ -431,7 +424,7 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4, method='stand
prediction, prediction,
conf_thres, conf_thres,
nms_thres, nms_thres,
method = 'standard', 'fast', 'soft_linear' or 'soft_gaussian' method = 'standard' or 'fast'
""" """
output = [None for _ in range(len(prediction))] output = [None for _ in range(len(prediction))]
@ -457,12 +450,6 @@ def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4, method='stand
# Non-maximum suppression # Non-maximum suppression
if method == 'standard': if method == 'standard':
nms_indices = nms(pred[:, :4], pred[:, 4], nms_thres) nms_indices = nms(pred[:, :4], pred[:, 4], nms_thres)
elif method == 'soft_linear':
dets = pred[:, :5].clone().contiguous().data.cpu().numpy()
nms_indices = soft_nms(dets, Nt=nms_thres, method=0)
elif method == 'soft_gaussian':
dets = pred[:, :5].clone().contiguous().data.cpu().numpy()
nms_indices = soft_nms(dets, Nt=nms_thres, method=1)
elif method == 'fast': elif method == 'fast':
nms_indices = fast_nms(pred[:, :4], pred[:, 4], iou_thres=nms_thres, conf_thres=conf_thres) nms_indices = fast_nms(pred[:, :4], pred[:, 4], iou_thres=nms_thres, conf_thres=conf_thres)
else: else: