Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(list): add fixed-size list support for value_counts #3521

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,18 @@
}

pub fn to_list(&self) -> ListArray {
let field = &self.field;

let DataType::FixedSizeList(inner_type, _) = &field.dtype else {
unreachable!("Expected FixedSizeListArray, got {:?}", field.dtype);

Check warning on line 204 in src/daft-core/src/array/fixed_size_list_array.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-core/src/array/fixed_size_list_array.rs#L204

Added line #L204 was not covered by tests
};

let datatype = DataType::List(inner_type.clone());
let mut field = (**field).clone();
field.dtype = datatype;

ListArray::new(
self.field.clone(),
field,
self.flat_child.clone(),
self.generate_offsets(),
self.validity.clone(),
Expand Down
14 changes: 9 additions & 5 deletions src/daft-functions/src/list/value_counts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ impl ScalarUDF for ListValueCountsFunction {

let data_field = data.to_field(schema)?;

let DataType::List(inner_type) = &data_field.dtype else {
return Err(DaftError::TypeError(format!(
"Expected list, got {}",
data_field.dtype
)));
let inner_type = match &data_field.dtype {
DataType::List(inner_type) => inner_type,
DataType::FixedSizeList(inner_type, _) => inner_type,
_ => {
return Err(DaftError::TypeError(format!(
"Expected list or fixed size list, got {}",
data_field.dtype
)));
}
};

let map_type = DataType::Map {
Expand Down
35 changes: 35 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,41 @@ def test_list_value_counts_nested():
)


def test_list_value_counts_fixed_size():
# Create data with lists of fixed size
data = {
"fixed_list": [
[1, 2, 3],
[4, 3, 4],
[4, 5, 6],
[1, 2, 3],
[7, 8, 9],
None,
]
}

# Create DataFrame and cast the column to fixed size list
df = daft.from_pydict(data).with_column(
"fixed_list", daft.col("fixed_list").cast(DataType.fixed_size_list(DataType.int64(), 3))
)

df = df.with_column("fixed_list", col("fixed_list").cast(DataType.fixed_size_list(DataType.int64(), 3)))

# Get value counts
result = df.with_column("value_counts", col("fixed_list").list.value_counts())

# Verify the value counts
result_dict = result.to_pydict()
assert result_dict["value_counts"] == [
[(1, 1), (2, 1), (3, 1)],
[(4, 2), (3, 1)],
[(4, 1), (5, 1), (6, 1)],
[(1, 1), (2, 1), (3, 1)],
[(7, 1), (8, 1), (9, 1)],
[],
]


def test_list_value_counts_degenerate():
import pyarrow as pa

Expand Down
Loading