I have a tensorflow dataset:
def fake_sequence(length=100, mutation_rate=0.001):
    """Generate a random letter sequence with occasional mutated positions.

    Each position is drawn uniformly from "A", "B", "C", "D". Independently,
    with probability ``mutation_rate``, the position is replaced by a letter
    drawn uniformly from "E", "F", "G", "H".

    Args:
        length: Number of characters in the returned sequence.
        mutation_rate: Per-position probability of a mutation, in [0, 1].

    Returns:
        str: The generated sequence, ``length`` characters long.
    """
    # One vectorised draw per array instead of one NumPy call per character.
    seq = np.random.choice(["A", "B", "C", "D"], size=length)
    mutate = np.random.choice(["E", "F", "G", "H"], size=length)
    # True keeps the base letter, False takes the mutated one.
    mask = np.random.random(length) >= mutation_rate
    return "".join(np.where(mask, seq, mutate))
# Build 100 random sequences and turn them into a dataset with one string
# element per sequence.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
I'd like to filter it with the following pythonic function:
def python_filter(x):
    """Return True if every character of `x` is one of "A", "B", "C", "D".

    Args:
        x: An iterable of single characters, e.g. a plain Python string.

    Returns:
        bool: True when the set of characters in ``x`` is a subset of
        {"A", "B", "C", "D"}. An empty ``x`` yields True.
    """
    return set(x).issubset({"A", "B", "C", "D"})
Unfortunately, decorating it with `@tf.function` doesn't work.
Can any of you wizards help me? Here's what I have so far:
def filter(x):
    """Work-in-progress TensorFlow port of the pure-Python subset check.

    NOTE(review): incomplete. It currently returns the distinct characters
    of `x`, not a boolean, because the subset test below is still missing.
    The name also shadows the built-in `filter`.
    """
    # Split the string into one element per byte (one per character for
    # ASCII letters).
    x = tf.strings.bytes_split(x)
    # tf.unique returns (values, indices); keep only the distinct values.
    x = tf.unique(x)[0]
    # tensorflow function for x.issubset({"A", "B", "C", "D"})
    return x
# Use `filter` as the dataset's filter predicate.
ds = ds.filter(filter)
You could use a `tf.lookup.StaticHashTable` and `tf.cond` to get what you want:
import tensorflow as tf
import numpy as np
def fake_sequence(length=100, mutation_rate=0.001):
    """Generate a random letter sequence with occasional mutated positions.

    Each position is drawn uniformly from "A", "B", "C", "D". Independently,
    with probability ``mutation_rate``, the position is replaced by a letter
    drawn uniformly from "E", "F", "G", "H".

    Args:
        length: Number of characters in the returned sequence.
        mutation_rate: Per-position probability of a mutation, in [0, 1].

    Returns:
        str: The generated sequence, ``length`` characters long.
    """
    # One vectorised draw per array instead of one NumPy call per character.
    seq = np.random.choice(["A", "B", "C", "D"], size=length)
    mutate = np.random.choice(["E", "F", "G", "H"], size=length)
    # True keeps the base letter, False takes the mutated one.
    mask = np.random.random(length) >= mutation_rate
    return "".join(np.where(mask, seq, mutate))
# Build 100 random sequences and turn them into a dataset with one string
# element per sequence.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)
# Give every letter that can occur a distinct positive integer id.
# Anything not listed here resolves to the table's default of -1.
keys_tensor = tf.constant(
    ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'],
)
vals_tensor = tf.constant(
    [1, 2, 3, 4, 5, 6, 7, 8],
)
initializer = tf.lookup.KeyValueTensorInitializer(
    keys=keys_tensor,
    values=vals_tensor,
)
table = tf.lookup.StaticHashTable(initializer, default_value=-1)
def filter(x):
    """Return a scalar bool tensor: does `x` contain only the letters A-D?

    TensorFlow counterpart of ``set(x).issubset({"A", "B", "C", "D"})`` for a
    scalar string tensor, usable as a ``tf.data`` map function or filter
    predicate. Relies on the module-level ``table`` to turn letters into
    integer ids.

    Args:
        x: Scalar ``tf.string`` tensor holding one sequence.

    Returns:
        Scalar ``tf.bool`` tensor, True when every character of ``x`` is one
        of "A", "B", "C", "D" (an empty string counts as a subset).
    """
    subset = tf.constant(["A", "B", "C", "D"])
    # Split into single characters and keep each distinct one once.
    chars = tf.unique(tf.strings.bytes_split(x))[0]
    # Compare integer ids. A character missing from the table gets the
    # table's default id, which never equals an allowed id.
    x_ids, allowed_ids = table.lookup(chars), table.lookup(subset)

    def _all_ids_allowed():
        # Membership test via a pairwise (n_chars, n_allowed) comparison.
        # An element-wise equality of the two id vectors would only detect
        # set equality, and the vectors differ in length whenever the
        # sequence uses fewer than four distinct letters.
        matches = tf.equal(x_ids[:, tf.newaxis], allowed_ids[tf.newaxis, :])
        return tf.reduce_all(tf.reduce_any(matches, axis=1))

    # Early exit: more distinct characters than allowed letters can never be
    # a subset.
    return tf.cond(
        tf.size(x_ids) > tf.size(allowed_ids),
        lambda: tf.constant(False),
        _all_ids_allowed,
    )
# `map` (rather than `filter`) is used here so the boolean verdict for each
# sequence is printed. Use `ds.filter(filter)` instead to drop the sequences
# for which the predicate is False.
ds = ds.map(filter)
for x in ds.take(5):
    print(x)
The `tf.lookup.StaticHashTable` simply maps each letter to an integer value, and integers are easier to compare than strings.