address some comments
eddyxu committed Jul 26, 2023
1 parent 3ac8ad6 commit f4ea9ec
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions python/python/lance/tf/data.py
@@ -25,7 +25,7 @@
 
 import logging
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Iterable, List, Optional, Union
 
 import lance
 import numpy as np
@@ -61,7 +61,12 @@ def arrow_data_type_to_tf(dt: pa.DataType) -> tf.DType:
         return tf.float32
     elif pa.types.is_float64(dt):
         return tf.float64
-    elif pa.types.is_string(dt) or pa.types.is_binary(dt):
+    elif (
+        pa.types.is_string(dt)
+        or pa.types.is_large_string(dt)
+        or pa.types.is_binary(dt)
+        or pa.types.is_large_binary(dt)
+    ):
         return tf.string
 
     raise TypeError(f"Arrow/Tf conversion: Unsupported arrow data type: {dt}")
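
For context (not part of the commit), a minimal sketch of what the widened branch accepts, assuming pyarrow, tensorflow, and lance are installed and arrow_data_type_to_tf is imported from this module:

import pyarrow as pa
import tensorflow as tf

from lance.tf.data import arrow_data_type_to_tf

# Both the regular and the large Arrow variants now map to tf.string.
assert arrow_data_type_to_tf(pa.string()) == tf.string
assert arrow_data_type_to_tf(pa.large_string()) == tf.string
assert arrow_data_type_to_tf(pa.binary()) == tf.string
assert arrow_data_type_to_tf(pa.large_binary()) == tf.string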
@@ -76,7 +81,11 @@ def data_type_to_tensor_spec(dt: pa.DataType) -> tf.TensorSpec:
         or pa.types.is_string(dt)
     ):
         return tf.TensorSpec(shape=(None,), dtype=arrow_data_type_to_tf(dt))
-    elif pa.types.is_list(dt):
+    elif pa.types.is_fixed_size_list(dt):
+        return tf.TensorSpec(
+            shape=(dt.list_size, None), dtype=arrow_data_type_to_tf(dt.value_type)
+        )
+    elif pa.types.is_list(dt) or pa.types.is_large_list(dt):
         return tf.TensorSpec(
             shape=(
                 None,
@@ -103,7 +112,7 @@ def from_lance(
     columns: Optional[List[str]] = None,
     batch_size: int = 256,
     filter: Optional[str] = None,
-    fragments: Union[List[LanceFragment], tf.data.Dataset] = None,
+    fragments: Union[Iterable[LanceFragment], tf.data.Dataset] = None,
 ) -> tf.data.Dataset:
     """Create a ``tf.data.Dataset`` from a Lance dataset.
 
@@ -112,7 +121,8 @@
     dataset : Union[str, Path, LanceDataset]
         Lance dataset or dataset URI/path.
     columns : Optional[List[str]], optional
-        List of columns to include in the output dataset, by default None
+        List of columns to include in the output dataset.
+        If not set, all columns will be read.
     batch_size : int, optional
         Batch size, by default 256
     filter : Optional[str], optional
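
A hedged usage sketch of the relaxed fragments parameter: with Iterable[LanceFragment] in the signature, a generator works as well as a list. The dataset path and column names below are hypothetical.

import lance
from lance.tf.data import from_lance

ds = lance.dataset("/tmp/images.lance")  # hypothetical dataset path

# get_fragments() returns the dataset's fragments; wrapping them in a
# generator expression shows that any iterable is now accepted.
fragments = (frag for frag in ds.get_fragments())

tf_ds = from_lance(
    ds,
    columns=["image", "label"],  # hypothetical columns; omit to read all
    batch_size=256,
    fragments=fragments,
)

for batch in tf_ds.take(1):
    print(batch)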
