Tags: python-3.x, numpy, multidimensional-array, reshape, numpy-ndarray

How to remove subarrays containing nan elements on a 3d array preserving the shape?


I have a sparse array of shape (863, 923, 2) that contains a lot of NaNs:

[[[ 43.06010628 -11.01121568]
  [ 25.03068277  16.3949826 ]
  [-23.75853158 -10.95350074]
  ...
  [ 25.52110353   3.00428452]
  [ 32.66945663   9.76115107]
  [ 19.1341548    8.48547008]]

 [[ 19.08099208  11.27167832]
  [-29.4360534  -12.39131814]
  [ 11.24612069  14.38915742]
  ...
  [ 16.6897315   10.04601296]
  [ 30.09409518  17.09382562]
  [ -9.47312129  -9.57484782]]

 [[ 21.22006655  -5.01340343]
  [ 11.65512749   2.32398374]
  [-22.14668148 -11.05883399]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 ...

 [[ 32.32522443  -3.73563526]
  [ 30.88408144  -2.92184744]
  [ 37.44548043 -21.8209554 ]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 [[ 36.85471348  -7.86696711]
  [ 37.20204074  -6.32105844]
  [ 32.32522443  -3.73563526]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]

 [[ 34.21397091  -5.88930588]
  [ 35.88819735  -7.64992589]
  [ 35.48958094 -10.34708285]
  ...
  [         nan          nan]
  [         nan          nan]
  [         nan          nan]]]

I would like to remove all NaN-containing subarrays while preserving the dimensionality of the array. My understanding is that the shape of the array will change to something like (m, n, 2), but I am unable to produce such an array after removing the NaNs. Here is my attempt:

nonnan_arr = arr[~np.isnan(arr).any(axis=-1)].reshape((863, -1, 2))

And here is the error message:

Traceback (most recent call last):
  File "c:\Users\username\Desktop\observables\my_script.py", line 167, in <module>
    main()
  File "c:\Users\username\Desktop\observables\my_script.py", line 104, in main
    time_stamp_num, agents_num, spatial_dimensions_num = dataframe_splitter()
  File "c:\Users\username\Desktop\observables\utilities.py", line 1351, in dataframe_splitter
    nonnan_arr = arr[~np.isnan(arr).any(axis=-1)].reshape(
ValueError: cannot reshape array of size 226512 into shape (863,newaxis,2)
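The error comes from how boolean indexing works. Indexing a 3-D array with a 2-D boolean mask over its first two axes flattens those two axes, so the result has shape (number of True entries, 2). Here 226512 / 2 = 113256 pairs survive, and 113256 is not a multiple of 863, which means the rows kept different numbers of pairs and cannot be reshaped back into 863 equal-length rows. A minimal sketch with a small hypothetical array (`demo` is made up for illustration):

```python
import numpy as np

# Hypothetical 3 x 4 x 2 array whose rows end in different numbers of NaN pairs
demo = np.arange(3 * 4 * 2, dtype=np.float64).reshape((3, 4, 2))
demo[0, 3] = np.nan    # row 0 keeps 3 pairs
demo[1, 2:] = np.nan   # row 1 keeps 2 pairs
demo[2, 3] = np.nan    # row 2 keeps 3 pairs

keep = ~np.isnan(demo).any(axis=-1)  # shape (3, 4): one flag per pair
flat = demo[keep]                    # boolean indexing flattens axes 0 and 1
print(flat.shape)                    # (8, 2)
print(keep.sum(axis=1))              # [3 2 3]

try:
    # 8 pairs cannot be split into 3 rows of equal length
    flat.reshape((3, -1, 2))
except ValueError as exc:
    print("ValueError:", exc)
```

Even when the total happens to divide evenly, the reshape would silently move pairs into the wrong rows instead of failing, so reshaping cannot restore the per-row structure; a rectangular result requires dropping whole slices.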

Solution

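  • If you have an N-dimensional array, you need to reduce the NaN mask np.isnan(arr) with any() over N - 1 axes. The one axis you leave out is the axis along which whole NaN-containing slices get dropped, so every other axis keeps its length.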

    In your case N = 3, so there are comb(N, N - 1) = 3 possibilities, one for each axis you choose to shrink.
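As a sketch of this general rule for any N, the helper below keeps one axis, reduces the NaN mask over the remaining N - 1 axes, and drops every slice along the kept axis that contains a NaN. The name `drop_nan_slices` and the small test array are hypothetical; the helper relies only on `np.isnan`, `ndarray.any` with a tuple of axes, and `np.compress`.

```python
import numpy as np

def drop_nan_slices(arr, axis):
    """Return a copy of arr without the slices along `axis` that contain a NaN.

    The NaN mask is reduced with any() over the other N - 1 axes, so the
    result keeps arr.ndim dimensions; only arr.shape[axis] can shrink.
    """
    axis = axis % arr.ndim  # accept negative axes such as -1
    other_axes = tuple(ax for ax in range(arr.ndim) if ax != axis)
    has_nan = np.isnan(arr).any(axis=other_axes)  # 1-D, length arr.shape[axis]
    return np.compress(~has_nan, arr, axis=axis)

# Hypothetical 2 x 3 x 2 input with a single NaN at index (0, 2, 1)
small = np.arange(12, dtype=np.float64).reshape((2, 3, 2))
small[0, 2, 1] = np.nan
for axis in range(small.ndim):
    print(axis, drop_nan_slices(small, axis).shape)
# 0 (1, 3, 2)
# 1 (2, 2, 2)
# 2 (2, 3, 1)
```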

    For example, with this input:

    import numpy as np
    
    
    # np.float_ was removed in NumPy 2.0; np.float64 works in all versions
    arr = np.arange(3 * 4 * 5, dtype=np.float64).reshape((3, 4, 5))
    print(arr[1, 1, 1])
    # 26.0
    arr[1, 1, 1] = np.nan
    print(arr)
    # [[[ 0.  1.  2.  3.  4.]
    #   [ 5.  6.  7.  8.  9.]
    #   [10. 11. 12. 13. 14.]
    #   [15. 16. 17. 18. 19.]]
    
    #  [[20. 21. 22. 23. 24.]
    #   [25. nan 27. 28. 29.]
    #   [30. 31. 32. 33. 34.]
    #   [35. 36. 37. 38. 39.]]
    
    #  [[40. 41. 42. 43. 44.]
    #   [45. 46. 47. 48. 49.]
    #   [50. 51. 52. 53. 54.]
    #   [55. 56. 57. 58. 59.]]]
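Before building the masks, it can help to locate the NaN. `np.argwhere` returns one row per NaN with its index along every axis, which is why each of the three masks below flags position 1:

```python
import numpy as np

arr = np.arange(3 * 4 * 5, dtype=np.float64).reshape((3, 4, 5))
arr[1, 1, 1] = np.nan

# One row per NaN, one column per axis
print(np.argwhere(np.isnan(arr)))
# [[1 1 1]]
```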
    

    You could reduce over axes (1, 2), which removes the slices along axis 0 that contain a NaN:

    mask1 = np.isnan(arr).any(axis=(1, 2))
    print(mask1)
    # [False  True False]
    
    print(arr[~mask1, :, :].shape)
    # (2, 4, 5)
    
    print(arr[~mask1, :, :])
    # [[[ 0.  1.  2.  3.  4.]
    #   [ 5.  6.  7.  8.  9.]
    #   [10. 11. 12. 13. 14.]
    #   [15. 16. 17. 18. 19.]]
    
    #  [[40. 41. 42. 43. 44.]
    #   [45. 46. 47. 48. 49.]
    #   [50. 51. 52. 53. 54.]
    #   [55. 56. 57. 58. 59.]]]
    

    or over axes (0, 2), which removes the slices along axis 1:

    mask2 = np.isnan(arr).any(axis=(0, 2))
    print(mask2)
    # [False  True False False]
    print(arr[:, ~mask2, :].shape)
    # (3, 3, 5)
    
    print(arr[:, ~mask2, :])
    # [[[ 0.  1.  2.  3.  4.]
    #   [10. 11. 12. 13. 14.]
    #   [15. 16. 17. 18. 19.]]
    
    #  [[20. 21. 22. 23. 24.]
    #   [30. 31. 32. 33. 34.]
    #   [35. 36. 37. 38. 39.]]
    
    #  [[40. 41. 42. 43. 44.]
    #   [50. 51. 52. 53. 54.]
    #   [55. 56. 57. 58. 59.]]]
    

    or over axes (0, 1), which removes the slices along axis 2:

    mask3 = np.isnan(arr).any(axis=(0, 1))
    print(mask3)
    # [False  True False False False]
    print(arr[:, :, ~mask3].shape)
    # (3, 4, 4)
    
    print(arr[:, :, ~mask3])
    # [[[ 0.  2.  3.  4.]
    #   [ 5.  7.  8.  9.]
    #   [10. 12. 13. 14.]
    #   [15. 17. 18. 19.]]
    
    #  [[20. 22. 23. 24.]
    #   [25. 27. 28. 29.]
    #   [30. 32. 33. 34.]
    #   [35. 37. 38. 39.]]
    
    #  [[40. 42. 43. 44.]
    #   [45. 47. 48. 49.]
    #   [50. 52. 53. 54.]
    #   [55. 57. 58. 59.]]]
    

    In your case the last axis (the 2 spatial dimensions) has to keep its length, so you cannot reduce over (0, 1), but either (1, 2) or (0, 2) works. Pick whichever of the first two axes you can afford to shrink.
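A minimal sketch of the two valid choices on data laid out like yours (time stamps, agents, spatial dimensions), judging by the names in your traceback. The array `data` is a small hypothetical stand-in with random values, not your real data, and the trailing-NaN pattern only imitates the printout above:

```python
import numpy as np

# Hypothetical stand-in for the (863, 923, 2) data: 4 time steps, 5 agents, (x, y)
rng = np.random.default_rng(0)
data = rng.normal(size=(4, 5, 2))
data[2:, 3:] = np.nan  # the last two agents are missing in the last two time steps

# Reduce over (1, 2): drop every time step that has a NaN for any agent
rows_with_nan = np.isnan(data).any(axis=(1, 2))
by_time = data[~rows_with_nan]
print(by_time.shape)   # (2, 5, 2)

# Reduce over (0, 2): drop every agent that is NaN at any time step
agents_with_nan = np.isnan(data).any(axis=(0, 2))
by_agent = data[:, ~agents_with_nan]
print(by_agent.shape)  # (4, 3, 2)
```

Both results keep the last axis of length 2, but each also discards the valid values inside every dropped slice, so the better choice depends on whether time steps or agents are more expendable for your analysis.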