diff --git a/lib/explorer/backend/lazy_series.ex b/lib/explorer/backend/lazy_series.ex index 6e14ea3e1..c5dbcb74a 100644 --- a/lib/explorer/backend/lazy_series.ex +++ b/lib/explorer/backend/lazy_series.ex @@ -135,7 +135,9 @@ defmodule Explorer.Backend.LazySeries do minute: 1, second: 1, # List functions - join: 2 + join: 2, + lengths: 1, + member: 3 ] @comparison_operations [:equal, :not_equal, :greater, :greater_equal, :less, :less_equal] @@ -990,6 +992,20 @@ defmodule Explorer.Backend.LazySeries do Backend.Series.new(data, :string) end + @impl true + def lengths(series) do + data = new(:lengths, [lazy_series!(series)], :integer) + + Backend.Series.new(data, :integer) + end + + @impl true + def member?(%Series{dtype: {:list, inner_dtype}} = series, value) do + data = new(:member, [lazy_series!(series), value, inner_dtype], :boolean) + + Backend.Series.new(data, :boolean) + end + @remaining_non_lazy_operations [ at: 2, at_every: 2, diff --git a/lib/explorer/backend/series.ex b/lib/explorer/backend/series.ex index 86e2274a0..d3f603e77 100644 --- a/lib/explorer/backend/series.ex +++ b/lib/explorer/backend/series.ex @@ -269,6 +269,8 @@ defmodule Explorer.Backend.Series do # List @callback join(s, String.t()) :: s + @callback lengths(s) :: s + @callback member?(s, valid_types()) :: s # Functions diff --git a/lib/explorer/polars_backend/expression.ex b/lib/explorer/polars_backend/expression.ex index 65886d2ce..6d94f5b9a 100644 --- a/lib/explorer/polars_backend/expression.ex +++ b/lib/explorer/polars_backend/expression.ex @@ -134,7 +134,9 @@ defmodule Explorer.PolarsBackend.Expression do split: 2, # Lists - join: 2 + join: 2, + lengths: 1, + member: 3 ] @custom_expressions [ diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 2c08f5d7e..e51b6143d 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -401,6 +401,8 @@ defmodule Explorer.PolarsBackend.Native do def s_atan(_s), do: err() def s_join(_s, _separator), do: err() + def s_lengths(_s), do: err() + def s_member(_s, _value, _inner_dtype), do: err() defp err, do: :erlang.nif_error(:nif_not_loaded) end diff --git a/lib/explorer/polars_backend/series.ex b/lib/explorer/polars_backend/series.ex index 94a1f3cb3..8de5e9439 100644 --- a/lib/explorer/polars_backend/series.ex +++ b/lib/explorer/polars_backend/series.ex @@ -653,6 +653,14 @@ defmodule Explorer.PolarsBackend.Series do def join(series, separator), do: Shared.apply_series(series, :s_join, [separator]) + @impl true + def lengths(series), + do: Shared.apply_series(series, :s_lengths) + + @impl true + def member?(%Series{dtype: {:list, inner_dtype}} = series, value), + do: Shared.apply_series(series, :s_member, [value, inner_dtype]) + # Polars specific functions def name(series), do: Shared.apply_series(series, :s_name) diff --git a/lib/explorer/series.ex b/lib/explorer/series.ex index bb005f72e..6e81eec4c 100644 --- a/lib/explorer/series.ex +++ b/lib/explorer/series.ex @@ -5431,6 +5431,48 @@ defmodule Explorer.Series do def join(%Series{dtype: dtype}, _separator), do: dtype_error("join/2", dtype, [{:list, :string}]) + @doc """ + Calculates the length of each list in a list series. + + ## Examples + + iex> s = Series.from_list([[1], [1, 2]]) + iex> Series.lengths(s) + #Explorer.Series< + Polars[2] + integer [1, 2] + > + + """ + @doc type: :list_wise + @spec lengths(Series.t()) :: Series.t() + def lengths(%Series{dtype: {:list, _}} = series), + do: apply_series(series, :lengths) + + def lengths(%Series{dtype: dtype}), + do: dtype_error("lengths/1", dtype, [{:list, :_}]) + + @doc """ + Checks for the presence of a value in a list series. + + ## Examples + + iex> s = Series.from_list([[1], [1, 2]]) + iex> Series.member?(s, 2) + #Explorer.Series< + Polars[2] + boolean [false, true] + > + + """ + @doc type: :list_wise + @spec member?(Series.t(), Explorer.Backend.Series.valid_types()) :: Series.t() + def member?(%Series{dtype: {:list, _}} = series, value), + do: apply_series(series, :member?, [value]) + + def member?(%Series{dtype: dtype}, _value), + do: dtype_error("member?/2", dtype, [{:list, :_}]) + # Escape hatch @doc """ diff --git a/native/explorer/src/datatypes.rs b/native/explorer/src/datatypes.rs index 0c9f4e93b..699764e37 100644 --- a/native/explorer/src/datatypes.rs +++ b/native/explorer/src/datatypes.rs @@ -213,6 +213,12 @@ impl From for ExDate { } } +impl Literal for ExDate { + fn lit(self) -> Expr { + NaiveDate::from(self).lit().dt().date() + } +} + #[derive(NifStruct, Copy, Clone, Debug)] #[module = "Explorer.Duration"] pub struct ExDuration { @@ -226,6 +232,30 @@ impl From for i64 { } } +impl Literal for ExDuration { + fn lit(self) -> Expr { + // Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but + // doing so will lose precision information as `chrono::Duration`s have no time units. + Expr::Literal(LiteralValue::Duration( + self.value, + time_unit_of_ex_duration(&self), + )) + } +} + +fn time_unit_of_ex_duration(duration: &ExDuration) -> TimeUnit { + let precision = duration.precision; + if precision == atoms::millisecond() { + TimeUnit::Milliseconds + } else if precision == atoms::microsecond() { + TimeUnit::Microseconds + } else if precision == atoms::nanosecond() { + TimeUnit::Nanoseconds + } else { + panic!("unrecognized precision: {precision:?}") + } +} + #[derive(NifStruct, Copy, Clone, Debug)] #[module = "NaiveDateTime"] pub struct ExDateTime { @@ -318,6 +348,12 @@ impl From for ExDateTime { } } +impl Literal for ExDateTime { + fn lit(self) -> Expr { + NaiveDateTime::from(self).lit() + } +} + #[derive(NifStruct, Copy, Clone, Debug)] #[module = "Time"] pub struct ExTime { @@ -379,6 +415,83 @@ impl From for ExTime { } } +impl Literal for ExTime { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Time(self.into())) + } +} + +/// Represents valid Elixir types that can be used as literals in Polars. +pub enum ExValidValue<'a> { + I64(i64), + F64(f64), + Bool(bool), + Str(&'a str), + Date(ExDate), + Time(ExTime), + DateTime(ExDateTime), + Duration(ExDuration), +} + +impl<'a> ExValidValue<'a> { + pub fn lit_with_matching_precision(self, data_type: &DataType) -> Expr { + match data_type { + DataType::Datetime(time_unit, _) => self.lit().dt().cast_time_unit(*time_unit), + DataType::Duration(time_unit) => self.lit().dt().cast_time_unit(*time_unit), + _ => self.lit(), + } + } +} + +impl<'a> Literal for &ExValidValue<'a> { + fn lit(self) -> Expr { + match self { + ExValidValue::I64(v) => v.lit(), + ExValidValue::F64(v) => v.lit(), + ExValidValue::Bool(v) => v.lit(), + ExValidValue::Str(v) => v.lit(), + ExValidValue::Date(v) => v.lit(), + ExValidValue::Time(v) => v.lit(), + ExValidValue::DateTime(v) => v.lit(), + ExValidValue::Duration(v) => v.lit(), + } + } +} + +impl<'a> rustler::Decoder<'a> for ExValidValue<'a> { + fn decode(term: rustler::Term<'a>) -> rustler::NifResult { + use rustler::*; + + match term.get_type() { + TermType::Atom => term.decode::().map(ExValidValue::Bool), + TermType::Binary => term.decode::<&'a str>().map(ExValidValue::Str), + TermType::Number => { + if let Ok(i) = term.decode::() { + Ok(ExValidValue::I64(i)) + } else if let Ok(f) = term.decode::() { + Ok(ExValidValue::F64(f)) + } else { + Err(rustler::Error::BadArg) + } + } + TermType::Map => { + if let Ok(date) = term.decode::() { + Ok(ExValidValue::Date(date)) + } else if let Ok(time) = term.decode::() { + Ok(ExValidValue::Time(time)) + } else if let Ok(datetime) = term.decode::() { + Ok(ExValidValue::DateTime(datetime)) + } else if let Ok(duration) = term.decode::() { + Ok(ExValidValue::Duration(duration)) + } else { + Err(rustler::Error::BadArg) + } + } + _ => Err(rustler::Error::BadArg), + } + } +} + // In Elixir this would be represented like this: // * `:uncompressed` for `ExParquetCompression::Uncompressed` // * `{:brotli, 7}` for `ExParquetCompression::Brotli(Some(7))` diff --git a/native/explorer/src/expressions.rs b/native/explorer/src/expressions.rs index 63aceb179..4c0564b64 100644 --- a/native/explorer/src/expressions.rs +++ b/native/explorer/src/expressions.rs @@ -4,13 +4,12 @@ // or an expression and returns an expression that is // wrapped in an Elixir struct. -use chrono::{NaiveDate, NaiveDateTime}; -use polars::lazy::dsl::{col, concat_str, cov, pearson_corr, when, Expr, StrptimeOptions}; -use polars::prelude::{DataType, Literal, TimeUnit}; -use polars::prelude::{IntoLazy, LiteralValue, SortOptions}; +use polars::prelude::{ + col, concat_str, cov, pearson_corr, when, IntoLazy, LiteralValue, SortOptions, +}; +use polars::prelude::{DataType, Expr, Literal, StrptimeOptions, TimeUnit}; -use crate::atoms::{microsecond, millisecond, nanosecond}; -use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype}; +use crate::datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExValidValue}; use crate::series::{cast_str_to_f64, ewm_opts, rolling_opts}; use crate::{ExDataFrame, ExExpr, ExSeries}; @@ -54,38 +53,17 @@ pub fn expr_atom(atom: &str) -> ExExpr { #[rustler::nif] pub fn expr_date(date: ExDate) -> ExExpr { - let naive_date = NaiveDate::from(date); - let expr = naive_date.lit().dt().date(); - ExExpr::new(expr) + ExExpr::new(date.lit()) } #[rustler::nif] pub fn expr_datetime(datetime: ExDateTime) -> ExExpr { - let naive_datetime = NaiveDateTime::from(datetime); - let expr = naive_datetime.lit(); - ExExpr::new(expr) + ExExpr::new(datetime.lit()) } #[rustler::nif] pub fn expr_duration(duration: ExDuration) -> ExExpr { - // Note: it's tempting to use `.lit()` on a `chrono::Duration` struct in this function, but - // doing so will lose precision information as `chrono::Duration`s have no time units. - let time_unit = time_unit_of_ex_duration(duration); - let expr = Expr::Literal(LiteralValue::Duration(duration.value, time_unit)); - ExExpr::new(expr) -} - -fn time_unit_of_ex_duration(duration: ExDuration) -> TimeUnit { - let precision = duration.precision; - if precision == millisecond() { - TimeUnit::Milliseconds - } else if precision == microsecond() { - TimeUnit::Microseconds - } else if precision == nanosecond() { - TimeUnit::Nanoseconds - } else { - panic!("unrecognized precision: {precision:?}") - } + ExExpr::new(duration.lit()) } #[rustler::nif] @@ -977,3 +955,28 @@ pub fn expr_second(expr: ExExpr) -> ExExpr { ExExpr::new(expr.dt().second().cast(DataType::Int64)) } + +#[rustler::nif] +pub fn expr_join(expr: ExExpr, sep: String) -> ExExpr { + let expr = expr.clone_inner(); + + ExExpr::new(expr.list().join(sep.lit())) +} + +#[rustler::nif] +pub fn expr_lengths(expr: ExExpr) -> ExExpr { + let expr = expr.clone_inner(); + + ExExpr::new(expr.list().len().cast(DataType::Int64)) +} + +#[rustler::nif] +pub fn expr_member(expr: ExExpr, value: ExValidValue, inner_dtype: ExSeriesDtype) -> ExExpr { + let expr = expr.clone_inner(); + let inner_dtype = DataType::try_from(&inner_dtype).unwrap(); + + ExExpr::new( + expr.list() + .contains(value.lit_with_matching_precision(&inner_dtype)), + ) +} diff --git a/native/explorer/src/lib.rs b/native/explorer/src/lib.rs index b7e48a5a0..9d788f25f 100644 --- a/native/explorer/src/lib.rs +++ b/native/explorer/src/lib.rs @@ -267,6 +267,10 @@ rustler::init!( expr_round, expr_floor, expr_ceil, + // list expressions + expr_join, + expr_lengths, + expr_member, // lazyframe lf_collect, lf_describe_plan, @@ -446,6 +450,8 @@ rustler::init!( s_floor, s_ceil, s_join, + s_lengths, + s_member, ], load = on_load ); diff --git a/native/explorer/src/series.rs b/native/explorer/src/series.rs index 338ebd7b2..59a04c709 100644 --- a/native/explorer/src/series.rs +++ b/native/explorer/src/series.rs @@ -1,6 +1,8 @@ use crate::{ atoms, - datatypes::{ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime}, + datatypes::{ + ExDate, ExDateTime, ExDuration, ExSeriesDtype, ExSeriesIoType, ExTime, ExValidValue, + }, encoding, ExDataFrame, ExSeries, ExplorerError, }; @@ -1664,3 +1666,35 @@ pub fn s_join(s1: ExSeries, separator: &str) -> Result Ok(ExSeries::new(s2)) } + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn s_lengths(s: ExSeries) -> Result { + let s2 = s + .list()? + .lst_lengths() + .into_series() + .cast(&DataType::Int64)?; + + Ok(ExSeries::new(s2)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +fn s_member( + s: ExSeries, + value: ExValidValue, + inner_dtype: ExSeriesDtype, +) -> Result { + let inner_dtype = DataType::try_from(&inner_dtype)?; + let value_expr = value.lit_with_matching_precision(&inner_dtype); + + let s2 = s + .clone_inner() + .into_frame() + .lazy() + .select([col(s.name()).list().contains(value_expr)]) + .collect()? + .column(s.name())? + .clone(); + + Ok(ExSeries::new(s2)) +} diff --git a/test/explorer/data_frame_test.exs b/test/explorer/data_frame_test.exs index 52ac1a1a3..e04f50c4d 100644 --- a/test/explorer/data_frame_test.exs +++ b/test/explorer/data_frame_test.exs @@ -1849,6 +1849,37 @@ defmodule Explorer.DataFrameTest do assert Series.to_list(df[:simple_result]) == ["Exceptional", "Passed", "Passed"] assert Series.to_list(df[:result]) == [nil, "Failed", nil] end + + test "supports list operations" do + df = + DF.new( + a: [~w(a b c), ~w(d e f)], + b: [[1, 2, 3], [4, 5, 6]], + c: [ + [~N[2021-01-01 00:00:00], ~N[2021-01-02 00:00:00]], + [~N[2021-01-03 00:00:00], ~N[2021-01-04 00:00:00]] + ] + ) + + df = + DF.mutate(df, + join: join(a, ","), + lengths: lengths(b), + member?: member?(c, ~N[2021-01-02 00:00:00]) + ) + + assert DF.to_columns(df, atom_keys: true) == %{ + a: [~w(a b c), ~w(d e f)], + b: [[1, 2, 3], [4, 5, 6]], + c: [ + [~N[2021-01-01 00:00:00.000000], ~N[2021-01-02 00:00:00.000000]], + [~N[2021-01-03 00:00:00.000000], ~N[2021-01-04 00:00:00.000000]] + ], + join: ["a,b,c", "d,e,f"], + lengths: [3, 3], + member?: [true, false] + } + end end describe "arrange/3" do diff --git a/test/explorer/series_test.exs b/test/explorer/series_test.exs index 9b1ebd982..89e8e08b0 100644 --- a/test/explorer/series_test.exs +++ b/test/explorer/series_test.exs @@ -4454,7 +4454,7 @@ defmodule Explorer.SeriesTest do end end - describe "join" do + describe "join/2" do test "join/2" do series = Series.from_list([["1"], ["1", "2"]]) @@ -4462,6 +4462,74 @@ defmodule Explorer.SeriesTest do end end + describe "lengths/1" do + test "calculates the length of each list in a series" do + series = Series.from_list([[1], [1, 2, 3], [1, 2]]) + + assert series |> Series.lengths() |> Series.to_list() == [1, 3, 2] + end + end + + describe "member?/2" do + test "checks if any of the element lists contain the given value" do + series = Series.from_list([[1], [1, 2, 3], [1, 2]]) + + assert series |> Series.member?(1) |> Series.to_list() == [true, true, true] + assert series |> Series.member?(2) |> Series.to_list() == [false, true, true] + assert series |> Series.member?(3) |> Series.to_list() == [false, true, false] + end + + test "works with floats" do + series = Series.from_list([[1.0], [1.0, 2.0]]) + + assert series |> Series.member?(2.0) |> Series.to_list() == [false, true] + end + + test "works with booleans" do + series = Series.from_list([[true], [true, false]]) + + assert series |> Series.member?(false) |> Series.to_list() == [false, true] + end + + test "works with strings" do + series = Series.from_list([["a"], ["a", "b"]]) + + assert series |> Series.member?("b") |> Series.to_list() == [false, true] + end + + test "works with dates" do + series = Series.from_list([[~D[2021-01-01]], [~D[2021-01-01], ~D[2021-01-02]]]) + + assert series |> Series.member?(~D[2021-01-02]) |> Series.to_list() == [false, true] + end + + test "works with times" do + series = Series.from_list([[~T[00:00:00]], [~T[00:00:00], ~T[00:00:01]]]) + + assert series |> Series.member?(~T[00:00:01]) |> Series.to_list() == [false, true] + end + + test "works with datetimes" do + series = + Series.from_list([ + [~N[2021-01-01 00:00:00]], + [~N[2021-01-01 00:00:00], ~N[2021-01-01 00:00:01]] + ]) + + assert series |> Series.member?(~N[2021-01-01 00:00:01]) |> Series.to_list() == [ + false, + true + ] + end + + test "works with durations" do + series = Series.from_list([[1], [1, 2]], dtype: {:list, {:duration, :millisecond}}) + duration = %Explorer.Duration{value: 2000, precision: :microsecond} + + assert series |> Series.member?(duration) |> Series.to_list() == [false, true] + end + end + describe "to_iovec/1" do test "integer" do series = Series.from_list([-1, 0, 1])