Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add broadcast_like. #3086

Merged
merged 6 commits into from
Jul 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ Reshaping and reorganizing
Dataset.shift
Dataset.roll
Dataset.sortby
Dataset.broadcast_like

DataArray
=========
Expand Down Expand Up @@ -386,6 +387,7 @@ Reshaping and reorganizing
DataArray.shift
DataArray.roll
DataArray.sortby
DataArray.broadcast_like

.. _api.ufuncs:

Expand Down
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ v0.12.4 (unreleased)
New functions/methods
~~~~~~~~~~~~~~~~~~~~~

- Added :py:meth:`DataArray.broadcast_like` and :py:meth:`Dataset.broadcast_like`.
By `Deepak Cherian <https://github.com/dcherian>`_.

Enhancements
~~~~~~~~~~~~

Expand Down Expand Up @@ -48,6 +51,7 @@ New functions/methods
(:issue:`3026`).
By `Julia Kent <https://github.com/jukent>`_.


Enhancements
~~~~~~~~~~~~

Expand Down
97 changes: 55 additions & 42 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,58 @@ def reindex_variables(
return reindexed, new_indexes


def _get_broadcast_dims_map_common_coords(args, exclude):

common_coords = OrderedDict()
dims_map = OrderedDict()
for arg in args:
for dim in arg.dims:
if dim not in common_coords and dim not in exclude:
dims_map[dim] = arg.sizes[dim]
if dim in arg.coords:
common_coords[dim] = arg.coords[dim].variable

return dims_map, common_coords


def _broadcast_helper(arg, exclude, dims_map, common_coords):

from .dataarray import DataArray
from .dataset import Dataset

def _set_dims(var):
# Add excluded dims to a copy of dims_map
var_dims_map = dims_map.copy()
for dim in exclude:
with suppress(ValueError):
# ignore dim not in var.dims
var_dims_map[dim] = var.shape[var.dims.index(dim)]

return var.set_dims(var_dims_map)

def _broadcast_array(array):
data = _set_dims(array.variable)
coords = OrderedDict(array.coords)
coords.update(common_coords)
return DataArray(data, coords, data.dims, name=array.name,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could potentially use ._replace, and then only supply the changed items (but not a big deal)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i.e. return array._replace(data, coords)
So array keeps its type if someone inherited from DataArray and we're not coupled to specific DataArray properties if those ever change

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, but let's save this for a later refactor.

attrs=array.attrs)

def _broadcast_dataset(ds):
data_vars = OrderedDict(
(k, _set_dims(ds.variables[k]))
for k in ds.data_vars)
coords = OrderedDict(ds.coords)
coords.update(common_coords)
return Dataset(data_vars, coords, ds.attrs)

if isinstance(arg, DataArray):
return _broadcast_array(arg)
elif isinstance(arg, Dataset):
return _broadcast_dataset(arg)
else:
raise ValueError('all input must be Dataset or DataArray objects')


def broadcast(*args, exclude=None):
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.
Expand Down Expand Up @@ -463,55 +515,16 @@ def broadcast(*args, exclude=None):
a (x, y) int64 1 1 2 2 3 3
b (x, y) int64 5 6 5 6 5 6
"""
from .dataarray import DataArray
from .dataset import Dataset

if exclude is None:
exclude = set()
args = align(*args, join='outer', copy=False, exclude=exclude)

common_coords = OrderedDict()
dims_map = OrderedDict()
for arg in args:
for dim in arg.dims:
if dim not in common_coords and dim not in exclude:
dims_map[dim] = arg.sizes[dim]
if dim in arg.coords:
common_coords[dim] = arg.coords[dim].variable

def _set_dims(var):
# Add excluded dims to a copy of dims_map
var_dims_map = dims_map.copy()
for dim in exclude:
with suppress(ValueError):
# ignore dim not in var.dims
var_dims_map[dim] = var.shape[var.dims.index(dim)]

return var.set_dims(var_dims_map)

def _broadcast_array(array):
data = _set_dims(array.variable)
coords = OrderedDict(array.coords)
coords.update(common_coords)
return DataArray(data, coords, data.dims, name=array.name,
attrs=array.attrs)

def _broadcast_dataset(ds):
data_vars = OrderedDict(
(k, _set_dims(ds.variables[k]))
for k in ds.data_vars)
coords = OrderedDict(ds.coords)
coords.update(common_coords)
return Dataset(data_vars, coords, ds.attrs)

dims_map, common_coords = _get_broadcast_dims_map_common_coords(
args, exclude)
result = []
for arg in args:
if isinstance(arg, DataArray):
result.append(_broadcast_array(arg))
elif isinstance(arg, Dataset):
result.append(_broadcast_dataset(arg))
else:
raise ValueError('all input must be Dataset or DataArray objects')
result.append(_broadcast_helper(arg, exclude, dims_map, common_coords))

return tuple(result)

Expand Down
27 changes: 26 additions & 1 deletion xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
utils)
from .accessor_dt import DatetimeAccessor
from .accessor_str import StringAccessor
from .alignment import align, reindex_like_indexers
from .alignment import (align, _broadcast_helper,
_get_broadcast_dims_map_common_coords,
reindex_like_indexers)
from .common import AbstractArray, DataWithCoords
from .coordinates import (
DataArrayCoordinates, LevelCoordinatesSource, assert_coordinate_consistent,
Expand Down Expand Up @@ -994,6 +996,29 @@ def sel_points(self, dim='points', method=None, tolerance=None,
dim=dim, method=method, tolerance=tolerance, **indexers)
return self._from_temp_dataset(ds)

def broadcast_like(self,
other: Union['DataArray', Dataset],
exclude=None) -> 'DataArray':
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]

Parameters
----------
other : Dataset or DataArray
Object against which to broadcast this array.
exclude : sequence of str, optional
Dimensions that must not be broadcasted
"""

if exclude is None:
exclude = set()
args = align(other, self, join='outer', copy=False, exclude=exclude)

dims_map, common_coords = _get_broadcast_dims_map_common_coords(
args, exclude)

return _broadcast_helper(self, exclude, dims_map, common_coords)

def reindex_like(self, other: Union['DataArray', Dataset],
method: Optional[str] = None, tolerance=None,
copy: bool = True, fill_value=dtypes.NA) -> 'DataArray':
Expand Down
27 changes: 26 additions & 1 deletion xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from ..coding.cftimeindex import _parse_array_of_cftime_strings
from . import (alignment, dtypes, duck_array_ops, formatting, groupby,
indexing, ops, pdcompat, resample, rolling, utils)
from .alignment import align
from .alignment import (align, _broadcast_helper,
_get_broadcast_dims_map_common_coords)
from .common import (ALL_DIMS, DataWithCoords, ImplementsDatasetReduce,
_contains_datetime_like_objects)
from .coordinates import (DatasetCoordinates, LevelCoordinatesSource,
Expand Down Expand Up @@ -2027,6 +2028,30 @@ def sel_points(self, dim='points', method=None, tolerance=None,
)
return self.isel_points(dim=dim, **pos_indexers)

def broadcast_like(self,
other: Union['Dataset', 'DataArray'],
exclude=None) -> 'Dataset':
"""Broadcast this DataArray against another Dataset or DataArray.
This is equivalent to xr.broadcast(other, self)[1]

Parameters
----------
other : Dataset or DataArray
Object against which to broadcast this array.
exclude : sequence of str, optional
Dimensions that must not be broadcasted

"""

if exclude is None:
exclude = set()
args = align(other, self, join='outer', copy=False, exclude=exclude)

dims_map, common_coords = _get_broadcast_dims_map_common_coords(
args, exclude)

return _broadcast_helper(self, exclude, dims_map, common_coords)

def reindex_like(self, other, method=None, tolerance=None, copy=True,
fill_value=dtypes.NA):
"""Conform this object onto the indexes of another object, filling in
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,21 @@ def test_coords_non_string(self):
expected = DataArray(2, coords={1: 2}, name=1)
assert_identical(actual, expected)

def test_broadcast_like(self):
original1 = DataArray(np.random.randn(5),
[('x', range(5))])

original2 = DataArray(np.random.randn(6),
[('y', range(6))])

expected1, expected2 = broadcast(original1, original2)

assert_identical(original1.broadcast_like(original2),
expected1.transpose('y', 'x'))

assert_identical(original2.broadcast_like(original1),
expected2)

def test_reindex_like(self):
foo = DataArray(np.random.randn(5, 6),
[('x', range(5)), ('y', range(6))])
Expand Down
15 changes: 15 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,6 +1560,21 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
assert_identical(mdata.sel(x={'one': 'a', 'two': 1}),
mdata.sel(one='a', two=1))

def test_broadcast_like(self):
original1 = DataArray(np.random.randn(5),
[('x', range(5))], name='a').to_dataset()

original2 = DataArray(np.random.randn(6),
[('y', range(6))], name='b')

expected1, expected2 = broadcast(original1, original2)

assert_identical(original1.broadcast_like(original2),
expected1.transpose('y', 'x'))

assert_identical(original2.broadcast_like(original1),
expected2)

def test_reindex_like(self):
data = create_test_data()
data['letters'] = ('dim3', 10 * ['a'])
Expand Down