Search code examples
Tags: tensorflow, keras, tensorflow2.0, tensorflow-lite

Node number X (RESHAPE) failed to prepare. Tensor resize with tflite v2.2


Here is a simple script that reproduces the error:

import os
os.environ["CUDA_VISIBLE_DEVICES"]="-1"  # hide all GPUs so TensorFlow runs on CPU

import numpy as np
from keras.models import Sequential
from keras.layers import Conv1D, Flatten, Dense
import tensorflow as tf

model_path = 'test.h5'

# Tiny model built and saved with the standalone `keras` package.
# Conv1D(8 filters, kernel 5) on a (100, 1) input gives (batch, 96, 8), which
# Flatten turns into (batch, 768) -- the 768 that appears in the RESHAPE error.
model = Sequential()
model.add(Conv1D(8,(5,), input_shape=(100,1)))
model.add(Flatten())
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')
model.save(model_path)

# Reload with tf.keras and convert to a TFLite flatbuffer. The converted graph
# uses a batch dimension of 1 (the reshape tensor below has shape [1, 768]).
model = tf.keras.models.load_model(model_path, compile=False)
converter = tf.lite.TFLiteConverter.from_keras_model(model)

tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)

# Ask the interpreter to accept a batch of 2 samples instead of 1.
interpreter.resize_tensor_input(interpreter.get_input_details()[0]['index'], (2,100,1))
# NOTE(review): this passes an *output* tensor index to resize_tensor_input;
# output shapes are normally derived during allocate_tensors(), so this call
# presumably does not affect the failure -- confirm.
interpreter.resize_tensor_input(interpreter.get_output_details()[0]['index'], (2,1))

# Raises RuntimeError: the RESHAPE node generated for Flatten appears to keep
# its [1, 768] target shape, so 2 * 768 = 1536 input elements cannot be
# reshaped into 768 output elements.
interpreter.allocate_tensors()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-3-ad8e2eea467f> in <module>
     27 interpreter.resize_tensor_input(interpreter.get_output_details()[0]['index'], (2,1))
     28 
---> 29 interpreter.allocate_tensors()

<>/tensorflow/lite/python/interpreter.py in allocate_tensors(self)
    240   def allocate_tensors(self):
    241     self._ensure_safe()
--> 242     return self._interpreter.AllocateTensors()
    243 
    244   def _safe_to_run(self):

<>/tensorflow/lite/python/interpreter_wrapper/tensorflow_wrap_interpreter_wrapper.py in AllocateTensors(self)
    108 
    109     def AllocateTensors(self):
--> 110         return _tensorflow_wrap_interpreter_wrapper.InterpreterWrapper_AllocateTensors(self)
    111 
    112     def Invoke(self):

RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (1536 != 768)Node number 3 (RESHAPE) failed to prepare.

It seems the issue comes from the Reshape op generated for the Flatten layer. I was able to do this kind of resize with TensorFlow 1.5, but not with version 2.2.

Here are the tensor details of the reshape layer's output:

 {'name': 'sequential_1/flatten_1/Reshape',
  'index': 8,
  'shape': array([  1, 768], dtype=int32),
  'shape_signature': array([  1, 768], dtype=int32),
  'dtype': numpy.float32,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
   'zero_points': array([], dtype=int32),
   'quantized_dimension': 0},
  'sparsity_parameters': {}},

I thought I might need to resize this tensor too, so I added:

# NOTE(review): tensor 8 is the Flatten/Reshape output shown above, not a model
# input; resizing it presumably does not change the RESHAPE node's target shape.
interpreter.resize_tensor_input(8, (2,768))

but I got the exact same error.

RuntimeError: tensorflow/lite/kernels/reshape.cc:66 num_input_elements != num_output_elements (1536 != 768)Node number 3 (RESHAPE) failed to prepare.


Solution

  • I've come up with a workaround that fixes the batch size before converting to TFLite. Wrap the Keras model in a concrete function whose input signature already has the desired batch size, then use `from_concrete_functions` instead of `from_keras_model`. The trade-off is that this batch size is baked into the converted model.

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide GPUs; must be set before TensorFlow is imported

    import numpy as np
    import tensorflow as tf
    # Use tf.keras for both building and reloading the model. Saving with the
    # standalone `keras` package and loading with tf.keras may not round-trip
    # the .h5 file reliably, depending on the installed versions.
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import Conv1D, Flatten, Dense

    model_path = 'test.h5'

    model = Sequential()
    model.add(Conv1D(8, (5,), input_shape=(100, 1)))
    model.add(Flatten())
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')
    model.save(model_path)

    model = tf.keras.models.load_model(model_path, compile=False)

    # Fix the batch dimension *before* conversion: trace the model into a
    # concrete function whose input signature already has `batch_size` rows,
    # so the Reshape generated for Flatten is built for that batch size.
    # Trade-off: the converted model only accepts exactly `batch_size` samples.
    batch_size = 2
    input_shape = model.inputs[0].shape.as_list()  # [None, 100, 1]
    input_shape[0] = batch_size
    func = tf.function(model).get_concrete_function(
        tf.TensorSpec(input_shape, model.inputs[0].dtype))
    converter = tf.lite.TFLiteConverter.from_concrete_functions([func])

    tflite_model = converter.convert()

    interpreter = tf.lite.Interpreter(model_content=tflite_model)

    # No resize_tensor_input call is needed: the input is already batch_size x 100 x 1.
    interpreter.allocate_tensors()

    # Sanity check: run one batch of zeros through the converted model.
    input_details = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]
    dummy_batch = np.zeros(input_shape, dtype=input_details['dtype'])
    interpreter.set_tensor(input_details['index'], dummy_batch)
    interpreter.invoke()
    print(interpreter.get_tensor(output_details['index']).shape)  # expected: (2, 1)