tensorflow, keras, tensorflow2.0, tensorflow2.x

How can I point to the last (channel) dimension, i.e. (6,), of inputs with shape (100, 24, 24, 6) so it can be worked on?


I am trying to use tf.map_fn(), where elems should point to the channel dimension of my inputs (shape = (100, 24, 24, 6)). So elems should be a list/tuple of tensors pointing to, or accessing, the values along the channel dimension (6) of the inputs. I am trying to do this with nested for loops, like so:

@tf.function
def call(self, inputs, training=True):

    elems = []

    for b in inputs:
        for h in b:
            for w in h:
                for c in w:
                    elems.append(c)

    changed_inputs = tf.map_fn(self.do_mapping, elems)
    return changed_inputs

What I am trying to achieve in self.do_mapping is a dictionary lookup: given a key, it returns the corresponding value of a dictionary (vmap). The dictionary vmap is built by accessing the output of a layer and adding only the distinct values along the channel dimension of that output, so the keys of the dictionary are tuples of 6 tf.Tensor objects (6 being the size of the channel dimension) and the values are a running count that I keep. This is how the dictionary is made:

# Build vmap: enumerate each distinct channel tuple found in the layer output
value = list(self.get_values())
vmap = {}
cnt = 0
for v0 in value:
    for v1 in v0:
        for v2 in v1:
            for v3 in v2:
                v = tuple(v3)          # length-6 channel vector as a tuple key
                if v not in vmap:
                    vmap[v] = cnt      # value is a running count
                    cnt += 1

The do_mapping function is:

@tf.function
def do_mapping(self, pixel):
    if self._compression:
        pixel = tuple(pixel)
        enumerated_value = self._vmap.get(pixel)   # look up the channel tuple
        print(enumerated_value)                    # debugging output
        print(tf.shape(pixel))                     # debugging output
        exit()
        return enumerated_value

If I try to use tf.map_fn now, with elems pointing at the channel dimension, I get the following error: ValueError: elements in elems must be 1+ dimensional Tensors, not scalars. Please help me understand how I can use tf.map_fn for my case. Thank you in advance.
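For reference, the same error seems to be reproducible with a much smaller snippet, which suggests the problem is that elems ends up being a Python list of 0-dimensional (scalar) tensors (the values below are only illustrative):

    import tensorflow as tf

    # A Python list of scalar tensors, similar to what the nested loops above build.
    elems = [tf.constant(1.0), tf.constant(2.0)]

    tf.map_fn(lambda x: x, elems)
    # ValueError: elements in elems must be 1+ dimensional Tensors, not scalars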


Solution

  • First, instead of the nested for loops (try to avoid them for efficiency), you can simply reshape:

    elems = tf.reshape(inputs, [-1])

    Second, what exactly do you want to do? What do you mean by "it doesn't work"? What is the error message? What is self.do_mapping?
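    Below is a minimal sketch of how the reshape and tf.map_fn could fit together. The tf.reduce_sum is only a hypothetical stand-in for do_mapping (the dictionary lookup from the question would need to be rewritten with tensor ops to run inside map_fn), and reshaping to [-1, channels] keeps each element as one length-6 channel vector, which is what the question asks for:

        import tensorflow as tf

        # Hypothetical stand-in for do_mapping: any function that maps one
        # (6,) channel vector to a value using tensor ops.
        def do_mapping(pixel):
            return tf.reduce_sum(pixel)  # placeholder computation

        inputs = tf.random.uniform((100, 24, 24, 6))

        # Flatten everything except the channel dimension, so each element
        # handed to map_fn is one (6,) channel vector rather than a scalar.
        channels = inputs.shape[-1]
        elems = tf.reshape(inputs, [-1, channels])    # shape (100*24*24, 6)

        mapped = tf.map_fn(do_mapping, elems)         # shape (100*24*24,)

        # If needed, restore the spatial layout (one mapped value per pixel).
        changed_inputs = tf.reshape(mapped, tf.shape(inputs)[:-1])
        print(changed_inputs.shape)                   # (100, 24, 24)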

    Best,

    Keivan