Tags: python, tensorflow, keras, nlp

How to save Keras TextVectorization layer configuration with custom standardization function into a pickle file and reload it?


I have a Keras TextVectorization layer which uses a custom standardization function.

import re
import string

import tensorflow as tf
from tensorflow import keras

def custom_standardization(input_string, preserve=['[', ']'], add=['¿']):
    # Build the set of characters to strip: all punctuation plus the added
    # characters, minus the ones that should be preserved.
    strip_chars = string.punctuation
    for item in add:
        strip_chars += item

    for item in preserve:
        strip_chars = strip_chars.replace(item, '')

    lowercase = tf.strings.lower(input_string)
    output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')

    return output

target_vectorization = keras.layers.TextVectorization(max_tokens=vocab_size,
                                                       output_mode='int',
                                                       output_sequence_length=sequence_length + 1,
                                                       standardize=custom_standardization)
target_vectorization.adapt(train_spanish_texts)

I want to save the adapted configuration so that an inference model can make use of it.

One way, as described here, is to save the weights and config separately as a pickle file and reload them.
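In code, that approach amounts to dumping the layer's config and weights into one pickle file, roughly like this:

import pickle

# Save the adapted layer's configuration and weights in a single pickle file.
pickle.dump({'config': target_vectorization.get_config(),
             'weights': target_vectorization.get_weights()},
            open('ckpts/spanish_vectorization.pkl', 'wb'))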

However, target_vectorization.get_config() returns

{'name': 'text_vectorization_5',
 'trainable': True,
 ...
 'standardize': <function __main__.custom_standardization(input_string, preserve=['[', ']'], add=['¿'])>,
 ...
 'vocabulary_size': 15000}

which is being saved into the pickle file.

Trying to load this config using

keras.layers.TextVectorization.from_config(pickle.load(open('ckpts/spanish_vectorization.pkl', 'rb'))['config'])

results in

TypeError: Could not parse config: <function custom_standardization at 0x2a1973a60>

because the pickle file does not carry any information about the custom standardization function.

What is a good way to save the TextVectorization weights and configuration in this scenario, so that an inference model can make use of them?


Solution

  • The solution here was to define a wrapper around the TextVectorization object and turn the custom standardizer into a method of that wrapper. In addition, callable entries have to be excluded from the configuration before it is saved to the pickle file. Here's the fixed code:

    import re
    import string
    import pickle

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers


    @keras.utils.register_keras_serializable(package='custom_layers', name='TextVectorizer')
    class TextVectorizer(layers.Layer):
        '''English - Spanish Text Vectorizer'''

        def __init__(self, max_tokens=None, output_mode='int', output_sequence_length=None,
                     standardize='lower_and_strip_punctuation', vocabulary=None, config=None):
            super().__init__()
            if config:
                # Rebuild the wrapped layer from a previously saved (callable-free) config.
                self.vectorization = layers.TextVectorization.from_config(config)
            else:
                self.max_tokens = max_tokens
                self.output_mode = output_mode
                self.output_sequence_length = output_sequence_length
                self.vocabulary = vocabulary
                if standardize != 'lower_and_strip_punctuation':
                    # Any non-default value selects the custom standardize() method below.
                    self.vectorization = layers.TextVectorization(max_tokens=self.max_tokens,
                                                                  output_mode=self.output_mode,
                                                                  output_sequence_length=self.output_sequence_length,
                                                                  vocabulary=self.vocabulary,
                                                                  standardize=self.standardize)
                else:
                    self.vectorization = layers.TextVectorization(max_tokens=self.max_tokens,
                                                                  output_mode=self.output_mode,
                                                                  output_sequence_length=self.output_sequence_length,
                                                                  vocabulary=self.vocabulary)

        def standardize(self, input_string, preserve=['[', ']'], add=['¿']):
            # Strip punctuation (plus any added characters) except the preserved ones,
            # then lowercase the input.
            strip_chars = string.punctuation
            for item in add:
                strip_chars += item

            for item in preserve:
                strip_chars = strip_chars.replace(item, '')

            lowercase = tf.strings.lower(input_string)
            output = tf.strings.regex_replace(lowercase, f'[{re.escape(strip_chars)}]', '')

            return output

        def __call__(self, *args, **kwargs):
            # Forward calls straight to the wrapped TextVectorization layer.
            return self.vectorization.__call__(*args, **kwargs)

        def get_config(self):
            # Drop callable entries (e.g. the standardize method) so the config
            # stays picklable and parseable by from_config().
            return {key: value if not callable(value) else None
                    for key, value in self.vectorization.get_config().items()}

        @classmethod
        def from_config(cls, config):
            return cls(config=config)

        def set_weights(self, weights):
            self.vectorization.set_weights(weights)

        def adapt(self, dataset):
            self.vectorization.adapt(dataset)

        def get_vocabulary(self):
            return self.vectorization.get_vocabulary()


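    Registering the class with @keras.utils.register_keras_serializable also lets Keras resolve the wrapper by its registered name (custom_layers>TextVectorizer) if it later ends up inside a model saved in the native Keras format.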
    To adapt and save weights [Training Phase]:

    vocab_size = 15000
    sequence_length = 20

    source_vectorization = TextVectorizer(max_tokens=vocab_size,
                                          output_mode='int',
                                          output_sequence_length=sequence_length)

    # Any value other than 'lower_and_strip_punctuation' (here 'spanish') makes the
    # wrapper use its custom standardize() method.
    target_vectorization = TextVectorizer(max_tokens=vocab_size,
                                          output_mode='int',
                                          output_sequence_length=sequence_length + 1,
                                          standardize='spanish')

    train_english_texts = [pair[0] for pair in train_pairs]
    train_spanish_texts = [pair[1] for pair in train_pairs]
    source_vectorization.adapt(train_english_texts)
    target_vectorization.adapt(train_spanish_texts)

    pickle.dump({'config': source_vectorization.get_config(),
                 'weights': source_vectorization.get_weights()}, open('ckpts/english_vectorization.pkl', 'wb'))

    pickle.dump({'config': target_vectorization.get_config(),
                 'weights': target_vectorization.get_weights()}, open('ckpts/spanish_vectorization.pkl', 'wb'))
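
    Because get_config() replaces every callable entry with None, the saved config contains only plain data that from_config() can parse again at load time, which is exactly what failed before.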
    

    To load and use them [Inference Phase]:

    # Rebuild each vectorizer from its saved config, then restore the vocabulary weights.
    vectorization_data = pickle.load(open('ckpts/english_vectorization.pkl', 'rb'))
    source_vectorization = TextVectorizer.from_config(vectorization_data['config'])
    source_vectorization.set_weights(vectorization_data['weights'])

    vectorization_data = pickle.load(open('ckpts/spanish_vectorization.pkl', 'rb'))
    target_vectorization = TextVectorizer.from_config(vectorization_data['config'])
    target_vectorization.set_weights(vectorization_data['weights'])
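
    As a quick sanity check (the sample sentences below are just placeholders), the restored vectorizers can be called directly on raw strings:

    # Hypothetical smoke test: the restored layers tokenize raw strings with the
    # vocabulary and standardization learned at training time.
    print(source_vectorization(['How are you?']))       # int tensor of shape (1, 20)
    print(target_vectorization(['¿Cómo estás?']))       # int tensor of shape (1, 21)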