Commit c1735fa
pr feedback
wjones127 committed Aug 29, 2023
1 parent e6113c9 commit c1735fa
Showing 2 changed files with 56 additions and 39 deletions.
17 changes: 14 additions & 3 deletions python/src/lib.rs
@@ -30,7 +30,7 @@ use arrow_array::RecordBatchIterator;
 use arrow_schema::{ArrowError, Schema};
 use env_logger::Env;
 use futures::StreamExt;
-use pyo3::exceptions::PyIOError;
+use pyo3::exceptions::{PyIOError, PyValueError};
 use pyo3::prelude::*;
 
 #[macro_use]
@@ -112,6 +112,9 @@ fn json_to_schema(py: Python<'_>, json: &str) -> PyResult<PyObject> {
 /// string_features: Optional[List[str]]
 ///     Names of features that should be treated as strings. Otherwise they
 ///     will be treated as binary.
+/// num_rows: Optional[int], default None
+///     Number of records to read to infer the schema. If None, will read
+///     the entire file.
 ///
 /// Returns
 /// -------
@@ -120,11 +123,12 @@ fn json_to_schema(py: Python<'_>, json: &str) -> PyResult<PyObject> {
 ///     alphabetically sorted by field names, since TFRecord doesn't have
 ///     a concept of field order.
 #[pyfunction]
-#[pyo3(signature = (uri, *, tensor_features = None, string_features = None))]
+#[pyo3(signature = (uri, *, tensor_features = None, string_features = None, num_rows = None))]
 fn infer_tfrecord_schema(
     uri: &str,
     tensor_features: Option<Vec<String>>,
     string_features: Option<Vec<String>>,
+    num_rows: Option<usize>,
 ) -> PyResult<PyArrowType<ArrowSchema>> {
     let tensor_features = tensor_features.unwrap_or_default();
     let tensor_features = tensor_features
@@ -141,6 +145,7 @@ fn infer_tfrecord_schema(
             uri,
             &tensor_features,
             &string_features,
+            num_rows,
         ))
         .map_err(|err| PyIOError::new_err(err.to_string()))?;
     Ok(PyArrowType(schema))
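
For context, a minimal sketch of exercising the new `num_rows` knob through the underlying Rust function (not part of this commit; the file path and feature names are hypothetical, and a tokio runtime is assumed):

    use lance::utils::tfrecord::infer_tfrecord_schema;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Sample only the first 100 records; None would scan the whole file.
        let schema = infer_tfrecord_schema(
            "data/train.tfrecord", // hypothetical path
            &["embedding"],        // features decoded as tensors
            &["caption"],          // features decoded as UTF-8 strings
            Some(100),
        )
        .await?;
        println!("{:#?}", schema);
        Ok(())
    }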
@@ -156,6 +161,8 @@ fn infer_tfrecord_schema(
 ///     Arrow schema of the tfrecord file. Use :py:func:`infer_tfrecord_schema`
 ///     to infer the schema. The schema is allowed to be a subset of fields; the
 ///     reader will only parse the fields that are present in the schema.
+/// batch_size: int, default 10k
+///     Number of records to read per batch.
 ///
 /// Returns
 /// -------
@@ -164,15 +171,18 @@
 ///     :py:func:`lance.write_dataset`. The output schema will match the schema
 ///     provided, including field order.
 #[pyfunction]
+#[pyo3(signature = (uri, schema, *, batch_size = 10_000))]
 fn read_tfrecord(
     uri: &str,
     schema: PyArrowType<ArrowSchema>,
+    batch_size: usize,
 ) -> PyResult<PyArrowType<ArrowArrayStreamReader>> {
     let schema = schema.0;
     let mut batch_stream = RT
         .block_on(::lance::utils::tfrecord::read_tfrecord(
             uri,
             Arc::new(schema.clone()),
+            Some(batch_size),
         ))
         .map_err(|err| PyIOError::new_err(err.to_string()))?;
 
@@ -185,7 +195,8 @@ fn read_tfrecord(
 
     // TODO: this should be handled by upstream
     let stream = FFI_ArrowArrayStream::new(Box::new(batch_reader));
-    let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
+    let stream_reader = ArrowArrayStreamReader::try_new(stream)
+        .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
     Ok(PyArrowType(stream_reader))
 }
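
A companion sketch of the Rust-level reader this wrapper delegates to, showing the `batch_size` override (again not from this commit; the path and field name are hypothetical, a tokio runtime is assumed, and the stream is assumed to yield `Result<RecordBatch>` items):

    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};
    use futures::StreamExt;
    use lance::utils::tfrecord::read_tfrecord;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Partial schema: only the listed fields are parsed from the file.
        let schema = Arc::new(Schema::new(vec![Field::new(
            "label",
            DataType::Int64,
            true,
        )]));
        // Some(1_000) overrides the 10_000-row default.
        let mut batches = read_tfrecord("data/train.tfrecord", schema, Some(1_000)).await?;
        while let Some(batch) = batches.next().await {
            println!("read {} rows", batch?.num_rows());
        }
        Ok(())
    }

The returned stream matches the provided schema, including field order, so its batches can be handed directly to a dataset writer.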
78 changes: 42 additions & 36 deletions rust/src/utils/tfrecord.rs
@@ -49,10 +49,14 @@ use tfrecord::{Example, Feature};
 ///
 /// The features named by `string_features` will be assumed to be UTF-8 encoded
 /// strings.
+///
+/// `num_rows` determines the number of rows to read from the file to infer the
+/// schema. If `None`, the entire file will be read.
 pub async fn infer_tfrecord_schema(
     uri: &str,
     tensor_features: &[&str],
     string_features: &[&str],
+    num_rows: Option<usize>,
 ) -> Result<ArrowSchema> {
     let mut columns: HashMap<String, FeatureMeta> = HashMap::new();
 
@@ -66,6 +70,7 @@ pub async fn infer_tfrecord_schema(
         .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))
         .into_async_read();
     let mut records = RecordStream::<Example, _>::from_reader(data, Default::default());
+    let mut i = 0;
     while let Some(record) = records.next().await {
         let record = record.map_err(|err| Error::IO {
             message: err.to_string(),
@@ -78,15 +83,22 @@ pub async fn infer_tfrecord_schema(
                 } else {
                     columns.insert(
                         name.clone(),
-                        FeatureMeta::new(
+                        FeatureMeta::try_new(
                             &feature,
                             tensor_features.contains(&name.as_str()),
                             string_features.contains(&name.as_str()),
-                        ),
+                        )?,
                     );
                 }
             }
         }
+
+        i += 1;
+        if let Some(num_rows) = num_rows {
+            if i >= num_rows {
+                break;
+            }
+        }
     }
 
     let mut fields = columns
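
The explicit counter above has a stream-combinator equivalent; a hedged alternative sketch (not what the commit does), where a `None` limit degenerates to `usize::MAX`:

    use futures::{stream, StreamExt};

    #[tokio::main]
    async fn main() {
        let num_rows: Option<usize> = Some(3);
        // Stand-in for the TFRecord record stream.
        let records = stream::iter(0..10u32);
        // `take` caps the stream just like the manual `i >= num_rows` break.
        let mut limited = records.take(num_rows.unwrap_or(usize::MAX));
        while let Some(r) = limited.next().await {
            println!("record {r}");
        }
    }

The committed counter keeps the early exit visible inside the while-let loop, next to the per-record error handling.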
@@ -101,12 +113,17 @@
 
 /// Read a TFRecord file into an Arrow record batch stream.
 ///
-/// Reads 10k rows at a time.
+/// Reads `batch_size` rows at a time. If `batch_size` is `None`, a default
+/// batch size of 10,000 is used.
 ///
 /// The schema may be a partial schema, in which case only the fields present in
 /// the schema will be read.
-pub async fn read_tfrecord(uri: &str, schema: ArrowSchemaRef) -> Result<SendableRecordBatchStream> {
-    let batch_size = 10_000;
+pub async fn read_tfrecord(
+    uri: &str,
+    schema: ArrowSchemaRef,
+    batch_size: Option<usize>,
+) -> Result<SendableRecordBatchStream> {
+    let batch_size = batch_size.unwrap_or(10_000);
 
     let (store, path) = ObjectStore::from_uri(uri).await?;
     let data = store
@@ -163,23 +180,11 @@ struct FeatureMeta {
 
 impl FeatureMeta {
     /// Create a new FeatureMeta from a single example.
-    pub fn new(feature: &Feature, is_tensor: bool, is_string: bool) -> Self {
+    pub fn try_new(feature: &Feature, is_tensor: bool, is_string: bool) -> Result<Self> {
         let feature_type = match feature.kind.as_ref().unwrap() {
             Kind::BytesList(data) => {
                 if is_tensor {
-                    let val = &data.value[0];
-                    let tensor_proto = TensorProto::decode(val.as_slice()).unwrap();
-                    FeatureType::Tensor {
-                        shape: tensor_proto
-                            .tensor_shape
-                            .as_ref()
-                            .unwrap()
-                            .dim
-                            .iter()
-                            .map(|d| d.size)
-                            .collect(),
-                        dtype: tensor_proto.dtype(),
-                    }
+                    Self::extract_tensor(data.value[0].as_slice())?
                 } else if is_string {
                     FeatureType::String
                 } else {
@@ -189,10 +194,10 @@ impl FeatureMeta {
             Kind::FloatList(_) => FeatureType::Float,
             Kind::Int64List(_) => FeatureType::Integer,
         };
-        Self {
+        Ok(Self {
             repeated: feature_is_repeated(feature),
             feature_type,
-        }
+        })
     }
 
     /// Update the FeatureMeta with a new example, or return an error if the
@@ -202,21 +207,7 @@ impl FeatureMeta {
             Kind::BytesList(data) => match self.feature_type {
                 FeatureType::String => FeatureType::String,
                 FeatureType::Binary => FeatureType::Binary,
-                FeatureType::Tensor { .. } => {
-                    let val = &data.value[0];
-                    let tensor_proto = TensorProto::decode(val.as_slice()).unwrap();
-                    FeatureType::Tensor {
-                        shape: tensor_proto
-                            .tensor_shape
-                            .as_ref()
-                            .unwrap()
-                            .dim
-                            .iter()
-                            .map(|d| d.size)
-                            .collect(),
-                        dtype: tensor_proto.dtype(),
-                    }
-                }
+                FeatureType::Tensor { .. } => Self::extract_tensor(data.value[0].as_slice())?,
                 _ => {
                     return Err(Error::IO {
                         message: format!(
@@ -240,6 +231,21 @@ impl FeatureMeta {
         }
         Ok(())
     }
+
+    fn extract_tensor(data: &[u8]) -> Result<FeatureType> {
+        let tensor_proto = TensorProto::decode(data)?;
+        Ok(FeatureType::Tensor {
+            shape: tensor_proto
+                .tensor_shape
+                .as_ref()
+                .unwrap()
+                .dim
+                .iter()
+                .map(|d| d.size)
+                .collect(),
+            dtype: tensor_proto.dtype(),
+        })
+    }
 }
 
 /// Metadata for a fixed-shape tensor.
