
How to extract chunks of a 2D numpy array that has been flattened


I would like to know the best way of extracting chunks of elements from a 2D NumPy array that has been flattened. The example Python code below hopefully explains what I want to do a little better.

import numpy as np

nx = 5
nz = 7
numGPs = nx*nz

GPs_matrix = np.arange(numGPs).reshape((nx,nz), order='F')
av = np.zeros_like(GPs_matrix)
av[1:-1,1:-1] = (GPs_matrix[1:-1,2:] + GPs_matrix[1:-1,:-2] + GPs_matrix[2:,1:-1] + GPs_matrix[:-2,1:-1])/4
# How can the above be done if GPs_matrix is flattened as below?
GPs_flat = GPs_matrix.reshape(-1, order='F')
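# One (very clunky) way is the following. It works only because each entry of
# GPs_matrix equals its own position in GPs_flat, so the slices double as index arrays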
cor = GPs_matrix[1:-1,1:-1].reshape(-1, order='F')
btm = GPs_matrix[1:-1,:-2].reshape(-1, order='F')
top = GPs_matrix[1:-1,2:].reshape(-1, order='F')
lft = GPs_matrix[:-2,1:-1].reshape(-1, order='F')
rgt = GPs_matrix[2:,1:-1].reshape(-1, order='F')
av_flat = np.zeros_like(GPs_flat)
av_flat[cor] = (GPs_flat[top] + GPs_flat[btm] + GPs_flat[rgt] + GPs_flat[lft])/4
# Check
print(av.reshape(-1, order='F'))
print(av_flat)
# Is there a better way?
# np.r_ looks like it may be useful, but I don't know how best to use it. I assume the attempt below can be improved upon
sl_cor = np.r_[nx+1:2*nx-1,2*nx+1:3*nx-1,3*nx+1:4*nx-1,4*nx+1:5*nx-1,5*nx+1:6*nx-1] # Should match cor above (hard-coded for nz = 7)
# Check
print(sl_cor)
print(cor)
# The below also works (np.r_ concatenates the arrays in the tuple), but I would like to avoid loops if possible
print(np.r_[tuple(np.arange(1,nx-1)+nx*j for j in range(1,nz-1))])
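A loop-free way to build the same index arrays is sketched below using broadcasting. It assumes the column-major (`order='F'`) layout used above, where grid point `(i, j)` sits at flat position `i + nx*j`. Under that assumption the x-neighbours of a flat index `k` are `k - 1` and `k + 1` and the z-neighbours are `k - nx` and `k + nx`, so only the interior indices need to be constructed. The helper name `interior_indices` is hypothetical.

```python
import numpy as np

def interior_indices(nx, nz):
    """Flat (order='F') indices of the interior points of an nx-by-nz grid."""
    i = np.arange(1, nx - 1)  # interior x positions
    j = np.arange(1, nz - 1)  # interior z positions
    # Broadcast to an (nx-2, nz-2) grid of i + nx*j, then flatten in F order so that
    # i varies fastest, matching GPs_matrix[1:-1, 1:-1].reshape(-1, order='F').
    return (i[:, None] + nx * j[None, :]).ravel(order='F')

nx, nz = 5, 7
GPs_flat = np.arange(nx * nz, dtype=float)  # any field flattened with order='F'

cor = interior_indices(nx, nz)
# With flat index k = i + nx*j, the four neighbours are fixed offsets from k.
lft, rgt = cor - 1, cor + 1    # (i-1, j) and (i+1, j)
btm, top = cor - nx, cor + nx  # (i, j-1) and (i, j+1)

av_flat = np.zeros_like(GPs_flat)
av_flat[cor] = (GPs_flat[top] + GPs_flat[btm] + GPs_flat[rgt] + GPs_flat[lft]) / 4
print(cor)  # [ 6  7  8 11 12 13 16 17 18 21 22 23 26 27 28]
```

Because the offsets are constants, the same `cor`, `lft`, `rgt`, `btm` and `top` arrays can be computed once and reused on every call.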

Essentially, I am trying to solve the Poisson equation in two spatial dimensions. It is convenient to set up the problem as a 2D array, because the position of each element in the array corresponds to the position of a grid point in the Cartesian mesh. Ultimately I will use an iterative solver in JAX (e.g. bicgstab) to solve the linear system, and it requires a linear operator function as input. That function therefore needs to take and return a flat vector, and loops are not efficient.
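To make the "linear operator function" requirement concrete, here is a minimal sketch using SciPy's `LinearOperator` and `bicgstab` in place of the JAX solver. Its assumptions are labelled: the unknowns are only the interior points, the boundary condition is `u = 0` (homogeneous Dirichlet), and the grid spacing `h` is uniform. The names `neg_laplacian`, `mx`, `mz` and `h` are hypothetical. The operator reshapes the flat vector, zero-pads it with the boundary values, applies the 5-point stencil, and flattens the result again.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, bicgstab

nx, nz = 5, 7            # full grid size, boundary points included
mx, mz = nx - 2, nz - 2  # interior points: these are the unknowns
h = 1.0                  # grid spacing, assumed equal in x and z

def neg_laplacian(u_flat):
    """5-point -Laplacian on the interior, assuming u = 0 on the boundary.

    u_flat holds the interior values flattened with order='F'.
    """
    u = np.reshape(u_flat, (mx, mz), order='F')
    up = np.pad(u, 1)  # surround with the zero Dirichlet boundary values
    out = (4 * up[1:-1, 1:-1]
           - up[2:, 1:-1] - up[:-2, 1:-1]          # x-neighbours
           - up[1:-1, 2:] - up[1:-1, :-2]) / h**2  # z-neighbours
    return np.reshape(out, -1, order='F')

n = mx * mz
A = LinearOperator((n, n), matvec=neg_laplacian, dtype=float)

# Manufactured problem: pick u_true, build f = A u_true, then recover u.
rng = np.random.default_rng(0)
u_true = rng.standard_normal(n)
f = neg_laplacian(u_true)
u, info = bicgstab(A, f)
print(info)  # 0 means the solver reported convergence
```

Keeping only interior unknowns makes the operator symmetric positive definite. With non-zero boundary values, the known boundary terms would move into `f` instead. `neg_laplacian` does no in-place writes, so the expectation (untested here) is that replacing `np` with `jax.numpy` and passing the function itself to `jax.scipy.sparse.linalg.bicgstab`, which accepts a callable for `A`, would work the same way.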


Solution

  • First off, it looks like what you're attempting to compute is a convolution. In general I'd avoid doing a convolution by hand, and instead use something like scipy.signal.convolve2d. Here's the equivalent for your case:

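    from scipy.signal import convolve2d
    # 4-neighbour averaging kernel (symmetric, so flipping it in the convolution changes nothing)
    kernel = np.array([[0, 1, 0],
                       [1, 0, 1],
                       [0, 1, 0]]) / 4
    av = np.zeros_like(GPs_matrix)
    # 'valid' returns only the interior; casting back to the integer dtype truncates fractional averages
    av[1:-1, 1:-1] = convolve2d(GPs_matrix, kernel, mode='valid').astype(av.dtype)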
    

    In the flattened case, your best approach is probably going to be to reshape the 1D input, perform the 2D convolution, and then flatten the output. For example:

    def eval_1D(vec):
      mat = vec.reshape(nx, nz, order='F')
      kernel = np.array([[0, 1, 0],
                         [1, 0, 1],
                         [0, 1, 0]]) / 4
      av = np.zeros_like(mat)
      av[1:-1, 1:-1] = convolve2d(mat, kernel, mode='valid').astype(av.dtype)
      return av.ravel(order='F')
    
    print(eval_1D(GPs_matrix.ravel(order='F')))
    # [ 0  0  0  0  0  0  6  7  8  0  0 11 12 13  0  0 16 17
    #  18  0  0 21 22 23  0  0 26 27 28  0  0  0  0  0  0]
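One caveat for the JAX use case: `eval_1D` writes into `av` in place, which works for NumPy arrays but not for JAX arrays, since they are immutable. Below is a sketch of the same operator written without in-place assignment. It zero-pads the `'valid'` convolution back to full size with `np.pad`. The name `eval_1D_functional` is hypothetical, and it returns floats rather than casting to the input dtype. JAX provides `jax.numpy.pad` and `jax.scipy.signal.convolve2d`, so the assumption (not tested here) is that swapping those in gives a JAX version.

```python
import numpy as np
from scipy.signal import convolve2d

nx, nz = 5, 7

# 4-neighbour averaging kernel from the answer above
KERNEL = np.array([[0, 1, 0],
                   [1, 0, 1],
                   [0, 1, 0]]) / 4

def eval_1D_functional(vec):
    """Like eval_1D, but without in-place assignment; returns floats."""
    mat = vec.reshape(nx, nz, order='F')
    interior = convolve2d(mat, KERNEL, mode='valid')  # shape (nx-2, nz-2)
    av = np.pad(interior, 1)  # zero boundary instead of writing into av[1:-1, 1:-1]
    return av.ravel(order='F')

# Compare with the hand-written stencil from the question, on float data
rng = np.random.default_rng(0)
field = rng.standard_normal((nx, nz))
expected = np.zeros_like(field)
expected[1:-1, 1:-1] = (field[1:-1, 2:] + field[1:-1, :-2]
                        + field[2:, 1:-1] + field[:-2, 1:-1]) / 4
result = eval_1D_functional(field.ravel(order='F'))
print(np.allclose(result, expected.ravel(order='F')))  # True
```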