
I am trying to run batched inference with a YOLO pose-estimation model (i.e. I concatenate the images along the batch dimension before passing them through the model). The non-max-suppression step takes very long for the first image in the batch.

Here is the code (not mine):

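import time

import torch
import torchvision

# xywh2xyxy and box_iou are assumed to be defined in the same module
# (utils/general.py in the YOLOv7 repository).

def non_max_suppression_kpt(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=True, multi_label=False,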
                        labels=(), kpt_label=False, nc=None, nkpt=None):
    """Runs Non-Maximum Suppression (NMS) on inference results

    Returns:
         list of detections, one (n,6) tensor per image [xyxy, conf, cls]
    """
    if nc is None:
        nc = prediction.shape[2] - 5  if not kpt_label else prediction.shape[2] - 56 # number of classes
    xc = prediction[..., 4] > conf_thres  # candidates

    # Settings
    min_wh, max_wh = 2, 4096  # (pixels) minimum and maximum box width and height
    max_det = 300  # maximum number of detections per image
    max_nms = 30000  # maximum number of boxes into torchvision.ops.nms()
    time_limit = 10.0  # seconds to quit after
    redundant = True  # require redundant detections
    multi_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)
    merge = False  # use merge-NMS

    t = time.time()
    output = [torch.zeros((0,6), device=prediction.device)] * prediction.shape[0]
    for xi, x in enumerate(prediction):  # image index, image inference
        # Apply constraints
        # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height
        
        t1 = time.time()
        x = x[xc[xi]]  # confidence
        print("Time taken for NMS cam : " , xi , " x = x[xc[xi]] : " , time.time() - t1, " xc shape : " , xc.shape)
        # Cat apriori labels if autolabelling
        if labels and len(labels[xi]):
            l = labels[xi]
            v = torch.zeros((len(l), nc + 5), device=x.device)
            v[:, :4] = l[:, 1:5]  # box
            v[:, 4] = 1.0  # conf
            v[range(len(l)), l[:, 0].long() + 5] = 1.0  # cls
            x = torch.cat((x, v), 0)

        # If none remain process next image
        if not x.shape[0]:
            continue

        # Compute conf
        t1 = time.time()
        x[:, 5:5+nc] *= x[:, 4:5]  # conf = obj_conf * cls_conf
        print("Time taken for NMS cam : " , xi , " x[:, 5:5+nc] *= x[:, 4:5] : " , time.time() - t1)

        # Box (center x, center y, width, height) to (x1, y1, x2, y2)
        t1 = time.time()
        box = xywh2xyxy(x[:, :4])
        print("Time taken for NMS cam : " , xi , " box = xywh2xyxy(x[:, :4]) : " , time.time() - t1)

        # Detections matrix nx6 (xyxy, conf, cls)
        if multi_label:
            i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)
        else:  # best class only
            if not kpt_label:
                conf, j = x[:, 5:].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]
            else:
                t1 = time.time()
                kpts = x[:, 6:]
                conf, j = x[:, 5:6].max(1, keepdim=True)
                x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres]
                print("Time taken for NMS cam : " , xi , " x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] : " , time.time() - t1)


        # Filter by class
        if classes is not None:
            x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]

        # Apply finite constraint
        # if not torch.isfinite(x).all():
        #     x = x[torch.isfinite(x).all(1)]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        elif n > max_nms:  # excess boxes
            x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence

        # Batched NMS
        t1 = time.time()
        c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes
        print("Time taken for NMS cam : " , xi , " c = x[:, 5:6] * (0 if agnostic else max_wh) : " , time.time() - t1)
        t1 = time.time()
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores
        print("Time taken for NMS cam : " , xi , " boxes, scores = x[:, :4] + c, x[:, 4] : " , time.time() - t1)
        t1 = time.time()
        i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS
        print("Time taken for NMS cam : " , xi , " i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS : " , time.time() - t1)
        if i.shape[0] > max_det:  # limit detections
            i = i[:max_det]
        if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)
            # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
            iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrix
            weights = iou * scores[None]  # box weights
            x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxes
            if redundant:
                i = i[iou.sum(1) > 1]  # require redundancy

        output[xi] = x[i]
        if (time.time() - t) > time_limit:
            print(f'WARNING: NMS time limit {time_limit}s exceeded')
            break  # time limit exceeded

    return output

And here are the timings for a single iteration with a batch size of 6 (images 0 to 5 in the log):

Time taken for NMS cam :  0  x = x[xc[xi]] :  0.024521827697753906  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  0  x[:, 5:5+nc] *= x[:, 4:5] :  6.413459777832031e-05
Time taken for NMS cam :  0  box = xywh2xyxy(x[:, :4]) :  0.0001697540283203125
Time taken for NMS cam :  0  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00015687942504882812
Time taken for NMS cam :  0  c = x[:, 5:6] * (0 if agnostic else max_wh) :  2.09808349609375e-05
Time taken for NMS cam :  0  boxes, scores = x[:, :4] + c, x[:, 4] :  1.6927719116210938e-05
Time taken for NMS cam :  0  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00021004676818847656
Time taken for NMS cam :  1  x = x[xc[xi]] :  4.9114227294921875e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  1  x[:, 5:5+nc] *= x[:, 4:5] :  2.9802322387695312e-05
Time taken for NMS cam :  1  box = xywh2xyxy(x[:, :4]) :  0.00013899803161621094
Time taken for NMS cam :  1  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00010418891906738281
Time taken for NMS cam :  1  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.7881393432617188e-05
Time taken for NMS cam :  1  boxes, scores = x[:, :4] + c, x[:, 4] :  1.9073486328125e-05
Time taken for NMS cam :  1  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011444091796875
Time taken for NMS cam :  2  x = x[xc[xi]] :  4.5299530029296875e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  2  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  2  box = xywh2xyxy(x[:, :4]) :  0.0001354217529296875
Time taken for NMS cam :  2  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.965896606445312e-05
Time taken for NMS cam :  2  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.811981201171875e-05
Time taken for NMS cam :  2  boxes, scores = x[:, :4] + c, x[:, 4] :  1.811981201171875e-05
Time taken for NMS cam :  2  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011396408081054688
Time taken for NMS cam :  3  x = x[xc[xi]] :  4.649162292480469e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  3  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  3  box = xywh2xyxy(x[:, :4]) :  0.0001380443572998047
Time taken for NMS cam :  3  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  0.00011682510375976562
Time taken for NMS cam :  3  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  3  boxes, scores = x[:, :4] + c, x[:, 4] :  1.8596649169921875e-05
Time taken for NMS cam :  3  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011110305786132812
Time taken for NMS cam :  4  x = x[xc[xi]] :  4.482269287109375e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  4  x[:, 5:5+nc] *= x[:, 4:5] :  2.86102294921875e-05
Time taken for NMS cam :  4  box = xywh2xyxy(x[:, :4]) :  0.00013637542724609375
Time taken for NMS cam :  4  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.822845458984375e-05
Time taken for NMS cam :  4  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  4  boxes, scores = x[:, :4] + c, x[:, 4] :  1.9311904907226562e-05
Time taken for NMS cam :  4  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00011277198791503906
Time taken for NMS cam :  5  x = x[xc[xi]] :  4.458427429199219e-05  xc shape :  torch.Size([6, 16320])
Time taken for NMS cam :  5  x[:, 5:5+nc] *= x[:, 4:5] :  2.8133392333984375e-05
Time taken for NMS cam :  5  box = xywh2xyxy(x[:, :4]) :  0.00013685226440429688
Time taken for NMS cam :  5  x = torch.cat((box, conf, j.float(), kpts), 1)[conf.view(-1) > conf_thres] :  9.775161743164062e-05
Time taken for NMS cam :  5  c = x[:, 5:6] * (0 if agnostic else max_wh) :  1.8358230590820312e-05
Time taken for NMS cam :  5  boxes, scores = x[:, :4] + c, x[:, 4] :  1.811981201171875e-05
Time taken for NMS cam :  5  i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMS :  0.00010848045349121094

As the log shows, the line "x = x[xc[xi]]" takes far longer in the first iteration of the loop over predictions (i.e. for the first image in the batch) than in the later iterations: about 0.0245 s versus roughly 0.00005 s.

I read somewhere that it might be due to PyTorch moving data to the GPU. Adding "torch.cuda.empty_cache()" just moves the bottleneck to that call.
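
One possible explanation, assuming the model runs on a CUDA device: CUDA kernels are queued asynchronously, so time.time() can attribute the wait for the still-running forward pass to the first operation that forces the host to synchronize. Boolean-mask indexing such as "x = x[xc[xi]]" is one such operation, and torch.cuda.empty_cache() may be another, which would be consistent with the bottleneck moving there. Below is a minimal sketch of timing with explicit synchronization. The helper name timed and the random stand-in tensor are illustrative and not part of the YOLOv7 repository.

```python
import time

import torch


def timed(fn):
    """Run fn() and return (result, seconds).

    On CUDA, synchronize before and after so that the measurement covers
    only fn itself and not kernels queued earlier (e.g. the forward pass).
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()  # wait for previously queued work
    start = time.perf_counter()
    result = fn()
    if use_cuda:
        torch.cuda.synchronize()  # wait for the work launched by fn
    return result, time.perf_counter() - start


device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the model output: 6 images, 16320 candidates, 57 values each
# (assumed layout: 4 box + 1 objectness + 1 class + 51 keypoint values).
prediction = torch.rand(6, 16320, 57, device=device)
xc = prediction[..., 4] > 0.25  # candidate mask, as in the NMS function

# Time only the mask indexing for the first image.
x0, seconds = timed(lambda: prediction[0][xc[0]])
print(f"x = x[xc[xi]] for image 0: {seconds:.6f} s, kept {x0.shape[0]} rows")
```

If the first-image delay disappears when measured this way, the time is probably being spent in the forward pass rather than in NMS.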

Am I missing something? Is it due to PyTorch moving data around, or is it because this code is not meant for batched NMS? If so, how can I convert this code to batched NMS?

In either case, I would really appreciate any help or pointers to anything useful.

I tried clearing the CUDA cache on every iteration, but the same thing happens: the delay now shows up in the cache-clearing call, and only in the first iteration.

  • Please provide a [minimal reproducible example](https://stackoverflow.com/help/minimal-reproducible-example). – Gugu72 Aug 28 '23 at 21:22
  • @Gugu72 thanks for the reply. I'll try to write this up, but it will require the repository to be set up. I am using the following repository: https://github.com/WongKinYiu/yolov7 – gamerchief gaming Aug 28 '23 at 23:56
