
I need to create a reverse image search engine. The idea is the same as Google Images: you submit an image and the engine returns the most similar images.

I did some research about content-based image retrieval and found a process to achieve this:

  1. Process all the images to extract features.
  2. Index the features to enable fast retrieval.

I have some doubts about the technique for extracting features from images. I found some documentation about using feature detectors (SIFT/SURF/ORB) and other documentation about using CNNs to extract a feature vector.

What would be the best solution? CNNs seem faster and easier for a quick start, but I don't have the hardware to train a CNN myself (and I can have very similar images in my database).

I think I will have 10k-200k images in total.

What would be a good way to index the extracted features? I had a look at Elasticsearch's dense_vector field, which seems good, but I am worried about the memory requirements (with the Caltech101 dataset, the index is twice as big as the images themselves).

I also had a look at Annoy, Milvus and Qdrant. Which indexing system would be best?


1 Answer


First, you don't need to train your own neural network to extract features. You can get pretty far with a standard CNN that was trained on ImageNet or a similar large, generic image dataset. These models learn internal representations that can be useful for many different downstream visual tasks.

You can essentially chop the top off an ImageNet-trained classifier (the part that does the actual classification into the dataset's categories) and use hidden activations from a prior layer as features for your image retrieval task.

Here is a PyTorch example:

import torch
from torch import nn
from torchvision.models import resnet34, ResNet34_Weights


class FeatureExtractor(nn.Module):
    def __init__(self, statedict_path=None):
        super().__init__()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # get backbone, pretrained on ImageNet
        self.backbone = resnet34(weights=ResNet34_Weights.IMAGENET1K_V1)
        self.embedding = torch.empty(0)
        # capture the activation right before the final fully connected layer
        self.backbone.avgpool.register_forward_hook(self.get_activation())
        self.backbone.eval()

        # optionally load your own fine-tuned parameters;
        # with the default None the plain ImageNet weights are kept
        if statedict_path is not None:
            loading_result = self.load_state_dict(torch.load(statedict_path))
            print(loading_result)

        self.to(self.device)
        # preprocessing (resize, crop, normalize) matching the pretrained weights
        self.transforms = ResNet34_Weights.IMAGENET1K_V1.transforms()
    
    def get_activation(self):
        def fn(_model, _input, output):
            self.embedding = torch.squeeze(output)
        return fn
    
    def forward(self, x):
        with torch.no_grad():

            # prepare sample or batch
            if not isinstance(x, torch.Tensor):
                x = self.transforms(x)
            if x.ndim == 3:
                x = x.unsqueeze(0)

            # inference time!
            x = x.to(self.device)
            _ = self.backbone(x)  # note: we are not using the output (ImageNet logits)
            return self.embedding # instead we return the embedding that we captured with a forward hook

This uses a standard, pretrained torchvision model. Of course you may try to optimize this model for your particular task and image distribution, but that is significantly more difficult, and you need a labeled dataset.
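For reference, extracting an embedding for a single image could then look like this (a minimal sketch; cat.jpg is a hypothetical file, and with no state dict passed the plain ImageNet weights are used):

from PIL import Image

extractor = FeatureExtractor()
img = Image.open("cat.jpg").convert("RGB")
embedding = extractor(img)
print(embedding.shape)  # torch.Size([512]) for ResNet-34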

It's worth noting that simple centering and normalization may significantly boost the performance, without any training! See: https://arxiv.org/abs/1911.04623
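A sketch of that idea, assuming embeddings is an (N, D) numpy array holding the feature vectors of your whole database:

import numpy as np

# mean over the whole database, used to center every vector
mean = embeddings.mean(axis=0, keepdims=True)

def center_and_normalize(v):
    # subtract the database mean, then L2-normalize
    v = v - mean
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

db = center_and_normalize(embeddings)
# apply the same transform to each query vector before nearest-neighbour search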

Also, I heard some time ago (and can confirm from my own experience) that older models, in particular VGG, have 'richer' internal representations. If you are not planning to fine-tune the model, they may produce better results.
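If you want to try VGG, here is a quick sketch with torchvision (one common choice is to keep the convolutional part plus pooling and drop the classifier; img is the same PIL image as in the snippet above):

from torchvision.models import vgg16, VGG16_Weights

weights = VGG16_Weights.IMAGENET1K_V1
vgg = vgg16(weights=weights).eval()
# keep conv features + pooling, drop the ImageNet classifier head
vgg_features = nn.Sequential(vgg.features, vgg.avgpool, nn.Flatten())

with torch.no_grad():
    x = weights.transforms()(img).unsqueeze(0)
    emb = vgg_features(x)  # shape: (1, 25088)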


As for the second part of your question, I have limited experience here, but Qdrant seems quite good. If memory is a concern and you are OK with longer query times, you can store all vectors on disk (see the Qdrant documentation). Vector quantization is also possible; in practical cases it seems you don't even need to sacrifice much accuracy for it.
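For illustration, a minimal sketch with the qdrant-client Python package (the collection name, connection details, and the db / query_vec arrays are placeholders; assumes a recent client version and a Qdrant instance running locally):

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct

client = QdrantClient(url="http://localhost:6333")
client.create_collection(
    collection_name="images",
    vectors_config=VectorParams(size=512, distance=Distance.COSINE, on_disk=True),
)

# index the database embeddings (ids should map back to your image files)
client.upsert(
    collection_name="images",
    points=[PointStruct(id=i, vector=vec.tolist()) for i, vec in enumerate(db)],
)

# query: the 10 most similar images
hits = client.search(collection_name="images", query_vector=query_vec.tolist(), limit=10)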
