tensorflow nlp tensorflow-datasets bert-language-model tensorflow-data-validation

Dimension does not match when using `keras.Model.fit` in `BERT` of tensorflow


I followed the Fine-tuning BERT instructions to build a model with my own dataset (it is fairly large, more than 20 GB), then re-encoded my data and loaded it from TFRecord files. The training_dataset I created has the same signature as the one in the instructions:

training_dataset.element_spec

({'input_word_ids': TensorSpec(shape=(32, 1024), dtype=tf.int32, name=None), 
'input_mask': TensorSpec(shape=(32, 1024), dtype=tf.int32, name=None), 
'input_type_ids': TensorSpec(shape=(32, 1024), dtype=tf.int32, name=None)}, 
TensorSpec(shape=(32,), dtype=tf.int32, name=None))

where batch_size is 32 and max_seq_length is 1024. As the instructions say,

The resulting tf.data.Datasets return (features, labels) pairs, as expected by keras.Model.fit

It seems that everything works as expected (although the instructions do not show how to use training_dataset). However, the following code

bert_classifier.fit(
    x = training_dataset, 
    validation_data=test_dataset, # has the same signature just as training_dataset
    batch_size=32,
    epochs=epochs,
    verbose=1,
)

encounters an error that seems weird to me:

Traceback (most recent call last):
  File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/captain/project/dataload/train.py", line 81, in <module>
    verbose=1,
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 871, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 726, in _initialize
    *args, **kwds))
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2969, in _get_concrete_function_internal_garbage_collected
    graph_function, _ = self._maybe_define_function(args, kwargs)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3361, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3206, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 990, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 634, in wrapped_fn
    out = weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py", line 977, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    /home/captain/.local/lib/python3.7/site-packages/official/nlp/keras_nlp/layers/position_embedding.py:88 call  *
        return tf.broadcast_to(position_embeddings, input_shape)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py:845 broadcast_to  **
        "BroadcastTo", input=input, shape=shape, name=name)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:592 _create_op_internal
        compute_device)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:3536 _create_op_internal
        op_def=op_def)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:2016 __init__
        control_input_ops, op_def)
    /home/captain/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py:1856 _create_c_op
        raise ValueError(str(e))

    ValueError: Dimensions must be equal, but are 512 and 1024 for '{{node bert_classifier/bert_encoder_1/position_embedding/BroadcastTo}} = 
BroadcastTo[T=DT_FLOAT, Tidx=DT_INT32](bert_classifier/bert_encoder_1/position_embedding/strided_slice_1, bert_classifier/bert_encoder_1/position_embedding/Shape)' 
with input shapes: [512,768], [3] and with input tensors computed as partial shapes: input[1] = [32,1024,768].

My code has nothing to do with 512, and I didn't use 512 anywhere in it. So what is wrong with my code, and how do I fix it?


Solution

  • The instructions create bert_classifier from the bert_config loaded from bert_config.json:

    bert_classifier, bert_encoder = bert.bert_models.classifier_model(bert_config, num_labels=2)
    

    bert_config.json

    {
      "attention_probs_dropout_prob": 0.1,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.1,
      "hidden_size": 768,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "max_position_embeddings": 512,
      "num_attention_heads": 12,
      "num_hidden_layers": 12,
      "type_vocab_size": 2,
      "vocab_size": 30522
    }
    

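    According to this config, hidden_size is 768 and max_position_embeddings is 512, so the model holds only 512 position embeddings and each input sequence can be at most 512 tokens long. Your inputs use max_seq_length 1024, so the position embedding table of shape [512, 768] cannot be broadcast to the input shape [32, 1024, 768]. That is exactly the shape mismatch the error reports, and it is where the 512 comes from.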
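To catch this before calling fit, a small guard can compare the sequence length against the config. This is only a sketch: check_seq_length is a hypothetical helper, not part of the official package, and the config here contains just the two relevant fields from bert_config.json.

```python
import json


def check_seq_length(bert_config: dict, max_seq_length: int) -> int:
    """Return the position limit, or raise if max_seq_length exceeds it.

    Hypothetical helper for illustration; not part of any BERT library.
    """
    limit = bert_config["max_position_embeddings"]
    if max_seq_length > limit:
        raise ValueError(
            f"max_seq_length={max_seq_length} exceeds "
            f"max_position_embeddings={limit}"
        )
    return limit


# Only the two relevant fields from the bert_config.json shown above.
bert_config = json.loads('{"hidden_size": 768, "max_position_embeddings": 512}')

check_seq_length(bert_config, 512)  # within the limit, no error

try:
    check_seq_length(bert_config, 1024)  # the length used in the question
except ValueError as err:
    print(err)
```

Running such a check right after loading the config turns the opaque BroadcastTo failure into a direct message about the sequence length.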

    Therefore, to make it work, change max_seq_length from 1024 to 512 (or less) everywhere you create the input tensors, so that each of the three input features has shape (32, 512).
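If re-encoding the TFRecord files is too expensive, one possible stopgap is to truncate the already-batched features inside the tf.data pipeline. The sketch below assumes a dataset with the same element_spec as in the question; the all-zero stand-in data and the truncate_to_max_positions function are made up for illustration.

```python
import tensorflow as tf

MAX_POSITION_EMBEDDINGS = 512  # value from bert_config.json


def truncate_to_max_positions(features, labels):
    """Keep only the first 512 positions of every batched input feature."""
    truncated = {
        name: tensor[:, :MAX_POSITION_EMBEDDINGS]
        for name, tensor in features.items()
    }
    return truncated, labels


# Stand-in data with the same structure as the question's dataset
# (assumption: the real dataset is read from TFRecord files instead).
feature_names = ("input_word_ids", "input_mask", "input_type_ids")
features = {name: tf.zeros((64, 1024), dtype=tf.int32) for name in feature_names}
labels = tf.zeros((64,), dtype=tf.int32)

training_dataset = (
    tf.data.Dataset.from_tensor_slices((features, labels))
    .batch(32, drop_remainder=True)  # static batch size, as in the question
    .map(truncate_to_max_positions)
)

print(training_dataset.element_spec)
```

Note that truncation simply discards everything after position 512, which may cut off the closing [SEP] token of long examples, so re-encoding the data with max_seq_length set to 512 is the cleaner fix.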