Skip to content

Commit

Permalink
Merge pull request #868 from helmholtz-analytics/bug/866-op-nested
Browse files Browse the repository at this point in the history
fix binary_op on operands with single element
  • Loading branch information
coquelin77 authored Oct 8, 2021
2 parents c570047 + 18a8297 commit 1e9619d
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 6 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,15 @@
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when axis and keepdim were set.
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min`, `max` where DNDarrays with empty processes can't be computed.
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element.

## Feature Additions
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`

### DNDarray
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`

Expand Down
16 changes: 10 additions & 6 deletions heat/core/_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from .communication import MPI, MPI_WORLD
from . import factories
from . import devices
from . import stride_tricks
from . import sanitation
from . import statistics
Expand Down Expand Up @@ -115,28 +114,33 @@ def __binary_op(
output_device = t1.device
output_comm = t1.comm

# ToDo: Fine tuning in case of comm.size>t1.shape[t1.split]. Send torch tensors only to ranks, that will hold data.
if t1.split is not None:
if t1.shape[t1.split] == 1 and t1.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of first operator between MPI ranks!"
# )
if t1.comm.rank > 0:
color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
newcomm = t1.comm.Split(color, t1.comm.rank)
if t1.comm.rank > 0 and color == 0:
t1.larray = torch.zeros(
t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
)
t1.comm.Bcast(t1)
newcomm.Bcast(t1)
newcomm.Free()

if t2.split is not None:
if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of second operator between MPI ranks!"
# )
if t2.comm.rank > 0:
color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
newcomm = t2.comm.Split(color, t2.comm.rank)
if t2.comm.rank > 0 and color == 0:
t2.larray = torch.zeros(
t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
)
t2.comm.Bcast(t2)
newcomm.Bcast(t2)
newcomm.Free()

else:
raise TypeError(
Expand Down
19 changes: 19 additions & 0 deletions heat/core/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,25 @@ def alltoall_recvbuffer(

return self.as_mpi_memory(obj), (recvcount, recvdispls), recvtypes

def Free(self) -> None:
"""
Free a communicator.
"""
self.handle.Free()

def Split(self, color: int = 0, key: int = 0) -> MPICommunication:
"""
Split communicator by color and key.
Parameters
----------
color : int, optional
Determines the new communicator for a process.
key: int, optional
Ordering within the new communicator.
"""
return MPICommunication(self.handle.Split(color, key))

def Irecv(
self,
buf: Union[DNDarray, torch.Tensor, Any],
Expand Down
11 changes: 11 additions & 0 deletions heat/core/tests/test_arithmetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ def test_add(self):
self.assertTrue(ht.equal(ht.add(self.a_tensor, self.an_int_scalar), result))
self.assertTrue(ht.equal(ht.add(self.a_split_tensor, self.a_tensor), result))

# Single element split
a = ht.array([1], split=0)
b = ht.array([1, 2], split=0)
c = ht.add(a, b)
self.assertTrue(ht.equal(c, ht.array([2, 3])))
if c.comm.size > 1:
if c.comm.rank < 2:
self.assertEqual(c.larray.size()[0], 1)
else:
self.assertEqual(c.larray.size()[0], 0)

with self.assertRaises(ValueError):
ht.add(self.a_tensor, self.another_vector)
with self.assertRaises(TypeError):
Expand Down
15 changes: 15 additions & 0 deletions heat/core/tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,21 @@ def test_default_comm(self):
with self.assertRaises(TypeError):
ht.use_comm("1")

def test_split(self):
a = ht.zeros((4, 5), split=0)

color = a.comm.rank % 2
newcomm = a.comm.Split(color, key=a.comm.rank)

self.assertIsInstance(newcomm, ht.MPICommunication)
if ht.MPI_WORLD.size == 1:
self.assertTrue(newcomm.size == a.comm.size)
else:
self.assertTrue(newcomm.size < a.comm.size)
self.assertIsNot(newcomm, a.comm)

newcomm.Free()

def test_allgather(self):
# contiguous data
data = ht.ones((1, 7))
Expand Down

0 comments on commit 1e9619d

Please sign in to comment.