
Converting an input to UTF-8 with @tf.function


I'm using TensorFlow 2.4.1. I tried to use tf.strings.unicode_decode inside a @tf.function to decode a base64-encoded string, but it raised ValueError: Rank of `input` must be statically known. I checked that tf.strings.unicode_decode works fine without @tf.function. Is there a way to decode a base64-encoded string inside a @tf.function? I would appreciate any help.
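Note that tf.strings.unicode_decode does not undo base64: it turns UTF-8 bytes into Unicode code points. If the input really is base64, a minimal eager-mode sketch decodes it in two steps. It assumes the web-safe base64 alphabet that tf.io.decode_base64 expects; the sample text "héllo" is made up for illustration.

```python
import tensorflow as tf

# Hypothetical input: the text "héllo" encoded as web-safe base64.
raw = tf.constant("héllo")
encoded = tf.io.encode_base64(raw)

# Step 1: base64 -> raw UTF-8 bytes. tf.io.decode_base64 expects the
# web-safe alphabet ('-' and '_' instead of '+' and '/').
utf8_bytes = tf.io.decode_base64(encoded)

# Step 2: UTF-8 bytes -> Unicode code points. An eager tensor always has a
# known rank (here 0, a scalar), so no shape has to be declared.
code_points = tf.strings.unicode_decode(utf8_bytes, "UTF-8")
print(code_points.numpy())  # [104 233 108 108 111]
```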

I loaded a SavedModel and wanted to replace its serving_default signature, but I got stuck converting the input to UTF-8. This is the code I have tried:

import tensorflow as tf

class CustomTransformer(tf.keras.Model):
    def __init__(self):
        super(CustomTransformer, self).__init__()
        # Load the existing SavedModel whose serving_default should be replaced.
        self.model = tf.saved_model.load('./models/transformer/1')

    @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)])
    def call(self, input_data):

        # Error occurred. ValueError: Rank of `input` must be statically known.
        _input_str = tf.strings.unicode_decode(input_data, 'UTF-8')

        return _input_str

Here is the error message.

ValueError: Rank of `input` must be statically known.
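A minimal standalone reproduction (no SavedModel needed; the function name decode_unknown_rank is hypothetical) shows that the same call succeeds eagerly but fails as soon as the function is traced with shape=None, because the rank of the input is then unknown at trace time:

```python
import tensorflow as tf

# Eager mode: the tensor's shape (and therefore its rank) is always known.
eager_result = tf.strings.unicode_decode(tf.constant("hi"), "UTF-8")
print(eager_result.numpy())  # [104 105]

# Graph mode with shape=None: the rank is unknown when the function is traced.
@tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.string)])
def decode_unknown_rank(text):
    return tf.strings.unicode_decode(text, "UTF-8")

try:
    # With an input_signature, get_concrete_function() traces right away,
    # which is where unicode_decode checks the rank and raises.
    decode_unknown_rank.get_concrete_function()
    trace_error = None
except ValueError as err:
    trace_error = err

print(trace_error)
```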

Is there a way to convert the input to UTF-8 when replacing serving_default on a loaded SavedModel?


Solution

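  • tf.strings.unicode_decode requires the rank (number of dimensions) of its input to be statically known when the function is traced (see the documentation). With shape=None in the TensorSpec the rank is unknown, hence the error. Because the input here is a Tensor without any dimensions (a single string, i.e. a scalar), just provide an empty tuple as the shape: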

    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.string)])
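A fuller sketch of the fix, under stated assumptions: it swaps the question's tf.keras.Model subclass for a plain tf.Module to keep SavedModel export simple, skips loading ./models/transformer/1, and the method names decode_one and decode_batch and the output key code_points are hypothetical. It also shows that shape=[None] works for a batch of strings, since only the rank has to be static, not the size. Signatures cannot return RaggedTensors, so the batch output is padded with 0 into a dense tensor.

```python
import tempfile

import tensorflow as tf


class CustomTransformer(tf.Module):
    """Hypothetical wrapper; the question's version would also hold the
    object returned by tf.saved_model.load(...)."""

    # shape=() declares a scalar string, so rank 0 is statically known.
    @tf.function(input_signature=[tf.TensorSpec(shape=(), dtype=tf.string)])
    def decode_one(self, text):
        # A scalar input yields a dense 1-D int32 tensor of code points.
        return {"code_points": tf.strings.unicode_decode(text, "UTF-8")}

    # shape=[None] declares a batch of strings: rank 1 is known, size is not.
    @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
    def decode_batch(self, texts):
        ragged = tf.strings.unicode_decode(texts, "UTF-8")
        # Signatures cannot return RaggedTensors, so pad rows with 0.
        return {"code_points": ragged.to_tensor(default_value=0)}


module = CustomTransformer()
export_dir = tempfile.mkdtemp()

# Register the scalar function as the new serving_default signature.
tf.saved_model.save(
    module,
    export_dir,
    signatures={
        "serving_default": module.decode_one,
        "decode_batch": module.decode_batch,
    },
)

# Reload and call the signatures; keyword names match the argument names.
reloaded = tf.saved_model.load(export_dir)
serving_fn = reloaded.signatures["serving_default"]
batch_fn = reloaded.signatures["decode_batch"]
print(serving_fn(text=tf.constant("héllo"))["code_points"].numpy())
print(batch_fn(texts=tf.constant(["hi", "héllo"]))["code_points"].numpy())
```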