
Flat indexing of all but first dimension with Numpy


Is there a way in NumPy to use flat (linear) indexing over all but the first dimension of an array? I'm trying to translate the following MATLAB function to Python:

function [indices, weights] = locate(values, gridpoints)
    indices = ones(size(values));
    weights = zeros([2, size(values)]);

    for ix = 1:numel(values)
        if values(ix) <= gridpoints(1)
            indices(ix) = 1;
            weights(:, ix) = [1; 0];
        elseif values(ix) >= gridpoints(end)
            indices(ix) = length(gridpoints) - 1;
            weights(:, ix) = [0; 1];
        else
            indices(ix) = find(gridpoints <= values(ix), 1, 'last');    
            weights(:, ix) = ...
                [gridpoints(indices(ix) + 1) - values(ix); ...
                 values(ix) - gridpoints(indices(ix))] ...
                / (gridpoints(indices(ix) + 1) - gridpoints(indices(ix)));
        end
    end
end

But I can't wrap my head around what the NumPy equivalent of MATLAB's weights(:, ix) would be, that is, linear indexing over only the remaining dimensions.

I was hoping the syntax would translate directly, but it doesn't. Suppose values is a 3-by-4 array, so that weights is a 2-by-3-by-4 array. In MATLAB, weights(:, ix) is then a 2-by-1 array, whereas in Python weights[:, ix] indexes only the second axis and gives a 2-by-4 array (and raises an IndexError once ix exceeds 2).
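A quick check of those shapes, using an arbitrary 3-by-4 values array as a stand-in. One way to get MATLAB-style linear indexing over the trailing dimensions is np.unravel_index, which turns a flat index into one subscript per axis; prepending slice(None) (the object behind `:`) then selects a length-2 column. This is a sketch, not a claim about the canonical idiom. Note that np.unravel_index uses row-major (C) order by default, matching values.flat, whereas MATLAB's linear indices are column-major; for an element-by-element function like this one, the order doesn't matter as long as values and weights are traversed the same way.

```python
import numpy as np

values = np.arange(12.0).reshape(3, 4)     # stand-in 3-by-4 input
weights = np.zeros((2,) + values.shape)    # shape (2, 3, 4)

# An integer in the second slot only indexes axis 1, so axis 2 survives.
print(weights[:, 2].shape)                 # (2, 4)

# Convert a flat (C-order) index into one subscript per axis of values.
ix = 5
sub = np.unravel_index(ix, values.shape)   # subscripts (1, 1)
assert values.flat[ix] == values[sub]      # same element as values.flat[ix]

# slice(None) is what ':' means, so this is the analogue of weights(:, ix).
col = (slice(None),) + sub
print(weights[col].shape)                  # (2,)

# Assignment through the same index writes into weights in place.
weights[col] = [0.25, 0.75]
print(weights[:, 1, 1])                    # [0.25 0.75]
```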

I believe I have translated everything else in the function below; the lines that need weights(:, ix) are commented out.

import numpy as np


def locate(values, gridpoints):
    indices = np.zeros(np.shape(values), dtype=int)
    weights = np.zeros((2,) + np.shape(values))

    for ix in range(values.size):
        if values.flat[ix] <= gridpoints[0]:
            indices.flat[ix] = 0
            # weights[:, ix] = [1, 0]
        elif values.flat[ix] >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            # weights[:, ix] = [0, 1]
        else:
            indices.flat[ix] = (
                np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
            )
            # weights[:, ix] = (
            #         np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
            #                   values.flat[ix] - gridpoints[indices.flat[ix]]])
            #         / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
            # )

    return indices, weights
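As an aside on the lookup line: assuming gridpoints is sorted in ascending order (which the boundary checks already presume), np.argwhere(gridpoints <= v).flatten()[-1] is the position of the last gridpoint not exceeding v. np.searchsorted with side="right" returns the number of gridpoints that are <= v, so subtracting 1 gives the same position. The grid below is arbitrary and the two helper names are only for illustration.

```python
import numpy as np

gridpoints = np.array([0.0, 1.0, 2.5, 4.0])   # arbitrary, strictly increasing

def last_leq_argwhere(v):
    # The lookup from the question: last index i with gridpoints[i] <= v.
    return np.argwhere(gridpoints <= v).flatten()[-1]

def last_leq_searchsorted(v):
    # side="right" counts the gridpoints <= v; subtract 1 to get the index.
    return np.searchsorted(gridpoints, v, side="right") - 1

# Sample values, including exact hits on gridpoints.
for v in [0.0, 0.3, 1.0, 2.0, 2.5, 3.9]:
    assert last_leq_argwhere(v) == last_leq_searchsorted(v)
```

Unlike the argwhere version, np.searchsorted also accepts an array of values, which is what makes a loop-free version possible.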

Do you have any suggestions? Perhaps I'm just thinking about the problem all wrong. I have also tried to keep the code as simple as possible, since I intend to use Numba to speed it up later.
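As an alternative to compiling the loop with Numba, here is a loop-free NumPy sketch, assuming gridpoints is strictly increasing with at least two entries; the name locate_vectorized is hypothetical. It clips the searchsorted position to the valid interval range and then clips the fractional position t to [0, 1], so both boundary branches fall out of the same formula: values at or below gridpoints[0] get [1, 0], and values at or above gridpoints[-1] get [0, 1].

```python
import numpy as np

def locate_vectorized(values, gridpoints):
    """Loop-free sketch; assumes gridpoints is strictly increasing, len >= 2."""
    values = np.asarray(values, dtype=float)
    gridpoints = np.asarray(gridpoints, dtype=float)

    # Last i with gridpoints[i] <= v, kept inside 0 .. len(gridpoints) - 2.
    indices = np.searchsorted(gridpoints, values, side="right") - 1
    indices = np.clip(indices, 0, gridpoints.size - 2)

    lo = gridpoints[indices]
    hi = gridpoints[indices + 1]
    # Fractional position inside the interval; clipping reproduces the
    # boundary branches (t = 0 gives [1, 0], t = 1 gives [0, 1]).
    t = np.clip((values - lo) / (hi - lo), 0.0, 1.0)

    weights = np.stack([1.0 - t, t])   # shape (2,) + values.shape
    return indices, weights

grid = np.array([0.0, 1.0, 2.5, 4.0])
vals = np.array([[-1.0, 0.0, 0.5, 1.0],
                 [2.0, 2.5, 3.25, 4.0],
                 [5.0, 0.25, 3.0, 1.75]])
idx, w = locate_vectorized(vals, grid)
print(idx)
print(w.shape)   # (2, 3, 4)
```

The interior weights are computed as [1 - t, t] rather than with two separate divisions, so they can differ from the loop version in the last bit of floating-point precision.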


Solution

  • As per hpaulj's comment, there doesn't seem to be a direct NumPy equivalent. Lacking one, the best I can come up with is to make weights a 2-by-N array while filling it and reshape it at the end, as in the code below, following the suggestion in NumPy for Matlab Users.

    import numpy as np
    
    
    def locate(values, gridpoints):
        indices = np.zeros(values.shape, dtype=int)
        weights = np.zeros((2, values.size))  # Temporarily make weights 2-by-N
    
        for ix in range(values.size):
            if values.flat[ix] <= gridpoints[0]:
                indices.flat[ix] = 0
                weights[:, ix] = [1, 0]
            elif values.flat[ix] >= gridpoints[-1]:
                indices.flat[ix] = gridpoints.size - 2
                weights[:, ix] = [0, 1]
            else:
                indices.flat[ix] = (
                    np.argwhere(gridpoints <= values.flat[ix]).flatten()[-1]
                )
                weights[:, ix] = (
                        np.array([gridpoints[indices.flat[ix] + 1] - values.flat[ix],
                                  values.flat[ix] - gridpoints[indices.flat[ix]]])
                        / (gridpoints[indices.flat[ix] + 1] - gridpoints[indices.flat[ix]])
                )
        
        # Give weights its final (2,) + values.shape dimensions (a view, no copy)
        weights = weights.reshape((2,) + values.shape)
        
        return indices, weights
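A variant of the same idea, sketched under the assumption that weights is freshly allocated by np.zeros and is therefore C-contiguous: in that case weights.reshape(2, -1) returns a view rather than a copy, so the loop can write through a 2-by-N view while weights keeps its final shape the whole time. The name locate_view is only for illustration.

```python
import numpy as np

def locate_view(values, gridpoints):
    values = np.asarray(values, dtype=float)
    gridpoints = np.asarray(gridpoints, dtype=float)
    indices = np.zeros(values.shape, dtype=int)
    weights = np.zeros((2,) + values.shape)
    # 2-by-N view of the same memory; column ix lines up with values.flat[ix].
    flat_weights = weights.reshape(2, -1)
    assert np.shares_memory(flat_weights, weights)

    for ix in range(values.size):
        v = values.flat[ix]
        if v <= gridpoints[0]:
            indices.flat[ix] = 0
            flat_weights[:, ix] = [1, 0]
        elif v >= gridpoints[-1]:
            indices.flat[ix] = gridpoints.size - 2
            flat_weights[:, ix] = [0, 1]
        else:
            i = np.argwhere(gridpoints <= v).flatten()[-1]
            indices.flat[ix] = i
            step = gridpoints[i + 1] - gridpoints[i]
            flat_weights[:, ix] = [(gridpoints[i + 1] - v) / step,
                                   (v - gridpoints[i]) / step]

    return indices, weights   # already shaped (2,) + values.shape

grid = np.array([0.0, 1.0, 2.5, 4.0])
vals = np.array([[-1.0, 0.5], [3.25, 5.0]])
idx, w = locate_view(vals, grid)
print(w.shape)   # (2, 2, 2)
```

If weights were ever created non-contiguously (for example as a transposed or sliced array), reshape would silently return a copy and the writes would be lost, which is what the shares_memory check guards against.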