
TF2.6: ValueError: Model cannot be saved because the input shapes have not been set


I want to create a custom model using transfer learning in Google Colab.

import tensorflow as tf
from tensorflow.keras.layers import Conv2D
from tensorflow.python.keras.applications.xception import Xception

class MyModel(tf.keras.Model):

  def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):
    super(MyModel, self).__init__()
    self.weight_dict = {}
    self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights='imagenet', include_top=False)
    self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")
    self.build((None,) + input_shape)

  def call(self, inputs, training=False):
    self.weight_dict['backbone'].trainable = False
    x = self.weight_dict['backbone'](inputs)
    x = self.weight_dict['outputs'](x)
    return x

model = MyModel(input_shape=(256, 256, 3))
model.save('./saved')

However, I encounter this error:

ValueError: Model `<__main__.MyModel object at 0x7fc66134bdd0>` cannot be saved because the input shapes have not been set. Usually, input shapes are automatically determined from calling `.fit()` or `.predict()`. To manually set the shapes, call `model.build(input_shape)`.

There is indeed no call to `.fit()` or `.predict()`, but `__init__()` does call `.build()`. What should I do?


Solution

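  • Call `model.compute_output_shape()` before saving. According to its documentation, if the layer has not been built, `compute_output_shape` will call `build` on the layer; this assumes that the layer will later be used with inputs that match the input shape provided.

    Note that the working code also imports `Xception` from the public `tensorflow.keras.applications.xception` module instead of the private `tensorflow.python.keras.applications.xception` module used in the question. Since TF 2.6, `tf.keras` is backed by the separate `keras` package, and `tensorflow/python/keras` is a stale legacy copy, so the two should not be mixed.

    Working code: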

    import tensorflow as tf
    print(tf.__version__)
    from tensorflow.keras.layers import Conv2D
    from tensorflow.keras.applications.xception import Xception
    
    class MyModel(tf.keras.Model):
    
      def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):
        super(MyModel, self).__init__()
        self.weight_dict = {}
        self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights='imagenet', include_top=False)
        self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")
        self.build((None,) + input_shape)
    
      def call(self, inputs, training=False):
        self.weight_dict['backbone'].trainable = False
        x = self.weight_dict['backbone'](inputs)
        x = self.weight_dict['outputs'](x)
        return x
    
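    input_shape = (256, 256, 3)
    model = MyModel(input_shape)
    
    # compute_output_shape builds the model for inputs of this shape,
    # so the input shapes are set before model.save() is called.
    model.compute_output_shape(input_shape=(None,) + input_shape)
    model.save('./saved')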
    

    Output:

    2.6.0
    
    Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
    83689472/83683744 [==============================] - 1s 0us/step
    INFO:tensorflow:Assets written to: ./saved/assets
    

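A commonly used alternative is to call the model once on a concrete batch before saving. Like `.fit()` or `.predict()`, a real call sets the input shapes, so the explicit `self.build(...)` in `__init__` can be dropped. The sketch below keeps the question's class otherwise unchanged, with two assumptions made purely for illustration: `Xception` gets `weights=None` so it runs without downloading the ImageNet weights (use `weights='imagenet'` for actual transfer learning), and the directory name `./saved_dummy_call` is arbitrary.

```python
import tensorflow as tf
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.layers import Conv2D


class MyModel(tf.keras.Model):

    def __init__(self, input_shape, num_classes=5, dropout_rate=0.5):
        super(MyModel, self).__init__()
        self.weight_dict = {}
        # Assumption: weights=None only so this sketch runs offline;
        # use weights='imagenet' for real transfer learning.
        self.weight_dict['backbone'] = Xception(input_shape=input_shape, weights=None, include_top=False)
        self.weight_dict['outputs'] = Conv2D(num_classes, (1, 1), padding="same", activation="softmax")
        # No self.build(...) here: the call below builds the model instead.

    def call(self, inputs, training=False):
        self.weight_dict['backbone'].trainable = False
        x = self.weight_dict['backbone'](inputs)
        x = self.weight_dict['outputs'](x)
        return x


input_shape = (256, 256, 3)
model = MyModel(input_shape)

# Call the model once on a concrete all-zeros batch containing one image.
# This builds the model and sets the input shapes model.save() needs.
outputs = model(tf.zeros((1,) + input_shape))
print(outputs.shape)  # (1, 8, 8, 5)

model.save('./saved_dummy_call')
```

Xception without its top downsamples a 256x256 input by a factor of 32, so the 1x1 `Conv2D` head produces an 8x8 grid with one softmax channel per class.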