I am attempting to convert a TF model to TFLite. The model was saved in .pb
format and I have converted it with the following code:
import os
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2

export_dir = os.path.join('export_dir', '0')
if not os.path.exists('export_dir'):
    os.mkdir('export_dir')

tf.compat.v1.enable_control_flow_v2()
tf.compat.v1.enable_v2_tensorshape()

# I took this function from a tutorial on the TF website
def wrap_frozen_graph(graph_def, inputs, outputs):
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")
    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(inputs, outputs)

graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(
    open(os.path.join(export_dir, 'saved_model.pb'), 'rb').read())

concrete_func = wrap_frozen_graph(
    graph_def,
    inputs=['extern_data/placeholders/data/data:0',
            'extern_data/placeholders/data/data_dim0_size:0'],
    outputs=['output/output_batch_major:0'])
concrete_func.inputs[0].set_shape([None, 50])
concrete_func.inputs[1].set_shape([None])
concrete_func.outputs[0].set_shape([None, 100])

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
converter.experimental_new_converter = True
converter.post_training_quantize = True
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS,
                                       tf.lite.OpsSet.SELECT_TF_OPS]
converter.allow_custom_ops = True
tflite_model = converter.convert()

# Save the model.
if not os.path.exists('tflite'):
    os.mkdir('tflite')
output_model = os.path.join('tflite', 'model.tflite')
with open(output_model, 'wb') as f:
    f.write(tflite_model)
However, when I try to use the interpreter with this model I get the following error:
INFO: TfLiteFlexDelegate delegate: 8 nodes delegated out of 970 nodes with 3 partitions.
INFO: TfLiteFlexDelegate delegate: 0 nodes delegated out of 4 nodes with 0 partitions.
INFO: TfLiteFlexDelegate delegate: 3 nodes delegated out of 946 nodes with 1 partitions.
INFO: TfLiteFlexDelegate delegate: 0 nodes delegated out of 1 nodes with 0 partitions.
INFO: TfLiteFlexDelegate delegate: 3 nodes delegated out of 16 nodes with 2 partitions.
Traceback (most recent call last):
File "/path/to/tflite_interpreter.py", line 9, in <module>
interpreter.allocate_tensors()
File "/path/to/lib/python3.6/site-packages/tensorflow/lite/python/interpreter.py", line 243, in allocate_tensors
return self._interpreter.AllocateTensors()
RuntimeError: Encountered unresolved custom op: VarHandleOp.Node number 0 (VarHandleOp) failed to prepare.
Now, I don't find any VarHandleOp in my code, but I found out that it is actually a TensorFlow raw op (https://www.tensorflow.org/api_docs/python/tf/raw_ops/VarHandleOp). So why isn't TFLite able to recognize it?
It's certainly hard to provide a minimal reproducible example in the case of model conversion, as the SO guidelines recommend, but the question would benefit from better pointers. For example, instead of saying "I took this function from a tutorial on the TF website", it is a much better idea to provide a link to the tutorial. The TF website is vast.
The tutorial that you are referring to is probably from the section on migrating from TF1 to TF2, specifically the part about handling raw graph files. The crucially important note is "if you have a 'Frozen graph' (a tf.Graph where *the variables have been turned into constants*)" (the bold highlight is mine). Apparently, your graph contains VarHandleOp (the same applies to Variable and VariableV2 nodes), and is not "frozen" by this definition. Your general approach makes sense, but you need a graph that contains the actual trained values for the variables in the form of Const nodes. You need variables at training time, but for inference time their values should be baked into the graph. TFLite, as an inference-time framework, does not support variables.
The rest of your idea seems fine. TFLiteConverter.from_concrete_functions currently takes exactly one concrete_function, but this is what you get from wrapping the graph. With enough luck it may work.
There is a utility tensorflow/python/tools/freeze_graph.py that tries its best to replace variables in a Graph.pb with constants taken from the latest checkpoint file. If you look at its code, either using the saved metagraph (checkpoint_name.meta) file or pointing the tool to the training directory eliminates a lot of guesswork; also, I think that providing the model directory is the only way to get a single frozen graph from a sharded model.
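For illustration, a Python call might look like this (a sketch only: the checkpoint prefix training_dir/model.ckpt-100000 is hypothetical, and restore_op_name/filename_tensor_name are legacy arguments that recent versions ignore):

from tensorflow.python.tools import freeze_graph

# Sketch: replace the hypothetical checkpoint prefix and the output
# node name with the ones from your own training setup.
freeze_graph.freeze_graph(
    input_graph=None,
    input_saver=None,
    input_binary=True,
    input_checkpoint='training_dir/model.ckpt-100000',
    output_node_names='output/output_batch_major',
    restore_op_name=None,
    filename_tensor_name=None,
    output_graph='frozen_graph.pb',
    clear_devices=True,
    initializer_nodes='',
    input_meta_graph='training_dir/model.ckpt-100000.meta',
)

The resulting frozen_graph.pb should then go through your wrap_frozen_graph and TFLiteConverter code unchanged.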
I noticed that you use just inputs in place of tf.nest.map_structure(import_graph.as_graph_element, inputs) in the example. You may have other reasons for that, but if you do it because as_graph_element complains about datatype/shape, this is likely to be resolved by freezing the graph properly. The concrete_function that you obtain from the frozen graph will have a good idea about its input shapes and datatypes. Generally, it's unexpected to need to set them manually, and the fact that you do seems odd to me (but I do not claim broad experience with this dark corner of TF).
By the way, map_structure has a check_types keyword argument to skip the check.
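For example, if you want to keep the tutorial's pruning step, a sketch (assuming the check in question is map_structure's structure-type check, and reusing the names from wrap_frozen_graph above):

# Sketch: the tutorial's pruning step, with the structure-type
# check disabled via check_types=False.
pruned = wrapped_import.prune(
    tf.nest.map_structure(import_graph.as_graph_element, inputs,
                          check_types=False),
    tf.nest.map_structure(import_graph.as_graph_element, outputs,
                          check_types=False))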