Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove unnecessary wraps in sort #445

Merged
merged 1 commit into from
Jun 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 51 additions & 45 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ pub fn sort_to_indices(

let (v, n) = partition_validity(values);

match values.data_type() {
Ok(match values.data_type() {
DataType::Boolean => sort_boolean(values, v, n, &options, limit),
DataType::Int8 => {
sort_primitive::<Int8Type, _>(values, v, n, cmp, &options, limit)
Expand Down Expand Up @@ -278,10 +278,12 @@ pub fn sort_to_indices(
DataType::Float64 => {
sort_list::<i32, Float64Type>(values, v, n, &options, limit)
}
t => Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
))),
t => {
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
)))
}
},
DataType::LargeList(field) => match field.data_type() {
DataType::Int8 => sort_list::<i64, Int8Type>(values, v, n, &options, limit),
Expand All @@ -304,10 +306,12 @@ pub fn sort_to_indices(
DataType::Float64 => {
sort_list::<i64, Float64Type>(values, v, n, &options, limit)
}
t => Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
))),
t => {
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
)))
}
},
DataType::FixedSizeList(field, _) => match field.data_type() {
DataType::Int8 => sort_list::<i32, Int8Type>(values, v, n, &options, limit),
Expand All @@ -330,10 +334,12 @@ pub fn sort_to_indices(
DataType::Float64 => {
sort_list::<i32, Float64Type>(values, v, n, &options, limit)
}
t => Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
))),
t => {
return Err(ArrowError::ComputeError(format!(
"Sort not supported for list type {:?}",
t
)))
}
},
DataType::Dictionary(key_type, value_type)
if *value_type.as_ref() == DataType::Utf8 =>
Expand Down Expand Up @@ -363,17 +369,21 @@ pub fn sort_to_indices(
DataType::UInt64 => {
sort_string_dictionary::<UInt64Type>(values, v, n, &options, limit)
}
t => Err(ArrowError::ComputeError(format!(
"Sort not supported for dictionary key type {:?}",
t
))),
t => {
return Err(ArrowError::ComputeError(format!(
"Sort not supported for dictionary key type {:?}",
t
)))
}
}
}
t => Err(ArrowError::ComputeError(format!(
"Sort not supported for data type {:?}",
t
))),
}
t => {
return Err(ArrowError::ComputeError(format!(
"Sort not supported for data type {:?}",
t
)))
}
})
}

/// Options that define how sort kernels should behave
Expand All @@ -396,14 +406,13 @@ impl Default for SortOptions {
}

/// Sort primitive values
#[allow(clippy::unnecessary_wraps)]
fn sort_boolean(
values: &ArrayRef,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> Result<UInt32Array> {
) -> UInt32Array {
let values = values
.as_any()
.downcast_ref::<BooleanArray>()
Expand Down Expand Up @@ -469,19 +478,18 @@ fn sort_boolean(
vec![],
);

Ok(UInt32Array::from(result_data))
UInt32Array::from(result_data)
}

/// Sort primitive values
#[allow(clippy::unnecessary_wraps)]
fn sort_primitive<T, F>(
values: &ArrayRef,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
cmp: F,
options: &SortOptions,
limit: Option<usize>,
) -> Result<UInt32Array>
) -> UInt32Array
where
T: ArrowPrimitiveType,
T::Native: std::cmp::PartialOrd,
Expand Down Expand Up @@ -549,7 +557,7 @@ where
vec![],
);

Ok(UInt32Array::from(result_data))
UInt32Array::from(result_data)
}

// insert valid and nan values in the correct order depending on the descending flag
Expand All @@ -574,7 +582,7 @@ fn sort_string<Offset: StringOffsetSizeTrait>(
null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> Result<UInt32Array> {
) -> UInt32Array {
let values = values
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
Expand All @@ -597,7 +605,7 @@ fn sort_string_dictionary<T: ArrowDictionaryKeyType>(
null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> Result<UInt32Array> {
) -> UInt32Array {
let values: &DictionaryArray<T> = as_dictionary_array::<T>(values);

let keys: &PrimitiveArray<T> = &values.keys_array();
Expand All @@ -620,15 +628,14 @@ fn sort_string_dictionary<T: ArrowDictionaryKeyType>(

/// shared implementation between dictionary encoded and plain string arrays
#[inline]
#[allow(clippy::unnecessary_wraps)]
fn sort_string_helper<'a, A: Array, F>(
values: &'a A,
value_indices: Vec<u32>,
null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
value_fn: F,
) -> Result<UInt32Array>
) -> UInt32Array
where
F: Fn(&'a A, u32) -> &str,
{
Expand Down Expand Up @@ -661,23 +668,22 @@ where
if options.nulls_first {
nulls.append(&mut valid_indices);
nulls.truncate(len);
return Ok(UInt32Array::from(nulls));
UInt32Array::from(nulls)
} else {
// no need to sort nulls as they are in the correct order already
valid_indices.append(&mut nulls);
valid_indices.truncate(len);
UInt32Array::from(valid_indices)
}

// no need to sort nulls as they are in the correct order already
valid_indices.append(&mut nulls);
valid_indices.truncate(len);
Ok(UInt32Array::from(valid_indices))
}

#[allow(clippy::unnecessary_wraps)]
fn sort_list<S, T>(
values: &ArrayRef,
value_indices: Vec<u32>,
mut null_indices: Vec<u32>,
options: &SortOptions,
limit: Option<usize>,
) -> Result<UInt32Array>
) -> UInt32Array
where
S: OffsetSizeTrait,
T: ArrowPrimitiveType,
Expand Down Expand Up @@ -727,12 +733,12 @@ where
if options.nulls_first {
null_indices.append(&mut valid_indices);
null_indices.truncate(len);
return Ok(UInt32Array::from(null_indices));
UInt32Array::from(null_indices)
} else {
valid_indices.append(&mut null_indices);
valid_indices.truncate(len);
UInt32Array::from(valid_indices)
}

valid_indices.append(&mut null_indices);
valid_indices.truncate(len);
Ok(UInt32Array::from(valid_indices))
}

/// Compare two `Array`s based on the ordering defined in [ord](crate::array::ord).
Expand Down