diff --git a/python/python/lance/ray/sink.py b/python/python/lance/ray/sink.py index 3034ec5824..bf472afd49 100644 --- a/python/python/lance/ray/sink.py +++ b/python/python/lance/ray/sink.py @@ -29,6 +29,8 @@ __all__ = ["LanceDatasink", "LanceFragmentWriter", "LanceCommitter", "write_lance"] +NONE_ARROW_STR = "None" + def _pd_to_arrow( df: Union[pa.Table, "pd.DataFrame", Dict], schema: Optional[pa.Schema] @@ -39,10 +41,27 @@ def _pd_to_arrow( if isinstance(df, dict): return pa.Table.from_pydict(df, schema=schema) - if _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame): + elif _PANDAS_AVAILABLE and isinstance(df, pd.DataFrame): tbl = pa.Table.from_pandas(df, schema=schema) - new_schema = tbl.schema.remove_metadata() - new_table = tbl.replace_schema_metadata(new_schema.metadata) + tbl.schema = tbl.schema.remove_metadata() + return tbl + elif isinstance(df, pa.Table): + fields = df.schema.names + new_columns = [] + new_fields = [] + for field in fields: + col = df[field] + new_field = pa.field(field, col.type) + if ( + pa.types.is_null(col.type) + and schema.field_by_name(field).type == pa.string() + ): + new_field = pa.field(field, pa.string()) + col = pa.compute.if_else(pa.compute.is_null(col), NONE_ARROW_STR, col) + new_columns.append(col) + new_fields.append(new_field) + new_schema = pa.schema(fields=new_fields) + new_table = pa.Table.from_arrays(new_columns, schema=new_schema) return new_table return df diff --git a/python/python/tests/test_ray.py b/python/python/tests/test_ray.py index b85f185aff..54f1c42492 100644 --- a/python/python/tests/test_ray.py +++ b/python/python/tests/test_ray.py @@ -116,3 +116,25 @@ def test_ray_empty_write_lance(tmp_path: Path): # empty write would not generate dataset. with pytest.raises(ValueError): lance.dataset(tmp_path) + + +@pytest.mark.filterwarnings("ignore::DeprecationWarning") +def test_ray_write_lance_none_str(tmp_path: Path): + def f(row): + return { + "id": row["id"], + "str": None, + } + + schema = pa.schema([pa.field("id", pa.int64()), pa.field("str", pa.string())]) + (ray.data.range(10).map(f).write_lance(tmp_path, schema=schema)) + + ds = lance.dataset(tmp_path) + ds.count_rows() == 10 + assert ds.schema == schema + + tbl = ds.to_table() + pylist = tbl["str"].to_pylist() + assert len(pylist) == 10 + for item in pylist: + assert item is None