tensorflow, keras, deep-learning, nlp, sequence

Attention Layer changing Batch Size at inference


I have trained a seq-to-seq model using an Encoder-Decoder architecture. I'm trying to produce an output sequence given an input context, and I want to do that on a batch of input context vectors. There is an Attention layer before the final output in the Decoder, and it seems to be changing the batch shape (or not getting the shape correctly) and throwing an error. It works fine if I infer on an individual sample or with a batch size of 1, but in production that would take too long, since I have to infer on thousands of input context vectors. So I need help debugging the error and implementing a computationally efficient way of producing the output sequences.

Below is my implementation:

import tensorflow as tf
from tensorflow.keras.layers import (Input, GRU, Bidirectional, Dropout,
                                     Dense, Attention, Concatenate)
from tensorflow.keras.models import Model

hsize = 256  # encoder/decoder hidden size from training (matches the (64, 256) state shape printed below)

### Define Inference Encoder
def define_inference_encoder(input_shape):
  encoder_input = Input(shape=input_shape, name='en_input_layer')

  ### First Bidirectional GRU Layer
  bidi_gru1 = Bidirectional(GRU(160, return_sequences=True), name='en_bidirect_gru1')
  gru1_out = bidi_gru1(encoder_input)

  gru1_out = Dropout(0.46598303573163413, name='bidirect_gru1_dropout')(gru1_out)

  ### Second GRU Layer
  # hp_units_2 = hp.Int('enc_lstm2', min_value=32, max_value=800, step=32)
  gru2 = GRU(hsize, return_sequences=True, return_state=True, name='en_gru2_layer')
  gru2_out, gru2_states = gru2(gru1_out)

  encoder_model = Model(inputs=encoder_input, outputs=[gru2_out, gru2_states])
  return encoder_model

### Define Inference Decoder
def define_inf_decoder(context_vec, input_shape):
  decoder_input = Input(shape=input_shape)
  decoder_state_input = Input(shape=(hsize,))

  de_gru1 = GRU(hsize, return_sequences=True, return_state=True, name='de_gru1_layer')
  de_gru1_out, de_state_out = de_gru1(decoder_input, initial_state=decoder_state_input)

  attn_layer = Attention(use_scale=True, name='attn_layer')
  attn_out = attn_layer([de_gru1_out, context_vec])

  attn_added = Concatenate(name='attn_source_concat_layer')([de_gru1_out, attn_out])

  attn_dense_layer = Dense(736, name='tanh_dense_layer', activation='tanh')

  h_hat = attn_dense_layer(attn_added)

  ### Output Layer
  preds = Dense(1, name='output_layer')(h_hat)

  decoder_model = Model(inputs=[decoder_input, decoder_state_input], outputs=[preds, de_state_out])
  return decoder_model

def set_weights(untrained_model, trained_model):
  trained_layers = [l.name for l in trained_model.layers]
  print(f"No. of trained layers: {len(trained_layers)}")

  for l in untrained_model.layers:
    if l.name in trained_layers:
      trained_wts = trained_model.get_layer(l.name).get_weights()
      if len(trained_wts)>0:
        untrained_model.get_layer(l.name).set_weights(trained_wts)
        print(f"Layer {l.name} weight set")
        
  return untrained_model

Generate output sequence:

seq_len = 12  # length of the input context window (see the (64, 12, 1) shape below)
inference_encoder = define_inference_encoder((seq_len, 1))
inference_encoder = set_weights(inference_encoder, tuned_model)

for (ex_context, ex_target_in), ex_target_out in test_ds.take(1):
  print(ex_context.shape, ex_target_in.shape) ### (64, 12, 1) (64, 3, 1)

test_context, test_states = inference_encoder.predict(tf.reshape(ex_context, shape=(-1,seq_len, 1)))
print(test_context.shape, test_states.shape) ### (64, 12, 256) (64, 256)

inf_decoder = define_inf_decoder(test_context, (1,1))
inf_decoder = set_weights(inf_decoder, tuned_model)

dec_inp = tf.reshape(ex_context[:,-1], shape=(-1,1,1))
dec_inp.shape ### (64,1,1)

test_inf_decoder_out = inf_decoder.predict([dec_inp, test_states])

Error:

ValueError: Exception encountered when calling layer 'attn_layer' (type Attention).

Dimensions must be equal, but are 32 and 64 for '{{node model_7/attn_layer/MatMul}} = BatchMatMulV2[T=DT_FLOAT, adj_x=false, adj_y=true](model_7/de_gru1_layer/PartitionedCall:1, model_7/15181)' with input shapes: [32,1,256], [64,12,256].

Call arguments received by layer 'attn_layer' (type Attention):
  • inputs=['tf.Tensor(shape=(32, 1, 256), dtype=float32)', 'tf.Tensor(shape=(64, 12, 256), dtype=float32)']
  • mask=None
  • training=False
  • return_attention_scores=False
  • use_causal_mask=False

What I don't understand is how attn_layer ends up with a batch size of 32 when I'm passing inputs with a batch size of 64. It works fine when I use a batch size of 1. What am I doing wrong?


Solution

  • I solved it by replacing .predict with .predict_on_batch. It turns out that .predict splits its input into smaller batches (32 samples by default) rather than running the whole batch in one call, so the attention layer receives decoder states for only 32 samples while the encoder context vector built into the decoder model still has a batch size of 64, which produces the shape mismatch. .predict_on_batch runs all the samples in a single batch, so switching to it when passing more than one sample works like a charm.
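
For reference, here is a minimal sketch of the fixed inference pass, continuing from the code above. The predict_on_batch calls are the actual fix; the autoregressive loop over 3 decoding steps and the names states, preds_per_step and pred_seq are illustrative assumptions, since the original decoding loop isn't shown:

### Encode the whole batch of context windows in a single call
test_context, test_states = inference_encoder.predict_on_batch(
    tf.reshape(ex_context, shape=(-1, seq_len, 1)))

### Rebuild the decoder against this batch's context vectors (batch size 64)
inf_decoder = define_inf_decoder(test_context, (1, 1))
inf_decoder = set_weights(inf_decoder, tuned_model)

### Greedy decoding: start from the last observed value and feed each
### prediction back in. predict_on_batch runs all 64 samples at once, so the
### decoder inputs keep the same batch size as the context vectors that were
### baked into the decoder graph when it was built.
dec_inp = tf.reshape(ex_context[:, -1], shape=(-1, 1, 1))
states = test_states
preds_per_step = []
for _ in range(3):  # 3 = target sequence length in this example
  preds, states = inf_decoder.predict_on_batch([dec_inp, states])
  preds_per_step.append(preds)                  # each preds has shape (64, 1, 1)
  dec_inp = tf.reshape(preds, shape=(-1, 1, 1))

pred_seq = tf.concat(preds_per_step, axis=1)    # shape (64, 3, 1)

Keeping .predict but passing batch_size=64 explicitly should also work, since the mismatch comes from .predict splitting the input into its default batches of 32 while the baked-in context vectors stay at 64.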