I am trying to run u2net model in browser, I have converted the pytorch u2netp model into ONNX model and wrote the following code to run it but the results very poor. I followed the same preprocessing steps as that of python script but did not get the results. I was not able to find onnx functions to perform the preprocessing so I have used for loops to change the values in each channel.
<!DOCTYPE html>
<html>
<header>
<title>ONNX Runtime JavaScript examples: Quick Start - Web (using script tag)</title>
<input id="image-selector" type="file" style="top:10px;left:10px" >
<button id="predict-button" class="btn btn-dark float-right" style="top:10px;left:70px" >Predict</button>
<img id="selected-image" src="" />
<canvas id="canvas" width =320px height=320px ></canvas>
</header>
<body>
<!-- import ONNXRuntime Web from CDN -->
<script src="https://code.jquery.com/jquery-3.3.1.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/ort.min.js"></script>
<script>
$("#image-selector").change(function () {
let reader = new FileReader();
reader.onload = function () {
let dataURL = reader.result;
$("#selected-image").attr("src", dataURL);
}
let file = $("#image-selector").prop("files")[0];
reader.readAsDataURL(file);
});
$("#predict-button").click(async function (){
const canvas = document.getElementById("canvas");
const ctx = canvas.getContext("2d");
const session = await ort.InferenceSession.create('./u2netp.onnx').then(console.log("model loaded"));
const inputNames = session.inputNames;
const outputNames = session.outputNames;
console.log('inputNames', inputNames)
console.log('outputNames', outputNames)
let image = $("#selected-image").get(0);
console.log("image.naturalHeight", image.naturalHeight)
console.log("image.naturalWidth", image.naturalWidth)
var oc = document.createElement('canvas'),
octx = oc.getContext('2d');
oc.width = 320;
oc.height = 320;
octx.drawImage(image, 0, 0, oc.width, oc.height);
var input_imageData = octx.getImageData(0, 0, 320, 320);
var floatArr = new Float32Array(320 * 320 * 3)
var j = 0
for (let i = 1; i < input_imageData.data.length+1; i ++) {
if(i % 4 != 0){
floatArr[j] = (input_imageData.data[i-1].toFixed(2))/255; // red color
j = j + 1;
}
}
console.log("floatArr1", floatArr)
for (let i = 1; i < floatArr.length+1; i += 3) {
floatArr[i-1] = (floatArr[i-1] - 0.485)/0.229 // red color
floatArr[i] = (floatArr[i] - 0.456)/0.224 // green color
floatArr[i+1] = (floatArr[i+1] - 0.406)/0.225 // blue color
}
console.log("floatArr2", floatArr)
const input = new ort.Tensor('float32', floatArr, [1, 3, 320, 320])
a = inputNames[0]
console.log("a", a)
const feeds = {"input.1": input};
console.log("feeds", feeds)
const results = await session.run(feeds).then();
const pred = Object.values(results)[0]
console.log('pred', pred)
console.log('pred.data.length', pred.data.length)
console.log('pred.data[0]', Math.round(pred.data[0]*255))
var myImageData = ctx.createImageData(320, 320);
for (let i = 0; i < pred.data.length*4; i += 4) {
var pixelIndex = i;
if(i != 0){
t = i/4;
}
else{
t = 0;
}
myImageData.data[pixelIndex ] = Math.round(pred.data[t]*255); // red color
myImageData.data[pixelIndex + 1] = Math.round(pred.data[t]*255); // green color
myImageData.data[pixelIndex + 2] = Math.round(pred.data[t]*255); // blue color
myImageData.data[pixelIndex + 3] = 255;
}
ctx.putImageData(myImageData, 10, 10);
console.log("myImageData", myImageData)
});
</script>
</body>
</html>