Search code examples
python tensorflow set tensorflow-datasets

Python Tensorflow Dataset Filter Set .issubset()


I have a tensorflow dataset:

def fake_sequence(length=100):
    """Return a random sequence string of `length` characters.

    Each position holds a base character drawn from "A"-"D"; with
    probability 0.001 it is replaced by a "mutated" character drawn from
    "E"-"H".

    Args:
        length: Number of characters to generate. Defaults to 100, the
            value that used to be hard-coded.

    Returns:
        A `str` of `length` characters over the alphabet A-H.
    """
    # The order of the random draws (base, mutation, mask) is unchanged, so
    # seeded runs with the default length reproduce the previous output.
    seq = [np.random.choice(["A", "B", "C", "D"]) for _ in range(length)]
    mutate = [np.random.choice(["E", "F", "G", "H"]) for _ in range(length)]
    # True keeps the base character; False (p=0.001) takes the mutated one.
    mask = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
    return "".join(np.where(mask, seq, mutate))


# Generate 100 random sequences and wrap them in a dataset with one string
# element per sequence.
seqs = [fake_sequence() for _ in range(100)]
ds = tf.data.Dataset.from_tensor_slices(seqs)

I'd like to filter it with the following pythonic function:

def python_filter(x):
    """Return True when every element of `x` is one of "A", "B", "C", "D"."""
    allowed = {"A", "B", "C", "D"}
    # `<=` between two sets is the subset test, same as set.issubset().
    return set(x) <= allowed

Unfortunately, decorating with @tf.function doesn't work. Can any of you wizards help me? Here's what I have so far.

def filter(x):
    """Unfinished predicate: currently returns the distinct characters of `x`.

    The subset test against {"A", "B", "C", "D"} is still missing, so the
    return value is a string tensor rather than a boolean.
    """
    chars = tf.strings.bytes_split(x)
    # tf.unique returns (values, indices); only the distinct values are needed.
    unique_chars, _ = tf.unique(chars)
    # tensorflow function for x.issubset({"A", "B", "C", "D"})
    return unique_chars

# Dataset.filter keeps the elements for which the predicate returns True, so
# `filter` has to return a scalar boolean tensor.
ds = ds.filter(filter)

Solution

  • You could use a tf.lookup.StaticHashTable and tf.cond to achieve what you want:

    import tensorflow as tf
    import numpy as np
    
    def fake_sequence(length=100):
        """Return a random sequence string of `length` characters.

        Each position holds a base character drawn from "A"-"D"; with
        probability 0.001 it is replaced by a "mutated" character drawn from
        "E"-"H".

        Args:
            length: Number of characters to generate. Defaults to 100, the
                value that used to be hard-coded.

        Returns:
            A `str` of `length` characters over the alphabet A-H.
        """
        # The order of the random draws (base, mutation, mask) is unchanged, so
        # seeded runs with the default length reproduce the previous output.
        seq = [np.random.choice(["A", "B", "C", "D"]) for _ in range(length)]
        mutate = [np.random.choice(["E", "F", "G", "H"]) for _ in range(length)]
        # True keeps the base character; False (p=0.001) takes the mutated one.
        mask = np.random.choice(a=[True, False], size=length, p=[0.999, 0.001])
        return "".join(np.where(mask, seq, mutate))
    
    
    # Generate 100 random sequences and wrap them in a dataset with one string
    # element per sequence.
    seqs = [fake_sequence() for _ in range(100)]
    ds = tf.data.Dataset.from_tensor_slices(seqs)
    
    # Static lookup table mapping each expected character to an integer id.
    # Characters that are not listed in keys_tensor look up to default_value (-1).
    keys_tensor = tf.constant(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H'])
    vals_tensor = tf.constant([1, 2, 3, 4, 5, 6, 7, 8])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
        default_value=-1)
    
    def filter(x):
        """Return a scalar bool tensor: True iff `x` only uses "A", "B", "C", "D".

        This is the TensorFlow counterpart of
        ``set(x).issubset({"A", "B", "C", "D"})``. The previous version compared
        the sorted ids element-wise, which tested set *equality* and failed
        with a shape mismatch when the sequence had fewer distinct characters
        than the allowed set.

        Args:
            x: Scalar string tensor holding one sequence.

        Returns:
            Scalar boolean tensor; True for the empty string as well, since the
            empty set is a subset of every set.
        """
        subset = tf.constant(["A", "B", "C", "D"])
        # Integer ids of the distinct characters in the sequence; characters
        # missing from `table` get its default value and therefore never match.
        x = table.lookup(tf.unique(tf.strings.bytes_split(x))[0])
        y = table.lookup(subset)

        def _all_ids_allowed():
            # Broadcast to an (n_distinct, n_allowed) matrix of pairwise matches:
            # each distinct id must match at least one allowed id.
            matches = tf.equal(x[:, tf.newaxis], y[tf.newaxis, :])
            return tf.reduce_all(tf.reduce_any(matches, axis=1))

        # More distinct characters than allowed ones can never be a subset, so
        # that case is answered without running the pairwise comparison.
        return tf.cond(
            tf.size(x) > tf.size(y),
            lambda: tf.constant(False),
            _all_ids_allowed)
    
    # map yields the predicate's boolean result for each sequence, which is
    # what gets printed below. To actually drop the sequences for which the
    # predicate is False, use ds.filter(filter) instead.
    ds = ds.map(filter)
    for x in ds.take(5):
      print(x)
    

    The tf.lookup.StaticHashTable just maps all letters to integer values, which are easier to compare.