diff --git a/doc/whats-new.rst b/doc/whats-new.rst index effb199f18e..b30268388e9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -29,6 +29,9 @@ Bug Fixes - Fix a major performance regression in :py:meth:`Coordinates.to_index` (and consequently :py:meth:`Dataset.to_dataframe`) caused by converting the cached code ndarrays into Python lists (:issue:`11305`). +- Allow non-mapping arguments such as ``"auto"`` or an integer to + :py:meth:`DataTree.chunk`, matching :py:meth:`Dataset.chunk` + (:issue:`11315`). Documentation diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 98934f29b92..1baefd0b892 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -4,6 +4,7 @@ import io import itertools import textwrap +import warnings from collections import ChainMap, defaultdict from collections.abc import ( Callable, @@ -60,6 +61,7 @@ _default, drop_dims_from_indexers, either_dict_or_kwargs, + emit_user_level_warning, maybe_wrap_array, parse_dims_as_set, ) @@ -2650,15 +2652,29 @@ def chunk( xarray.unify_chunks dask.array.from_array """ - # don't support deprecated ways of passing chunks - if not isinstance(chunks, Mapping): - raise TypeError( - f"invalid type for chunks: {type(chunks)}. Only mappings are supported." + if chunks is None and not chunks_kwargs: + warnings.warn( + "None value for 'chunks' is deprecated. " + "It will raise an error in the future. Use instead '{}'", + category=FutureWarning, + stacklevel=2, ) - combined_chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + chunks = {} all_dims = self._get_all_dims() + combined_chunks: Mapping[Any, Any] + if not isinstance(chunks, Mapping) and chunks is not None: + if isinstance(chunks, tuple | list): + emit_user_level_warning( + "Supplying chunks as dimension-order tuples is deprecated. " + "It will raise an error in the future. Instead use a dict with dimensions as keys.", + category=FutureWarning, + ) + combined_chunks = dict.fromkeys(all_dims, chunks) + else: + combined_chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") + bad_dims = combined_chunks.keys() - all_dims if bad_dims: raise ValueError( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 20c383d02cb..d86464d64a8 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2766,11 +2766,28 @@ def test_chunk(self): assert_identical(actual, expected) assert actual.chunksizes == expected.chunksizes - with pytest.raises(TypeError, match="invalid type"): + with pytest.warns(FutureWarning, match="None value for 'chunks'"): tree.chunk(None) - with pytest.raises(TypeError, match="invalid type"): - tree.chunk((1, 2)) - with pytest.raises(ValueError, match="not found in data dimensions"): tree.chunk({"u": 2}) + + @requires_dask + def test_chunk_non_mapping(self): + # GH11315: DataTree.chunk should accept non-mapping inputs like + # Dataset.chunk does (e.g. "auto" or an int that broadcasts to all dims). + ds1 = xr.Dataset({"a": (("x", "y"), np.zeros((10, 5)))}) + ds2 = xr.Dataset({"b": ("z", np.arange(4))}) + tree = xr.DataTree.from_dict({"/": ds1, "/group": ds2}) + + actual = tree.chunk("auto") + expected = xr.DataTree.from_dict( + {"/": ds1.chunk("auto"), "/group": ds2.chunk("auto")} + ) + assert_identical(actual, expected) + + actual_int = tree.chunk(3) + for node in actual_int.subtree: + for v in node.dataset.variables.values(): + if v.ndim: + assert v.chunks is not None