Skip to content

Commit

Permalink
fix(python): fix struct dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 13, 2022
1 parent 0bc2768 commit 8240f84
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 6 deletions.
6 changes: 4 additions & 2 deletions polars/polars-core/src/chunked_array/logical/struct_/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::collections::BTreeMap;

use super::*;
use crate::datatypes::*;
use crate::utils::index_to_chunked_index2;

/// This is logical type [`StructChunked`] that
/// dispatches most logic to the `fields` implementations
Expand Down Expand Up @@ -191,13 +192,14 @@ impl LogicalType for StructChunked {

/// Gets AnyValue from LogicalType
fn get_any_value(&self, i: usize) -> AnyValue<'_> {
let (chunk_idx, idx) = index_to_chunked_index2(&self.chunks, i);
if let DataType::Struct(flds) = self.dtype() {
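// safety: `chunk_idx` is derived from `self.chunks` by `index_to_chunked_index2`
// (in bounds as long as `i` is within the total length of the chunks), and the
// cast to `StructArray` is guarded by the struct dtype check above.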
unsafe {
let arr = &**self.chunks.get_unchecked(0);
let arr = &**self.chunks.get_unchecked(chunk_idx);
let arr = &*(arr as *const dyn Array as *const StructArray);
AnyValue::Struct(i, arr, flds)
AnyValue::Struct(idx, arr, flds)
}
} else {
unreachable!()
Expand Down
16 changes: 16 additions & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,22 @@ pub(crate) fn index_to_chunked_index<
(current_chunk_idx, index_remainder)
}

/// Translates a flat `index` into a `(chunk index, index within that chunk)` pair.
///
/// Chunks are visited in order and each chunk's length is subtracted from the
/// running offset until the offset falls inside the current chunk.
///
/// If `index` lies past the combined length of all chunks, the returned chunk
/// index equals `chunks.len()` and the second value is the remaining overshoot;
/// callers must not use that pair for unchecked access.
#[cfg(feature = "dtype-struct")]
pub(crate) fn index_to_chunked_index2(chunks: &[ArrayRef], index: usize) -> (usize, usize) {
    let mut offset = index;

    for (chunk_idx, chunk) in chunks.iter().enumerate() {
        let len = chunk.len();
        if offset < len {
            // The offset lands inside this chunk.
            return (chunk_idx, offset);
        }
        offset -= len;
    }

    // Ran out of chunks: report one-past-the-last chunk plus the leftover offset.
    (chunks.len(), offset)
}

/// # SAFETY
/// `dst` must be valid for `dst.len()` elements, and `src` and `dst` may not overlap.
#[inline]
Expand Down
27 changes: 23 additions & 4 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,15 @@ def arrow_to_pyseries(name: str, values: pa.Array, rechunk: bool = True) -> PySe
pys = PySeries.from_arrow(name, array)
else:
if array.num_chunks > 1:
it = array.iterchunks()
pys = PySeries.from_arrow(name, next(it))
for a in it:
pys.append(PySeries.from_arrow(name, a))
# somehow going through ffi with a StructArray
# returns the first chunk every time
if isinstance(array.type, pa.StructType):
pys = PySeries.from_arrow(name, array.combine_chunks())
else:
it = array.iterchunks()
pys = PySeries.from_arrow(name, next(it))
for a in it:
pys.append(PySeries.from_arrow(name, a))
elif array.num_chunks == 0:
pys = PySeries.from_arrow(name, pa.array([], array.type))
else:
Expand Down Expand Up @@ -816,6 +821,8 @@ def arrow_to_pydf(
# dictionaries cannot be built in different batches (categorical does not allow
# that) so we rechunk them and create them separately.
dictionary_cols = {}
# struct columns don't work properly if they contain multiple chunks.
struct_cols = {}
names = []
for i, column in enumerate(data):
# extract the name before casting
Expand All @@ -829,6 +836,9 @@ def arrow_to_pydf(
if pa.types.is_dictionary(column.type):
ps = arrow_to_pyseries(name, column, rechunk)
dictionary_cols[i] = pli.wrap_s(ps)
elif isinstance(column.type, pa.StructType) and column.num_chunks > 1:
ps = arrow_to_pyseries(name, column, rechunk)
struct_cols[i] = pli.wrap_s(ps)
else:
data_dict[name] = column

Expand All @@ -850,11 +860,20 @@ def arrow_to_pydf(
if rechunk:
pydf = pydf.rechunk()

reset_order = False
if len(dictionary_cols) > 0:
df = pli.wrap_df(pydf)
df = df.with_columns(
[pli.lit(s).alias(s.name) for s in dictionary_cols.values()]
)
reset_order = True

if len(struct_cols) > 0:
df = pli.wrap_df(pydf)
df = df.with_columns([pli.lit(s).alias(s.name) for s in struct_cols.values()])
reset_order = True

if reset_order:
df = df[names]
pydf = df._df

Expand Down
28 changes: 28 additions & 0 deletions py-polars/tests/slow/test_parquet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import os
import typing

import pyarrow.dataset as ds

import polars as pl


@typing.no_type_check
def test_struct_pyarrow_dataset_5796() -> None:
    """Round-trip a frame with a struct column through parquet and a pyarrow dataset.

    Builds ``2**17 + 1`` rows with an ``id`` column and a ``nested`` struct
    column, writes them to parquet via pyarrow, reads the file back as a
    pyarrow dataset table and checks that converting it with ``pl.from_arrow``
    reproduces the original frame.

    The parquet file is written into a per-run temporary directory so the test
    leaves nothing behind and cannot collide with other runs using the same
    fixed path.
    """
    import tempfile

    # NOTE(review): Windows is skipped; the reason is not recorded here —
    # confirm before removing this guard.
    if os.name == "nt":
        return

    # presumably large enough for pyarrow to hand back a multi-chunk table —
    # TODO confirm the exact threshold
    num_rows = 2**17 + 1

    df = pl.from_records(
        [{"id": i, "nested": {"a": i}} for i in range(num_rows)]
    )

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "out.parquet")
        df.write_parquet(path, use_pyarrow=True)
        tbl = ds.dataset(path).to_table()
        # Compare while the directory still exists in case the table is
        # backed by the file on disk.
        assert pl.from_arrow(tbl).frame_equal(df)
8 changes: 8 additions & 0 deletions py-polars/tests/unit/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,3 +696,11 @@ def test_concat_list_reverse_struct_fields() -> None:
assert df.select(pl.concat_list(["combo", "reverse_combo"])).frame_equal(
df.select(pl.concat_list(["combo", "combo"]))
)


def test_struct_any_value_get_after_append() -> None:
    """Indexing a struct Series produced by ``append`` returns each row's own value."""
    rows = [{"a": 1, "b": 2}, {"a": 2, "b": 3}]
    head = pl.Series("a", rows[:1])
    tail = pl.Series("a", rows[1:])

    combined = head.append(tail)

    # Each position must yield the struct that was placed there, not the first one.
    for position, expected in enumerate(rows):
        assert combined[position] == expected

0 comments on commit 8240f84

Please sign in to comment.