I need to extract the output of an intermediate layer of the model. I want to use it to compare images: the idea is that if the model recognizes a cat, it should check whether that image of a cat matches the previous images of cats.
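Once a feature vector has been extracted for each detected cat, the comparison itself can be as simple as a cosine similarity against the stored vectors. The sketch below is plain Python so it runs without torch; the names cosine_similarity and best_match and the 0.8 threshold are illustrative assumptions, not part of Ultralytics, and the vectors must already have the same length.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity of two equal-length feature vectors (1.0 = same direction)."""
    if len(a) != len(b):
        raise ValueError("feature vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # an all-zero vector has no direction to compare
    return dot / (norm_a * norm_b)


def best_match(query, gallery, threshold=0.8):
    """Return (index, score) of the most similar stored vector, or None.

    threshold is a hypothetical value; tune it on your own data.
    """
    best = None
    for i, vec in enumerate(gallery):
        score = cosine_similarity(query, vec)
        if best is None or score > best[1]:
            best = (i, score)
    if best is None or best[1] < threshold:
        return None  # nothing similar enough (or empty gallery)
    return best
```

For example, best_match([2.0, 0.0, 4.0], [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]) picks index 0 with a score of about 1.0, because the query is a scaled copy of the first stored vector. With torch tensors the same measure is available as torch.nn.functional.cosine_similarity.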

I have done the following:

from ultralytics import YOLO

model = YOLO("yolov8s")
del model.model.model[-1]  # removes the final (Detect) layer

This deletes the final layer, but I still can't get the data out of the model: when I call the predict method, I get the following error:

File "\Lib\site-packages\ultralytics\utils\ops.py", line 219, in non_max_suppression
    x = x[xc[xi]]  # confidence
        ~^^^^^^^^
IndexError: The shape of the mask [12, 20] at index 0 does not match the shape of the indexed tensor [512, 20, 12] at index 0

What can I do differently to get this data? Ideally I would not create a separate model with the final layer (or the last few layers) deleted, because I still need the final detection result.

1 Answer

The only way I found to do this is to modify the Detect class by patching its forward method to call a callback function:

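import torch
from ultralytics.nn.modules import Detect
from ultralytics.utils.tal import dist2bbox, make_anchors


def modifyDetectClass():
    # Class-level, one-shot callback; None means "don't capture anything"
    Detect.callback = None

    # Body copied from Detect.forward of the installed ultralytics release;
    # it must match your version, since the original method is replaced
    def forward(self, x):
        shape = x[0].shape  # BCHW
        for i in range(self.nl):
            # per-scale head output (box distribution + class logits), written back into x
            x[i] = torch.cat((self.cv2[i](x[i]), self.cv3[i](x[i])), 1)
        if self.training:
            return x
        elif self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
            self.shape = shape

        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)
        if self.export and self.format in ('saved_model', 'pb', 'tflite', 'edgetpu', 'tfjs'):
            box = x_cat[:, :self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4:]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)
        dbox = dist2bbox(self.dfl(box), self.anchors.unsqueeze(0), xywh=True, dim=1) * self.strides
        if self.export and self.format in ('tflite', 'edgetpu'):
            # normalize boxes to the image size for these export formats
            img_h = shape[2] * self.stride[0]
            img_w = shape[3] * self.stride[0]
            img_size = torch.tensor([img_w, img_h, img_w, img_h], device=dbox.device).reshape(1, 4, 1)
            dbox /= img_size

        y = torch.cat((dbox, cls.sigmoid()), 1)
        # Added code for intermediate data extraction:
        # pass the raw output of the second detection scale (x[1]) to the callback once
        if Detect.callback is not None:
            Detect.callback(x[1])
            Detect.callback = None
        return y if self.export else (y, x)

    # Replace the method on the class, so every Detect instance uses the patched version
    Detect.forward = forward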
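Copying the whole body of Detect.forward ties the patch to one Ultralytics release. A hedged alternative is to wrap the original method and only add the callback after it returns. This relies on an assumption taken from the copied code above: that the original forward overwrites the list entries x[i] in place, so that after it returns x[1] holds the second scale's head output. Check that against your installed version. The sketch uses a hypothetical ToyDetect stand-in so it runs without torch; with the real library you would pass Detect instead.

```python
class ToyDetect:
    """Hypothetical stand-in for ultralytics' Detect head (no torch needed)."""

    def forward(self, x):
        for i in range(len(x)):
            x[i] = [v * 10 for v in x[i]]  # per-scale "head output", written back into x
        y = sum(sum(xi) for xi in x)       # stand-in for the decoded predictions
        return (y, x)


def patch_with_callback(cls, index=1):
    """Wrap cls.forward so a one-shot class-level callback receives x[index].

    Returns the original forward so the patch can be undone later.
    """
    original_forward = cls.forward
    cls.callback = None

    def forward(self, x):
        out = original_forward(self, x)  # result is returned unchanged
        if cls.callback is not None:
            cb, cls.callback = cls.callback, None  # clear first so it fires only once
            cb(x[index])  # x was updated in place by the original forward
        return out

    cls.forward = forward
    return original_forward
```

Calling patch_with_callback(ToyDetect), setting ToyDetect.callback = captured.append and then calling ToyDetect().forward([[1, 2], [3, 4]]) leaves captured equal to [[30, 40]] while the return value is unchanged. Clearing the callback before invoking it, rather than after as in the answer, also clears it if the callback raises.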

Call modifyDetectClass() once, then register the callback every time before you call the model. For example:

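# inside a method of your own class, where self.model is the YOLO model
self.value = None
def setValue(x):
    self.value = x  # receives x[1] from the patched Detect.forward
Detect.callback = setValue
self.model.predict(img, verbose=False)
# self.value now holds the intermediate tensor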
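To avoid the self.value/setValue boilerplate and make sure a stale callback is never left registered, the registration can be wrapped in a context manager. This is a sketch: capture_features is a hypothetical helper, and ToyDetect stands in for the patched Detect class so the example runs on its own. One caveat worth checking in your setup: Ultralytics may run a warm-up forward pass on the first predict call (for example on a GPU), which would consume a one-shot callback with a dummy input; running one predict call before registering the callback avoids this.

```python
from contextlib import contextmanager


class ToyDetect:
    """Hypothetical stand-in for the patched Detect class."""

    callback = None

    def forward(self, x):
        if ToyDetect.callback is not None:
            cb, ToyDetect.callback = ToyDetect.callback, None  # one-shot
            cb(x[1])
        return x


@contextmanager
def capture_features(cls):
    """Register a one-shot callback on cls; yield a list that receives its argument."""
    captured = []
    cls.callback = captured.append
    try:
        yield captured
    finally:
        cls.callback = None  # never leave a stale callback behind
```

With the real model this would read: with capture_features(Detect) as feats: model.predict(img, verbose=False), after which feats[0] holds the tensor that was passed to the callback.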

Note that to get vectors of the same size for every image, you need to resize the images to a constant size (512, 512) before calling the model.
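The reason for the fixed size: x[1] is the head output of the second detection scale (stride 16), with shape (batch, no, H/16, W/16) for a network input of H x W, so flattening it gives a length that depends on the input size. Assuming a COCO-pretrained yolov8s (80 classes, reg_max = 16, hence no = 80 + 4*16 = 144 channels) and input dimensions that are multiples of 16, the arithmetic is sketched below. An alternative to resizing, swapped in here and not part of the original answer, is global average pooling over the spatial dimensions, which always yields one value per channel.

```python
def flatten_length(channels, height, width, stride=16):
    """Number of values obtained by flattening one scale's (C, H/stride, W/stride) map.

    Assumes height and width are multiples of stride.
    """
    return channels * (height // stride) * (width // stride)


def global_average_pool(feature_map):
    """Average a (C, H, W) nested-list feature map over H and W -> length-C vector."""
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled
```

flatten_length(144, 512, 512) gives 147456 values, while a 640x384 input gives 138240, so the two flattened vectors cannot be compared element by element; the pooled vector has length 144 in both cases. With torch, the pooling step would be x[1].mean(dim=(2, 3)).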