Some preliminary setup:
import xarray as xr
import numpy as np
xr.set_options(display_style="text")
<xarray.core.options.set_options at 0x7f3777111e50>
Suppose that I have labels which are composed of two parts, first and second:
raw_labels = np.array(
[["a", "c"], ["b", "a"], ["a", "b"], ["c", "a"]],
dtype="<U1",
)
raw_labels
array([['a', 'c'],
['b', 'a'],
['a', 'b'],
['c', 'a']], dtype='<U1')
I can make an xarray.DataArray easily enough to represent this raw information with informative tags:
label_metas = xr.DataArray(
raw_labels,
dims=("label", "parts"),
coords={
"label": ["-".join(x) for x in raw_labels],
"parts": ["first", "second"],
},
name="meta",
)
label_metas
<xarray.DataArray 'meta' (label: 4, parts: 2)> Size: 32B
array([['a', 'c'],
       ['b', 'a'],
       ['a', 'b'],
       ['c', 'a']], dtype='<U1')
Coordinates:
  * label    (label) <U3 48B 'a-c' 'b-a' 'a-b' 'c-a'
  * parts    (parts) <U6 48B 'first' 'second'
Now suppose that I have additional information for each label: let's say, for simplicity, some count information.
raw_counts = np.random.randint(0, 100, size=len(label_metas))
raw_counts
array([95, 23, 6, 77])
label_counts = xr.DataArray(
raw_counts,
dims="label",
coords={"label": label_metas.coords["label"]},
name="count",
)
label_counts
<xarray.DataArray 'count' (label: 4)> Size: 32B
array([95, 23, 6, 77])
Coordinates:
  * label    (label) <U3 48B 'a-c' 'b-a' 'a-b' 'c-a'
How do I combine these clearly related xr.DataArrays? From what I understand, by using an xr.Dataset.
label_info = xr.merge([label_metas, label_counts])
label_info
<xarray.Dataset> Size: 160B
Dimensions:  (label: 4, parts: 2)
Coordinates:
  * label    (label) <U3 48B 'a-c' 'b-a' 'a-b' 'c-a'
  * parts    (parts) <U6 48B 'first' 'second'
Data variables:
    meta     (label, parts) <U1 32B 'a' 'c' 'b' 'a' 'a' 'b' 'c' 'a'
    count    (label) int64 32B 95 23 6 77
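For what it's worth, xr.merge is not the only way to get there: the xr.Dataset constructor also accepts a dict of DataArrays and aligns them on their shared coordinates. A minimal sketch, assuming fixed counts instead of the random ones above so that the result is reproducible:

```python
import numpy as np
import xarray as xr

raw_labels = np.array([["a", "c"], ["b", "a"], ["a", "b"], ["c", "a"]], dtype="<U1")
labels = ["-".join(x) for x in raw_labels]

label_metas = xr.DataArray(
    raw_labels,
    dims=("label", "parts"),
    coords={"label": labels, "parts": ["first", "second"]},
    name="meta",
)
# fixed counts (an assumption) instead of np.random.randint, for reproducibility
label_counts = xr.DataArray(
    [95, 23, 6, 77], dims="label", coords={"label": labels}, name="count"
)

# The dict keys become the data variable names; the DataArrays are aligned
# on their shared `label` coordinate, giving the same Dataset as xr.merge here.
label_info = xr.Dataset({"meta": label_metas, "count": label_counts})
print(label_info)
```

The dict form makes the variable names explicit, while xr.merge takes them from each DataArray's name.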
Now suppose I want to filter this dataset so that only the labels whose first part is 'a' are left. How would I go about it? According to the docs, where can be applied to an xr.Dataset too, but no examples are given showing this in action. Here are the results of my experiments:
label_info["meta"].sel(parts="first")
<xarray.DataArray 'meta' (label: 4)> Size: 16B
array(['a', 'b', 'a', 'c'], dtype='<U1')
Coordinates:
  * label    (label) <U3 48B 'a-c' 'b-a' 'a-b' 'c-a'
    parts    <U6 24B 'first'
label_info.where(label_info["meta"].sel(parts="first") == "a")
<xarray.Dataset> Size: 192B
Dimensions:  (label: 4, parts: 2)
Coordinates:
  * label    (label) <U3 48B 'a-c' 'b-a' 'a-b' 'c-a'
  * parts    (parts) <U6 48B 'first' 'second'
Data variables:
    meta     (label, parts) object 64B 'a' 'c' nan nan 'a' 'b' nan nan
    count    (label) float64 32B 95.0 nan 6.0 nan
We see that the points which do not match the where condition are replaced with np.nan, as expected from the docs. Note that count has changed from int64 to float64, and meta from <U1 to object, since the original dtypes cannot hold a NaN. Does that mean there is some re-allocation of backing arrays involved? Suppose instead that we ask for the non-matching regions to be dropped: does that also cause a re-allocation? I am not sure, because I am unable to drop those values due to IndexError: dimension coordinate 'parts' conflicts between indexed and indexing objects:
label_info.where(label_info["meta"].sel(parts="first") == "a", drop=True)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Cell In[20], line 1
----> 1 label_info.where(label_info["meta"].sel(parts="first") == "a", drop=True)
File ~/miniforge3/envs/xarray-tutorial/lib/python3.11/site-packages/xarray/core/common.py:1225, in DataWithCoords.where(self, cond, other, drop)
1222 for dim in cond.sizes.keys():
1223 indexers[dim] = _get_indexer(dim)
-> 1225 self = self.isel(**indexers)
1226 cond = cond.isel(**indexers)
1228 return ops.where_method(self, cond, other)
File ~/miniforge3/envs/xarray-tutorial/lib/python3.11/site-packages/xarray/core/dataset.py:2972, in Dataset.isel(self, indexers, drop, missing_dims, **indexers_kwargs)
2970 indexers = either_dict_or_kwargs(indexers, indexers_kwargs, "isel")
2971 if any(is_fancy_indexer(idx) for idx in indexers.values()):
-> 2972 return self._isel_fancy(indexers, drop=drop, missing_dims=missing_dims)
2974 # Much faster algorithm for when all indexers are ints, slices, one-dimensional
2975 # lists, or zero or one-dimensional np.ndarray's
2976 indexers = drop_dims_from_indexers(indexers, self.dims, missing_dims)
File ~/miniforge3/envs/xarray-tutorial/lib/python3.11/site-packages/xarray/core/dataset.py:3043, in Dataset._isel_fancy(self, indexers, drop, missing_dims)
3040 selected = self._replace_with_new_dims(variables, coord_names, indexes)
3042 # Extract coordinates from indexers
-> 3043 coord_vars, new_indexes = selected._get_indexers_coords_and_indexes(indexers)
3044 variables.update(coord_vars)
3045 indexes.update(new_indexes)
File ~/miniforge3/envs/xarray-tutorial/lib/python3.11/site-packages/xarray/core/dataset.py:2844, in Dataset._get_indexers_coords_and_indexes(self, indexers)
2840 # we don't need to call align() explicitly or check indexes for
2841 # alignment, because merge_variables already checks for exact alignment
2842 # between dimension coordinates
2843 coords, indexes = merge_coordinates_without_align(coords_list)
-> 2844 assert_coordinate_consistent(self, coords)
2846 # silently drop the conflicted variables.
2847 attached_coords = {k: v for k, v in coords.items() if k not in self._variables}
File ~/miniforge3/envs/xarray-tutorial/lib/python3.11/site-packages/xarray/core/coordinates.py:941, in assert_coordinate_consistent(obj, coords)
938 for k in obj.dims:
939 # make sure there are no conflict in dimension coordinates
940 if k in coords and k in obj.coords and not coords[k].equals(obj[k].variable):
--> 941 raise IndexError(
942 f"dimension coordinate {k!r} conflicts between "
943 f"indexed and indexing objects:\n{obj[k]}\nvs.\n{coords[k]}"
944 )
IndexError: dimension coordinate 'parts' conflicts between indexed and indexing objects:
<xarray.DataArray 'parts' (parts: 2)> Size: 48B
array(['first', 'second'], dtype='<U6')
Coordinates:
* parts (parts) <U6 48B 'first' 'second'
vs.
<xarray.Variable ()> Size: 24B
array('first', dtype='<U6')
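Reading the traceback, the conflict seems to come from the mask itself: label_info["meta"].sel(parts="first") keeps parts as a scalar coordinate (the parts <U6 24B 'first' line in its output). With drop=True, where turns the mask into an indexer, and that scalar 'first' is then compared against the dataset's two-valued parts dimension coordinate. A minimal sketch of two workarounds, assuming fixed counts and a recent xarray version (I have not checked older releases): discard the scalar coordinate with sel(..., drop=True), or index with a plain boolean NumPy array via isel.

```python
import numpy as np
import xarray as xr

raw_labels = np.array([["a", "c"], ["b", "a"], ["a", "b"], ["c", "a"]], dtype="<U1")
labels = ["-".join(x) for x in raw_labels]
label_info = xr.Dataset(
    {
        "meta": (("label", "parts"), raw_labels),
        # fixed counts (an assumption) instead of random ones
        "count": ("label", np.array([95, 23, 6, 77])),
    },
    coords={"label": labels, "parts": ["first", "second"]},
)

# Option 1: drop=True on .sel discards the scalar `parts` coordinate, so the
# mask only carries the `label` coordinate and no longer clashes with the dataset.
mask = label_info["meta"].sel(parts="first", drop=True) == "a"
dropped = label_info.where(mask, drop=True)
print(dropped)

# Option 2: index with a plain boolean NumPy array. Nothing is replaced by NaN,
# so the original dtypes of `count` and `meta` are kept.
selected = label_info.isel(label=mask.values)
print(selected)
```

Even with drop=True, where still fills with NaN internally, so in option 1 count comes back as float64; option 2 avoids that.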
xarray was the wrong choice for my task, and probably every other task imaginable. polars turned out to be much better.
I think the reason why is:

- an n-dimensional array is the same as an n-argument function;
- an n-argument function is the same as a "table" with n + 1 "columns", where the first n columns correspond to the function's inputs, and the last column corresponds to the function's output.

Therefore, everything that one might want to try and do with a multi-dimensional array can be done with a "flat" table (a DataFrame in polars speak), and the quality of polars is such that:

- … xarray;
- … xarray.

When it comes to this question in particular, what I was trying to do is join two n-dimensional maps and then filter for particular rows. In SQL/polars/etc. speak, that's a join operation followed by a filter.
It's an utter pain to do this in xarray. See this question, and the related one: `xarray`: merging two `DataArray`s which have only one shared dimension results in a `Dataset` that lists other dimensions?
It is easy to do in polars:
import numpy as np
import polars as pl

raw_labels = np.array(
[["a", "c"], ["b", "a"], ["a", "b"], ["c", "a"]],
dtype="<U1",
)
# create a table with `first` and `second` as columns
df = pl.DataFrame(
[pl.Series(col, dtype=pl.String()) for col in raw_labels.transpose()]
).rename({"column_0": "first", "column_1": "second"})
print(df)
# attach new data to the table
raw_counts = np.random.randint(0, 100, size=raw_labels.shape[0])
df.insert_column(len(df.columns), pl.Series("counts", raw_counts))
print(df)
# filter the dataframe for those rows where column `first` == "a";
# `filter` returns a new DataFrame instead of modifying `df`, so keep the result
filtered = df.filter(pl.col("first") == "a")
print(filtered)
shape: (4, 2) # create a table with first, second as columns
┌───────┬────────┐
│ first ┆ second │
│ --- ┆ --- │
│ str ┆ str │
╞═══════╪════════╡
│ a ┆ c │
│ b ┆ a │
│ a ┆ b │
│ c ┆ a │
└───────┴────────┘
shape: (4, 3) # attach extra count information
┌───────┬────────┬────────┐
│ first ┆ second ┆ counts │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═══════╪════════╪════════╡
│ a ┆ c ┆ 71 │
│ b ┆ a ┆ 19 │
│ a ┆ b ┆ 46 │
│ c ┆ a ┆ 73 │
└───────┴────────┴────────┘
shape: (2, 3) # filter for those rows where column `first` == "a"
┌───────┬────────┬────────┐
│ first ┆ second ┆ counts │
│ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 │
╞═══════╪════════╪════════╡
│ a ┆ c ┆ 71 │
│ a ┆ b ┆ 46 │
└───────┴────────┴────────┘
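To tie the "n-dimensional array = table with n + 1 columns" argument back to the xarray objects from the question, here is a sketch using DataArray.to_dataframe(), which returns a pandas DataFrame (pandas is an xarray dependency). The 2-dimensional meta array becomes a table with one column per dimension plus one column for the values.

```python
import numpy as np
import xarray as xr

raw_labels = np.array([["a", "c"], ["b", "a"], ["a", "b"], ["c", "a"]], dtype="<U1")
label_metas = xr.DataArray(
    raw_labels,
    dims=("label", "parts"),
    coords={
        "label": ["-".join(x) for x in raw_labels],
        "parts": ["first", "second"],
    },
    name="meta",
)

# to_dataframe() puts the dimensions into a (label, parts) MultiIndex;
# reset_index() turns them into ordinary columns: 2 inputs + 1 output = 3 columns.
flat = label_metas.to_dataframe().reset_index()
print(flat)
```

Each of the 4 × 2 array cells becomes one row of this flat table.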