I am trying to create a model with the 'sequence.gather' operator, but I get the error "Where operation can only operate on scalar input" when calling 'train_minibatch'.
import cntk as C
from cntk import Axis, sequence
from cntk.layers import Dense, Sequential

# vocab_dim and hidden_dim are defined elsewhere
input_seq_axis = Axis('inputAxis')
input_sequence = sequence.input_variable(shape=vocab_dim, sequence_axis=input_seq_axis)
vowel_mask_sequence = sequence.input_variable(shape=2, sequence_axis=input_seq_axis)
a = Sequential([
    C.layers.Recurrence(C.layers.LSTM(hidden_dim)),
])
b = C.sequence.gather(a(input_sequence), vowel_mask_sequence)
z = Dense(3)(b)
label_sequence = sequence.input_variable(3, sequence_axis=z.dynamic_axes[1])
How can I fix the error? I don't even use the 'where' operator.
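For sequence.gather(x, y), y has to be a scalar at each step, that is to say:
assert y.shape == (1,)
The values of y must be either 0 or 1, and y must have exactly the same dynamic axes as x. In the question's code, vowel_mask_sequence is declared with shape=2, so declaring it with shape=1 (a single 0/1 flag per step) should resolve the error. The message mentions 'where' even though you never call it, presumably because gather relies on it internally.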
An example of how to use sequence.gather, from a library I maintain.
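Independently of that library, the constraints on y can be illustrated with a small NumPy-only sketch. It only emulates what gather is described as doing above for a single sequence; it is not CNTK code, and the function name gather_sketch and its error messages are made up for illustration.

```python
import numpy as np

def gather_sketch(x, mask):
    """Emulate the described semantics of sequence.gather for one sequence.

    Hypothetical helper, NumPy only (not part of CNTK).
    x:    array of shape (steps, features), one row per sequence step
    mask: array of shape (steps, 1), one scalar 0/1 flag per step
    """
    x = np.asarray(x)
    mask = np.asarray(mask)
    # The mask must carry exactly one scalar per step, i.e. per-step shape (1,).
    if mask.ndim != 2 or mask.shape[1:] != (1,):
        raise ValueError("mask must have per-step shape (1,), got %s" % (mask.shape[1:],))
    # Both sequences must have the same number of steps (same dynamic axis).
    if mask.shape[0] != x.shape[0]:
        raise ValueError("mask and x must have the same sequence length")
    # Only 0 and 1 are valid flag values.
    if not np.isin(mask, (0, 1)).all():
        raise ValueError("mask values must be 0 or 1")
    # Keep the steps whose flag is 1; the result is usually a shorter sequence.
    return x[mask[:, 0] == 1]

x = np.arange(12, dtype=np.float32).reshape(4, 3)        # 4 steps, 3 features
mask = np.array([[1], [0], [0], [1]], dtype=np.float32)  # keep steps 0 and 3
selected = gather_sketch(x, mask)
print(selected.shape)  # (2, 3)
```

Passing a mask with two values per step, such as np.ones((4, 2)), makes this sketch raise a ValueError, which mirrors the shape=2 mask in the question. The gathered result also has a different length from the input, which is consistent with the question taking the label's sequence axis from z.dynamic_axes[1] rather than from input_seq_axis.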