
As part of my Master's thesis, I have trained a UNET in PyTorch to detect certain objects in X-ray images. To generate the predictions, I implemented the following function:

import os

import cv2
import numpy as np
import torch

def make_predictions(model, imagePath):
    # set model to evaluation mode
    model.eval()

    # turn off gradient tracking
    with torch.no_grad():
        # load the image from disk in grayscale, add batch and channel
        # dimensions, cast it to float, and scale its pixel values to [0, 1]
        image = cv2.imread(imagePath, 0)
        image = np.expand_dims(image, 0)
        image = np.expand_dims(image, 0)
        image = image.astype("float32") / 255.0

        # find the filename and generate the path to the ground-truth mask
        filename = imagePath.split(os.path.sep)[-1]
        groundTruthPath = os.path.join(Config.Mask_dataset_dir, filename)

        # load the ground-truth segmentation mask in grayscale mode and resize it
        gtMask = cv2.imread(groundTruthPath, 0)
        gtMask = cv2.resize(gtMask, (Config.Input_Height, Config.Input_Height))

        # create a PyTorch tensor and move it to the current device
        image = torch.from_numpy(image).to(Config.DEVICE)

        # make the prediction, pass the result through the sigmoid
        # function, and convert it to a NumPy array
        predMask = model(image)
        predMask = torch.sigmoid(predMask)
        predMask = predMask.cpu().numpy()

        # filter out the weak predictions and convert them to integers
        predMask = (predMask > Config.Thresh) * 255
        predMask = predMask.astype(np.uint8)
        cv2.imwrite(Config.Base_Out + '\\' + filename, predMask)

        return gtMask, predMask
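To make the array shapes in this function concrete, here is a small NumPy-only sketch. It stands in a zero array for the `cv2.imread` result and, as an assumption, a random array for the model output, assuming the UNET has a single output channel and so returns the same (N, C, H, W) = (1, 1, H, W) layout as its input. The image size is arbitrary.

```python
import numpy as np

# Stand-in for cv2.imread(imagePath, 0): a grayscale image of shape (H, W).
H, W = 4, 6
image = np.zeros((H, W), dtype=np.uint8)

# The two expand_dims calls add the batch and channel axes.
image = np.expand_dims(image, 0)  # (1, H, W)
image = np.expand_dims(image, 0)  # (1, 1, H, W)
print(image.shape)

# Hypothetical model output with the same layout (assumes one output channel).
pred = np.random.rand(1, 1, H, W).astype("float32")
mask = ((pred > 0.5) * 255).astype(np.uint8)
print(mask.shape)  # still 4-D: this is the array passed to cv2.imwrite

# Indexing away the size-1 batch and channel axes gives a plain (H, W) image.
mask_2d = mask[0, 0]
print(mask_2d.shape)
```

Grayscale images are normally handled by OpenCV as 2-D (H, W) arrays, so the 4-D `predMask` is worth keeping in mind when looking at the `cv2.imwrite` call.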

This function works for generating the predictions and even for plotting them, but cv2.imwrite() does not save the predicted masks as images in the given directory, even though filename already ends with the .PNG extension. What could be the problem here?
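One way to narrow this down is to stop `cv2.imwrite` from failing silently. In OpenCV's Python bindings, `cv2.imwrite` returns `False` for several failures, such as a destination folder that does not exist, instead of raising, so an unchecked call gives no visible error. The sketch below assumes this behaviour; the helper names `output_path_for` and `save_mask` are hypothetical, and `out_dir` would play the role of `Config.Base_Out`. It also builds the path with `os.path.join` rather than concatenating `'\\'`.

```python
import os


def output_path_for(out_dir, image_path):
    """Hypothetical helper: build the output path portably and
    create the output folder if it is missing."""
    os.makedirs(out_dir, exist_ok=True)
    return os.path.join(out_dir, os.path.basename(image_path))


def save_mask(mask, path):
    """Hypothetical helper: write the mask and raise if OpenCV reports failure."""
    import cv2  # imported lazily so output_path_for works without OpenCV

    if not cv2.imwrite(path, mask):
        raise IOError(f"cv2.imwrite could not write {path!r}")
```

Inside `make_predictions`, the last `cv2.imwrite` line could then become `save_mask(predMask, output_path_for(Config.Base_Out, imagePath))`. Any failure would then surface as an exception naming the exact path.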

Christoph Rackwitz
