torch.from_numpy for complex dtypes (pytorch#35531)
Summary: Pull Request resolved: pytorch#35531

Differential Revision: D20693581

Pulled By: anjali411

fbshipit-source-id: d53e26b4175452fa00b287efbfceea18104c1364
anjali411 authored and facebook-github-bot committed Mar 27, 2020
1 parent f101949 commit 96eec95
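
For orientation, a minimal sketch of what this commit enables, assuming a PyTorch build that includes the change (names and printed values below are illustrative, not taken from the commit):

    import numpy as np
    import torch

    # Complex ndarrays can now be handed to torch.from_numpy.
    a = np.array([1 + 2j, 3 - 4j], dtype=np.complex64)
    t = torch.from_numpy(a)
    print(t.dtype)   # torch.complex64

    # As with the other supported dtypes, the tensor shares memory with
    # the ndarray, so in-place edits on one side are visible on the other.
    a[0] = 5 + 6j
    print(t[0])      # tensor(5.+6.j)
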
Showing 4 changed files with 28 additions and 14 deletions.
2 changes: 2 additions & 0 deletions caffe2/python/pybind_state.cc
@@ -113,6 +113,8 @@ int CaffeToNumpyType(const TypeMeta& meta) {
       {TypeMeta::Id<bool>(), NPY_BOOL},
       {TypeMeta::Id<double>(), NPY_DOUBLE},
       {TypeMeta::Id<float>(), NPY_FLOAT},
+      {TypeMeta::Id<std::complex<double>>(), NPY_COMPLEX128},
+      {TypeMeta::Id<std::complex<float>>(), NPY_COMPLEX64},
       {TypeMeta::Id<at::Half>(), NPY_FLOAT16},
       {TypeMeta::Id<int>(), NPY_INT},
       {TypeMeta::Id<int8_t>(), NPY_INT8},
23 changes: 16 additions & 7 deletions test/test_torch.py
@@ -4527,6 +4527,8 @@ def test_from_numpy(self):
             np.double,
             np.float,
             np.float16,
+            np.complex64,
+            np.complex128,
             np.int64,
             np.int32,
             np.int16,
@@ -4535,22 +4537,29 @@ def test_from_numpy(self):
             np.longlong,
             np.bool,
         ]
+        complex_dtypes = [
+            np.complex64,
+            np.complex128,
+        ]

         for dtype in dtypes:
             array = np.array([1, 2, 3, 4], dtype=dtype)
             tensor_from_array = torch.from_numpy(array)
             # TODO: change to tensor equality check once HalfTensor
             # implements `==`
             for i in range(len(array)):
                 self.assertEqual(tensor_from_array[i], array[i])
-            # This is a special test case for Windows
-            # https://github.com/pytorch/pytorch/issues/22615
-            array2 = array % 2
-            tensor_from_array2 = torch.from_numpy(array2)
-            for i in range(len(array2)):
-                self.assertEqual(tensor_from_array2[i], array2[i])
+            # ufunc 'remainder' not supported for complex dtypes
+            if dtype not in complex_dtypes:
+                # This is a special test case for Windows
+                # https://github.com/pytorch/pytorch/issues/22615
+                array2 = array % 2
+                tensor_from_array2 = torch.from_numpy(array2)
+                for i in range(len(array2)):
+                    self.assertEqual(tensor_from_array2[i], array2[i])

         # Test unsupported type
-        array = np.array([1, 2, 3, 4], dtype=np.complex)
+        array = np.array([1, 2, 3, 4], dtype=np.uint16)
         with self.assertRaises(TypeError):
             tensor_from_array = torch.from_numpy(array)

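The new "if dtype not in complex_dtypes:" guard above exists because NumPy's remainder ufunc rejects complex operands; a small sketch of the failure the test now skips (illustrative, not part of the commit):

    import numpy as np

    a = np.array([1 + 2j, 3 - 4j], dtype=np.complex128)
    try:
        a % 2  # the modulo used by the Windows regression check above
    except TypeError as e:
        # NumPy reports that ufunc 'remainder' is not supported for these
        # input types, which is why the test skips this branch for
        # complex dtypes.
        print(e)
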
9 changes: 5 additions & 4 deletions torch/_torch_docs.py
@@ -70,7 +70,7 @@ def merge_dicts(*dicts):
         returned tensor. Default: ``False``.
     pin_memory (bool, optional): If set, returned tensor would be allocated in
         the pinned memory. Works only for CPU tensors. Default: ``False``.
-    memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+    memory_format (:class:`torch.memory_format`, optional): the desired memory format of
         returned Tensor. Default: ``torch.contiguous_format``.
 """))

@@ -86,7 +86,7 @@ def merge_dicts(*dicts):
         returned tensor. Default: ``False``.
     pin_memory (bool, optional): If set, returned tensor would be allocated in
         the pinned memory. Works only for CPU tensors. Default: ``False``.
-    memory_format (:class:`torch.memory_format`, optional): the desired memory format of
+    memory_format (:class:`torch.memory_format`, optional): the desired memory format of
         returned Tensor. Default: ``torch.preserve_format``.
 """)

@@ -2199,8 +2199,9 @@ def merge_dicts(*dicts):
 tensor is not resizable.
 It currently accepts :attr:`ndarray` with dtypes of ``numpy.float64``,
-``numpy.float32``, ``numpy.float16``, ``numpy.int64``, ``numpy.int32``,
-``numpy.int16``, ``numpy.int8``, ``numpy.uint8``, and ``numpy.bool``.
+``numpy.float32``, ``numpy.float16``, ``numpy.complex64``, ``numpy.complex128``,
+``numpy.int64``, ``numpy.int32``, ``numpy.int16``, ``numpy.int8``, ``numpy.uint8``,
+and ``numpy.bool``.

 Example::
8 changes: 5 additions & 3 deletions torch/csrc/utils/tensor_numpy.cpp
@@ -190,11 +190,11 @@ at::Tensor tensor_from_numpy(PyObject* obj) {

 int aten_to_numpy_dtype(const ScalarType scalar_type) {
   switch (scalar_type) {
+    case kComplexDouble: return NPY_COMPLEX128;
+    case kComplexFloat: return NPY_COMPLEX64;
     case kDouble: return NPY_DOUBLE;
     case kFloat: return NPY_FLOAT;
     case kHalf: return NPY_HALF;
-    case kComplexDouble: return NPY_COMPLEX128;
-    case kComplexFloat: return NPY_COMPLEX64;
     case kLong: return NPY_INT64;
     case kInt: return NPY_INT32;
     case kShort: return NPY_INT16;
@@ -211,6 +211,8 @@ ScalarType numpy_dtype_to_aten(int dtype) {
     case NPY_DOUBLE: return kDouble;
     case NPY_FLOAT: return kFloat;
     case NPY_HALF: return kHalf;
+    case NPY_COMPLEX64: return kComplexFloat;
+    case NPY_COMPLEX128: return kComplexDouble;
     case NPY_INT16: return kShort;
     case NPY_INT8: return kChar;
     case NPY_UINT8: return kByte;
@@ -236,7 +238,7 @@ ScalarType numpy_dtype_to_aten(int dtype) {
   if (!pytype) throw python_error();
   throw TypeError(
       "can't convert np.ndarray of type %s. The only supported types are: "
-      "float64, float32, float16, int64, int32, int16, int8, uint8, and bool.",
+      "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
       ((PyTypeObject*)pytype.get())->tp_name);
 }

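Taken together, the two switch statements above cover both conversion directions; a short sketch of the resulting Python-level behavior, assuming a build that includes this change (the uint16 case mirrors the updated error message):

    import numpy as np
    import torch

    # ndarray -> tensor goes through numpy_dtype_to_aten, which now maps
    # NPY_COMPLEX64/NPY_COMPLEX128 to the complex scalar types.
    t = torch.from_numpy(np.ones(3, dtype=np.complex128))
    print(t.dtype)           # torch.complex128

    # tensor -> ndarray goes through aten_to_numpy_dtype.
    print(t.numpy().dtype)   # complex128

    # dtypes without a mapping still raise the (now longer) TypeError.
    try:
        torch.from_numpy(np.ones(3, dtype=np.uint16))
    except TypeError as e:
        print(e)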
