
I have been trying to convert the zero-shot text classification model joeddav/xlm-roberta-large-xnli (https://huggingface.co/joeddav/xlm-roberta-large-xnli) from an .h5 file to a .tflite file, but the following error appears and I can't find it described anywhere online. How can it be fixed? If it can't, is there another zero-shot text classifier I could use that would keep similar accuracy after conversion to TFLite?

AttributeError: 'T5ForConditionalGeneration' object has no attribute 'call'

I have tried a few different tutorials, and my current Google Colab notebook is an amalgam of a couple of them: https://colab.research.google.com/drive/1sYQJqvhM_KEvMt2IP15d8Ud9L-ApiYv6?usp=sharing


[ Converting a saved .h5 model to a TFLite model ]

There are multiple ways to convert a model to TFLite:

  1. The TFLite converter command-line tool, tflite_convert
  2. The Python API, tf.lite.TFLiteConverter
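The Python API also accepts a SavedModel directory instead of an in-memory Keras model, which is useful when the model was exported with tf.saved_model.save rather than saved as a single .h5 file. A minimal sketch, assuming TensorFlow 2.x; the tiny Dense model and the temporary paths are placeholders, not part of the original setup:

```python
import os
import tempfile

import tensorflow as tf

# Placeholder model: stands in for your own trained model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

# Export to the SavedModel format (a directory, not a single .h5 file).
export_root = tempfile.mkdtemp()
saved_model_dir = os.path.join(export_root, "saved_model")
tf.saved_model.save(model, saved_model_dir)

# Convert the SavedModel directory to a TFLite flatbuffer (bytes).
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Write the flatbuffer next to the SavedModel directory.
tflite_path = os.path.join(export_root, "model.tflite")
with open(tflite_path, "wb") as f:
    f.write(tflite_model)
```

The same SavedModel directory can also be passed to the command-line tool through its --saved_model_dir flag, with --output_file naming the .tflite file.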

Judging from the links provided, the goal is to convert a saved .h5 model to TFLite; the sample below shows that path with the Python API.

[ Sample ]:

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: Model Initialize
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(input_shape=( 32, 32, 3 )),
    tf.keras.layers.Dense(128, activation='relu'),
])
model.compile(optimizer='sgd', loss='mean_squared_error') # compile the model
model.summary()

model.save_weights(checkpoint_path)

"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
: FileWriter
"""""""""""""""""""""""""""""""""""""""""""""""""""""""""
if exists(checkpoint_path) :
    model.load_weights(checkpoint_path)
    print("model load: " + checkpoint_path)


tf_lite_model_converter = tf.lite.TFLiteConverter.from_keras_model(
    model
) # <tensorflow.lite.python.lite.TFLiteKerasModelConverterV2 object at 0x0000021095194E80>
tflite_model = tf_lite_model_converter.convert()

# Save the model.
with open(checkpoint_dir + '\\model.tflite', 'wb') as f:
    f.write(tflite_model)
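After writing model.tflite, it can be worth checking that the converted model produces the same outputs as the original Keras model. A minimal sketch using tf.lite.Interpreter, assuming TensorFlow 2.x and a float32 model with a single input and a single output; run_tflite is a hypothetical helper name, and a small stand-in model replaces the one above so the snippet runs on its own:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def run_tflite(tflite_model: bytes, x: np.ndarray) -> np.ndarray:
    """Hypothetical helper: run one batch through a TFLite model, return the first output."""
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    input_index = interpreter.get_input_details()[0]["index"]
    output_index = interpreter.get_output_details()[0]["index"]
    interpreter.set_tensor(input_index, x)
    interpreter.invoke()
    return interpreter.get_tensor(output_index)


# Small stand-in model, converted via a SavedModel directory.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])
saved_model_dir = os.path.join(tempfile.mkdtemp(), "saved_model")
tf.saved_model.save(model, saved_model_dir)
tflite_model = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir).convert()

# Compare Keras and TFLite outputs on the same random input (batch size 1).
x = np.random.rand(1, 4).astype(np.float32)
keras_out = model.predict(x, verbose=0)
tflite_out = run_tflite(tflite_model, x)
print("max abs difference:", np.max(np.abs(keras_out - tflite_out)))
```

The interpreter allocates the input with batch size 1 by default; to feed larger batches, call interpreter.resize_tensor_input on the input index before allocate_tensors.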