Skip to content

Commit

Permalink
feat(python)!: Read 2D numpy arrays as Array[dt, shape] instead of Li…
Browse files Browse the repository at this point in the history
…stst[dt] (#16710)
  • Loading branch information
ritchie46 authored Jun 4, 2024
1 parent 49da837 commit 7e5db3d
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
4 changes: 1 addition & 3 deletions py-polars/polars/_utils/construction/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,9 +499,7 @@ def numpy_to_pyseries(
return constructor(
name, values, nan_to_null if dtype in (np.float32, np.float64) else strict
)
# TODO: remove this branch on 1.0.
# This returns a List whereas we should return an Array type
elif values.ndim == 2:
elif sum(values.shape) == 0:
# Optimize by ingesting 1D and reshaping in Rust
original_shape = values.shape
values = values.reshape(-1)
Expand Down
11 changes: 8 additions & 3 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,15 +541,20 @@ def test_init_ndarray() -> None:
assert np.array_equal(df.to_numpy(), np.arange(4).reshape(-1, 1).astype(np.int64))

df = pl.DataFrame(np.arange(4).reshape(-1, 2).astype(np.int64), schema=["a"])
assert_frame_equal(df, pl.DataFrame({"a": [[0, 1], [2, 3]]}))
assert_frame_equal(
df,
pl.DataFrame(
{"a": [[0, 1], [2, 3]]}, schema={"a": pl.Array(pl.Int64, shape=2)}
),
)

# 2D numpy arrays
df = pl.DataFrame({"a": np.arange(5, dtype=np.int64).reshape(1, -1)})
assert df.dtypes == [pl.List(pl.Int64)]
assert df.dtypes == [pl.Array(pl.Int64, shape=5)]
assert df.shape == (1, 1)

df = pl.DataFrame({"a": np.arange(10, dtype=np.int64).reshape(2, -1)})
assert df.dtypes == [pl.List(pl.Int64)]
assert df.dtypes == [pl.Array(pl.Int64, shape=5)]
assert df.shape == (2, 1)
assert df.rows() == [([0, 1, 2, 3, 4],), ([5, 6, 7, 8, 9],)]

Expand Down
4 changes: 3 additions & 1 deletion py-polars/tests/unit/operations/namespaces/list/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,9 @@ def test_list_sliced_get_5186() -> None:
df = pl.from_dict(
{
"ind": pl.arange(0, n, eager=True),
"inds": np.stack([np.arange(n), -np.arange(n)], axis=-1),
"inds": pl.Series(
np.stack([np.arange(n), -np.arange(n)], axis=-1), dtype=pl.List
),
}
)

Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_init_inputs(monkeypatch: Any) -> None:
nan_to_null=True,
),
):
assert res.dtype == pl.List(pl.Float32)
assert res.dtype == pl.Array(pl.Float32, shape=2)
assert res[0].to_list() == [1.0, 2.0]
assert res[1].to_list() == [3.0, None]

Expand All @@ -134,7 +134,7 @@ def test_init_inputs(monkeypatch: Any) -> None:

assert pl.Series(
values=np.array([["foo", "bar"], ["foo2", "bar2"]])
).dtype == pl.List(pl.String)
).dtype == pl.Array(pl.String, shape=2)

# lists
assert pl.Series("a", [[1, 2], [3, 4]]).dtype == pl.List(pl.Int64)
Expand Down

0 comments on commit 7e5db3d

Please sign in to comment.