Skip to content

Commit

Permalink
feat: Support reading Enum dtype from csv (pola-rs#20188)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 6, 2024
1 parent f70c52b commit 39550c0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
42 changes: 40 additions & 2 deletions crates/polars-io/src/csv/read/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ pub struct CategoricalField {
escape_scratch: Vec<u8>,
quote_char: u8,
builder: CategoricalChunkedBuilder,
is_enum: bool,
}

#[cfg(feature = "dtype-categorical")]
Expand All @@ -275,6 +276,16 @@ impl CategoricalField {
escape_scratch: vec![],
quote_char: quote_char.unwrap_or(b'"'),
builder,
is_enum: false,
}
}

fn new_enum(quote_char: Option<u8>, builder: CategoricalChunkedBuilder) -> Self {
Self {
escape_scratch: vec![],
quote_char: quote_char.unwrap_or(b'"'),
builder,
is_enum: true,
}
}

Expand Down Expand Up @@ -552,7 +563,19 @@ pub fn init_buffers(
DataType::Categorical(_, ordering) => Buffer::Categorical(CategoricalField::new(
name, capacity, quote_char, *ordering,
)),
// TODO (ENUM) support writing to Enum
#[cfg(feature = "dtype-categorical")]
DataType::Enum(rev_map, _) => {
let Some(rev_map) = rev_map else {
polars_bail!(ComputeError: "enum categories must be set")
};
let cats = rev_map.get_categories();
let mut builder =
CategoricalChunkedBuilder::new(name, capacity, Default::default());
for cat in cats.values_iter() {
builder.register_value(cat);
}
Buffer::Categorical(CategoricalField::new_enum(quote_char, builder))
},
dt => polars_bail!(
ComputeError: "unsupported data type when reading CSV: {} when reading CSV", dt,
),
Expand Down Expand Up @@ -643,7 +666,22 @@ impl Buffer {
Buffer::Categorical(buf) => {
#[cfg(feature = "dtype-categorical")]
{
buf.builder.finish().into_series()
let ca = buf.builder.finish();

if buf.is_enum {
let DataType::Categorical(Some(rev_map), _) = ca.dtype() else {
unreachable!()
};
let idx = ca.physical().clone();
let dtype = DataType::Enum(Some(rev_map.clone()), Default::default());

unsafe {
CategoricalChunked::from_cats_and_dtype_unchecked(idx, dtype)
.into_series()
}
} else {
ca.into_series()
}
}
#[cfg(not(feature = "dtype-categorical"))]
{
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/datatypes/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,20 @@ class Number(*EnumBase): # type: ignore[misc]

expected = pl.Series(values=[8, 2, 4], dtype=pl.Int64)
assert_series_equal(expected, s)


def test_read_enum_from_csv() -> None:
df = pl.DataFrame(
{
"foo": ["ham", "spam", None, "and", "such"],
"bar": ["ham", "spam", None, "and", "such"],
}
)
f = io.BytesIO()
df.write_csv(f)
f.seek(0)

schema = {"foo": pl.Enum(["ham", "and", "such", "spam"]), "bar": pl.String()}
read = pl.read_csv(f, schema=schema)
assert read.schema == schema
assert_frame_equal(df.cast(schema), read) # type: ignore[arg-type]

0 comments on commit 39550c0

Please sign in to comment.