Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ht.array() default to copy=None (e.g., only if necessary) #1119

Merged
merged 29 commits into from
Jun 19, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
bb1e4db
Default to copy=None in `ht.array()` as per array-API
ClaudiaComito Mar 6, 2023
8f4e34b
remove unnecessary copy if split not None
ClaudiaComito Mar 6, 2023
fc59df1
return contiguous local tensor after torch.diagonal
ClaudiaComito Mar 7, 2023
0139c65
set default copy=None for asarray
ClaudiaComito Mar 7, 2023
eae8256
Update documentation
ClaudiaComito Mar 15, 2023
dafa614
allow copy for lists and tuples
ClaudiaComito Mar 15, 2023
aafa774
expand tests
ClaudiaComito Mar 15, 2023
c65eecd
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Mar 29, 2023
52e2c48
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Mar 30, 2023
f2c49fc
interpret numpy.ndarray.dtype
ClaudiaComito Mar 30, 2023
66069b5
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Apr 17, 2023
f5107b8
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Apr 17, 2023
729b9f3
expand tests
ClaudiaComito Apr 17, 2023
d56d8b8
fix dtype in numpy obj test
ClaudiaComito Apr 18, 2023
04b9259
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Apr 24, 2023
25a4cb9
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Apr 27, 2023
ac2bb29
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito May 22, 2023
a64eaa2
improve docs after review
ClaudiaComito May 22, 2023
ca868ea
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito May 31, 2023
e882c56
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Jun 6, 2023
c9ba402
Edit documentation
ClaudiaComito Jun 6, 2023
8514070
reduce memory footprint ht.array()
ClaudiaComito Jun 12, 2023
d3fb45b
skip tests on arm64 architecture
ClaudiaComito Jun 13, 2023
ef6b86a
assign correct split on 1 process as well
ClaudiaComito Jun 13, 2023
0312f86
add copy kwarg to asarray()
ClaudiaComito Jun 13, 2023
d6ebd8d
edit asarray docs
ClaudiaComito Jun 13, 2023
7b2a267
Merge branch 'main' into features/#1117-array-copy-None
ClaudiaComito Jun 19, 2023
49197b3
skip split mismatch test on 1 process
ClaudiaComito Jun 19, 2023
17ca24f
skip split mismatch test on 1 process
ClaudiaComito Jun 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 36 additions & 18 deletions heat/core/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def arange(
def array(
obj: Iterable,
dtype: Optional[Type[datatype]] = None,
copy: bool = True,
copy: bool = None,
ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
ndmin: int = 0,
order: str = "C",
split: Optional[int] = None,
Expand All @@ -172,8 +172,9 @@ def array(
to hold the objects in the sequence. This argument can only be used to ‘upcast’ the array. For downcasting, use
the :func:`~heat.core.dndarray.astype` method.
copy : bool, optional
If ``True`` (default), then the object is copied. Otherwise, a copy will only be made if obj is a nested
sequence or if a copy is needed to satisfy any of the other requirements, e.g. ``dtype``.
If ``True``, the input object is copied.
If ``False``, input which supports the buffer protocol is never copied.
If ``None`` (default), the function reuses the existing memory buffer if possible, and copies otherwise.
ndmin : int, optional
Specifies the minimum number of dimensions that the resulting array should have. Ones will, if needed, be
attached to the shape if ``ndim > 0`` and prefaced in case of ``ndim < 0`` to meet the requirement.
Expand All @@ -198,6 +199,8 @@ def array(
------
NotImplementedError
If order is one of the NumPy options ``'K'`` or ``'A'``.
ValueError
If ``copy`` is False but a copy is necessary.

ClaudiaComito marked this conversation as resolved.
Show resolved Hide resolved
Examples
--------
Expand Down Expand Up @@ -285,18 +288,16 @@ def array(
[torch.LongStorage of size 6]
"""
# array already exists; no copy
if (
isinstance(obj, DNDarray)
and not copy
and (dtype is None or dtype == obj.dtype)
and (split is None or split == obj.split)
and (is_split is None or is_split == obj.split)
and (device is None or device == obj.device)
):
return obj

# extract the internal tensor in case of a heat tensor
if isinstance(obj, DNDarray):
if not copy:
if (
(dtype is None or dtype == obj.dtype)
and (split is None or split == obj.split)
and (is_split is None or is_split == obj.split)
and (device is None or device == obj.device)
):
return obj
# extract the internal tensor
obj = obj.larray

# sanitize the data type
Expand Down Expand Up @@ -324,13 +325,28 @@ def array(
except RuntimeError:
raise TypeError("invalid data of type {}".format(type(obj)))
else:
if not isinstance(obj, DNDarray):
if copy is False and not np.isscalar(obj) and not isinstance(obj, (Tuple, List)):
# Python array-API compliance, cf. https://data-apis.org/array-api/2022.12/API_specification/generated/array_api.asarray.html
if not (
(dtype is None or dtype == types.canonical_heat_type(obj.dtype))
and (
device is None
or device.torch_device
== str(getattr(obj, "device", devices.get_device().torch_device))
)
):
raise ValueError(
"argument `copy` is set to False, but copy of input object is necessary. \n Set copy=None to reuse the memory buffer whenever possible and allow for copies otherwise."
)
try:
obj = torch.as_tensor(
obj,
device=device.torch_device
if device is not None
else devices.get_device().torch_device,
)
except RuntimeError:
raise TypeError("invalid data of type {}".format(type(obj)))

# infer dtype from obj if not explicitly given
if dtype is None:
Expand Down Expand Up @@ -377,9 +393,11 @@ def array(

# content shall be split, chunk the passed data object up
if split is not None:
# only keep local slice
_, _, slices = comm.chunk(gshape, split)
obj = obj[slices].clone()
obj = sanitize_memory_layout(obj, order=order)
_ = obj[slices].clone()
del obj
obj = sanitize_memory_layout(_, order=order)
# check with the neighboring rank whether the local shape would fit into a global shape
elif is_split is not None:
obj = sanitize_memory_layout(obj, order=order)
Expand Down Expand Up @@ -489,7 +507,7 @@ def asarray(
>>> ht.asarray(a, dtype=ht.float64) is a
False
"""
return array(obj, dtype=dtype, copy=False, order=order, is_split=is_split, device=device)
return array(obj, dtype=dtype, copy=None, order=order, is_split=is_split, device=device)


def empty(
Expand Down
6 changes: 4 additions & 2 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,13 @@ def diagonal(a: DNDarray, offset: int = 0, dim1: int = 0, dim2: int = 1) -> DNDa
split = len(shape) - 1

if a.split is None or a.split not in (dim1, dim2):
result = torch.diagonal(a.larray, offset=offset, dim1=dim1, dim2=dim2)
result = torch.diagonal(a.larray, offset=offset, dim1=dim1, dim2=dim2).contiguous()
else:
vz = 1 if a.split == dim1 else -1
off, _, _ = a.comm.chunk(a.shape, a.split)
result = torch.diagonal(a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2)
result = torch.diagonal(
a.larray, offset=offset + vz * off, dim1=dim1, dim2=dim2
).contiguous()
return factories.array(result, dtype=a.dtype, is_split=split, device=a.device, comm=a.comm)


Expand Down
4 changes: 4 additions & 0 deletions heat/core/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,10 @@ def test_array(self):
# invalid communicator
with self.assertRaises(TypeError):
ht.array((4,), comm={})
# copy=False but copy is necessary
data = np.arange(10)
with self.assertRaises(ValueError):
ht.array(data, dtype=ht.int32, copy=False)

# data already distributed but don't match in shape
if self.get_size() > 1:
Expand Down
3 changes: 3 additions & 0 deletions heat/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,9 @@ def canonical_heat_type(a_type: Union[str, Type[datatype], Any]) -> Type[datatyp
except TypeError:
pass

# extract type of numpy.dtype
a_type = getattr(a_type, "type", a_type)

# try to look the corresponding type up
try:
return __type_mappings[a_type]
Expand Down