tensorflow tensorflow-lite huggingface-transformers text-classification huggingface

Converting h5 to tflite


I have been trying to convert the zero-shot text classification model joeddav/xlm-roberta-large-xnli (https://huggingface.co/joeddav/xlm-roberta-large-xnli) from an .h5 file to a .tflite file, but the following error pops up and I can't find it described anywhere online. How can it be fixed? If it can't be, is there another zero-shot text classifier I can use that keeps similar accuracy after conversion to TFLite?

AttributeError: 'T5ForConditionalGeneration' object has no attribute 'call'

I have tried a few different tutorials, and my current Google Colab notebook is an amalgam of a couple of them: https://colab.research.google.com/drive/1sYQJqvhM_KEvMt2IP15d8Ud9L-ApiYv6?usp=sharing


Solution

  • [ Convert a saved .h5 model to a TFLite model ]

    There are multiple ways to convert a model with the TF-Lite converter:

    1. The TF-Lite converter command-line tool, tflite_convert
    2. The Python API, tf.lite.TFLiteConverter
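The Python API route can be sketched end to end with a SavedModel as the source. This is a minimal sketch, assuming a hypothetical toy `tf.Module` named `Scale` in place of a real model and a temporary export directory; the command-line tool reads the same kind of directory, e.g. `tflite_convert --saved_model_dir=<dir> --output_file=model.tflite`.

```python
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy module standing in for a real model."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(2.0)

    # A fixed input signature lets the function be exported as a serving signature.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
    def __call__(self, x):
        return x * self.w


saved_model_dir = tempfile.mkdtemp()  # hypothetical export location
module = Scale()
tf.saved_model.save(module, saved_model_dir, signatures=module.__call__)

# Convert the exported SavedModel; convert() returns the serialized FlatBuffer bytes.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
```

The resulting bytes can be written to a `.tflite` file exactly as in the sample's final step.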

    The tutorials linked in the question convert a saved .h5 model to TFLite, so the sample below follows the same route.

    [ Sample ]:

    import os
    from os.path import exists

    import tensorflow as tf

    # Output locations (placeholder values; adjust to your environment).
    # The ".weights.h5" suffix makes save_weights write HDF5 in both Keras 2 and Keras 3.
    checkpoint_dir = "checkpoints"
    checkpoint_path = os.path.join(checkpoint_dir, "model.weights.h5")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # --- Model Initialize ---
    model = tf.keras.models.Sequential([
        tf.keras.Input(shape=(32, 32, 3)),
        tf.keras.layers.Dense(128, activation='relu'),
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')  # compile the model
    model.summary()

    model.save_weights(checkpoint_path)

    # --- FileWriter ---
    # Reload the saved weights if the checkpoint file exists.
    if exists(checkpoint_path):
        model.load_weights(checkpoint_path)
        print("model load: " + checkpoint_path)

    # Convert the in-memory Keras model to a TFLite FlatBuffer.
    tf_lite_model_converter = tf.lite.TFLiteConverter.from_keras_model(model)
    tflite_model = tf_lite_model_converter.convert()

    # Save the model (os.path.join instead of a hard-coded Windows '\\' separator).
    with open(os.path.join(checkpoint_dir, 'model.tflite'), 'wb') as f:
        f.write(tflite_model)
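The sample above saves and reloads only the weights. When the whole model (architecture and weights) is stored in a single .h5 file, a sketch like the following could load and convert it. It assumes a hypothetical file name `model.h5` in a temporary directory and, instead of `from_keras_model`, traces the reloaded model into a concrete function with a fixed input shape of `[1, 32, 32, 3]` and passes that to `tf.lite.TFLiteConverter.from_concrete_functions`, another entry point of the same converter.

```python
import os
import tempfile

import tensorflow as tf

h5_path = os.path.join(tempfile.mkdtemp(), "model.h5")  # hypothetical file name

# Build and save a small full model (architecture + weights) in HDF5 format.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Dense(128, activation="relu"),
])
model.save(h5_path)

# Reload the complete model from the .h5 file.
restored = tf.keras.models.load_model(h5_path)

# Trace the model into a concrete function with a fixed input shape,
# then convert that function to a TFLite FlatBuffer.
serve = tf.function(lambda x: restored(x))
concrete = serve.get_concrete_function(tf.TensorSpec([1, 32, 32, 3], tf.float32))
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
tflite_model = converter.convert()
```

Fixing the batch dimension to 1 keeps the traced graph fully static, which is a common choice for on-device inference; a `None` batch dimension could be used instead if variable batch sizes are needed.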