Here I was pointed to use tf.TensorArray
instead of tf.Variable
or tf.queue.FIFOQueue
for making a FIFO buffer contained in a custom layer. Is it an effective way? Do any alternatives exist here?
If it's the most effective method, how can I replace self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0))
with methods of tf.TensorArray
?
class FIFOLayer(Layer):
    """Fixed-capacity FIFO buffer held in a non-trainable tf.Variable.

    Each call shifts the oldest `timesteps` rows out of the front of the
    queue and appends the new inputs at the back. Slots that have never
    been written still hold NaN; while the queue is filling, an attention
    mask marking the valid rows is returned alongside the queue.
    """

    def __init__(self, window_size, **kwargs):
        super(FIFOLayer, self).__init__(**kwargs)
        self.window_size = window_size  # queue capacity (number of rows)
        self.count = 0  # rows pushed so far; saturates at window_size

    def build(self, input_shape):
        super(FIFOLayer, self).build(input_shape)
        # NaN marks "empty" slots so the mask can distinguish them later.
        self.queue = self.add_weight(
            name="queue",
            shape=(self.window_size, input_shape[-1]),
            initializer=tf.initializers.Constant(value=np.nan),
            trainable=False,
        )

    def call(self, inputs, training=None):
        """Push `inputs` into the queue; return (queue, attention_mask).

        Args:
            inputs: 2-D tensor of shape (timesteps, features); timesteps
                must not exceed window_size.
            training: unused; kept (with the Keras-conventional default
                of None) for API compatibility.

        Returns:
            The queue with a leading batch axis of 1, paired with a
            (window_size, window_size) float mask while the queue is
            still filling, or with None once it is full.

        Raises:
            ValueError: if the batch has more rows than the queue holds.
        """
        timesteps = tf.shape(inputs)[0]
        # Reject batches larger than the queue capacity.
        if timesteps > self.window_size:
            raise ValueError(
                f"batch size exceeds queue capacity {self.window_size}"
            )
        # 1. append new state: drop the oldest `timesteps` rows, append inputs
        self.queue.assign(tf.concat([self.queue[timesteps:, :], inputs], axis=0))
        self.count += timesteps
        # 2. feed-forward
        if self.count < self.window_size:
            # 1.0 where the row holds real data (no NaNs), 0.0 for empty slots.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(self.queue)), axis=-1
                ),
                dtype=tf.float32,
            )
            # Outer product -> pairwise (query, key) validity mask.
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return self.queue[tf.newaxis, ...], attention_mask
        # !!! check overflow: saturate so is_full stays True afterwards.
        elif self.count > self.window_size:
            self.count = self.window_size
        return self.queue[tf.newaxis, ...], None

    @property
    def is_full(self):
        # True once window_size rows have been pushed.
        return self.count == self.window_size

    def clear(self):
        """Reset the queue to the empty (all-NaN) state."""
        self.count = 0
        self.queue.assign(tf.fill(self.queue.shape, np.nan))
# Demo: push six batches of 2 timesteps into a 10-slot queue,
# then clear it and push the last batch once more.
fifo = FIFOLayer(window_size=10)
for _ in range(6):
    batch = tf.random.normal((2, 12))
    result = fifo(batch)
    print(result)
    print(fifo.is_full, "\n\n")
fifo.clear()
print(fifo(batch))
print(fifo.is_full, "\n\n")
Using tf.TensorArray
, you can try something like this:
import tensorflow as tf
import numpy as np
# Fix the RNG seed so the tensors printed below are reproducible.
tf.random.set_seed(111)
class FIFOLayer(tf.keras.layers.Layer):
    """Fixed-capacity FIFO buffer backed by a tf.TensorArray.

    Each call drops the oldest `timesteps` rows and appends the new
    inputs. Never-written slots hold NaN; while the queue is filling,
    an attention mask marking the valid rows is returned as well.
    """

    def __init__(self, window_size, **kwargs):
        super(FIFOLayer, self).__init__(**kwargs)
        self.window_size = window_size  # queue capacity (number of rows)
        self.count = 0  # rows pushed so far; saturates at window_size

    def build(self, input_shape):
        super(FIFOLayer, self).build(input_shape)
        # clear_after_read=False: the default (True) makes gather()/stack()
        # destructive reads, which breaks reuse of the array across calls
        # under tf.function even though eager mode happens to tolerate it.
        self.queue_array = tf.TensorArray(
            dtype=tf.float32, size=self.window_size, clear_after_read=False
        )
        # Fill every slot with NaN so the mask can spot "empty" rows.
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.constant(np.nan) * tf.ones((self.window_size, input_shape[-1])),
        )

    def call(self, inputs, training=None):
        """Push `inputs` into the queue; return (queue, attention_mask).

        Args:
            inputs: 2-D tensor of shape (timesteps, features); timesteps
                must not exceed window_size.
            training: unused; kept (with the Keras-conventional default
                of None) for API compatibility.

        Returns:
            The stacked queue with a leading batch axis of 1, paired with
            a (window_size, window_size) float mask while the queue is
            still filling, or with None once it is full.

        Raises:
            ValueError: if the batch has more rows than the queue holds.
        """
        timesteps = tf.shape(inputs)[0]
        # Reject batches larger than the queue capacity.
        if timesteps > self.window_size:
            raise ValueError(
                f"batch size exceeds queue capacity {self.window_size}"
            )
        # Shift: keep rows [timesteps, window_size), append the new inputs.
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.concat(
                [
                    self.queue_array.gather(
                        tf.range(timesteps, self.window_size)
                    ),
                    inputs,
                ],
                axis=0,
            ),
        )
        queue_tensor = self.queue_array.stack()
        self.count += timesteps
        # 2. feed-forward
        if self.count < self.window_size:
            # 1.0 where the row holds real data (no NaNs), 0.0 for empty slots.
            attention_mask = tf.cast(
                tf.math.reduce_all(
                    tf.math.logical_not(tf.math.is_nan(queue_tensor)), axis=-1
                ),
                dtype=tf.float32,
            )
            # Outer product -> pairwise (query, key) validity mask.
            attention_mask = tf.matmul(
                attention_mask[..., tf.newaxis],
                attention_mask[..., tf.newaxis],
                transpose_b=True,
            )
            return queue_tensor[tf.newaxis, ...], attention_mask
        # !!! check overflow: saturate so is_full stays True afterwards.
        elif self.count > self.window_size:
            self.count = self.window_size
        return queue_tensor[tf.newaxis, ...], None

    @property
    def is_full(self):
        # True once window_size rows have been pushed.
        return self.count == self.window_size

    def clear(self):
        """Rebuild the TensorArray in the empty (all-NaN) state."""
        self.count = 0
        features = tf.shape(self.queue_array.stack())[-1]
        self.queue_array = tf.TensorArray(
            dtype=tf.float32, size=self.window_size, clear_after_read=False
        )
        self.queue_array = self.queue_array.scatter(
            tf.range(self.window_size),
            tf.constant(np.nan) * tf.ones((self.window_size, features)),
        )
# Exercise the layer: six pushes of 2 rows each into a 10-slot queue,
# then a clear() followed by one more push of the same batch.
layer = FIFOLayer(window_size=10)
for _ in range(6):
    sample = tf.random.normal((2, 12))
    output = layer(sample)
    print(output)
    print(layer.is_full, "\n\n")
layer.clear()
print(layer(sample))
print(layer.is_full, "\n\n")
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0., 0., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ],
[-0.33772066, 0.80799913, -0.00896128, 1.606288 ,
1.1561627 , 0.17252289, 0.2451608 , 1.4633939 ,
-0.9294784 , 0.42795137, -0.3016553 , -1.1823792 ],
[ 0.30927372, 0.3482721 , 1.0262096 , -0.97228396,
-0.55333287, -0.7914886 , 1.0115404 , -0.5656188 ,
0.30958036, -0.8476673 , 2.4919312 , 0.9093976 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 0.7558127 , 1.5447265 , 1.6315602 , -0.19868968,
0.08828261, 0.01711658, -1.8133892 , 0.12930395,
0.47128937, 0.08567389, -1.7158676 , -0.5843805 ],
[-0.7664911 , -0.7145203 , -1.089696 , 0.14649415,
0.03585422, 0.9916008 , 0.9384322 , 0.34755042,
-0.09592161, 0.76490027, -1.2517685 , -1.5740465 ],
[-0.31995258, -0.43669155, -0.28932425, -0.06870204,
-0.01291991, 1.171546 , 0.75079876, -0.7693662 ,
0.05902815, 0.60606545, -1.1038904 , -0.99837613],
[-0.6687948 , 0.22192897, -0.02249479, -0.08962449,
1.2408841 , 0.119805 , -0.53699344, 1.020805 ,
0.9610218 , 0.6133564 , -0.4358486 , 2.733222 ],
[-0.33772066, 0.80799913, -0.00896128, 1.606288 ,
1.1561627 , 0.17252289, 0.2451608 , 1.4633939 ,
-0.9294784 , 0.42795137, -0.3016553 , -1.1823792 ],
[ 0.30927372, 0.3482721 , 1.0262096 , -0.97228396,
-0.55333287, -0.7914886 , 1.0115404 , -0.5656188 ,
0.30958036, -0.8476673 , 2.4919312 , 0.9093976 ],
[-0.44241378, -0.6971805 , -0.37439492, 1.0154608 ,
-0.34494257, 0.1988212 , -0.9541314 , -0.44339198,
0.162457 , -0.31033182, -0.34568167, 1.0341203 ],
[-0.89020306, -0.8646532 , 0.13348487, -0.6604107 ,
0.07642484, 1.3407826 , 0.79119945, -0.7598532 ,
0.85146165, -0.2791065 , -0.4600736 , 0.809218 ]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.],
[0., 0., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ 7.5581270e-01, 1.5447265e+00, 1.6315602e+00, -1.9868968e-01,
8.8282607e-02, 1.7116580e-02, -1.8133892e+00, 1.2930395e-01,
4.7128937e-01, 8.5673891e-02, -1.7158676e+00, -5.8438051e-01],
[-7.6649112e-01, -7.1452028e-01, -1.0896960e+00, 1.4649415e-01,
3.5854220e-02, 9.9160081e-01, 9.3843222e-01, 3.4755042e-01,
-9.5921606e-02, 7.6490027e-01, -1.2517685e+00, -1.5740465e+00],
[-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
-1.2919909e-02, 1.1715460e+00, 7.5079876e-01, -7.6936620e-01,
5.9028149e-02, 6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
[-6.6879481e-01, 2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
1.2408841e+00, 1.1980500e-01, -5.3699344e-01, 1.0208050e+00,
9.6102178e-01, 6.1335641e-01, -4.3584859e-01, 2.7332220e+00],
[-3.3772066e-01, 8.0799913e-01, -8.9612845e-03, 1.6062880e+00,
1.1561627e+00, 1.7252289e-01, 2.4516080e-01, 1.4633939e+00,
-9.2947841e-01, 4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
[ 3.0927372e-01, 3.4827209e-01, 1.0262096e+00, -9.7228396e-01,
-5.5333287e-01, -7.9148859e-01, 1.0115404e+00, -5.6561881e-01,
3.0958036e-01, -8.4766728e-01, 2.4919312e+00, 9.0939760e-01],
[-4.4241378e-01, -6.9718051e-01, -3.7439492e-01, 1.0154608e+00,
-3.4494257e-01, 1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
1.6245700e-01, -3.1033182e-01, -3.4568167e-01, 1.0341203e+00],
[-8.9020306e-01, -8.6465323e-01, 1.3348487e-01, -6.6041070e-01,
7.6424837e-02, 1.3407826e+00, 7.9119945e-01, -7.5985318e-01,
8.5146165e-01, -2.7910650e-01, -4.6007359e-01, 8.0921799e-01],
[-6.7833281e-01, 4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
-5.1782840e-01, 5.2898288e-01, -1.4573561e+00, 4.6455118e-01,
-3.2871577e-01, -1.5697428e+00, 1.4454672e-01, 8.2387424e-01],
[ 2.5552011e-03, 1.2834518e+00, 4.1382611e-01, 1.6535892e+00,
7.8654990e-02, -1.2952465e-01, 3.6811054e-01, 1.1675907e+00,
9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01]]],
dtype=float32)>, None)
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[-3.1995258e-01, -4.3669155e-01, -2.8932425e-01, -6.8702042e-02,
-1.2919909e-02, 1.1715460e+00, 7.5079876e-01, -7.6936620e-01,
5.9028149e-02, 6.0606545e-01, -1.1038904e+00, -9.9837613e-01],
[-6.6879481e-01, 2.2192897e-01, -2.2494787e-02, -8.9624494e-02,
1.2408841e+00, 1.1980500e-01, -5.3699344e-01, 1.0208050e+00,
9.6102178e-01, 6.1335641e-01, -4.3584859e-01, 2.7332220e+00],
[-3.3772066e-01, 8.0799913e-01, -8.9612845e-03, 1.6062880e+00,
1.1561627e+00, 1.7252289e-01, 2.4516080e-01, 1.4633939e+00,
-9.2947841e-01, 4.2795137e-01, -3.0165529e-01, -1.1823792e+00],
[ 3.0927372e-01, 3.4827209e-01, 1.0262096e+00, -9.7228396e-01,
-5.5333287e-01, -7.9148859e-01, 1.0115404e+00, -5.6561881e-01,
3.0958036e-01, -8.4766728e-01, 2.4919312e+00, 9.0939760e-01],
[-4.4241378e-01, -6.9718051e-01, -3.7439492e-01, 1.0154608e+00,
-3.4494257e-01, 1.9882120e-01, -9.5413142e-01, -4.4339198e-01,
1.6245700e-01, -3.1033182e-01, -3.4568167e-01, 1.0341203e+00],
[-8.9020306e-01, -8.6465323e-01, 1.3348487e-01, -6.6041070e-01,
7.6424837e-02, 1.3407826e+00, 7.9119945e-01, -7.5985318e-01,
8.5146165e-01, -2.7910650e-01, -4.6007359e-01, 8.0921799e-01],
[-6.7833281e-01, 4.7877081e-02, -2.0416839e+00, -1.5634586e+00,
-5.1782840e-01, 5.2898288e-01, -1.4573561e+00, 4.6455118e-01,
-3.2871577e-01, -1.5697428e+00, 1.4454672e-01, 8.2387424e-01],
[ 2.5552011e-03, 1.2834518e+00, 4.1382611e-01, 1.6535892e+00,
7.8654990e-02, -1.2952465e-01, 3.6811054e-01, 1.1675907e+00,
9.6434945e-01, -4.2399967e-01, -1.3700709e-01, -5.2056974e-01],
[ 1.3070145e+00, -6.7240512e-01, 1.9308577e+00, 1.7688200e-03,
3.0533668e-01, 6.5813893e-01, 5.2471739e-01, 2.1659613e+00,
-8.7725663e-01, 3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
[-4.3180370e-01, -1.1814500e+00, 2.4167557e-01, 5.7490116e-01,
5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
-1.2400552e+00, 1.6947891e+00, -2.6127639e-01, 7.8419834e-01]]],
dtype=float32)>, None)
True
(<tf.Tensor: shape=(1, 10, 12), dtype=float32, numpy=
array([[[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ nan, nan, nan, nan,
nan, nan, nan, nan,
nan, nan, nan, nan],
[ 1.3070145e+00, -6.7240512e-01, 1.9308577e+00, 1.7688200e-03,
3.0533668e-01, 6.5813893e-01, 5.2471739e-01, 2.1659613e+00,
-8.7725663e-01, 3.5695407e-01, -1.2751107e+00, -7.7276069e-01],
[-4.3180370e-01, -1.1814500e+00, 2.4167557e-01, 5.7490116e-01,
5.6998456e-01, -7.4528801e-01, -9.1826969e-01, -7.3694932e-01,
-1.2400552e+00, 1.6947891e+00, -2.6127639e-01, 7.8419834e-01]]],
dtype=float32)>, <tf.Tensor: shape=(10, 10), dtype=float32, numpy=
array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 1.]], dtype=float32)>)
tf.Tensor(False, shape=(), dtype=bool)
On a side note, using tf.queue.FIFOQueue
is really slow.