
Python/Scipy: Find "bounded" min/max of a matrix


I think it is easiest to describe my specific problem; the general case is difficult to explain.

Say I have an array

a with dimensions NxMxT,

where one can think of T as a time dimension (to make the question easier). Let (n,m) be the indices over NxM; I call (n,m) the state-space identifier. Then I need to find the Python/SciPy equivalent of

for each (n,m):
     find a*(n,m) = min(a(n,m,:)) s.t. a*(n,m) > a(n,m,T)

That is, for the whole state space, find the smallest value along the time dimension that is still higher than the last observation.

My first attempt was to solve the inner problem first (find the values of a that are higher than a[...,-1]):

aHigherThanLast = a[ a > a[...,-1][...,newaxis] ]

Then I wanted to find the smallest of these for each (n,m). Unfortunately, aHigherThanLast is now a 1-D array of all these values, so I no longer have the (n,m) correspondence. What would be a better approach?
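A minimal sketch of the behaviour described above, on a small made-up array (N=2, M=1, T=3): boolean-mask indexing returns a flat 1-D array, whereas np.where with a fill value keeps the original shape, so a min along the last axis keeps the (n,m) correspondence. Using np.inf as the fill value is an assumption; it works here because it can never be the minimum of a non-empty selection.

```python
import numpy as np

# Small made-up example: shape (N, M, T) = (2, 1, 3), last axis is time
a = np.array([[[0., 2., 1.]],
              [[5., 3., 4.]]])

last = a[..., -1:]                     # newest observation, shape (N, M, 1)

# Boolean indexing picks the matching values but flattens them to 1-D
flat = a[a > last]
print(flat.shape)                      # (2,) -> the (n, m) positions are lost

# np.where keeps the full (N, M, T) shape; non-matching entries become +inf
masked = np.where(a > last, a, np.inf)
smallest_above = masked.min(axis=-1)   # shape (N, M)
print(smallest_above)
```

Positions where no value exceeds the last observation come out as inf, which np.isinf can detect.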

An additional problem: the state space is variable. It could also have 3 or more dimensions (NxMxKx...), and I cannot hard-code this. So any kind of

for (n,m,t) in nditer(a):

is not feasible.
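For checking any vectorised version, a slow but dimension-agnostic reference loop can still be written without hard-coding the number of state-space axes, because np.ndindex iterates over an arbitrary shape. This is a sketch: the helper name reference_min_above is hypothetical, and returning NaN where no value qualifies is an assumed convention.

```python
import numpy as np

def reference_min_above(a):
    """Hypothetical helper: for every state-space index, the smallest value
    along the last axis that is strictly greater than the last value.
    Positions with no such value get NaN (an assumed convention)."""
    out = np.full(a.shape[:-1], np.nan)
    # np.ndindex yields every index tuple of the leading (state-space) axes,
    # however many there are
    for idx in np.ndindex(*a.shape[:-1]):
        series = a[idx]                       # 1-D time series at this state
        above = series[series > series[-1]]
        if above.size:
            out[idx] = above.min()
    return out

# The example array from the edit below: [0, 2, 1] repeated over the state space
a = np.tile([0., 2., 1.], (1, 1, 2, 2, 1, 1, 10, 1))
print(reference_min_above(a).shape)           # (1, 1, 2, 2, 1, 1, 10)
```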

Many thanks!

/edit:

a = array([[[[[[[[ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.]]]],



          [[[[ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.]]]]],




         [[[[[ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.]]]],



          [[[[ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.],
             [ 0.,  2.,  1.]]]]]]]])
# a.shape = (1L, 1L, 2L, 2L, 1L, 1L, 10L, 3L), so in this case T = 3.
# The expected output is an array b with
# b.shape = (1L, 1L, 2L, 2L, 1L, 1L, 10L) such that, for every state-space index (i1,...,i7):
  • b[i1,...,i7] > a[i1,...,i7,-1] (b is higher than the newest observation), and
  • there is no t that satisfies both
      -- a[i1,...,i7,t] > a[i1,...,i7,-1]
      -- a[i1,...,i7,t] < b[i1,...,i7]
    (b is the smallest element that is higher than the newest observation).

So, given that the array above is simply [0,2,1] repeated along the last axis, I would expect

b = ones((1,1,2,2,1,1,10))*2

However:
  • if, at some state-space index, the values were not only {0,1,2} but also 3, I would still want 2, as it is the smaller of {2,3}, the values i that satisfy i > 1.
  • if, at some state-space index, the values were only {0,1,3}, I would want 3, as i = 3 is the smallest value that satisfies i > 1.
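The three cases above can be checked with a short sketch. Each series is a made-up 1-D time series for a single state-space point whose last entry (1.) is the newest observation; the expected answers 2, 2 and 3 are the ones stated above.

```python
import numpy as np

# Made-up series; the last entry (1.) is the newest observation in each case
cases = {
    "values {0, 1, 2}":    np.array([0., 2., 1.]),
    "values {0, 1, 2, 3}": np.array([0., 2., 3., 1.]),
    "values {0, 1, 3}":    np.array([0., 3., 1.]),
}

results = {}
for label, series in cases.items():
    # keep only the values strictly above the newest observation
    above = series[series > series[-1]]
    results[label] = above.min()
    print(label, "->", results[label])
```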

Hope that cleared it up a bit?

/edit2:

I appreciate the answer a lot; it works. How would I adjust it if I wanted the opposite, i.e. the largest among the values that are smaller? I did not try to work through that complicated indexing logic, so my (weak) attempt of only changing the first three lines did not succeed:

        b = sort(a[...,:-1], axis=-1)
        b = b[...,::-1]
        mask = b < a[..., -1:]
        index = argmax(mask, axis=-1)
        indices = tuple([arange(j) for j in a.shape[:-1]])
        indices = meshgrid(*indices, indexing='ij', sparse=True)
        indices.append(index)
        indices = tuple(indices)
        a[indices]

Also, a[...,::-1][indices], my second attempt, was not fruitful either.


Solution

  • I think Mr. E is on the right track. You definitely want to start by sorting the array without that last time value:

    b = np.sort(a[..., :-1], axis=-1)
    

    You would now ideally use np.searchsorted to find where the first item larger than the end value is, but unfortunately np.searchsorted only works on 1-D arrays, so we have to do some more work: create a boolean mask, then find the first True using np.argmax:

    mask = b > a[..., -1:]
    index = np.argmax(mask, axis=-1)
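Because each row of b is sorted, the position np.argmax finds is the same insertion point that np.searchsorted(row, value, side='right') returns for that row, and it can also be computed in a vectorised way by counting the elements that are <= the last value. A small sketch checking this on random data (the shape, value range and seed are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 5, size=(2, 3, 6)).astype(float)   # arbitrary test data

b = np.sort(a[..., :-1], axis=-1)
last = a[..., -1:]

mask = b > last
index = np.argmax(mask, axis=-1)

# In a sorted row, the number of elements <= value is the insertion point
# that searchsorted(side='right') would return
count = np.sum(b <= last, axis=-1)

# Row-by-row searchsorted for comparison
expected = np.empty(b.shape[:-1], dtype=int)
for idx in np.ndindex(*b.shape[:-1]):
    expected[idx] = np.searchsorted(b[idx], a[idx][-1], side='right')

print(np.array_equal(count, expected))                        # True
# argmax agrees wherever at least one element is larger; where none is,
# argmax gives 0 while the insertion point is b.shape[-1]
has_larger = mask.any(axis=-1)
print(np.array_equal(index[has_larger], count[has_larger]))   # True
```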
    

    You now have the indices. To extract the actual values, you need to do some indexing magic:

    indices = tuple([np.arange(j) for j in b.shape[:-1]])
    # list() so that .append below works even if meshgrid returns a tuple
    indices = list(np.meshgrid(*indices, indexing='ij', sparse=True))
    indices.append(index)
    indices = tuple(indices)
    

    And you can now finally do:

    >>> b[indices]
    array([[[[[[[ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.]]],
    
    
              [[[ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.]]]],
    
    
    
             [[[[ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.]]],
    
    
              [[[ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.]]]]]]])
    >>> b[indices].shape
    (1L, 1L, 2L, 2L, 1L, 1L, 10L)
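As an alternative to building the open mesh by hand, np.take_along_axis (NumPy 1.15 or later) performs the same gather when the index array is given a trailing length-1 axis. A sketch on the example array from the question, confirming that both forms agree:

```python
import numpy as np

a = np.tile([0., 2., 1.], (1, 1, 2, 2, 1, 1, 10, 1))   # the example array

b = np.sort(a[..., :-1], axis=-1)
index = np.argmax(b > a[..., -1:], axis=-1)

# Fancy indexing with an open mesh, as in the answer
indices = list(np.meshgrid(*[np.arange(j) for j in b.shape[:-1]],
                           indexing='ij', sparse=True))
indices.append(index)
via_mesh = b[tuple(indices)]

# Same gather with take_along_axis: index needs a trailing length-1 axis
via_take = np.take_along_axis(b, index[..., np.newaxis], axis=-1)[..., 0]

print(via_take.shape)                        # (1, 1, 2, 2, 1, 1, 10)
print(np.array_equal(via_mesh, via_take))    # True
```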
    

    To get the largest among those that are smaller, you could do something like:

    mask = b >= a[..., -1:]
    index = np.argmax(mask, axis=-1) - 1
    

    i.e. the largest among those that are smaller is the item right before the smallest among those that are equal or larger. This second case makes it clearer that this approach gives a garbage result if no item fulfills the condition. When that happens, the index is -1; but -1 also appears legitimately when every item is smaller, and then b[..., -1], the largest item, is the right answer. Since b is sorted, the invalid positions are exactly those where mask[..., 0] is True, so check np.any(mask[..., 0]) rather than np.any(index == -1).
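A sketch of the second case on made-up rows, marking positions with no smaller value as NaN (an assumed convention). It relies on b being sorted ascending, so "no smaller value" means that even the first element already satisfies b >= last. The -1 produced by the subtraction is mapped to the last position explicitly with a modulo, which is the element fancy indexing would pick anyway.

```python
import numpy as np

# Made-up rows; the last column is the newest observation
a = np.array([[0., 2., 3., 1.],    # smaller than 1: {0}        -> 0
              [5., 0., 2., 4.],    # smaller than 4: {0, 2}     -> 2
              [5., 6., 7., 4.],    # nothing smaller than 4     -> NaN
              [1., 2., 3., 9.]])   # everything smaller than 9  -> 3

b = np.sort(a[..., :-1], axis=-1)
mask = b >= a[..., -1:]
# argmax - 1 is the item before the first one >= last; -1 wraps to the end
index = (np.argmax(mask, axis=-1) - 1) % b.shape[-1]

# b is sorted ascending: no smaller value exists when the smallest is >= last
invalid = mask[..., 0]

values = np.take_along_axis(b, index[..., np.newaxis], axis=-1)[..., 0]
values = np.where(invalid, np.nan, values)
print(values)   # expected per row: 0, 2, NaN, 3
```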

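    For the first case, you can set the index to -1 where the condition cannot be satisfied. Here -1 is unambiguous, because np.argmax never returns a negative index, but b[indices] still returns a meaningless value at those positions, so check index == -1 before using them: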

    mask = b > a[..., -1:]
    wrong = np.all(~mask, axis=-1)
    index = np.argmax(mask, axis=-1)
    index[wrong] = -1
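Putting the first case together, a sketch of a function that works for any number of state-space axes. The name min_above_last is hypothetical, returning NaN where no value exceeds the last observation is an assumed convention, and np.take_along_axis stands in for the meshgrid indexing used above.

```python
import numpy as np

def min_above_last(a):
    """Hypothetical helper: smallest earlier value strictly greater than the
    last value along the final axis; NaN where there is none."""
    b = np.sort(a[..., :-1], axis=-1)
    mask = b > a[..., -1:]
    wrong = np.all(~mask, axis=-1)          # no value exceeds the last one
    index = np.argmax(mask, axis=-1)        # first True in each sorted row
    values = np.take_along_axis(b, index[..., np.newaxis], axis=-1)[..., 0]
    return np.where(wrong, np.nan, values)

a = np.tile([0., 2., 1.], (1, 1, 2, 2, 1, 1, 10, 1))   # the example array
result = min_above_last(a)
print(result.shape)                          # (1, 1, 2, 2, 1, 1, 10)
```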