I am trying to run this model:
https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1
and display its segmentation mask merged with the original image. However, the result is a completely black image. I expected a greyscale segmentation mask to sit on top of the image and highlight what the model found. Any ideas how to fix this?
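To be explicit about the goal: the overlay I am after is a simple alpha blend of the mask over the image, roughly like this (hypothetical helper, the names are made up):

function overlayMask(imageTensor, maskTensor, alpha = 0.5) {
  // imageTensor: [h, w, 3] in [0, 1]; maskTensor: [h, w, 1] in [0, 1], broadcast over the RGB channels
  return imageTensor.mul(1 - alpha).add(maskTensor.mul(alpha));
}

Here is the full page: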
<!DOCTYPE html>
<html>
<head>
<title>Image Transform with TensorFlow.js</title>
<!-- Load TensorFlow.js library -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.1.0/dist/tf.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.6/dist/tf-tflite.min.js"></script>
</head>
<body>
<h1>Image Transform with TensorFlow.js</h1>
<!-- Load the image to be transformed -->
<img id="input-image" src="https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png" crossorigin="anonymous" alt="Original image">
<br>
<!-- Button to trigger image transform -->
<button id="transform-button">Transform</button>
<!-- Display the transformed image -->
<h1>Transformed Image</h1>
<canvas id="output-image"></canvas>
<!-- Script to perform the transformation -->
<script>
  const transformButton = document.getElementById('transform-button');
  const inputImage = document.getElementById('input-image');
  const canvas = document.getElementById('output-image');

  async function transform(inputImageData) {
    const model = await tflite.loadTFLiteModel('https://storage.googleapis.com/tfhub-lite-models/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1.tflite');

    // Loop through the model's input tensors
    for (const input of model.inputs) {
      console.log(`Input Name: ${input.name}`);
      console.log(`Input Shape: ${input.shape}`);
      console.log(`Input Type: ${input.dtype}`);
      console.log();
    }

    // Loop through the model's output tensors
    for (const output of model.outputs) {
      console.log(`Output Name: ${output.name}`);
      console.log(`Output Shape: ${output.shape}`);
      console.log(`Output Type: ${output.dtype}`);
      console.log();
    }

    // Transform
    let inputTensor = tf.browser.fromPixels(inputImageData);
    inputTensor = tf.image.resizeBilinear(inputTensor, [513, 513]).expandDims(0);
    inputTensor = tf.cast(inputTensor, 'int32');

    // transform
    let outputTensor = model.predict(inputTensor);
    inputTensor = inputTensor.squeeze(0);
    outputTensor = outputTensor.squeeze(0);

    // merge output channels
    // TODO
    let scalarWeight = 0.5;
    let averageOutput = outputTensor.mul(scalarWeight).max(2);
    averageOutput = averageOutput.expandDims(-1);

    // merge input and output
    const inputNorm = inputTensor.div(255);
    const outputNorm = averageOutput.div(255);
    console.log(inputNorm.shape);
    console.log(outputNorm.shape);
    let combined = await inputNorm.mul(outputNorm);

    // Create the ImageData object
    let dataArray = await tf.browser.toPixels(combined);
    const outputImageData = new ImageData(dataArray, combined.shape[1], combined.shape[0]);
    return outputImageData;
  }

  // When the button is clicked, update the image
  transformButton.addEventListener('click', async () => {
    // Get the canvas element
    const canvas = document.querySelector("canvas");
    const ctx = canvas.getContext('2d');

    // Set the canvas size to the size of the original image
    canvas.width = inputImage.width;
    canvas.height = inputImage.height;

    // Draw the original image on the canvas
    ctx.drawImage(inputImage, 0, 0);
    let imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
    const result = await transform(imageData);

    // Update the image data on the canvas
    canvas.width = result.width;
    canvas.height = result.height;
    ctx.putImageData(result, 0, 0);
  });
</script>
</body>
</html>
It should behave like this pen:
https://codepen.io/jinjingforever/pen/yLMYVJw
which uses this model:
https://tfhub.dev/sayakpaul/lite-model/mobilenetv2-coco/fp16/1
and calls:
new ImageData(result.segmentationMap, result.width, result.height)
But I don't know where the .segmentationMap property comes from. I believe it might be documented here:
https://github.com/tensorflow/tfjs-models/tree/master/deeplab
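If that pen is built on the @tensorflow-models/deeplab package (which that repo documents), segmentationMap would be part of the object returned by model.segment(). A minimal sketch based on my reading of that README (I haven't verified this is what the pen actually does):

import * as deeplab from '@tensorflow-models/deeplab';

async function getSegmentationImageData(image) {
  // base can be 'pascal', 'cityscapes' or 'ade20k'
  const model = await deeplab.load({ base: 'pascal', quantizationBytes: 2 });
  // segment() resolves to { legend, width, height, segmentationMap },
  // where segmentationMap is a colour-mapped Uint8ClampedArray of RGBA values
  const { width, height, segmentationMap } = await model.segment(image);
  return new ImageData(segmentationMap, width, height);
}

If so, segmentationMap is already colour-mapped pixel data, which would explain why it can be passed straight to new ImageData.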