Skip to content

Commit

Permalink
update pyo3 to new bounds api
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael-J-Ward committed Apr 16, 2024
1 parent b506cfc commit 7ebeaa8
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 48 deletions.
20 changes: 10 additions & 10 deletions apis/python/node/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ impl Node {
&mut self,
output_id: String,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> eyre::Result<()> {
let parameters = pydict_to_metadata(metadata)?;

if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
self.node
.send_output_bytes(output_id.into(), parameters, data.len(), data)
.wrap_err("failed to send output")?;
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) {
self.node.send_output(
output_id.into(),
parameters,
Expand Down Expand Up @@ -203,15 +203,15 @@ pub fn start_runtime() -> eyre::Result<()> {
}

#[pymodule]
fn dora(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(start_runtime, m)?)?;
fn dora(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(start_runtime, &m)?)?;
m.add_class::<Node>().unwrap();

let ros2_bridge = PyModule::new(py, "ros2_bridge")?;
dora_ros2_bridge_python::create_dora_ros2_bridge_module(ros2_bridge)?;
let experimental = PyModule::new(py, "experimental")?;
experimental.add_submodule(ros2_bridge)?;
m.add_submodule(experimental)?;
let ros2_bridge = PyModule::new_bound(py, "ros2_bridge")?;
dora_ros2_bridge_python::create_dora_ros2_bridge_module(&ros2_bridge)?;
let experimental = PyModule::new_bound(py, "experimental")?;
experimental.add_submodule(&ros2_bridge)?;
m.add_submodule(&experimental)?;

Ok(())
}
16 changes: 10 additions & 6 deletions apis/python/operator/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use arrow::{array::ArrayRef, pyarrow::ToPyArrow};
use dora_node_api::{merged::MergedEvent, Event, Metadata, MetadataParameters};
use eyre::{Context, Result};
use pyo3::{exceptions::PyLookupError, prelude::*, types::PyDict};
use pyo3::{exceptions::PyLookupError, prelude::*, pybacked::PyBackedStr, types::PyDict};

#[pyclass]
pub struct PyEvent {
Expand Down Expand Up @@ -110,11 +110,15 @@ impl From<MergedEvent<PyObject>> for PyEvent {
}
}

pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
pub fn pydict_to_metadata(dict: Option<Bound<'_, PyDict>>) -> Result<MetadataParameters> {
let mut default_metadata = MetadataParameters::default();
if let Some(metadata) = dict {
for (key, value) in metadata.iter() {
match key.extract::<&str>().context("Parsing metadata keys")? {
match key
.extract::<PyBackedStr>()
.context("Parsing metadata keys")?
.as_ref()
{
"watermark" => {
default_metadata.watermark =
value.extract().context("parsing watermark failed")?;
Expand All @@ -124,7 +128,7 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
value.extract().context("parsing deadline failed")?;
}
"open_telemetry_context" => {
let otel_context: &str = value
let otel_context: PyBackedStr = value
.extract()
.context("parsing open telemetry context failed")?;
default_metadata.open_telemetry_context = otel_context.to_string();
Expand All @@ -136,8 +140,8 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result<MetadataParameters> {
Ok(default_metadata)
}

pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> &'a PyDict {
let dict = PyDict::new(py);
pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> pyo3::Bound<'a, PyDict> {
let dict = PyDict::new_bound(py);
dict.set_item(
"open_telemetry_context",
&metadata.parameters.open_telemetry_context,
Expand Down
2 changes: 1 addition & 1 deletion apis/rust/operator/types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ pub fn dora_free_input_id(_input_id: char_p_boxed) {}
#[ffi_export]
pub fn dora_read_data(input: &mut Input) -> Option<safer_ffi::Vec<u8>> {
let data_array = input.data_array.take()?;
let data = unsafe {arrow::ffi::from_ffi(data_array, &input.schema).ok()? };
let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema).ok()? };
let array = ArrowData(arrow::array::make_array(data));
let bytes: &[u8] = TryFrom::try_from(&array).ok()?;
Some(bytes.to_owned().into())
Expand Down
32 changes: 17 additions & 15 deletions binaries/runtime/src/operator/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use dora_operator_api_types::DoraStatus;
use eyre::{bail, eyre, Context, Result};
use pyo3::{
pyclass,
types::{IntoPyDict, PyDict},
types::{IntoPyDict, PyAnyMethods, PyDict, PyTracebackMethods},
Py, PyAny, Python,
};
use std::{
Expand All @@ -23,7 +23,7 @@ use tokio::sync::{mpsc::Sender, oneshot};
use tracing::{error, field, span, warn};

fn traceback(err: pyo3::PyErr) -> eyre::Report {
let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok()));
let traceback = Python::with_gil(|py| err.traceback_bound(py).and_then(|t| t.format().ok()));
if let Some(traceback) = traceback {
eyre::eyre!("{traceback}\n{err}")
} else {
Expand Down Expand Up @@ -78,7 +78,9 @@ pub fn run(
let parent_path = parent_path
.to_str()
.ok_or_else(|| eyre!("module path is not valid utf8"))?;
let sys = py.import("sys").wrap_err("failed to import `sys` module")?;
let sys = py
.import_bound("sys")
.wrap_err("failed to import `sys` module")?;
let sys_path = sys
.getattr("path")
.wrap_err("failed to import `sys.path` module")?;
Expand All @@ -90,14 +92,14 @@ pub fn run(
.wrap_err("failed to append module path to python search path")?;
}

let module = py.import(module_name).map_err(traceback)?;
let module = py.import_bound(module_name).map_err(traceback)?;
let operator_class = module
.getattr("Operator")
.wrap_err("no `Operator` class found in module")?;

let locals = [("Operator", operator_class)].into_py_dict(py);
let locals = [("Operator", operator_class)].into_py_dict_bound(py);
let operator = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)?;
operator.setattr(
"dataflow_descriptor",
Expand Down Expand Up @@ -140,11 +142,11 @@ pub fn run(
.wrap_err("could not extract operator state as a PyDict")?;
// Reload module
let module = py
.import(module_name)
.import_bound(module_name)
.map_err(traceback)
.wrap_err(format!("Could not retrieve {module_name} while reloading"))?;
let importlib = py
.import("importlib")
.import_bound("importlib")
.wrap_err("failed to import `importlib` module")?;
let module = importlib
.call_method("reload", (module,), None)
Expand All @@ -154,9 +156,9 @@ pub fn run(
.wrap_err("no `Operator` class found in module")?;

// Create a new reloaded operator
let locals = [("Operator", reloaded_operator_class)].into_py_dict(py);
let locals = [("Operator", reloaded_operator_class)].into_py_dict_bound(py);
let operator: Py<pyo3::PyAny> = py
.eval("Operator()", None, Some(locals))
.eval_bound("Operator()", None, Some(&locals))
.map_err(traceback)
.wrap_err("Could not initialize reloaded operator")?
.into();
Expand Down Expand Up @@ -299,8 +301,8 @@ mod callback_impl {
use eyre::{eyre, Context, Result};
use pyo3::{
pymethods,
types::{PyBytes, PyDict},
PyObject, Python,
types::{PyBytes, PyBytesMethods, PyDict},
Bound, PyObject, Python,
};
use tokio::sync::oneshot;
use tracing::{field, span};
Expand All @@ -317,7 +319,7 @@ mod callback_impl {
&mut self,
output: &str,
data: PyObject,
metadata: Option<&PyDict>,
metadata: Option<Bound<'_, PyDict>>,
py: Python,
) -> Result<()> {
let parameters = pydict_to_metadata(metadata)
Expand Down Expand Up @@ -353,12 +355,12 @@ mod callback_impl {
}
};

let (sample, type_info) = if let Ok(py_bytes) = data.downcast::<PyBytes>(py) {
let (sample, type_info) = if let Ok(py_bytes) = data.downcast_bound::<PyBytes>(py) {
let data = py_bytes.as_bytes();
let mut sample = allocate_sample(data.len())?;
sample.copy_from_slice(data);
(sample, ArrowTypeInfo::byte_array(data.len()))
} else if let Ok(arrow_array) = ArrayData::from_pyarrow(data.as_ref(py)) {
} else if let Ok(arrow_array) = ArrayData::from_pyarrow_bound(data.bind(py)) {
let total_len = required_data_size(&arrow_array);
let mut sample = allocate_sample(total_len)?;

Expand Down
16 changes: 8 additions & 8 deletions libraries/extensions/ros2-bridge/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use eyre::{eyre, Context, ContextCompat};
use futures::{Stream, StreamExt};
use pyo3::{
prelude::{pyclass, pymethods},
types::{PyDict, PyList, PyModule},
PyAny, PyObject, PyResult, Python,
types::{PyAnyMethods, PyDict, PyList, PyModule, PyModuleMethods},
Bound, PyAny, PyObject, PyResult, Python,
};
use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue};

Expand Down Expand Up @@ -194,8 +194,8 @@ pub struct Ros2Publisher {

#[pymethods]
impl Ros2Publisher {
pub fn publish(&self, data: &PyAny) -> eyre::Result<()> {
let pyarrow = PyModule::import(data.py(), "pyarrow")?;
pub fn publish(&self, data: Bound<'_, PyAny>) -> eyre::Result<()> {
let pyarrow = PyModule::import_bound(data.py(), "pyarrow")?;

let data = if data.is_instance_of::<PyDict>() {
// convert to arrow struct scalar
Expand All @@ -204,15 +204,15 @@ impl Ros2Publisher {
data
};

let data = if data.is_instance(pyarrow.getattr("StructScalar")?)? {
let data = if data.is_instance(&pyarrow.getattr("StructScalar")?)? {
// convert to arrow array
let list = PyList::new(data.py(), [data]);
let list = PyList::new_bound(data.py(), [data]);
pyarrow.getattr("array")?.call1((list,))?
} else {
data
};

let value = arrow::array::ArrayData::from_pyarrow(data)?;
let value = arrow::array::ArrayData::from_pyarrow_bound(&data)?;
//// add type info to ensure correct serialization (e.g. struct types
//// and map types need to be serialized differently)
let typed_value = TypedValue {
Expand Down Expand Up @@ -297,7 +297,7 @@ impl Stream for Ros2SubscriptionStream {
}
}

pub fn create_dora_ros2_bridge_module(m: &PyModule) -> PyResult<()> {
pub fn create_dora_ros2_bridge_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<Ros2Context>()?;
m.add_class::<Ros2Node>()?;
m.add_class::<Ros2NodeOptions>()?;
Expand Down
18 changes: 10 additions & 8 deletions libraries/extensions/ros2-bridge/python/src/typed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,12 @@ mod tests {
use arrow::pyarrow::ToPyArrow;

use pyo3::types::IntoPyDict;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyDict;
use pyo3::types::PyList;
use pyo3::types::PyModule;
use pyo3::types::PyTuple;
use pyo3::PyNativeType;
use pyo3::Python;
use serde::de::DeserializeSeed;
use serde::Serialize;
Expand All @@ -61,13 +63,13 @@ mod tests {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); //.join("test_utils.py"); // Adjust this path as needed

// Add the Python module's directory to sys.path
py.run(
py.run_bound(
"import sys; sys.path.append(str(path))",
Some([("path", path)].into_py_dict(py)),
Some(&[("path", path)].into_py_dict_bound(py)),
None,
)?;

let my_module = PyModule::import(py, "test_utils")?;
let my_module = PyModule::import_bound(py, "test_utils")?;

let arrays: &PyList = my_module.getattr("TEST_ARRAYS")?.extract()?;
for array_wrapper in arrays.iter() {
Expand All @@ -77,7 +79,7 @@ mod tests {
println!("Checking {}::{}", package_name, message_name);
let in_pyarrow = arrays.get_item(2)?;

let array = arrow::array::ArrayData::from_pyarrow(in_pyarrow)?;
let array = arrow::array::ArrayData::from_pyarrow_bound(&in_pyarrow.as_borrowed())?;
let type_info = TypeInfo {
package_name: package_name.into(),
message_name: message_name.clone().into(),
Expand All @@ -99,17 +101,17 @@ mod tests {

let out_pyarrow = out_value.to_pyarrow(py)?;

let test_utils = PyModule::import(py, "test_utils")?;
let context = PyDict::new(py);
let test_utils = PyModule::import_bound(py, "test_utils")?;
let context = PyDict::new_bound(py);

context.set_item("test_utils", test_utils)?;
context.set_item("in_pyarrow", in_pyarrow)?;
context.set_item("out_pyarrow", out_pyarrow)?;

let _ = py
.eval(
.eval_bound(
"test_utils.is_subset(in_pyarrow, out_pyarrow)",
Some(context),
Some(&context),
None,
)
.context("could not check if it is a subset")?;
Expand Down

0 comments on commit 7ebeaa8

Please sign in to comment.