0

I am trying to run this model:

https://tfhub.dev/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1

and display a segmentation mask merged with the original image.

However the result is a completely black image.

I expected a greyscale segmentation mask to sit on top of the image to highlight findings.

Any ideas how to fix?

<!DOCTYPE html>
<html>
  <head>
    <title>Image Transform with TensorFlow.js</title>
    <!-- Load TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.1.0/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.6/dist/tf-tflite.min.js"></script>
  </head>
  <body>
    <h1>Image Transform with TensorFlow.js</h1>
    <!-- Load the image to be transformed -->
    <img id="input-image" src="https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png" crossorigin="anonymous" alt="Original image">
    <br>
    <!-- Button to trigger image transform -->
    <button id="transform-button">Transform</button>
    <!-- Display the transformed image -->
    <h1>Transformed Image</h1>
    <canvas id="output-image"></canvas>
    <!-- Script to perform the transformation -->
    <script>
      // Grab the DOM elements the script works with up front.
      const inputImage = document.getElementById('input-image');
      const canvas = document.getElementById('output-image');
      const transformButton = document.getElementById('transform-button');
      /**
       * Runs the food-segmenter model over an ImageData and returns a new
       * ImageData where each input pixel is scaled by the (normalized)
       * segmentation mask, so segmented regions stay bright.
       * @param {ImageData} inputImageData - pixels read back from the canvas.
       * @returns {Promise<ImageData>} 513x513 blended image.
       */
      async function transform(inputImageData) {
        // Load the seefood food-segmenter TFLite model from TF Hub's GCS mirror.
        const model = await tflite.loadTFLiteModel('https://storage.googleapis.com/tfhub-lite-models/google/lite-model/seefood/segmenter/mobile_food_segmenter_V1/1.tflite');
        // Log the model's input signature for debugging.
        for (const input of model.inputs) {
          console.log(`Input Name: ${input.name}`);
          console.log(`Input Shape: ${input.shape}`);
          console.log(`Input Type: ${input.dtype}`);
          console.log();
        }
        // Log the model's output signature for debugging.
        for (const output of model.outputs) {
          console.log(`Output Name: ${output.name}`);
          console.log(`Output Shape: ${output.shape}`);
          console.log(`Output Type: ${output.dtype}`);
          console.log();
        }
        // Convert the ImageData to a tensor, resize to the model's expected
        // 513x513 input, and add a batch dimension.
        let inputTensor = tf.browser.fromPixels(inputImageData);
        inputTensor = tf.image.resizeBilinear(inputTensor, [513, 513]).expandDims(0);
        // NOTE(review): assumes the model wants int32 pixel values in
        // [0, 255] — confirm against the input dtype logged above.
        inputTensor = tf.cast(inputTensor, 'int32');

        // Run inference, then drop the batch dimension from both tensors.
        let outputTensor = model.predict(inputTensor);
        inputTensor = inputTensor.squeeze(0);
        outputTensor = outputTensor.squeeze(0);

        // Collapse the per-class channel axis into a single-channel mask by
        // taking the per-pixel maximum activation.
        let mask = outputTensor.max(2).expandDims(-1);
        // BUG FIX: the original divided this mask by 255 (after also scaling
        // it by 0.5). The raw activations are not 0-255 pixel values, so the
        // result was near zero everywhere and the blended canvas came out
        // completely black. Normalize the mask by its own maximum instead so
        // it spans [0, 1].
        mask = mask.div(mask.max());

        // Blend: scale each normalized input pixel by the mask value.
        const inputNorm = inputTensor.div(255);
        const combined = inputNorm.mul(mask);
        console.log(inputNorm.shape);
        console.log(combined.shape);

        // Render the float tensor (clamped to [0, 1]) to RGBA bytes and wrap
        // them in an ImageData for putImageData().
        const dataArray = await tf.browser.toPixels(combined);
        const outputImageData = new ImageData(dataArray, combined.shape[1], combined.shape[0]);
        return outputImageData;
      }
      // When the button is clicked, update the image
      // When the button is clicked, run the model and paint the result.
      transformButton.addEventListener('click', async () => {
        // Reuse the #output-image canvas declared above; the original
        // re-queried the document here and shadowed that binding.
        const ctx = canvas.getContext('2d');
        // Match the canvas to the source image and draw it, so we can read
        // its pixels back as ImageData for the model.
        canvas.width = inputImage.width;
        canvas.height = inputImage.height;
        ctx.drawImage(inputImage, 0, 0);
        const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
        const result = await transform(imageData);
        // Resize the canvas to the model output and display it.
        canvas.width = result.width;
        canvas.height = result.height;
        ctx.putImageData(result, 0, 0);
      });
    </script>
  </body>
</html>

It should behave like this:

https://codepen.io/jinjingforever/pen/yLMYVJw

Where they use model:

https://tfhub.dev/sayakpaul/lite-model/mobilenetv2-coco/fp16/1

And call:

new ImageData(result.segmentationMap, result.width, result.height)

But I don't know where the .segmentationMap property comes from. I believe it might be documented here:

https://github.com/tensorflow/tfjs-models/tree/master/deeplab

Hugh Pearse
  • 699
  • 1
  • 7
  • 18

1 Answer

0

Found that the solution is to use the TFJS Task API: https://www.npmjs.com/package/@tensorflow-models/tasks

<!DOCTYPE html>
<html>
  <head>
    <title>Image Transform with TensorFlow.js</title>
    <!-- Load TensorFlow.js library -->
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.1.0/dist/tf.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs-tflite@0.0.1-alpha.6/dist/tf-tflite.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/tasks@0.0.1-alpha.8/dist/tfjs-tasks.min.js"></script>
  </head>
  <body>
    <h1>Image Transform with TensorFlow.js</h1>
    <!-- Load the image to be transformed -->
    <img id="input-image" src="https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png" crossorigin="anonymous" alt="Original image">
    <br>
    <!-- Button to trigger image transform -->
    <button id="transform-button">Transform</button>
    <!-- Display the transformed image -->
    <h1>Transformed Image</h1>
    <canvas id="output-image"></canvas>
    <!-- Script to perform the transformation -->
    <script>
      /**
       * Segments the given image via the TFJS Task API and returns the
       * colored segmentation map as an ImageData ready for putImageData().
       * @param {HTMLImageElement|ImageData} inputImageData - source pixels.
       * @returns {Promise<ImageData>} the segmentation map.
       */
      async function transform(inputImageData) {
        // mobilenetv2-coco segmenter, loaded through the Task API wrapper.
        const modelUrl = "https://tfhub.dev/sayakpaul/lite-model/mobilenetv2-coco/fp16/1?lite-format=tflite";
        const model = await tfTask.ImageSegmentation.CustomModel.TFLite.load({
          model: modelUrl
        });
        // The Task API result carries width/height plus an RGBA
        // segmentationMap byte array.
        const { segmentationMap, width, height } = await model.predict(inputImageData);
        return new ImageData(segmentationMap, width, height);
      }
      // When the button is clicked, update the image
      // Wire the button: on click, segment the source <img> and paint the
      // result onto the output canvas, scaled back to the image's CSS size.
      document.getElementById('transform-button').addEventListener('click', async () => {
        const sourceImage = document.getElementById('input-image');
        const segmented = await transform(sourceImage);
        const outputCanvas = document.querySelector("canvas");
        // Size the backing store to the model output before drawing.
        outputCanvas.width = segmented.width;
        outputCanvas.height = segmented.height;
        const context = outputCanvas.getContext("2d");
        context.clearRect(0, 0, segmented.width, segmented.height);
        context.putImageData(segmented, 0, 0);
        // Scale the displayed size (not the backing store) to match the
        // original image element.
        outputCanvas.style.width = `${sourceImage.width}px`;
        outputCanvas.style.height = `${sourceImage.height}px`;
      });
    </script>
  </body>
</html>
Hugh Pearse
  • 699
  • 1
  • 7
  • 18