
I'm trying to implement handwriting text recognition in my Android app. I found TensorFlow to be a doable solution, so I tried to create a .tflite model from the Handwriting Recognition model from Keras. The tutorial states that it is fully compatible with TF Lite.

I managed to create the .tflite model and then, in Android, initialize the Interpreter with it. I then ran the Interpreter with a ByteBuffer of a bitmap, and the output has a shape of [1,32,81], which is an array of floats. As far as I know, the output should just be a String: the predicted text for the given input. How can I decode the output into the String I need?

I had a few problems:

  1. Converting the model to a .tflite file. I managed to do it using certain flags as follows:
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model(prediction_model)
# Allow select TF ops, since the model uses ops that are not TFLite builtins
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
converter._experimental_lower_tensor_list_ops = False
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tf_lite_model = converter.convert()
open('textRecognitionModel.tflite', 'wb').write(tf_lite_model)
  2. According to the TF Lite docs, you have to use the following dependencies:
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly-SNAPSHOT'
// This dependency adds the necessary TF op support.
implementation 'org.tensorflow:tensorflow-lite-select-tf-ops:0.0.0-nightly-SNAPSHOT'

After finally creating a .tflite model file, I added it to the assets directory of my Android app and tried importing it. However, it crashed with no error message, apparently due to a memory failure. I updated the libraries to the latest version:

"org.tensorflow:tensorflow-lite:2.11.0"
"org.tensorflow:tensorflow-lite-select-tf-ops:2.11.0"

And loaded my model into a ByteBuffer as follows (I'm not sure whether I'm handling the native-order logic correctly):

// filename is the name of the model file in the assets dir
val inputStream = assetManager.open(filename)
val output = ByteArrayOutputStream()
inputStream.copyTo(output, 1024)
val file = output.toByteArray()
val bb = ByteBuffer.allocateDirect(file.size)
bb.order(ByteOrder.nativeOrder())
bb.put(file)
bb.rewind() // reset the position so the Interpreter reads from the start
return bb
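
As an alternative, the TF Lite examples usually memory-map the model file from assets instead of copying it into a direct buffer. A minimal sketch of that approach; note that it assumes the .tflite asset is stored uncompressed in the APK (e.g. via aaptOptions { noCompress 'tflite' }):

import android.content.res.AssetManager
import java.io.FileInputStream
import java.nio.MappedByteBuffer
import java.nio.channels.FileChannel

// Memory-maps a model file from the assets directory. openFd() only works
// if the asset is stored uncompressed in the APK.
fun loadModelFile(assetManager: AssetManager, filename: String): MappedByteBuffer {
    val fd = assetManager.openFd(filename)
    FileInputStream(fd.fileDescriptor).use { input ->
        return input.channel.map(
            FileChannel.MapMode.READ_ONLY,
            fd.startOffset,
            fd.declaredLength
        )
    }
}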

With that, the initialization of the Interpreter API is finally working. I then run the interpreter on a ByteBuffer of a bitmap, expecting the model to read the input and give the predicted text (a String) as output. However, the output has shape [1,32,81], so I created an array of that shape to read the output and ran the Interpreter with it:

// Output buffer matching the model's [1, 32, 81] output shape
val output = Array(1) {
    Array(32) {
        FloatArray(81)
    }
}
// byteBuffer: ByteBuffer of the input bitmap
interpreter.run(byteBuffer, output)

The output is an array of floats, and I don't understand what it means. Shouldn't it just be a String? I've attached a screenshot of the output array.
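
For completeness, this is roughly how the input ByteBuffer can be built from the bitmap. The 128x32 input size and the 0..1 grayscale normalization here are assumptions taken from the tutorial's preprocessing; they should be checked against interpreter.getInputTensor(0):

import android.graphics.Bitmap
import java.nio.ByteBuffer
import java.nio.ByteOrder

// Builds a float32 input buffer from a bitmap. The 128x32 size and the
// 0..1 normalization are assumptions; mirror the exact preprocessing used
// in the Keras notebook for your own model.
fun bitmapToByteBuffer(bitmap: Bitmap, width: Int = 128, height: Int = 32): ByteBuffer {
    val scaled = Bitmap.createScaledBitmap(bitmap, width, height, true)
    val buffer = ByteBuffer.allocateDirect(4 * width * height) // 4 bytes per float
    buffer.order(ByteOrder.nativeOrder())
    val pixels = IntArray(width * height)
    scaled.getPixels(pixels, 0, width, 0, 0, width, height)
    for (pixel in pixels) {
        // Average the RGB channels to grayscale and normalize to 0..1
        val r = (pixel shr 16) and 0xFF
        val g = (pixel shr 8) and 0xFF
        val b = pixel and 0xFF
        buffer.putFloat(((r + g + b) / 3.0f) / 255.0f)
    }
    buffer.rewind()
    return buffer
}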

Can someone please help me? I would highly appreciate any tips or solutions :)

  • The code of the model doesn't return the string. There is another function, `decode_batch_predictions(pred)`, in your reference link which takes your predicted output and returns the string. Unfortunately that function is not part of the model, hence you won't be able to use it directly. – MSS Dec 13 '22 at 01:01

1 Answer


Before converting prediction_model to the tflite format, you need to add a custom decoding layer at the end and then convert the combined model:

import numpy as np
import tensorflow as tf
from tensorflow import keras

# This line is present in the handwriting_recognition notebook.
prediction_model = keras.models.Model(
    model.get_layer(name="image").input, model.get_layer(name="dense2").output
)

def CTCDecoder():
  def decode_batch_predictions(pred):
    # max_length and num_to_char are defined in the handwriting_recognition notebook
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Use greedy search. For complex tasks, you can use beam search.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]
    # Iterate over the results and get back the text
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        output_text.append(res)
    return output_text

  return tf.keras.layers.Lambda(decode_batch_predictions, name='decode')

decoded_pred_model = keras.models.Model(prediction_model.input, outputs=CTCDecoder()(prediction_model.output))

Now you can convert decoded_pred_model to the tflite format and use it. CTCDecoder is the custom layer added on top of prediction_model.output; it decodes the predictions of shape [1,32,81] into text.
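
Alternatively, if you keep the original prediction_model and its raw [1,32,81] output, the greedy CTC decoding can also be done on the Android side. A rough Kotlin sketch; the charset table must match the notebook's num_to_char vocabulary, and the blank-as-last-class convention (TF's default blank_index=-1) is an assumption to verify against your model:

// Greedy CTC decode of a [1, 32, 81] output tensor: take the argmax at each
// of the 32 timesteps, collapse consecutive repeats, and drop the blank.
// `charset` is the index-to-character mapping used in training (assumed).
fun greedyCtcDecode(output: Array<Array<FloatArray>>, charset: List<Char>): String {
    val blankIndex = output[0][0].size - 1 // assumption: blank is the last class
    val sb = StringBuilder()
    var prev = -1
    for (timestep in output[0]) {
        val best = timestep.indices.maxByOrNull { timestep[it] }!! // argmax
        if (best != prev && best != blankIndex && best < charset.size) {
            sb.append(charset[best])
        }
        prev = best
    }
    return sb.toString()
}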

MSS
  • Thanks @MSS. In this line `results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][:, :max_length]` I get the following error: `ValueError: Exception encountered when calling layer "decode" (type Lambda). Shape must be rank 1 but is rank 0 for '{{node decode/CTCGreedyDecoder}} = CTCGreedyDecoder[T=DT_FLOAT, blank_index=-1, merge_repeated=true](decode/Log, decode/Cast)' with input shapes: [32,?,81], []. Call arguments received by layer "decode" (type Lambda): inputs=tf.Tensor(shape=(None, 32, 81), dtype=float32), mask=None, training=False` – Mehdi Karbalai Dec 13 '22 at 13:31
  • Is this error coming up in Android while feeding the image, or inside the notebook? – MSS Dec 13 '22 at 15:09
  • No, in Python, when I run the .py in the console. If it's possible, could you perhaps start a chat so we can continue there? – Mehdi Karbalai Dec 13 '22 at 15:20
  • Ok let me check – MSS Dec 13 '22 at 15:21
  • https://chat.stackoverflow.com/rooms/250375/ctcdecoder-issue – MSS Dec 13 '22 at 15:24
  • Sorry, apparently I need 20 reputation to chat, so it's not possible. – Mehdi Karbalai Dec 13 '22 at 15:27
  • I used the solution from https://stackoverflow.com/a/67220052/9061528 and I'm able to generate a model with decoded output. The output is now an array of type INT64 (long). Can you help me with how to read it in Java/Android? – Mehdi Karbalai Dec 17 '22 at 20:59
  • The complete solution is given in https://github.com/tulasiram58827/ocr_tflite/blob/main/colabs/KERAS_OCR_TFLITE.ipynb, with the facility to load pretrained weights. Use it to save a model in tflite format for your use. – MSS Dec 18 '22 at 17:30
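
For reference, an INT64 output tensor like the one described in the comments above can be read in Kotlin by passing a long-array container of the output shape to the interpreter and mapping the indices back to characters. A minimal sketch, assuming a [1, 32] output of character indices (verify the actual shape with interpreter.getOutputTensor(0)); charset is again the notebook's num_to_char table:

import org.tensorflow.lite.Interpreter
import java.nio.ByteBuffer

// Reads a [1, 32] INT64 tensor of character indices from the decoded model.
// Indices outside the charset (e.g. padding values) are skipped.
fun runDecodedModel(interpreter: Interpreter, input: ByteBuffer, charset: List<Char>): String {
    val output = Array(1) { LongArray(32) } // maps to an INT64 [1, 32] tensor
    interpreter.run(input, output)
    val sb = StringBuilder()
    for (idx in output[0]) {
        if (idx >= 0 && idx < charset.size) sb.append(charset[idx.toInt()])
    }
    return sb.toString()
}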