
I'd really appreciate it if anyone could advise on a task I've been working on without success for the last week. I have a semantic segmentation model (MobileNetV3 + Lightweight ASPP). Short info: input is 1024x1024, output is the same size with 2 classes (bg and vehicle), so my output shape is (1, 1048576, 2). I'm not a mobile dev or a Java-world guy, so I used a few complete Android examples for image segmentation to test it: the one from Google: https://github.com/tensorflow/examples/tree/master/lite/examples/image_segmentation and another open-source one: https://github.com/pillarpond/image-segmenter-android

I successfully converted it to the tflite format, and its inference time on a OnePlus 7 with GPU enabled and 10 threads is between 105-140ms at that size. But here I run into a problem: the overall execution time in these two Android examples, or in any you can find for semantic segmentation, is about 1050-1300ms (which is less than 1 FPS). The slowest part of this pipeline is image post-processing (~900-1150ms). You can see that part in the Deeplab#segment method. Since I have only 1 class besides bg, I don't have the third loop, but everything else is untouched and still very slow. The output size isn't small compared to other common mobile sizes like 128/226/512, but still, I don't think it should take this long to process a 1024x1024 matrix and draw rectangles on a canvas on a modern smartphone. I tried different solutions, like splitting the matrix manipulation across threads, or creating all the RectF and Recognition objects once up front and only filling their attributes with new data inside the nested loops, but neither helped. On the desktop side I handle this easily with numpy and opencv, and I'm not even close to understanding how to do the same in Android, or whether it would even be efficient. Here's the code I use in Python:

CLASS_COLORS = [(0, 0, 0), (255, 255, 255)] # black for bg and white for mask


def get_image_array(image_input, width, height):
    # Read the image as BGR, resize, zero-center around 128, then flip BGR -> RGB
    img = cv2.imread(image_input, 1)
    img = cv2.resize(img, (width, height))
    img = img.astype(np.float32)
    img[:, :, 0] -= 128.0
    img[:, :, 1] -= 128.0
    img[:, :, 2] -= 128.0
    img = img[:, :, ::-1]
    return img

def get_segmentation_array(seg_arr, n_classes):
    # Map each class id in the HxW prediction to its RGB color
    output_height = seg_arr.shape[0]
    output_width = seg_arr.shape[1]
    seg_img = np.zeros((output_height, output_width, 3))
    for c in range(n_classes):
        seg_arr_c = seg_arr[:, :] == c
        seg_img[:, :, 0] += ((seg_arr_c)*(CLASS_COLORS[c][0])).astype('uint8')
        seg_img[:, :, 1] += ((seg_arr_c)*(CLASS_COLORS[c][1])).astype('uint8')
        seg_img[:, :, 2] += ((seg_arr_c)*(CLASS_COLORS[c][2])).astype('uint8')

    return seg_img.astype(np.uint8)  # cv2.imwrite expects 8-bit data for PNG


interpreter = tf.lite.Interpreter(model_path="my_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


img_arr = get_image_array("input.png", 1024, 1024)
interpreter.set_tensor(input_details[0]['index'], np.array([img_arr]))
interpreter.invoke()

output = interpreter.get_tensor(output_details[0]['index'])
output = output.reshape((1024, 1024, 2)).argmax(axis=2)
seg_img = get_segmentation_array(output, 2)
cv2.imwrite("output.png", seg_img)

Maybe there's something more powerful than the current solution for post-processing. I would really appreciate any help with this. I'm sure something can bring post-processing down to ~100ms, which would give me ~5 FPS overall.

  • I think you should try OpenCV inside Android. Take a look at [this](https://github.com/farmaker47/Pneumothorax) project, where OpenCV is used inside an Android app that segments lung disease. OpenCV has a class for blob detection that can draw rectangles at the edges of the blob (mask); see the sketch after these comments. Tag me if you need any help – Farmaker Jun 14 '20 at 05:49
  • Thanks for the help @Farmaker! – ezelen Jun 14 '20 at 23:48
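A minimal sketch of the OpenCV-on-Android approach the comment suggests, assuming the OpenCV Android SDK has already been loaded (e.g. via OpenCVLoader.initDebug()); the method name drawBlobRects, the grayscale conversion, and the red stroke color are illustrative choices, not code from the linked project:

    import android.graphics.Bitmap;
    import android.graphics.Canvas;
    import android.graphics.Color;
    import android.graphics.Paint;
    import java.util.ArrayList;
    import java.util.List;
    import org.opencv.android.Utils;
    import org.opencv.core.Mat;
    import org.opencv.core.MatOfPoint;
    import org.opencv.imgproc.Imgproc;

    // Sketch: find blobs in a binary mask and draw their bounding rectangles.
    public static void drawBlobRects(Bitmap maskBitmap, Canvas canvas) {
        Mat mask = new Mat();
        Utils.bitmapToMat(maskBitmap, mask);                  // ARGB_8888 -> RGBA Mat
        Imgproc.cvtColor(mask, mask, Imgproc.COLOR_RGBA2GRAY);

        // RETR_EXTERNAL keeps only outer contours; CHAIN_APPROX_SIMPLE compresses them.
        List<MatOfPoint> contours = new ArrayList<>();
        Mat hierarchy = new Mat();
        Imgproc.findContours(mask, contours, hierarchy,
                Imgproc.RETR_EXTERNAL, Imgproc.CHAIN_APPROX_SIMPLE);

        Paint paint = new Paint();
        paint.setStyle(Paint.Style.STROKE);
        paint.setStrokeWidth(4f);
        paint.setColor(Color.RED);

        for (MatOfPoint contour : contours) {
            // Fully qualified to avoid clashing with android.graphics.Rect.
            org.opencv.core.Rect r = Imgproc.boundingRect(contour);
            canvas.drawRect(r.x, r.y, r.x + r.width, r.y + r.height, paint);
        }
        mask.release();
        hierarchy.release();
    }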

1 Answer


New update: thanks to Farmaker, I used a piece of code found in the repo from his comment above, and now the pipeline looks like this:

    int channels = 3;
    int n_classes = 2;
    int float_byte_size = 4;
    int width = model.inputWidth;
    int height = model.inputHeight;

    int[] intValues = new int[width * height];
    ByteBuffer inputBuffer = ByteBuffer.allocateDirect(width * height * channels * float_byte_size).order(ByteOrder.nativeOrder());
    ByteBuffer outputBuffer = ByteBuffer.allocateDirect(width * height * n_classes * float_byte_size).order(ByteOrder.nativeOrder());

    Bitmap input = textureView.getBitmap(width, height);
    input.getPixels(intValues, 0, width, 0, 0, width, height);

    inputBuffer.rewind();
    outputBuffer.rewind();

    // Unpack each ARGB pixel into R, G, B floats, zero-centered around 128.
    for (final int value : intValues) {
        inputBuffer.putFloat((value >> 16 & 0xff) - 128.0f);
        inputBuffer.putFloat((value >> 8 & 0xff) - 128.0f);
        inputBuffer.putFloat((value & 0xff) - 128.0f);
    }

    tfLite.run(inputBuffer, outputBuffer);

    final Bitmap output = Bitmap.createBitmap(width, height, Bitmap.Config.ARGB_8888);
    outputBuffer.flip();
    int[] pixels = new int[width * height];
    for (int i = 0; i < width * height; i++) {
        // Two class scores per pixel: background first, then vehicle.
        float bg = outputBuffer.getFloat();
        float vehicle = outputBuffer.getFloat();
        // Transparent for background, semi-transparent blue for the vehicle mask.
        pixels[i] = vehicle > bg ? 0x990000ff : 0x00000000;
    }
    output.setPixels(pixels, 0, width, 0, 0, width, height);
    resultView.setImageBitmap(resizeBitmap(output, resultView.getWidth(), resultView.getHeight()));


    public static Bitmap resizeBitmap(Bitmap bm, int newWidth, int newHeight) {
        int width = bm.getWidth();
        int height = bm.getHeight();
        float scaleWidth = ((float) newWidth) / width;
        float scaleHeight = ((float) newHeight) / height;

        // Scale the bitmap to the view size with a transformation matrix.
        Matrix matrix = new Matrix();
        matrix.postScale(scaleWidth, scaleHeight);

        Bitmap resizedBitmap = Bitmap.createBitmap(
                bm, 0, 0, width, height, matrix, false);
        bm.recycle();
        return resizedBitmap;
    }
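
As a side note, the two-class comparison above generalizes to more classes with a small per-pixel argmax; a minimal sketch under that assumption, with a hypothetical CLASS_COLORS lookup table of ARGB ints (reusing outputBuffer, pixels, width, height, and n_classes from the code above):

    // Hypothetical per-class ARGB colors; index 0 is the transparent background.
    final int[] CLASS_COLORS = {0x00000000, 0x990000ff, 0x9900ff00};

    for (int i = 0; i < width * height; i++) {
        // Scores for one pixel are laid out contiguously: class 0, 1, ..., n-1.
        int best = 0;
        float bestScore = outputBuffer.getFloat();
        for (int c = 1; c < n_classes; c++) {
            float score = outputBuffer.getFloat();
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        pixels[i] = CLASS_COLORS[best];
    }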

Right now post-processing takes ~70-130ms, with the 95th percentile around 90ms. Together with ~60ms of image pre-processing, ~140ms of inference (GPU enabled, 10 threads), and 30-40ms for everything else, the overall execution time is around 330ms, which is 3 FPS! And that's for a large model at 1024x1024. At this point I'm more than satisfied and want to try different configurations of my model, including MobileNetV3-small as a backbone.
