I'm trying to train a model with Detectron2 on the COCO dataset for vehicle and person detection, and I'm having problems with loading the pretrained model.
I've used posts here on SO and the code from https://github.com/immersive-limit/coco-manager (the filter.py file) to filter the COCO dataset so that it only includes annotations and images from the classes "person", "car", "bike", "truck" and "bicycle". Now my directory structure is:
main
- annotations:
  - instances_train2017_filtered.json
  - instances_val2017_filtered.json
- images:
  - train2017_filtered (lots of images inside)
  - val2017_filtered (lots of images inside)
Basically, the only things I've done here were to remove the annotation entries and images that don't correspond to those classes, and to change the category IDs so they run from 1 to 5.
Then I used the code from the Detectron2 tutorial:
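For reference, the filtering and ID remapping step can be sketched in plain Python. This is not the coco-manager filter.py code; `filter_coco` and `filter_coco_file` are hypothetical helpers, and the sketch assumes the standard COCO annotation layout (top-level `images`, `annotations` and `categories` lists). Kept categories are renumbered 1..N in the order the names are given, and names that don't appear in `categories` are silently ignored, so it is worth checking how many categories the output actually contains.

```python
import json


def filter_coco(coco, keep_names):
    """Return a copy of a COCO-style dict restricted to the categories in keep_names.

    Kept categories get new IDs 1..N in the order of keep_names; names that do not
    appear in coco["categories"] are silently ignored.
    """
    by_name = {c["name"]: c for c in coco["categories"]}
    kept = [by_name[n] for n in keep_names if n in by_name]
    # Original category id -> new contiguous id (1-based)
    old_to_new = {c["id"]: new_id for new_id, c in enumerate(kept, start=1)}

    categories = [dict(c, id=old_to_new[c["id"]]) for c in kept]
    annotations = [
        dict(a, category_id=old_to_new[a["category_id"]])
        for a in coco["annotations"]
        if a["category_id"] in old_to_new
    ]
    # Drop images that no longer have any annotation
    image_ids = {a["image_id"] for a in annotations}
    images = [img for img in coco["images"] if img["id"] in image_ids]

    # Copy any other top-level keys (info, licenses, ...) unchanged
    out = {k: v for k, v in coco.items() if k not in ("images", "annotations", "categories")}
    out.update(images=images, annotations=annotations, categories=categories)
    return out


def filter_coco_file(src_path, dst_path, keep_names):
    """Read a COCO annotation file, filter it, and write the result as JSON."""
    with open(src_path) as f:
        coco = json.load(f)
    with open(dst_path, "w") as f:
        json.dump(filter_coco(coco, keep_names), f)
```

Calling `filter_coco_file(src, dst, names)` once for the train and once for the val annotation file would produce JSON files like the `*_filtered.json` ones above; copying the matching image files into the `*_filtered` folders is a separate step.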
import random
import cv2
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.config import get_cfg
import os
from detectron2.model_zoo import model_zoo
from detectron2.utils.visualizer import Visualizer
register_coco_instances("train",
{},
"/home/jakub/Projects/coco/annotations/instances_train2017_filtered.json",
"/home/jakub/Projects/coco/images/train2017_filtered/")
register_coco_instances("val",
{},
"/home/jakub/Projects/coco/annotations/instances_val2017_filtered.json",
"/home/jakub/Projects/coco/images/val2017_filtered/")
metadata = MetadataCatalog.get("train")
dataset_dicts = DatasetCatalog.get("train")
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.DATASETS.TRAIN = ("train",)
cfg.DATASETS.TEST = ()
cfg.DATALOADER.NUM_WORKERS = 2
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.00025
cfg.SOLVER.MAX_ITER = 300
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 5
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)
trainer.train()
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
cfg.DATASETS.TEST = ("val", )
predictor = DefaultPredictor(cfg)
img = cv2.imread("demo/input.jpg")
outputs = predictor(img)
for d in random.sample(dataset_dicts, 1):
    im = cv2.imread(d["file_name"])
    outputs = predictor(im)
    v = Visualizer(im[:, :, ::-1],
                   metadata=metadata,
                   scale=0.8)
    out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
    cv2.imwrite('demo/output_retrained.jpg', out.get_image()[:, :, ::-1])
During training, I get the following errors:
Unable to load 'roi_heads.box_predictor.cls_score.weight' to the model due to incompatible shapes: (81, 1024) in the checkpoint but (6, 1024) in the model!
Unable to load 'roi_heads.box_predictor.cls_score.bias' to the model due to incompatible shapes: (81,) in the checkpoint but (6,) in the model!
Unable to load 'roi_heads.box_predictor.bbox_pred.weight' to the model due to incompatible shapes: (320, 1024) in the checkpoint but (20, 1024) in the model!
Unable to load 'roi_heads.box_predictor.bbox_pred.bias' to the model due to incompatible shapes: (320,) in the checkpoint but (20,) in the model!
Unable to load 'roi_heads.mask_head.predictor.weight' to the model due to incompatible shapes: (80, 256, 1, 1) in the checkpoint but (5, 256, 1, 1) in the model!
Unable to load 'roi_heads.mask_head.predictor.bias' to the model due to incompatible shapes: (80,) in the checkpoint but (5,) in the model!
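The six shapes in this log follow from the number of classes alone. A minimal sketch, assuming the default class-specific box regression and mask prediction (`MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG` and `MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK` both left at False): the classifier has one extra output for background, the box regressor predicts 4 deltas per class, and the mask predictor has one output channel per class. `roi_head_shapes` is a hypothetical helper, and 1024 and 256 are simply the feature sizes visible in the log.

```python
def roi_head_shapes(num_classes, fc_dim=1024, mask_dim=256):
    """Shapes of the class-dependent ROI head parameters in a Mask R-CNN model,
    assuming class-specific box regression and class-specific masks."""
    return {
        # num_classes foreground scores + 1 background score
        "roi_heads.box_predictor.cls_score.weight": (num_classes + 1, fc_dim),
        "roi_heads.box_predictor.cls_score.bias": (num_classes + 1,),
        # 4 box deltas (dx, dy, dw, dh) per foreground class
        "roi_heads.box_predictor.bbox_pred.weight": (4 * num_classes, fc_dim),
        "roi_heads.box_predictor.bbox_pred.bias": (4 * num_classes,),
        # one 1x1-conv output channel (mask logit) per foreground class
        "roi_heads.mask_head.predictor.weight": (num_classes, mask_dim, 1, 1),
        "roi_heads.mask_head.predictor.bias": (num_classes,),
    }


coco_ckpt = roi_head_shapes(80)  # the COCO-pretrained checkpoint (80 classes)
custom = roi_head_shapes(5)      # cfg.MODEL.ROI_HEADS.NUM_CLASSES = 5
for name in coco_ckpt:
    # Prints the same six shape pairs as the log lines above
    print(f"{name}: {coco_ckpt[name]} in the checkpoint but {custom[name]} in the model")
```

Every other tensor in the model (backbone, FPN, RPN, the box head's fully connected layers, the mask head's convolutions) does not depend on `NUM_CLASSES`, so its shape is the same in both models.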
The model cannot predict anything useful after training, even though total_loss decreases during training. I understand that I should get these warnings because of the size mismatch (I've reduced the number of classes), and from what I've seen online this is normal, but I don't get "Skipped" at the end of each of these lines. I suspect the model isn't actually loading any of the pretrained weights, and I'd like to know why and how I can fix this.
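One way to check whether anything besides these six tensors is loaded is to compare parameter names and shapes between the checkpoint and the model. The sketch below works on plain name-to-shape dicts so it runs without Detectron2; `compare_shapes` is a hypothetical helper, and building the dicts from the real objects (for example `{k: tuple(v.shape) for k, v in trainer.model.state_dict().items()}` for the model, and the same comprehension over the checkpoint's state dict) is an assumption about how you would obtain them.

```python
def compare_shapes(ckpt_shapes, model_shapes):
    """Classify parameter names by how they would load from a checkpoint.

    Both arguments map parameter name -> shape (any sequence of ints).
    """
    loaded, mismatched = [], []
    for name, shape in model_shapes.items():
        if name not in ckpt_shapes:
            continue
        if tuple(ckpt_shapes[name]) == tuple(shape):
            loaded.append(name)       # same name and shape: weights can be copied
        else:
            mismatched.append(name)   # same name, different shape: must be skipped
    missing = [n for n in model_shapes if n not in ckpt_shapes]     # stays at its initial value
    unexpected = [n for n in ckpt_shapes if n not in model_shapes]  # ignored
    return {"loaded": loaded, "mismatched": mismatched,
            "missing": missing, "unexpected": unexpected}


if __name__ == "__main__":
    # Toy shapes: one backbone tensor that matches, one head tensor that doesn't
    ckpt = {"backbone.bottom_up.stem.conv1.weight": (64, 3, 7, 7),
            "roi_heads.box_predictor.cls_score.bias": (81,)}
    model = {"backbone.bottom_up.stem.conv1.weight": (64, 3, 7, 7),
             "roi_heads.box_predictor.cls_score.bias": (6,)}
    report = compare_shapes(ckpt, model)
    print({k: len(v) for k, v in report.items()})
    # → {'loaded': 1, 'mismatched': 1, 'missing': 0, 'unexpected': 0}
```

If only the six head tensors from the log end up in `mismatched` and almost everything else is in `loaded`, the pretrained backbone weights would be usable regardless of how the log line is worded.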
EDIT
For comparison, similar behaviour in an almost identical situation was reported as an issue, but there each of these lines ended with "Skipped", which made them effectively warnings rather than errors: https://github.com/facebookresearch/detectron2/issues/196