tensorflow machine-learning tensorflow-lite

Strange output of Conv2D in tflite graph


I have a tflite graph, a fragment of which is depicted in the attached picture.

I needed to debug its behavior, and already at the first step I got quite puzzling results. When I feed a tensor of zeros as input, I expect the output of the first Conv2D to consist only of values from the Conv2D bias (since all kernel elements get multiplied by zeros). Instead I get a tensor that consists of seemingly random data. Here is the code snippet:

import numpy as np
import tensorflow as tf

def test_graph(path=PATH_DEFAULT):
    interp = tf.lite.Interpreter(path)
    interp.allocate_tensors()

    input_details = interp.get_input_details()
    in_idx = input_details[0]['index']

    zeros = np.zeros(shape=(1, 256, 256, 3), dtype=np.float32)
    interp.set_tensor(in_idx, zeros)
    interp.invoke()

    # index of output of first conv2d operator is 3 (see netron pic)
    after_conv_2d = interp.get_tensor(3)

    # shape of bias is just [count of output channels]
    n, h, w, c = after_conv_2d.shape

    # if we feed zeros as input, we can expect that the only values we get are the values of bias
    # since all kernel elems in that case are multiplied by zeros

    uniq_vals_cnt = len(np.unique(after_conv_2d))
    assert uniq_vals_cnt <= c, f"There are {uniq_vals_cnt} in output, should be <= than {c}"

output:

AssertionError: There are 287928 in output, should be <= than 24

Can someone help me with my misunderstanding?
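
The expectation itself can be checked independently of TFLite. Below is a minimal NumPy sketch, not the TFLite kernel: `conv2d_nhwc` is a hypothetical naive helper (no padding, stride 1, no fused activation), and the kernel and bias are random stand-ins for the real model weights. With an all-zero input every output pixel equals the bias, so the number of unique values cannot exceed the number of output channels.

```python
import numpy as np

def conv2d_nhwc(x, kernel, bias):
    """Naive 'valid' convolution, stride 1.

    x:      (n, h, w, c_in)
    kernel: (kh, kw, c_in, c_out)
    bias:   (c_out,)
    """
    n, h, w, _ = x.shape
    kh, kw, _, c_out = kernel.shape
    out = np.empty((n, h - kh + 1, w - kw + 1, c_out), dtype=x.dtype)
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw, :]  # (n, kh, kw, c_in)
            # Contract (kh, kw, c_in) of the patch with the kernel -> (n, c_out)
            out[:, i, j, :] = np.tensordot(patch, kernel, axes=3) + bias
    return out

# Random stand-ins for the model's weights (assumption: 24 output channels,
# as in the question; the spatial size is reduced to keep the loop fast).
rng = np.random.default_rng(0)
kernel = rng.standard_normal((3, 3, 3, 24)).astype(np.float32)
bias = rng.standard_normal(24).astype(np.float32)

zeros = np.zeros((1, 16, 16, 3), dtype=np.float32)
out = conv2d_nhwc(zeros, kernel, bias)

uniq_vals_cnt = len(np.unique(out))
print(uniq_vals_cnt <= out.shape[-1])
```

So the reasoning about the convolution holds, and the unexpected values must come from how the tensor is read back.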


Solution

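  • It seems my assumption that any intermediate tensor can be read from the interpreter was wrong. Only output tensors are reliable, even though the interpreter does not raise an error and even returns tensors of the right shape for indices of non-output tensors. TFLite reuses the memory of intermediate tensors between ops, so by the time invoke() returns, their buffers may already hold data written by later ops.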

    One way to debug such a graph is to make all tensors outputs. The easiest way to do that seems to be converting the tflite file to pb with toco and then converting the pb back to tflite with the new outputs specified. This is not ideal, though, because toco support for tflite -> pb conversion was removed after 1.9, and versions before that can break on some graphs (they did in my case).

    More on this here: tflite: get_tensor on non-output tensors gives random values
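
As an alternative to the toco round trip, newer TensorFlow releases (2.5 and later, as far as I know) accept an `experimental_preserve_all_tensors=True` argument in `tf.lite.Interpreter`, which asks the interpreter not to reuse intermediate buffers. The sketch below is a hypothetical helper, `dump_all_tensors`, that runs one inference and collects every readable tensor by index. It takes an already constructed interpreter, and the construction shown in the trailing comment is an assumption about how it would be used with the model from the question. Delegates may still hide intermediate values, so treat the result as a debugging aid rather than a guarantee.

```python
import numpy as np

def dump_all_tensors(interp, input_array):
    """Run one inference and return {tensor index: value} for readable tensors.

    `interp` is expected to behave like a tf.lite.Interpreter that was created
    with experimental_preserve_all_tensors=True. Without that flag the values
    of non-output tensors are not meaningful.
    """
    interp.allocate_tensors()
    in_idx = interp.get_input_details()[0]['index']
    interp.set_tensor(in_idx, input_array)
    interp.invoke()

    values = {}
    for detail in interp.get_tensor_details():
        try:
            values[detail['index']] = interp.get_tensor(detail['index'])
        except ValueError:
            # Some entries have no data buffer and cannot be read; skip them.
            continue
    return values

# Assumed usage (requires TensorFlow; PATH_DEFAULT as in the question):
#   import tensorflow as tf
#   interp = tf.lite.Interpreter(model_path=PATH_DEFAULT,
#                                experimental_preserve_all_tensors=True)
#   tensors = dump_all_tensors(interp, np.zeros((1, 256, 256, 3), np.float32))
#   after_conv_2d = tensors[3]
```

Keying the result by tensor index keeps it compatible with the indices shown in Netron, which is how the question identifies the Conv2D output.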