Tags: python, deep-learning, cntk

Calculate cumulative sum across static axis in CNTK


I'd like to calculate the cumulative sum of a tensor in a CNTK model. This is reasonably straightforward to do along the dynamic sequence axis, but it's unclear how to do it along a static axis. If the dimensionality of the axis is known a priori, one could conceivably do this with a convoluted set of gather/reduce_sum/splice operations, but that would be ridiculously inefficient.
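For reference, on the dynamic sequence axis a running sum can be written as a recurrence. A minimal sketch, assuming the CNTK 2.x layers API (Recurrence(C.plus) accumulates the input along the sequence):

    import numpy as np
    import cntk as C
    from cntk.layers import Recurrence

    # Running (cumulative) sum along the dynamic sequence axis.
    x_seq = C.sequence.input_variable(1)
    seq_cumsum = Recurrence(C.plus)(x_seq)

    # A 3-step sequence [1, 2, 3] -> [1, 3, 6].
    print(seq_cumsum.eval({x_seq: [np.array([[1.], [2.], [3.]], dtype=np.float32)]}))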


Solution

  • Operations like that are definitely on our to-do list. Since a cumulative sum can be expressed as a matrix product, it won't be too inefficient with the following implementation:

    import numpy as np
    import cntk as C

    def cumsum(x, axis=-1):
        # Length of the axis to accumulate over.
        d = x.shape[axis]
        # Upper-triangular matrix of ones: multiplying by it sums every
        # element up to and including the current position.
        U = C.constant(np.triu(np.ones((d, d))).astype(x.dtype))
        # Move the target axis to the end so the matrix product applies to it.
        if axis != -1:
            x = C.swapaxes(x, -1, axis)
        z = C.times(x, U)
        # Restore the original axis order.
        if axis != -1:
            z = C.swapaxes(z, -1, axis)
        return z
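
    A quick check of the function above; a sketch assuming the CNTK 2.x input_variable API:

    x = C.input_variable((4,))
    z = cumsum(x)
    # One batch entry [1, 2, 3, 4] -> cumulative sums [1, 3, 6, 10].
    print(z.eval({x: np.array([[1, 2, 3, 4]], dtype=np.float32)}))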