I am trying to run detectron2 google colab code given on the front page of the repository. I have a system with apple silicon GPU and I am encountering an error. It seems like the "mps" only supports up to 16 proposals from the error message. Any workaround or monkey-patching?
Reproduction steps:
1. Download sample image from coco dataset using the following command:
!wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
2. Run the following code:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
# import some common libraries
import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
cfg = get_cfg()
# This line is added to set device to "mps"
cfg.MODEL.DEVICE = "mps"
# add project-specific config (e.g., TensorMask) here if you're not running a model in detectron2's core library
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model
# Find a model from detectron2's model zoo. You can use the https://dl.fbaipublicfiles... url as well
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
# !wget http://images.cocodataset.org/val2017/000000439715.jpg -q -O input.jpg
im = cv2.imread("input.jpg")
plt.imshow(im)
outputs = predictor(I'm)
print(outputs)
3. Actual output
WARNING:root:Pytorch pre-release version 1.13.0.dev20220618 - assuming intent to test it
/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/torch/functional.py:478: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/native/TensorShape.cpp:2890.)
return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
Traceback (most recent call last):
File "/Users/bikram/PycharmProjects/detectron-example/main.py", line 33, in <module>
outputs = predictor(im)
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/engine/defaults.py", line 317, in __call__
predictions = self.model([inputs])[0]
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 150, in forward
return self.inference(batched_inputs)
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/modeling/meta_arch/rcnn.py", line 208, in inference
proposals, _ = self.proposal_generator(images, features, None)
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 478, in forward
anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/rpn.py", line 511, in predict_proposals
self.training,
File "/Users/bikram/miniforge3/envs/mcs2/lib/python3.7/site-packages/detectron2/modeling/proposal_generator/proposal_utils.py", line 79, in find_top_rpn_proposals
topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
RuntimeError: Currently topk on mps works only for k<=16
Expected output:
When the device is set to "CPU" on the same system, I get the following outputs without any error:
{'instances': Instances(num_instances=15, image_height=480, image_width=640, fields=[pred_boxes: Boxes(tensor([[126.6035, 244.8977, 459.8291, 480.0000],
[251.1083, 157.8127, 338.9731, 413.6379],
[114.8496, 268.6864, 148.2352, 398.8111],
[ 0.8217, 281.0327, 78.6072, 478.4210],
[ 49.3954, 274.1229, 80.1545, 342.9808],
[561.2248, 271.5816, 596.2755, 385.2552],
[385.9072, 270.3125, 413.7130, 304.0397],
[515.9295, 278.3744, 562.2792, 389.3803],
[335.2409, 251.9167, 414.7491, 275.9375],
[350.9300, 269.2060, 386.0984, 297.9081],
[331.6292, 230.9996, 393.2759, 257.2009],
[510.7349, 263.2656, 570.9865, 295.9194],
[409.0841, 271.8646, 460.5582, 356.8722],
[506.8767, 283.3257, 529.9403, 324.0392],
[594.5663, 283.4820, 609.0577, 311.4124]])), scores: tensor([0.9997, 0.9957, 0.9915, 0.9882, 0.9861, 0.9840, 0.9769, 0.9716, 0.9062,
0.9037, 0.8870, 0.8575, 0.6592, 0.5899, 0.5767]), pred_classes: tensor([17, 0, 0, 0, 0, 0, 0, 0, 25, 0, 25, 25, 0, 0, 24]), pred_masks: tensor([[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
...,
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]],
[[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
...,
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False],
[False, False, False, ..., False, False, False]]])])}