I want to create a neural network that takes a categorical tuple as input and passes its one-hot-encoded value to its layers.
For example, assuming that the tuple value limits were (2, 2, 3), I need a preprocessing layer that transforms the following list of three-element tuples:
[(1, 0, 0), (0, 0, 1), (1, 1, 2)]
into the following one-dimensional tensor of length 2 × 2 × 3 = 12:
[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]
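The expected output implies row-major (C-order) flattening of the 2 × 2 × 3 grid. Under that assumption, each tuple (a, b, c) switches on exactly one position i, and the example's three tuples land on positions 6, 1 and 11:

```latex
\begin{aligned}
i(a, b, c) &= a \cdot (2 \cdot 3) + b \cdot 3 + c = 6a + 3b + c, \qquad 0 \le a < 2,\ 0 \le b < 2,\ 0 \le c < 3 \\
i(1, 0, 0) &= 6, \qquad i(0, 0, 1) = 1, \qquad i(1, 1, 2) = 6 + 3 + 2 = 11
\end{aligned}
```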
Does such a function exist?
I assume that this custom layer operates on a batch with a varying number of tuples per sample. For example, an input batch may be
[[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
 [(1, 0, 0), (1, 1, 2)]]
and the desired output tensor would be
[[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0],
 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]]
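For reference, the intended behaviour can be reproduced in plain NumPy, outside any Keras layer. This is a sketch assuming row-major flattening; `multi_hot_reference` is a hypothetical name, and `np.ravel_multi_index` does the tuple-to-position mapping:

```python
import numpy as np


def multi_hot_reference(batch, limits):
    """Hypothetical NumPy reference: one fixed-length multi-hot row per sample."""
    size = int(np.prod(limits))  # 2 * 2 * 3 = 12 output positions
    rows = []
    for sample in batch:
        row = np.zeros(size, dtype=np.float32)
        if len(sample) > 0:
            # one index array per tuple position, then row-major flat indices
            flat = np.ravel_multi_index(tuple(np.asarray(sample).T), limits)
            row[flat] = 1.0
        rows.append(row)
    # every row has the same length, so the result is a regular 2-D array
    return np.stack(rows)


batch = [[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
         [(1, 0, 0), (1, 1, 2)]]
print(multi_hot_reference(batch, (2, 2, 3)))
```

Because each sample collapses to a vector of the same length, the output is rectangular even though the input batch is not.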
Since the samples can be of uneven sizes, the batch needs to be converted to a tf.RaggedTensor (instead of a normal Tensor) before being fed to the layer. However, the following solution works with both tf.Tensor and tf.RaggedTensor as input.
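To see why the uneven batch needs tf.RaggedTensor: tf.constant rejects a non-rectangular nested list, whereas tf.ragged.constant accepts it. A minimal sketch; the exact error message depends on the TensorFlow version:

```python
import tensorflow as tf

batch = [[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
         [(1, 0, 0), (1, 1, 2)]]

dense_failed = False
try:
    tf.constant(batch)  # 3 tuples in the first sample, 2 in the second
except ValueError as err:
    dense_failed = True
    print("tf.constant rejected the batch:", err)

ragged = tf.ragged.constant(batch)
print(ragged.shape)          # only the batch dimension is fixed
print(ragged.row_lengths())  # number of tuples in each sample
```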
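import tensorflow as tf


class FillOneLayer(tf.keras.layers.Layer):
    """Multi-hot encodes a variable-length list of index tuples per sample."""

    def __init__(self, shape, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.shape = shape  # value limits of each tuple position, e.g. (2, 2, 3)

    def call(self, inputs, training=None):
        # RaggedTensor exposes the batch size via nrows(); a dense Tensor via tf.shape
        num_samples = inputs.nrows() if isinstance(inputs, tf.RaggedTensor) else tf.shape(inputs)[0]
        num_samples = tf.cast(num_samples, tf.int32)
        ret = tf.TensorArray(tf.float32, size=num_samples, dynamic_size=False)
        for i in range(num_samples):
            sample = inputs[i]
            # a ragged row becomes a dense (num_tuples, len(shape)) index matrix
            sample = sample.to_tensor() if isinstance(sample, tf.RaggedTensor) else sample
            updates_shape = tf.shape(sample)[:-1]  # one update per tuple
            tmp = tf.zeros(self.shape)
            # write 1.0 at every grid position named by a tuple
            tmp = tf.tensor_scatter_nd_update(tmp, sample, tf.ones(updates_shape))
            # flatten the grid, e.g. (2, 2, 3) -> 12 values
            ret = ret.write(i, tf.reshape(tmp, (-1,)))
        return ret.stack()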
Output for a normal input tensor:
>>> a = tf.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
...                  [(1, 0, 0), (0, 0, 1), (1, 0, 2)]])
>>> FillOneLayer((2,2,3))(a)
<tf.Tensor: shape=(2, 12), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0.]], dtype=float32)>
Output for a ragged tensor:
>>> a = tf.ragged.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
...                         [(1, 0, 0), (0, 0, 1)]])
>>> FillOneLayer((2,2,3))(a)
<tf.Tensor: shape=(2, 12), dtype=float32, numpy=
array([[0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]], dtype=float32)>
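If the per-sample Python loop is a concern, the same encoding can be sketched with batch-wide ops. This is a swapped-in technique, not the answer's method: it computes each tuple's row-major flat index, one-hot encodes it with tf.one_hot, and takes a per-sample maximum with tf.math.unsorted_segment_max. It assumes every tuple has exactly len(shape) integer entries within the limits; `multi_hot_vectorized` is a hypothetical name:

```python
import tensorflow as tf


def multi_hot_vectorized(inputs, shape):
    """Hypothetical loop-free alternative to FillOneLayer.call (a sketch)."""
    if not isinstance(inputs, tf.RaggedTensor):
        # treat a dense (batch, n, k) tensor as a ragged batch with equal row lengths
        inputs = tf.RaggedTensor.from_tensor(inputs)
    # row-major strides and total size, e.g. (2, 2, 3) -> [6, 3, 1] and 12
    strides, size = [], 1
    for dim in reversed(shape):
        strides.insert(0, size)
        size *= dim
    # all tuples of the whole batch, in order, as a (total_tuples, k) matrix
    tuples = tf.reshape(inputs.flat_values, (-1, len(shape)))
    # index of the sample that each tuple belongs to
    row_ids = inputs.value_rowids()
    flat_idx = tf.reduce_sum(tuples * tf.constant(strides, dtype=tuples.dtype), axis=-1)
    one_hot = tf.one_hot(flat_idx, depth=size)  # float32, (total_tuples, size)
    # combine the tuples of each sample; a sample with no tuples would get the
    # lowest float32 value from the segment max, hence the clip at 0
    per_sample = tf.math.unsorted_segment_max(one_hot, row_ids, num_segments=inputs.nrows())
    return tf.maximum(per_sample, 0.0)


dense = tf.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
                     [(1, 0, 0), (0, 0, 1), (1, 0, 2)]])
ragged = tf.ragged.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
                             [(1, 0, 0), (0, 0, 1)]])
print(multi_hot_vectorized(dense, (2, 2, 3)))
print(multi_hot_vectorized(ragged, (2, 2, 3)))
```

The trade-off is memory: it materializes a (total_tuples, 12) one-hot matrix before reducing it, whereas the loop writes into one grid per sample. Which is faster depends on batch size and hardware.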
The solution also works when call() is decorated with tf.function, which is usually what happens when you call fit on a model that contains this layer. In that case, to avoid graph retracing, you should ensure that all batches are of the same type, i.e., either all RaggedTensor or all Tensor.
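One way to keep the batch type fixed, sketched with a plain tf.function rather than inside fit: convert every dense batch to a RaggedTensor and pin the expected type with an input_signature. `as_ragged`, `count_tuples` and `trace_log` are hypothetical names, and ragged_rank=2 assumes batches shaped like the examples above (samples of 3-element tuples), which is what tf.ragged.constant builds for them:

```python
import tensorflow as tf


def as_ragged(batch):
    """Hypothetical helper: normalize every batch to a ragged_rank=2 RaggedTensor."""
    if isinstance(batch, tf.RaggedTensor):
        return batch
    return tf.RaggedTensor.from_tensor(batch, ragged_rank=2)


trace_log = []  # records each trace; Python side effects run only while tracing


@tf.function(
    input_signature=[tf.RaggedTensorSpec(shape=[None, None, None], dtype=tf.int32, ragged_rank=2)],
    autograph=False,
)
def count_tuples(batch):
    trace_log.append(1)
    return batch.row_lengths()  # number of tuples in each sample


dense = tf.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
                     [(1, 0, 0), (0, 0, 1), (1, 0, 2)]])
ragged = tf.ragged.constant([[(1, 0, 0), (0, 0, 1), (1, 1, 2)],
                             [(1, 0, 0), (0, 0, 1)]])
print(count_tuples(as_ragged(dense)))
print(count_tuples(as_ragged(ragged)))
print("traces:", len(trace_log))
```

With a fixed input_signature, a batch of a different type is rejected rather than retraced, so the conversion has to happen before the call.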