Skip to content

Commit

Permalink
feat(python): Allow insert_column to take expressions (#19024)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Oct 1, 2024
1 parent 529e3a3 commit 2e2823a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
26 changes: 22 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4525,7 +4525,7 @@ def rename(
"""
return self.lazy().rename(mapping, strict=strict).collect(_eager=True)

def insert_column(self, index: int, column: Series) -> DataFrame:
def insert_column(self, index: int, column: IntoExprColumn) -> DataFrame:
"""
Insert a Series (or expression) at a certain column index.
Expand All @@ -4536,7 +4536,7 @@ def insert_column(self, index: int, column: Series) -> DataFrame:
index
Index at which to insert the new column.
column
`Series` to insert.
`Series` or expression to insert.
Examples
--------
Expand Down Expand Up @@ -4575,9 +4575,27 @@ def insert_column(self, index: int, column: Series) -> DataFrame:
│ 4 ┆ 13.0 ┆ true ┆ 0.0 │
└─────┴──────┴───────┴──────┘
"""
if index < 0:
if (original_index := index) < 0:
index = len(self.columns) + index
self._df.insert_column(index, column._s)
if index < 0:
msg = f"column index {original_index} is out of range (frame has {len(self.columns)} columns)"
raise IndexError(msg)
elif index > len(self.columns):
msg = f"column index {original_index} is out of range (frame has {len(self.columns)} columns)"
raise IndexError(msg)

if isinstance(column, pl.Series):
self._df.insert_column(index, column._s)
else:
if isinstance(column, str):
column = F.col(column)
if isinstance(column, pl.Expr):
cols = self.columns
cols.insert(index, column) # type: ignore[arg-type]
self._df = self.select(cols)._df
else:
msg = f"column must be a Series or Expr, got {column!r} (type={type(column)})"
raise TypeError(msg)
return self

def filter(
Expand Down
34 changes: 34 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ def test_assignment() -> None:


def test_insert_column() -> None:
# insert series
df = (
pl.DataFrame({"z": [3, 4, 5]})
.insert_column(0, pl.Series("x", [1, 2, 3]))
Expand All @@ -466,6 +467,39 @@ def test_insert_column() -> None:
expected_df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]})
assert_frame_equal(expected_df, df)

# insert expressions
df = pl.DataFrame(
{
"id": ["xx", "yy", "zz"],
"v1": [5, 4, 6],
"v2": [7, 3, 3],
}
)
df.insert_column(3, (pl.col("v1") * pl.col("v2")).alias("v3"))
df.insert_column(1, (pl.col("v2") - pl.col("v1")).alias("v0"))

expected = pl.DataFrame(
{
"id": ["xx", "yy", "zz"],
"v0": [2, -1, -3],
"v1": [5, 4, 6],
"v2": [7, 3, 3],
"v3": [35, 12, 18],
}
)
assert_frame_equal(df, expected)

# check that we raise suitable index errors
for idx, column in (
(10, pl.col("v1").sqrt().alias("v1_sqrt")),
(-10, pl.Series("foo", [1, 2, 3])),
):
with pytest.raises(
IndexError,
match=rf"column index {idx} is out of range \(frame has 5 columns\)",
):
df.insert_column(idx, column)


def test_replace_column() -> None:
df = (
Expand Down

0 comments on commit 2e2823a

Please sign in to comment.