tensorflow tf.keras tpu

Tensor Slicing on TPU


I would like to run a model on a Google Cloud TPU. I've tried to reduce the code to a minimum. I left out the model code since it isn't relevant; my issue happens earlier.

Here is the main Python file:

import tensorflow as tf
import os
from Model import Model
from DataGeneratorTPU import load_dataset

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=os.environ['TPU_NAME'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    model = Model(32768,7)
    model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])
    dg = load_dataset('gs://bucket/data.tf','gs://bucket/annotations.tf',32768).batch(32,drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
    model.fit(dg,epochs=10,verbose=1)
    model.save('test')

And this is DataGeneratorTPU.py:

import tensorflow as tf

def slice(i,data,annotations,lead_length):
    X = data[i:i+lead_length,:]
    y = annotations[i+lead_length,0,:]
    print(X.shape,y.shape) #OUTPUT2
    return X,y

def load_dataset(filename_data,filename_annotations,lead_length,step_size=1):
    data = tf.io.parse_tensor(tf.io.read_file(filename_data), tf.float32)
    annotations = tf.io.parse_tensor(tf.io.read_file(filename_annotations), tf.int32)
    print(data.shape,annotations.shape) #OUTPUT1
    rangeds = tf.data.Dataset.range(0,data.shape[0]-lead_length,step_size)
    def slice_(i):
        return slice(i,data,annotations,tf.constant(lead_length,dtype=tf.int64))

    return rangeds.map(slice_, tf.data.experimental.AUTOTUNE)

As you may notice, I marked the two print statements with OUTPUT1 and OUTPUT2 so I can tell you what they print:

OUTPUT1 is (432001, 7) (432001, 7, 3)

OUTPUT2 is (None, 7) (3, )

However, I believe OUTPUT2 should be (32768, 7) (3, ).

And indeed, the model then complains (this is just the error from one conv1d layer; other layers report similar ones):

  (0) Invalid argument: {{function_node __inference_train_function_33579}} Compilation failure: Dynamic Spatial Convolution is not supported: lhs shape is f32[4,1,<=32774,7]     
     [[{{node Model/conv1d/conv1d}}]]
        TPU compilation failed
         [[tpu_compile_succeeded_assert/_12623170171032432447/_5]] 
         [[tpu_compile_succeeded_assert/_12623170171032432447/_5/_303]]

The error says that the dimension in question (the one I printed in the mapped function) is dynamic rather than fixed at 32768. It should be static, though: I slice with a constant width of 32768, and I made sure the range stops before the last 32768 elements, where a window could run past the end. The compiler only seems able to bound the dimension by 32774, and I have no idea where the 6 extra elements come from...

What am I doing wrong? How can I get this static?


Solution

  • There seems to be a case where tf.strided_slice (the op behind the __getitem__ method) loses the shape information of the tensor passed to it. Here the start index i is only known at run time, so static shape inference sees two dynamic bounds and cannot tell that their difference is always lead_length. I guess this is also because slices are quite flexible: an "impossible" slice (for example, an end index bigger than the size of the array) is permitted and simply yields fewer rows, which would result in variable-shaped elements in your final dataset. The function cannot guarantee anything about the final shape, so that dimension defaults to None.
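
The clamping behaviour mentioned above can be seen without TensorFlow at all, since tensor slicing follows the same convention as Python and NumPy slicing. The sketch below uses a plain Python list as a stand-in for the first axis of data; the sizes and the helper name window_length are made up for illustration.

```python
# Plain-Python stand-in for the first axis of `data` (illustrative sizes only).
rows = list(range(10))   # pretend data.shape[0] == 10
lead_length = 4

def window_length(i):
    # Same slice expression as data[i:i+lead_length], applied to a list.
    return len(rows[i:i + lead_length])

# Start indices that leave room for a full window give lead_length rows...
print([window_length(i) for i in range(0, 7)])   # [4, 4, 4, 4, 4, 4, 4]
# ...but a start index near the end is silently clamped to a shorter window.
print([window_length(i) for i in range(7, 11)])  # [3, 2, 1, 0]
```

Because the length depends on the runtime value of i, nothing in the slice expression alone tells shape inference that every window has lead_length rows.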

    Your case is simple enough to be replaced by a call to tf.slice, which takes the size of the slice explicitly instead of an end index and therefore preserves the shape information.

    Replace your slice function with the following:

    def slice(i, data, annotations, lead_length):
        X = tf.slice(data, [i,0], [lead_length, tf.shape(data)[1]])
        # I also used tf.slice for y for the sake of it, but it's probably more readable to use
        # y = annotations[i+lead_length,0,:]
        y = tf.squeeze(tf.slice(annotations, [i+lead_length,0,0], [1, 1, tf.shape(annotations)[2]]))
        return X, y
    

    And looking at the dataset shapes gives:

    >>> ds = rangeds.map(slice_, tf.data.experimental.AUTOTUNE)
    >>> ds
    <ParallelMapDataset shapes: ((32768, 7), (3,)), types: (tf.float32, tf.float32)>
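
To compare the two approaches side by side, here is a minimal sketch that assumes TensorFlow 2.x and uses a small synthetic tensor in place of the parsed file. LEAD_LENGTH, window_getitem and window_tf_slice are hypothetical names chosen for this example, and the sizes are deliberately tiny.

```python
import tensorflow as tf

LEAD_LENGTH = 8  # small stand-in for 32768

# Synthetic stand-in for the tensor parsed from the data file.
data = tf.zeros((20, 7), dtype=tf.float32)

def window_getitem(i):
    # Begin and end are both runtime tensors, so the window length
    # is not known to static shape inference.
    return data[i:i + LEAD_LENGTH, :]

def window_tf_slice(i):
    # The size is a constant, so the static shape survives.
    begin = tf.stack([i, tf.constant(0, dtype=tf.int64)])
    size = tf.constant([LEAD_LENGTH, 7], dtype=tf.int64)
    return tf.slice(data, begin, size)

starts = tf.data.Dataset.range(0, data.shape[0] - LEAD_LENGTH)
spec_getitem = starts.map(window_getitem).element_spec
spec_tf_slice = starts.map(window_tf_slice).element_spec

print(spec_getitem.shape)   # leading dimension is unknown
print(spec_tf_slice.shape)  # fully defined
```

Checking element_spec this way before calling model.fit is a cheap way to catch a dynamic dimension before the TPU compiler does.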
    

    Another possibility is to call set_shape on your tensors, provided you can guarantee that the shape is correct (i.e., that i+lead_length never gets bigger than the size of your first dimension). If you can't, it will lead to hard-to-debug runtime errors.

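    def slice(i,data,annotations,lead_length):
        # lead_length should be a plain Python int here (not a tf.Tensor):
        # set_shape needs static dimension values.
        X = data[i:i+lead_length,:]
        y = annotations[i+lead_length,0,:]
        # Assert the static shape; only valid if i+lead_length stays in bounds.
        X.set_shape((lead_length,7))
        return X,y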
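
Whether set_shape is safe here comes down to arithmetic on the index range. A minimal check, assuming the sizes reported in the question (432001 rows from OUTPUT1, a lead_length of 32768, a step_size of 1) and the same range bounds as load_dataset:

```python
n_rows = 432001        # data.shape[0] from OUTPUT1
lead_length = 32768
step_size = 1

# Same start indices that Dataset.range(0, n_rows - lead_length, step_size) yields.
starts = range(0, n_rows - lead_length, step_size)
last = starts[-1]

# X = data[i:i+lead_length] needs i + lead_length <= n_rows for a full window.
assert last + lead_length <= n_rows
# y = annotations[i+lead_length] needs i + lead_length <= n_rows - 1.
assert last + lead_length <= n_rows - 1

print(len(starts), last + lead_length)  # 399233 432000
```

With these numbers every window is full, so the set_shape call would not misreport the shape. If the range bounds or step_size change, the check is worth rerunning.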
    

    I think that in your case, letting tf.slice do the job is cleaner.