Tags: python, tensorflow, image-preprocessing

Convert image sequences to a 4D tensor / .npy file in Python


I have sequences of 2D images, each sequence showing how a single simulation evolves over its time steps. For example, say I have 1000 simulations, each with 10 time-frame images. This is not a supervised learning problem, as there are no class labels: the model has to learn how the simulation progresses with time. (I have a separate folder for each simulation, each containing its 10 time-frame images.)

Can anyone help me create a suitable 4D tensor / .npy file for this, in the form [no_frames_in_each_sample, total_samples, image_height, image_width]? In our example, that would be [10, 1000, 64, 64].

Later I can split this array into training and validation sets.

Any help would be much appreciated! Thank you.


Solution

  • Sample code to convert images to a 4-dimensional array of shape (batch, height, width, channels):

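    import pathlib
    import tarfile
    
    import tensorflow as tf
    
    # Unpack the archive into the current working directory;
    # pass path=... to extractall() to extract somewhere else.
    with tarfile.open('images.tar.gz') as my_tar:
        my_tar.extractall()
    
    # image_dataset_from_directory expects one subfolder per class here.
    data_dir = pathlib.Path('/content/images/')
    
    batch_size = 32
    img_height = 224
    img_width = 224
    
    # Newer TensorFlow releases also expose this function as
    # tf.keras.utils.image_dataset_from_directory.
    # Use the same seed and validation_split in both calls so the
    # training and validation subsets do not overlap.
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="training",
      seed=123,
      image_size=(img_height, img_width),
      batch_size=batch_size)
    
    # Class names are inferred from the subfolder names.
    class_names = train_ds.class_names
    
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
      data_dir,
      validation_split=0.2,
      subset="validation",
      seed=123,
      image_size=(img_height, img_width),
      batch_size=batch_size)
    
    # Carve a test set out of the validation batches (one fifth of them).
    val_batches = tf.data.experimental.cardinality(val_ds)
    test_dataset = val_ds.take(val_batches // 5)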
    

    Output

    Found 8 files belonging to 2 classes.
    Using 7 files for training.
    Found 8 files belonging to 2 classes.
    Using 1 files for validation.
    

    for image_batch, labels_batch in train_ds:
      print(image_batch.shape)
      print(labels_batch.shape)
    

    Output

    (7, 224, 224, 3)
    (7,)
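The code above yields labelled batches of shape (batch, height, width, channels). For the unlabelled layout asked for in the question, [no_frames_in_each_sample, total_samples, image_height, image_width], a plain NumPy approach may be simpler. The following is only a sketch under several assumptions: each simulation sits in its own subfolder, frame file names sort into time order (for example zero-padded names), every simulation has the same number of frames, and images are read as grayscale with Pillow. The function names `load_grayscale`, `load_sequences`, and `split_samples`, and the folder name in the usage comment, are hypothetical.

```python
from pathlib import Path

import numpy as np


def load_grayscale(path, size=(64, 64)):
    """Read one image file as a 2D float32 array scaled to [0, 1]."""
    # Pillow is imported lazily so a different frame_loader works without it.
    from PIL import Image

    with Image.open(path) as img:
        # size is (height, width); PIL's resize expects (width, height).
        img = img.convert("L").resize((size[1], size[0]))
        return np.asarray(img, dtype=np.float32) / 255.0


def load_sequences(root, frame_loader=load_grayscale):
    """Stack every simulation folder under `root` into one 4D array.

    Returns an array of shape (frames, samples, height, width).
    """
    samples = []
    for sim_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        # Assumption: sorting file names gives the frames in time order.
        frame_paths = sorted(p for p in sim_dir.iterdir() if p.is_file())
        samples.append(np.stack([frame_loader(p) for p in frame_paths]))
    # np.stack gives (samples, frames, H, W); swap the first two axes.
    return np.stack(samples).transpose(1, 0, 2, 3)


def split_samples(data, val_fraction=0.2, seed=123):
    """Shuffle along the sample axis (axis 1) and split into train/validation."""
    order = np.random.default_rng(seed).permutation(data.shape[1])
    n_val = int(round(data.shape[1] * val_fraction))
    return data[:, order[n_val:]], data[:, order[:n_val]]


# Usage (the folder name is a placeholder):
# data = load_sequences("simulations")   # e.g. shape (10, 1000, 64, 64)
# np.save("simulations.npy", data)
# train, val = split_samples(data)
```

Splitting along the sample axis keeps all frames of one simulation together, so no simulation contributes frames to both the training and the validation set. `np.stack` raises an error if the simulations differ in frame count or image size, which surfaces inconsistent folders early.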