From 899881ae5da003aa3f1be4993ab1bca8fbc12e6c Mon Sep 17 00:00:00 2001 From: Siddharth Vishwakarma <153494533+siddharth-vi@users.noreply.github.com> Date: Tue, 26 Nov 2024 20:44:00 +0530 Subject: [PATCH] fix: Bug fix in existing fast path for sorted series (#20004) --- .../src/chunked_array/ops/sort/mod.rs | 2 +- py-polars/tests/unit/operations/test_sort.py | 27 ++++++++++++++++--- 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index add7e8b696a4..9d5bbe6503dd 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -138,7 +138,7 @@ macro_rules! sort_with_fast_path { // if the nulls are already last we can clone if $options.nulls_last && $ca.get($ca.len() - 1).is_none() || // if the nulls are already first we can clone - $ca.get(0).is_none() + (!$options.nulls_last && $ca.get(0).is_none()) { return $ca.clone(); } diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index b19d008556e1..e74d8302880f 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import date, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import pytest from hypothesis import given @@ -163,11 +163,30 @@ def test_sort_by_exprs() -> None: assert out.to_list() == [1, -1, 2, -2] -def test_arg_sort_nulls() -> None: +@pytest.mark.parametrize( + ("sort_function", "expected"), + [ + (lambda x: x, ([0, 1, 2, 3, 4], [3, 4, 0, 1, 2])), + ( + lambda x: x.sort(descending=False, nulls_last=True), + ([0, 1, 2, 3, 4], [3, 4, 0, 1, 2]), + ), + ( + lambda x: x.sort(descending=False, nulls_last=False), + ([2, 3, 4, 0, 1], [0, 1, 2, 3, 4]), + ), + ], +) +def test_arg_sort_nulls( + sort_function: Callable[[pl.Series], pl.Series], + expected: tuple[list[int], list[int]], +) -> None: a = pl.Series("a", [1.0, 2.0, 3.0, None, None]) - assert a.arg_sort(nulls_last=True).to_list() == [0, 1, 2, 3, 4] - assert a.arg_sort(nulls_last=False).to_list() == [3, 4, 0, 1, 2] + a = sort_function(a) + + assert a.arg_sort(nulls_last=True).to_list() == expected[0] + assert a.arg_sort(nulls_last=False).to_list() == expected[1] res = a.to_frame().sort(by="a", nulls_last=False).to_series().to_list() assert res == [None, None, 1.0, 2.0, 3.0]