-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable pickling for tensor and tensor based geometry #5509
Changes from 6 commits
993ea37
55a7edc
360edc7
76cd178
c1a43c7
47c33bc
7dce8f5
ff8dc7e
7712caf
3e8e28b
a3e0f8e
6f1f084
8c3ee32
593d32a
350884f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -90,3 +90,6 @@ docs/Doxyfile | |
docs/getting_started.rst | ||
docs/docker.rst | ||
docs/tensorboard.md | ||
|
||
# test | ||
*.pkl | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,7 +46,20 @@ void pybind_core_device(py::module &m) { | |
.def("__repr__", &Device::ToString) | ||
.def("__str__", &Device::ToString) | ||
.def("get_type", &Device::GetType) | ||
.def("get_id", &Device::GetID); | ||
.def("get_id", &Device::GetID) | ||
.def(py::pickle( | ||
[](const Device &d) { | ||
return py::make_tuple(d.GetType(), d.GetID()); | ||
}, | ||
[](py::tuple t) { | ||
if (t.size() != 2) { | ||
utility::LogError( | ||
"Invalid state! Expecting a tuple of size " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a comment: "Cannot unpickle Device." It is useful for the user to know which class went wrong. Same for others. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
"2."); | ||
} | ||
return Device(t[0].cast<Device::DeviceType>(), | ||
t[1].cast<int>()); | ||
})); | ||
|
||
py::enum_<Device::DeviceType>(device, "DeviceType") | ||
.value("CPU", Device::DeviceType::CPU) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -356,6 +356,34 @@ void pybind_core_tensor(py::module& m) { | |
BindTensorFullCreation<bool>(m, tensor); | ||
docstring::ClassMethodDocInject(m, "Tensor", "full", argument_docs); | ||
|
||
// Pickling support. | ||
// The tensor will be on the same device after deserialization. | ||
// Non contiguous tensors will be converted to contiguous tensors after | ||
// deserialization. | ||
tensor.def(py::pickle( | ||
[](const Tensor& t) { | ||
// __getstate__ | ||
return py::make_tuple(t.GetDevice(), | ||
TensorToPyArray(t.To(Device("CPU:0")))); | ||
}, | ||
[](py::tuple t) { | ||
// __setstate__ | ||
if (t.size() != 2) { | ||
utility::LogError( | ||
"Invalid state! Expecting a tuple of size 2."); | ||
} | ||
const Device& device = t[0].cast<Device>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (device.IsCUDA() && !core::cuda::IsAvailable()) { | ||
utility::LogWarning( | ||
"CUDA is not available, tensor will be " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Change this comment as well. Same for others. |
||
"created on CPU."); | ||
return PyArrayToTensor(t[1].cast<py::array>(), true); | ||
} else { | ||
return PyArrayToTensor(t[1].cast<py::array>(), true) | ||
.To(device); | ||
} | ||
})); | ||
|
||
tensor.def_static( | ||
"eye", | ||
[](int64_t n, utility::optional<Dtype> dtype, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
import open3d.core as o3c | ||
import numpy as np | ||
import pytest | ||
import pickle | ||
|
||
import sys | ||
import os | ||
|
@@ -76,3 +77,12 @@ def test_buffer_protocol_cpu(device): | |
im = im.to(device=device) | ||
dst_t = np.asarray(im.cpu()) | ||
np.testing.assert_array_equal(src_t, dst_t) | ||
|
||
|
||
@pytest.mark.parametrize("device", list_devices()) | ||
def test_pickle(device): | ||
img = o3d.t.geometry.Image(o3c.Tensor.ones((10, 10, 3), o3c.uint8, device)) | ||
pickle.dump(img, open("img.pkl", "wb")) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use import tempfile
with tempfile.TemporaryDirectory() as temp_dir:
pickle_path = xxx Same for others. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
img_load = pickle.load(open("img.pkl", "rb")) | ||
assert img_load.as_tensor().allclose(img.as_tensor()) | ||
assert img_load.device == img.device and img_load.dtype == o3c.uint8 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.