How to access Tensor shape within .map function?


I have a dataset of audio clips of varying lengths, and I want to crop all of them to 5-second windows (240,000 samples at a 48,000 Hz sample rate). After loading the .tfrecord, I'm doing:

audio, sr = tf.audio.decode_wav(image_data)

which returns a Tensor whose length is the number of samples in the clip. If this length is less than 240000, I would like to repeat the audio content until it reaches 240000. So right now I'm doing this on ALL audios, inside a tf.data.Dataset.map() function:

audio = tf.tile(audio, [5])

This is because five repetitions are what it takes to pad my shortest audio to the desired length.
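One detail worth checking in this step: tf.audio.decode_wav returns audio with shape [samples, channels] (plus the sample rate), so a 1-D tf.tile(audio, [5]) needs the channel axis removed first. A minimal sketch, assuming mono audio and a shortest clip of at least 1 second (48,000 samples, so 5 repetitions reach 240,000). The synthetic clip built with tf.audio.encode_wav stands in for the real .tfrecord bytes, and decode_and_tile is a hypothetical helper name:

```python
import tensorflow as tf

SAMPLE_RATE = 48_000
TARGET_LEN = 5 * SAMPLE_RATE  # 5 seconds -> 240,000 samples

def decode_and_tile(wav_bytes):
    # decode_wav returns audio of shape [samples, channels] and a scalar sample rate
    audio, sr = tf.audio.decode_wav(wav_bytes, desired_channels=1)
    # Drop the channel axis so the tensor is 1-D: [samples]
    audio = tf.squeeze(audio, axis=-1)
    # Fixed 5x repetition: enough for any clip of at least 1 second
    return tf.tile(audio, [5]), sr

# Hypothetical 1-second mono clip standing in for the real record payload
one_second = tf.zeros([SAMPLE_RATE, 1], tf.float32)
wav_bytes = tf.audio.encode_wav(one_second, sample_rate=SAMPLE_RATE)

tiled, sr = decode_and_tile(wav_bytes)
print(tiled.shape, int(sr))  # (240000,) 48000
```

For a 1-second clip, 5 x 48,000 = 240,000 samples, which is exactly TARGET_LEN; longer clips overshoot and still need trimming.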

But for efficiency, I wanted to apply the operation only to the elements that need it:

if audio.shape[0] < 240000:
  pad_num = tf.math.ceil(240000 / audio.shape[0]) #i.e. if the audio is 120000 long, the audio will repeat 2 times
  audio = tf.tile(audio, [pad_num])

But I can't access the shape, since it is dynamic and varies from one audio to the next. I've tried tf.shape(audio), audio.shape and audio.get_shape(), but I get values like None for the length, which doesn't let me do the comparison.

Is it possible to do this?
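To see where the None comes from: the function passed to .map() is traced into a graph, so audio.shape only reports the static shape declared for the dataset elements (here [None]), while tf.shape(audio) is a tensor computed for each element at run time. A small sketch that records both, assuming TF 2.4+ (for output_signature); seen_static_shapes and inspect are hypothetical names:

```python
import tensorflow as tf

# Hypothetical log of the static shapes seen while .map() traces the function
seen_static_shapes = []

def inspect(audio):
    # Static shape: fixed at trace time, so the unknown length shows up as None
    seen_static_shapes.append(audio.shape.as_list())
    # Dynamic shape: a scalar int32 tensor evaluated per element at run time
    return tf.shape(audio)[0]

ds = tf.data.Dataset.from_generator(
    lambda: iter([tf.zeros([100_000]), tf.zeros([300_000])]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)
lengths = [int(n) for n in ds.map(inspect)]

print(seen_static_shapes)
print(lengths)  # [100000, 300000]
```

The Python list is only appended to during tracing, which is why it never contains per-element lengths; the real lengths only exist as tensors.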


Solution

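  • Use tf.shape(audio)[0], which gives the length as a tensor evaluated per element at run time, instead of the static audio.shape[0]. You can then compute the number of repetitions, tile, and trim with a function like this: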

    import tensorflow as tf
    
    def enforce_length(audio):
        # Target shape
        AUDIO_LEN = 240_000
        # Current shape
        current_len = tf.shape(audio)[0]
        # Compute number of necessary repetitions
        num_reps = AUDIO_LEN // current_len
        num_reps += tf.dtypes.cast((AUDIO_LEN % current_len) > 0, num_reps.dtype)
        # Do repetitions
        audio_rep = tf.tile(audio, [num_reps])
        # Trim to required size
        return audio_rep[:AUDIO_LEN]
    
    # Test
    examples = tf.data.Dataset.from_generator(lambda: iter([
        tf.zeros([100_000], tf.float32),
        tf.zeros([300_000], tf.float32),
        tf.zeros([123_456], tf.float32),
    ]), output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32))
    result = examples.map(enforce_length)
    for item in result:
        print(item.shape)
    

    Output:

    (240000,)
    (240000,)
    (240000,)
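The question's original idea, tiling only the clips that actually need it, can also be expressed with tf.cond, which picks a branch per element at run time based on the dynamic length. This is a sketch of a variant, not a claim that it is faster than always tiling: whether skipping tf.tile for long clips helps is something to benchmark. It assumes non-empty 1-D float32 clips and TF 2.4+; enforce_length_cond is a hypothetical name, and tf.ensure_shape is added so the fixed length is visible to later ops such as .batch():

```python
import tensorflow as tf

AUDIO_LEN = 240_000  # 5 s at 48 kHz

def enforce_length_cond(audio):
    """Tile clips shorter than AUDIO_LEN, then trim every clip to AUDIO_LEN.

    Assumes a non-empty 1-D float32 clip (an empty clip would divide by zero).
    """
    current_len = tf.shape(audio)[0]  # dynamic length, known per element

    def tile_up():
        # Ceiling division: smallest repeat count whose total is >= AUDIO_LEN
        num_reps = (AUDIO_LEN + current_len - 1) // current_len
        return tf.tile(audio, [num_reps])

    # tf.cond evaluates the predicate per element; both branches return 1-D float32
    padded = tf.cond(current_len < AUDIO_LEN, tile_up, lambda: audio)
    # Trim, and declare the now-fixed static length
    return tf.ensure_shape(padded[:AUDIO_LEN], [AUDIO_LEN])

# Usage: a ramp clip makes the repetition visible; a long clip is only trimmed
ds = tf.data.Dataset.from_generator(
    lambda: iter([tf.range(100_000, dtype=tf.float32), tf.zeros([300_000])]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.float32),
)
out = list(ds.map(enforce_length_cond))
print([x.shape.as_list() for x in out])  # [[240000], [240000]]
```

For the 100,000-sample ramp, the ceiling division gives 3 repetitions (300,000 samples), so index 100,000 wraps back to the start of the clip before the result is cut to 240,000.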