
Operating on the two inner-most dimensions in an xarray as a matrix


I have a multi-dimensional DataArray in xarray where the two inner-most dimensions have equal length. I want to apply np.linalg.det() to the square matrices formed by these two dimensions. How do I do that?

I have looked through the documentation and spent some time trying to get apply_ufunc() to work, but I don't think that's the way to go.
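For the layout described in the question, where the square dimensions really are the last two axes, a plain NumPy call may already be enough: np.linalg.det accepts arrays of shape (..., M, M) and returns one determinant per matrix. The following is a minimal sketch under stated assumptions: the dimension names time, level, row and col and the random data are hypothetical, and coordinates are not carried over, so it only suits the simplest case.

```python
import numpy as np
import xarray as xr

# Hypothetical example: two leading dims plus two square inner-most dims
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.random((4, 3, 5, 5)),
    dims=["time", "level", "row", "col"],
)

# np.linalg.det treats the last two axes as the matrix and loops over all
# leading axes, returning an array of shape (4, 3)
det_values = np.linalg.det(da.values)

# Rewrap as a DataArray over the remaining (non-matrix) dims; coordinates
# are not carried over in this sketch
det = xr.DataArray(det_values, dims=da.dims[:-2])
print(det.dims, det.shape)
```

This relies on the matrix dims being last. xr.apply_ufunc, as used in the answer, does not need that and keeps the coordinates of the remaining dimensions.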


Solution

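  • I think apply_ufunc is indeed the way to go: list the two matrix dimensions in input_core_dims, and apply_ufunc moves them to the end of the array before calling np.linalg.det, which then computes one determinant per square matrix. What about this?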

    import xarray as xr
    import numpy as np
    
    # sample DataArray with 3 dimensions: n and m (equal length, forming
    # the square matrices) and x
    
    n = 100
    m = 100
    x = 50
    
    data = np.random.rand(n, m, x)
    
    da = xr.DataArray(
        data,
        dims=["n", "m", "x"],
        coords={"n": np.arange(n), "m": np.arange(m), "x": np.arange(x)},
    )
    print(da)
    

    output:

    <xarray.DataArray (n: 100, m: 100, x: 50)>
    array([[[0.94590714, 0.09124397, 0.76948787, ..., 0.47377442,
             0.49205666, 0.2548719 ],
            [0.76343264, 0.41485065, 0.14439847, ..., 0.59190488,
             0.24071555, 0.00809688],
            [0.35608536, 0.44927753, 0.50033374, ..., 0.63904572,
             0.95172323, 0.24194656],
            ...,
            [0.71689346, 0.44678367, 0.92199706, ..., 0.99269701,
            [0.11836402, 0.23515399, 0.48816566, ..., 0.60960886,
    ...
    Coordinates:
      * n        (n) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
      * m        (m) int64 0 1 2 3 4 5 6 7 8 9 10 ... 90 91 92 93 94 95 96 97 98 99
      * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
    

    then:

    result = xr.apply_ufunc(
        np.linalg.det,
        da,
        input_core_dims=[["n", "m"]],
        output_core_dims=[[]],
    )
    print(result)
    

    output:

    <xarray.DataArray (x: 50)>
    array([-3.05070222e+25,  4.43036648e+25,  1.46495819e+25,  4.70795386e+24,
           ...
           -2.37514032e+24, -2.44225924e+25, -1.29536184e+24,  6.68729607e+24,
           -5.06400130e+24,  9.12843726e+24, -2.96175281e+25, -1.79864819e+26,
           -3.26686808e+24,  2.83969992e+23])
    Coordinates:
      * x        (x) int64 0 1 2 3 4 5 6 7 8 9 10 ... 40 41 42 43 44 45 46 47 48 49
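The same call works for the exact layout in the question, with the square dimensions inner-most and an extra leading dimension. The following is a minimal sketch that also cross-checks the result against an explicit loop. The names time, row and col and the random data are hypothetical. It relies only on the documented input_core_dims behaviour and on det(A.T) == det(A).

```python
import numpy as np
import xarray as xr

# Hypothetical data: a "time" dimension plus two square inner-most dims
rng = np.random.default_rng(42)
da = xr.DataArray(
    rng.random((6, 4, 4)),
    dims=["time", "row", "col"],
    coords={"time": np.arange(6)},
)

# apply_ufunc moves the core dims to the end (in the listed order) before
# calling np.linalg.det; the output keeps the "time" dim and its coordinate
det = xr.apply_ufunc(
    np.linalg.det,
    da,
    input_core_dims=[["row", "col"]],
)

# Cross-check against an explicit loop over time
expected = np.array(
    [np.linalg.det(da.isel(time=t).values) for t in range(da.sizes["time"])]
)
print(det.dims, np.allclose(det.values, expected))

# Listing the core dims in the opposite order transposes each matrix,
# which leaves the determinant unchanged
det_t = xr.apply_ufunc(np.linalg.det, da, input_core_dims=[["col", "row"]])
print(np.allclose(det_t.values, det.values))
```

Because apply_ufunc reorders the core dimensions itself, the matrix dims do not have to be inner-most. Only their order in input_core_dims matters, and for a determinant even that order makes no difference.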