
I have two models. The first one detects objects in an image, and I crop the detections out by class.
I then send the cropped regions to the second model, which detects digits.
Both are YOLOv5 models.
The problem is that I can't pass the crops to the second model directly on the GPU.
First I have to crop, which gives me a NumPy array, and only then can I send that array to the second model.
I want to stop losing time converting from tensor to NumPy and back again.

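import cv2
import torch

# img_cls_path, ocr_path and some_video_path are placeholders for your own paths
model = torch.hub.load('.', 'custom', path=img_cls_path, source='local', force_reload=True)
model_ocr = torch.hub.load('.', 'custom', path=ocr_path, source='local', force_reload=True)
cap = cv2.VideoCapture(some_video_path)

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:  # end of video or failed read
        break
    results = model(frame)  # stage 1: detect regions in the frame
    crops = results.crop(save=False)  # crop['im'] is a NumPy array on the CPU
    for crop in crops:
        if 'number' in crop['label']:
            ocr_result = model_ocr(crop['im'])  # stage 2: the NumPy crop is turned back into a tensor
            ocr_crop = ocr_result.crop(save=False)

cap.release()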
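To make the cost concrete, the per-crop round trip in the loop above boils down to the two conversions below, compared with simply slicing the tensor where it already lives. This is a minimal illustration with a random tensor standing in for a frame; the helper names are hypothetical, and YOLOv5's own code does more (letterboxing, normalization).

```python
import numpy as np
import torch


def crop_via_numpy(frame: torch.Tensor, x1: int, y1: int, x2: int, y2: int) -> torch.Tensor:
    """Hypothetical helper: crop by going device -> host (NumPy) -> device."""
    device = frame.device
    host = frame.cpu().numpy()              # device-to-host copy (blocks on CUDA)
    crop = host[y1:y2, x1:x2]               # crop in NumPy
    return torch.from_numpy(np.ascontiguousarray(crop)).to(device)  # host-to-device copy


def crop_on_device(frame: torch.Tensor, x1: int, y1: int, x2: int, y2: int) -> torch.Tensor:
    """Hypothetical helper: same crop by slicing the tensor; pixel data never leave the device."""
    return frame[y1:y2, x1:x2]


device = "cuda" if torch.cuda.is_available() else "cpu"
# HWC uint8 layout, like a frame returned by cv2.VideoCapture.read()
frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8, device=device)
a = crop_via_numpy(frame, 100, 50, 200, 120)
b = crop_on_device(frame, 100, 50, 200, 120)
print(a.shape, torch.equal(a, b))  # torch.Size([70, 100, 3]) True
```

Both give identical pixels; the first version pays for two copies across the PCIe bus for every crop, which is exactly the overhead the question is about.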

How can I combine the two models?



You will need to modify the model's source code so that the first model returns its outputs as PyTorch tensors instead of converting them to NumPy arrays. Short of that, there is no way to avoid the GPU -> CPU -> GPU transfer.
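Once the first model's boxes are available as a tensor, the cropping itself can stay on the GPU. The sketch below assumes the frame has been uploaded once as an HWC uint8 tensor and that the boxes are an (N, 4) xyxy pixel-coordinate tensor on the same device (the first four columns of YOLOv5's `results.xyxy[0]` use this layout). The function name `crop_and_resize` and the fixed output size are hypothetical; whether the second model's wrapper accepts the resulting batch directly depends on the YOLOv5 version, which is what the comments below are about.

```python
import torch
import torch.nn.functional as F


def crop_and_resize(frame: torch.Tensor, boxes: torch.Tensor, size: int = 64) -> torch.Tensor:
    """Hypothetical helper: crop xyxy boxes from an HWC uint8 frame and resize each crop.

    Returns an (N, C, size, size) float batch in [0, 1] on frame.device.
    Channel order (BGR or RGB) is left unchanged.
    """
    img = frame.permute(2, 0, 1).float().div(255.0)  # HWC uint8 -> CHW float in [0, 1]
    h, w = img.shape[1:]
    crops = []
    # Only the N x 4 box coordinates are copied to the host here;
    # the pixel data stay on the device.
    for x1, y1, x2, y2 in boxes.round().long().tolist():
        # Clamp to the image and keep at least one pixel per side.
        x1 = min(max(x1, 0), w - 1)
        y1 = min(max(y1, 0), h - 1)
        x2 = max(min(x2, w), x1 + 1)
        y2 = max(min(y2, h), y1 + 1)
        crop = img[:, y1:y2, x1:x2].unsqueeze(0)  # 1 x C x h' x w'
        crops.append(F.interpolate(crop, size=(size, size), mode="bilinear", align_corners=False))
    if not crops:
        return img.new_zeros((0, img.shape[0], size, size))
    return torch.cat(crops)


device = "cuda" if torch.cuda.is_available() else "cpu"
frame = torch.randint(0, 256, (480, 640, 3), dtype=torch.uint8, device=device)
boxes = torch.tensor([[100.0, 50.0, 200.0, 120.0], [0.0, 0.0, 32.0, 32.0]], device=device)
batch = crop_and_resize(frame, boxes, size=64)
print(batch.shape, batch.device == frame.device)  # torch.Size([2, 3, 64, 64]) True
```

Resizing every crop to one size lets all crops from a frame go through the second network as a single batch instead of one call per crop.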

DerekG
  • I tried it, but the second model doesn't accept a PyTorch tensor, only a NumPy array – Nurislom Rakhmatullaev Oct 12 '21 at 13:46
  • Then rewrite the start of the network's `forward` function so it can handle either type of input. That should be straightforward. – DerekG Oct 12 '21 at 17:45
  • Could you add some sample code or link to an example, please? – Nurislom Rakhmatullaev Oct 12 '21 at 18:01
  • Not really, without knowing exactly which YOLOv5 model from torch.hub you're using. These models are designed for off-the-shelf use and are not easy to modify, because there are several layers of abstraction on top of the base CNN. I suggest finding the line of code that expects a NumPy array and throws an error when passed a PyTorch tensor; that is the line you need to modify. – DerekG Oct 12 '21 at 18:13
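The suggestion to make `forward` accept either input type can be sketched generically. This does not touch YOLOv5's code: it wraps an arbitrary `nn.Module`, and the class name and the stand-in `nn.Conv2d` are hypothetical. It assumes NumPy inputs are HWC uint8 images and tensor inputs are already NCHW floats in [0, 1]; in YOLOv5 itself the equivalent change goes wherever the OCR model's input check lives, which depends on the version in use.

```python
from typing import Union

import numpy as np
import torch
from torch import nn


class AcceptsNumpyOrTensor(nn.Module):
    """Hypothetical wrapper whose forward() takes either a NumPy image or a ready tensor."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
        device = next(self.model.parameters()).device
        if isinstance(x, np.ndarray):
            # NumPy path: HWC uint8 -> 1 x C x H x W float in [0, 1]
            x = torch.from_numpy(np.ascontiguousarray(x)).permute(2, 0, 1).unsqueeze(0).float().div(255.0)
        # Tensor path: assumed already NCHW float in [0, 1];
        # .to() returns the same tensor if it is already on the right device.
        return self.model(x.to(device))


wrapped = AcceptsNumpyOrTensor(nn.Conv2d(3, 8, kernel_size=3, padding=1))
out_np = wrapped(np.zeros((32, 32, 3), dtype=np.uint8))
out_t = wrapped(torch.rand(2, 3, 32, 32))
print(out_np.shape, out_t.shape)  # torch.Size([1, 8, 32, 32]) torch.Size([2, 8, 32, 32])
```

With a check like this at the top of the second model's entry point, GPU crops from the first model can be passed straight in, while the existing NumPy path keeps working.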