I am trying to run inference with a YOLO model for pose estimation in a batched manner (i.e. I concatenate the images along the batch dimension before passing them through the model). The problem is that the non-max suppression (NMS) step takes a very long time for the first image in the batch.
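Concatenating along the batch dimension here means something like the following minimal sketch. It assumes each camera frame has already been preprocessed into a `(3, H, W)` float tensor of the same size; the frame count and the 640×640 size are placeholders, not values from the post.

```python
import torch

# Hypothetical preprocessed frames, one per camera, each (3, H, W)
frames = [torch.rand(3, 640, 640) for _ in range(5)]

# torch.stack adds a new leading batch dimension: (N, 3, H, W)
batch = torch.stack(frames, dim=0)

# Equivalent: give each frame a batch dim of 1, then concatenate along dim 0
batch_cat = torch.cat([f.unsqueeze(0) for f in frames], dim=0)

print(batch.shape)
```

Either form produces one tensor that the model can process in a single forward pass.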
Here is the code (not mine):
```python
import time

import torch
import torchvision
# xywh2xyxy and box_iou are helpers from the YOLO repository (not shown here)


def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=True, multi_label=False,
                            labels=(), kpt_label=False, nc=None, nkpt=None):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, on (n,6) tensor per image [xyxy, conf, cls]
    """
    if nc is None:
        nc = prediction.shape[2] - 5 if not kpt_label else prediction.shape[2] - 56  # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        t1 = time.time()
        x = x[xc[xi]]  # confidence
        print("Time taken for NMS cam : ", xi, " x = x[xc[xi]] : ", time.time() - t1, " xc shape : ", xc.shape)

        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        t1 = time.time()
        x[:, 5:5+nc] *= x[:, 4:5]  # conf = obj_conf * cls_conf
        print("Time taken for NMS cam : ", xi, " x[:, 5:5+nc] *= x[:, 4:5] : ", time.time() - t1)

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        t1 = time.time()
        box = xywh2xyxy(x[:, :4])
        print("Time taken for NMS cam : ", xi, " box = xywh2xyxy(x[:, :4]) : ", time.time() - t1)

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            if not kpt_label:
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
            else:
                t1 = time.time()
                kpts = x[:, 6:]
                conf, j = x[:, 5:6].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres]
                print("Time taken for NMS cam : ", xi, " x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : ", time.time() - t1)

        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        t1 = time.time()
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        print("Time taken for NMS cam : ", xi, " c = x[:, 5:6] * (0 if agnostic else max_wh) : ", time.time() - t1)
        t1 = time.time()
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        print("Time taken for NMS cam : ", xi, " boxes, scores = x[:, :4] + c, x[:, 4] : ", time.time() - t1)
        t1 = time.time()
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        print("Time taken for NMS cam : ", xi, " i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : ", time.time() - t1)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output
```
And here are the timings for a single iteration with a batch size of 5:
```
Time taken for NMS cam : 0 x = x[xc[xi]] : 0.024521827697753906 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 0 x[:, 5:5+nc] *= x[:, 4:5] : 6.413459777832031e-05
Time taken for NMS cam : 0 box = xywh2xyxy(x[:, :4]) : 0.0001697540283203125
Time taken for NMS cam : 0 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00015687942504882812
Time taken for NMS cam : 0 c = x[:, 5:6] * (0 if agnostic else max_wh) : 2.09808349609375e-05
Time taken for NMS cam : 0 boxes, scores = x[:, :4] + c, x[:, 4] : 1.6927719116210938e-05
Time taken for NMS cam : 0 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00021004676818847656
Time taken for NMS cam : 1 x = x[xc[xi]] : 4.9114227294921875e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 1 x[:, 5:5+nc] *= x[:, 4:5] : 2.9802322387695312e-05
Time taken for NMS cam : 1 box = xywh2xyxy(x[:, :4]) : 0.00013899803161621094
Time taken for NMS cam : 1 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00010418891906738281
Time taken for NMS cam : 1 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.7881393432617188e-05
Time taken for NMS cam : 1 boxes, scores = x[:, :4] + c, x[:, 4] : 1.9073486328125e-05
Time taken for NMS cam : 1 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011444091796875
Time taken for NMS cam : 2 x = x[xc[xi]] : 4.5299530029296875e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 2 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 2 box = xywh2xyxy(x[:, :4]) : 0.0001354217529296875
Time taken for NMS cam : 2 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.965896606445312e-05
Time taken for NMS cam : 2 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.811981201171875e-05
Time taken for NMS cam : 2 boxes, scores = x[:, :4] + c, x[:, 4] : 1.811981201171875e-05
Time taken for NMS cam : 2 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011396408081054688
Time taken for NMS cam : 3 x = x[xc[xi]] : 4.649162292480469e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 3 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 3 box = xywh2xyxy(x[:, :4]) : 0.0001380443572998047
Time taken for NMS cam : 3 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 0.00011682510375976562
Time taken for NMS cam : 3 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 3 boxes, scores = x[:, :4] + c, x[:, 4] : 1.8596649169921875e-05
Time taken for NMS cam : 3 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011110305786132812
Time taken for NMS cam : 4 x = x[xc[xi]] : 4.482269287109375e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 4 x[:, 5:5+nc] *= x[:, 4:5] : 2.86102294921875e-05
Time taken for NMS cam : 4 box = xywh2xyxy(x[:, :4]) : 0.00013637542724609375
Time taken for NMS cam : 4 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.822845458984375e-05
Time taken for NMS cam : 4 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 4 boxes, scores = x[:, :4] + c, x[:, 4] : 1.9311904907226562e-05
Time taken for NMS cam : 4 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00011277198791503906
Time taken for NMS cam : 5 x = x[xc[xi]] : 4.458427429199219e-05 xc shape : torch.Size([6, 16320])
Time taken for NMS cam : 5 x[:, 5:5+nc] *= x[:, 4:5] : 2.8133392333984375e-05
Time taken for NMS cam : 5 box = xywh2xyxy(x[:, :4]) : 0.00013685226440429688
Time taken for NMS cam : 5 x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : 9.775161743164062e-05
Time taken for NMS cam : 5 c = x[:, 5:6] * (0 if agnostic else max_wh) : 1.8358230590820312e-05
Time taken for NMS cam : 5 boxes, scores = x[:, :4] + c, x[:, 4] : 1.811981201171875e-05
Time taken for NMS cam : 5 i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS : 0.00010848045349121094
```
As you can see, `x = x[xc[xi]]` takes about 24.5 ms in the first iteration of the loop over `prediction` (i.e. for the first image in the batch), compared with 45 to 49 µs in every later iteration.
I read somewhere that this might be caused by PyTorch moving data to the GPU. Adding `torch.cuda.empty_cache()` just moves the bottleneck to that call.
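One explanation that would fit these numbers (an assumption to verify, not something the logs prove): CUDA kernels run asynchronously, so `time.time()` around a single line normally measures only how long it took to queue that line's work. Boolean-mask indexing such as `x[xc[xi]]` is different: PyTorch must know how many elements are `True` before it can allocate the result, so it blocks until everything queued earlier has finished, which would include the model's forward pass and the `xc` comparison. `torch.cuda.empty_cache()` most likely also waits on the device, which would explain why it absorbs the delay instead. A minimal sketch of a timer that synchronizes before reading the clock, so each measurement covers only its own block (`cuda_timer` and the random `prediction` tensor are hypothetical stand-ins):

```python
import time
from contextlib import contextmanager

import torch


@contextmanager
def cuda_timer(label: str, results: dict):
    """Store the wall-clock time of the enclosed block in results[label].

    Synchronizing before the start drains previously queued kernels (e.g. the
    forward pass) so they are not billed to this block; synchronizing before
    the end waits for the block's own kernels to finish.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    yield
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    results[label] = time.perf_counter() - start


device = "cuda" if torch.cuda.is_available() else "cpu"
# Hypothetical stand-in for the model output: (batch, anchors, 57) for pose
prediction = torch.rand(6, 16320, 57, device=device)
timings = {}

with cuda_timer("candidate mask", timings):
    xc = prediction[..., 4] > 0.25
for xi, x in enumerate(prediction):
    with cuda_timer(f"x[xc[{xi}]]", timings):
        x = x[xc[xi]]

for label, seconds in timings.items():
    print(f"{label}: {seconds * 1e3:.3f} ms")
```

In the real pipeline, calling `torch.cuda.synchronize()` right after the model's forward pass and before the first `t1 = time.time()` should show whether the ~24 ms belongs to the forward pass rather than to NMS.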
Am I missing something? Is this caused by PyTorch moving data around, or by the fact that this code is not meant for batched NMS? If the latter, how can I convert this code to batched NMS?
Either way, I would really appreciate any help or pointers to anything useful.
I also tried clearing the CUDA cache on every iteration, but the same thing happens: the delay now shows up in the `empty_cache()` call instead, and still only in the first iteration.