diff --git a/native/explorer/Cargo.lock b/native/explorer/Cargo.lock index 6ea3f58a6..db6a6961b 100644 --- a/native/explorer/Cargo.lock +++ b/native/explorer/Cargo.lock @@ -1216,8 +1216,7 @@ checksum = "a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244" [[package]] name = "rustler" version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d61e8ddf75de20513455d7b6f17241a595abbb01b53a6340cecc798a1b13422d" +source = "git+https://github.com/rusterlium/rustler#b673d5cb4b8f653e46e0562032d16e93b44c2337" dependencies = [ "lazy_static", "rustler_codegen", @@ -1227,8 +1226,7 @@ dependencies = [ [[package]] name = "rustler_codegen" version = "0.26.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baa2e45c0165272070f80ce93bcd7dd5407a3c84a1ef73ab9900e00f00ef3d36" +source = "git+https://github.com/rusterlium/rustler#b673d5cb4b8f653e46e0562032d16e93b44c2337" dependencies = [ "heck", "proc-macro2", @@ -1239,8 +1237,7 @@ dependencies = [ [[package]] name = "rustler_sys" version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ff26a42e62d538f82913dd34f60105ecfdffbdb25abdc3c3580b0c622285332" +source = "git+https://github.com/rusterlium/rustler#b673d5cb4b8f653e46e0562032d16e93b44c2337" dependencies = [ "regex", "unreachable", diff --git a/native/explorer/Cargo.toml b/native/explorer/Cargo.toml index 0c6292ba1..f7121428f 100644 --- a/native/explorer/Cargo.toml +++ b/native/explorer/Cargo.toml @@ -16,7 +16,7 @@ anyhow = "1" chrono = "0.4" rand = { version = "0.8.4", features = ["alloc"] } rand_pcg = "0.3.1" -rustler = "0.26.0" +rustler = { git = "https://github.com/rusterlium/rustler" } thiserror = "1" # MiMalloc won´t compile on Windows with the GCC compiler. diff --git a/native/explorer/src/datatypes.rs b/native/explorer/src/datatypes.rs index 8bc3f260b..8b7d82c3c 100644 --- a/native/explorer/src/datatypes.rs +++ b/native/explorer/src/datatypes.rs @@ -1,8 +1,7 @@ use crate::atoms; use chrono::prelude::*; use polars::prelude::*; -use rustler::resource::ResourceArc; -use rustler::{Atom, NifStruct}; +use rustler::{Atom, NifStruct, ResourceArc}; use std::convert::TryInto; pub struct ExDataFrameRef(pub DataFrame); diff --git a/native/explorer/src/encoding.rs b/native/explorer/src/encoding.rs index 9fac710e5..eb5d3e4cb 100644 --- a/native/explorer/src/encoding.rs +++ b/native/explorer/src/encoding.rs @@ -1,39 +1,17 @@ use chrono::prelude::*; use polars::prelude::*; -use rustler::{Binary, Encoder, Env, NewBinary, Term}; +use rustler::{Encoder, Env, ResourceArc, Term}; use crate::atoms::{ self, calendar, day, hour, infinity, microsecond, minute, month, nan, neg_infinity, second, year, }; -use crate::datatypes::{days_to_date, timestamp_to_datetime, ExSeriesRef}; +use crate::datatypes::{days_to_date, timestamp_to_datetime, ExSeries, ExSeriesRef}; use rustler::types::atom; -use rustler::wrapper::list; -use rustler::wrapper::{map, NIF_TERM}; +use rustler::wrapper::{binary, list, map, NIF_TERM}; -pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Term<'b> { - match v { - AnyValue::Null => None::.encode(env), - AnyValue::Boolean(v) => Some(v).encode(env), - AnyValue::Utf8(v) => Some(v).encode(env), - AnyValue::Int8(v) => Some(v).encode(env), - AnyValue::Int16(v) => Some(v).encode(env), - AnyValue::Int32(v) => Some(v).encode(env), - AnyValue::Int64(v) => Some(v).encode(env), - AnyValue::UInt8(v) => Some(v).encode(env), - AnyValue::UInt16(v) => Some(v).encode(env), - AnyValue::UInt32(v) => Some(v).encode(env), - AnyValue::UInt64(v) => Some(v).encode(env), - AnyValue::Float64(v) => Some(v).encode(env), - AnyValue::Float32(v) => Some(v).encode(env), - AnyValue::Date(v) => encode_date(v, env), - AnyValue::Datetime(v, time_unit, None) => encode_datetime(v, time_unit, env), - dt => panic!("get/2 not implemented for {:?}", dt), - } -} - -// ExSeriesRef encoding +// Encoding helpers // TODO: Implement this as a regular function or encapsulate it inside Rustler. macro_rules! unsafe_iterator_to_list { @@ -208,18 +186,23 @@ fn encode_datetime_series<'b>(s: &Series, time_unit: TimeUnit, env: Env<'b>) -> } #[inline] -fn encode_utf8_series<'b>(s: &Series, env: Env<'b>) -> Term<'b> { +fn encode_utf8_series<'b>( + resource: &ResourceArc, + s: &Series, + env: Env<'b>, +) -> Term<'b> { let utf8 = s.utf8().unwrap(); - let nil_atom = atom::nil().to_term(env); let env_as_c_arg = env.as_c_arg(); + let nil_as_c_arg = atom::nil().to_term(env).as_c_arg(); let acc = unsafe { list::make_list(env_as_c_arg, &[]) }; let list = utf8.downcast_iter().rfold(acc, |acc, array| { // Create a binary per array buffer let values = array.values(); - let mut new_binary = NewBinary::new(env, values.len()); - new_binary.copy_from_slice(values.as_slice()); - let binary: Binary = new_binary.into(); + + let binary = unsafe { resource.make_binary_unsafe(env, |_| values) } + .to_term(env) + .as_c_arg(); // Offsets have one more element than values and validity, // so we read the last one as the initial accumulator and skip it. @@ -235,17 +218,16 @@ fn encode_utf8_series<'b>(s: &Series, env: Env<'b>) -> Term<'b> { iter.rfold(acc, |acc, uncast_offset| { let offset = *uncast_offset as NIF_TERM; - let term = if validity_iter.next_back().unwrap_or(true) { - binary - .make_subbinary(offset, last_offset - offset) - .unwrap() - .to_term(env) + let term_as_c_arg = if validity_iter.next_back().unwrap_or(true) { + unsafe { + binary::make_subbinary(env_as_c_arg, binary, offset, last_offset - offset) + } } else { - nil_atom + nil_as_c_arg }; last_offset = offset; - unsafe { list::make_list_cell(env_as_c_arg, term.as_c_arg(), acc) } + unsafe { list::make_list_cell(env_as_c_arg, term_as_c_arg, acc) } }) }); @@ -315,23 +297,45 @@ macro_rules! encode_list { }; } -impl Encoder for ExSeriesRef { - fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { - let s = &self.0; - match s.dtype() { - DataType::Boolean => encode!(s, env, bool), - DataType::Int32 => encode!(s, env, i32), - DataType::Int64 => encode!(s, env, i64), - DataType::UInt8 => encode!(s, env, u8), - DataType::UInt32 => encode!(s, env, u32), - DataType::Utf8 => encode_utf8_series(s, env), - DataType::Float64 => encode_float64_series(s, env), - DataType::Date => encode_date_series(s, env), - DataType::Datetime(time_unit, None) => encode_datetime_series(s, *time_unit, env), - DataType::List(t) if t as &DataType == &DataType::UInt32 => { - encode_list!(s, env, u32, u32) - } - dt => panic!("to_list/1 not implemented for {:?}", dt), +// API + +pub fn term_from_value<'b>(v: AnyValue, env: Env<'b>) -> Term<'b> { + match v { + AnyValue::Null => None::.encode(env), + AnyValue::Boolean(v) => Some(v).encode(env), + AnyValue::Utf8(v) => Some(v).encode(env), + AnyValue::Int8(v) => Some(v).encode(env), + AnyValue::Int16(v) => Some(v).encode(env), + AnyValue::Int32(v) => Some(v).encode(env), + AnyValue::Int64(v) => Some(v).encode(env), + AnyValue::UInt8(v) => Some(v).encode(env), + AnyValue::UInt16(v) => Some(v).encode(env), + AnyValue::UInt32(v) => Some(v).encode(env), + AnyValue::UInt64(v) => Some(v).encode(env), + AnyValue::Float64(v) => Some(v).encode(env), + AnyValue::Float32(v) => Some(v).encode(env), + AnyValue::Date(v) => encode_date(v, env), + AnyValue::Datetime(v, time_unit, None) => encode_datetime(v, time_unit, env), + dt => panic!("get/2 not implemented for {:?}", dt), + } +} + +pub fn list_from_series(data: ExSeries, env: Env) -> Term { + let s = &data.resource.0; + + match s.dtype() { + DataType::Boolean => encode!(s, env, bool), + DataType::Int32 => encode!(s, env, i32), + DataType::Int64 => encode!(s, env, i64), + DataType::UInt8 => encode!(s, env, u8), + DataType::UInt32 => encode!(s, env, u32), + DataType::Utf8 => encode_utf8_series(&data.resource, s, env), + DataType::Float64 => encode_float64_series(s, env), + DataType::Date => encode_date_series(s, env), + DataType::Datetime(time_unit, None) => encode_datetime_series(s, *time_unit, env), + DataType::List(t) if t as &DataType == &DataType::UInt32 => { + encode_list!(s, env, u32, u32) } + dt => panic!("to_list/1 not implemented for {:?}", dt), } } diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 4174e95fa..86b5792a4 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1,7 +1,6 @@ use crate::{ datatypes::{ExDate, ExDateTime}, - encoding::term_from_value, - ExDataFrame, ExSeries, ExSeriesRef, ExplorerError, + encoding, ExDataFrame, ExSeries, ExplorerError, }; use polars::prelude::*; @@ -443,8 +442,7 @@ pub fn rolling_opts( #[rustler::nif(schedule = "DirtyCpu")] pub fn s_to_list(env: Env, data: ExSeries) -> Result { - let s = ExSeriesRef(data.resource.0.clone()); - Ok(s.encode(env)) + Ok(encoding::list_from_series(data, env)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -574,7 +572,7 @@ pub fn s_std(env: Env, data: ExSeries) -> Result { #[rustler::nif] pub fn s_get(env: Env, data: ExSeries, idx: usize) -> Result { let s = &data.resource.0; - Ok(term_from_value(s.get(idx), env)) + Ok(encoding::term_from_value(s.get(idx), env)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -616,7 +614,7 @@ pub fn s_quantile<'a>( Some(microseconds) => Ok(ExDateTime::from(microseconds as i64).encode(env)), } } - _ => Ok(term_from_value( + _ => Ok(encoding::term_from_value( s.quantile_as_series(quantile, strategy)? .cast(dtype)? .get(0),