python tensorflow tensorflow-datasets kaggle

TypeError: Tensor is unhashable. Instead, use tensor.ref()


I'm getting "TypeError: Tensor is unhashable. Instead, use tensor.ref() as the key"

I made a slight change to a public Kaggle kernel.

I defined a function which checks whether a certain value is in a set:

li = pd.read_csv('../input/ch-stock-market-companies/censored_list')
def map_id (id):
  li_set = set(li['investment_id'])
  if id in li_set: return id
  return -5

This function is called during the preprocessing of a tensorflow dataset:

def preprocess(item):
  return (map_id(item["investment_id"]), item["features"]), item["target"] #this is the offending line

def make_dataset(file_paths, batch_size=4096, mode="train"):
  ds = tf.data.TFRecordDataset(file_paths)
  ds = ds.map(decode_function)
  ds = ds.map(preprocess)
  if mode == "train":
      ds = ds.shuffle(batch_size * 4)
  ds = ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)
  return ds

Before my change, the offending line looked like this:

    def preprocess(item):
      return (item["investment_id"], item["features"]), item["target"] #this was the line before I changed it

The error message tells me that I cannot use the function map_id as defined.

But how do I properly do what I am trying to achieve? Namely, I want to "censor" some of the values in a pandas dataframe by replacing them with a default value of -5, and ideally I want to do this as part of creating a TensorFlow dataset.


Solution

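  • As the error message says, you cannot use a tensor as a member of a Python set directly, since tf.Tensor is not hashable. Inside ds.map, item["investment_id"] is a tensor, so the check "id in li_set" tries to hash it and fails. Try using a tf.lookup.StaticHashTable instead: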

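    import tensorflow as tf

    # Each allowed key maps to itself; any other key returns the default value.
    keys_tensor = tf.constant([1, 2, 3])
    vals_tensor = tf.constant([1, 2, 3])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
        default_value=-5)
    
    print(table.lookup(tf.constant(1)))  # key present
    print(table.lookup(tf.constant(5)))  # key missing, falls back to -5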
    
    tf.Tensor(1, shape=(), dtype=int32)
    tf.Tensor(-5, shape=(), dtype=int32)
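
To connect the lookup table back to the question, here is a minimal sketch of how it might be used inside preprocess. The values in allowed_ids and the toy dataset are made up for illustration; in the question the keys would come from li["investment_id"] and the dataset from decode_function. The int64 dtype is also an assumption, so adjust it to whatever dtype the decoded investment_id actually has.

```python
import tensorflow as tf

# Hypothetical allowed IDs; in the question these would come from li["investment_id"].
allowed_ids = tf.constant([10, 20, 30], dtype=tf.int64)

# Build the table once, outside the mapped function.
# Each allowed ID maps to itself; anything else falls back to -5.
id_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(allowed_ids, allowed_ids),
    default_value=-5)

def preprocess(item):
    # The lookup key dtype must match the table's key dtype (assumed int64 here).
    investment_id = id_table.lookup(tf.cast(item["investment_id"], tf.int64))
    return (investment_id, item["features"]), item["target"]

# Toy stand-in for the decoded TFRecord dataset.
toy_ds = tf.data.Dataset.from_tensor_slices({
    "investment_id": tf.constant([10, 99, 30], dtype=tf.int64),
    "features": tf.zeros([3, 4]),
    "target": tf.zeros([3]),
})

mapped_ids = [int(inputs[0].numpy()) for inputs, _ in toy_ds.map(preprocess)]
print(mapped_ids)
```

Building the table at module level rather than inside preprocess also avoids rebuilding the set of IDs on every call, which the original map_id did.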
    

    Alternatively, you could also use tf.where (note that with keepdims=True the result has shape (1,) rather than being a scalar):

    def check_value(value):
      frozen_set = tf.constant([1, 2, 3])
      return tf.where(tf.reduce_any(tf.equal(value, frozen_set), axis=0, keepdims=True), value, tf.constant(-5))
    
    print(check_value(tf.constant(1)))
    print(check_value(tf.constant(2)))
    print(check_value(tf.constant(4)))
    
    tf.Tensor([1], shape=(1,), dtype=int32)
    tf.Tensor([2], shape=(1,), dtype=int32)
    tf.Tensor([-5], shape=(1,), dtype=int32)
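
If the extra dimension in the tf.where output is unwanted, one possible variation keeps the input shape and also works on whole batches. This is only a sketch: the function name censor and its default argument are made up here, and the membership test compares every value against every allowed ID, so it is likely better suited to short lists than to large ones.

```python
import tensorflow as tf

def censor(values, allowed, default=-5):
    """Replace entries of `values` that are not in `allowed` with `default`."""
    values = tf.convert_to_tensor(values)
    allowed = tf.cast(allowed, values.dtype)
    # Compare each value against every allowed ID, then reduce over the last axis.
    is_allowed = tf.reduce_any(
        tf.equal(tf.expand_dims(values, -1), allowed), axis=-1)
    return tf.where(is_allowed, values, tf.cast(default, values.dtype))

allowed = tf.constant([1, 2, 3])
scalar_result = censor(tf.constant(4), allowed)            # scalar in, scalar out
batch_result = censor(tf.constant([1, 4, 3, 7]), allowed)  # shape (4,) in and out
print(scalar_result.numpy(), batch_result.numpy().tolist())
```

Because the output has the same shape as the input, this version can be dropped into preprocess without changing the shape of investment_id.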