Commit
Support append TFDS into Space dataset (#11)
* Support append TFDS into Space dataset

* Add append TFDS unit tests
Zhou Fang authored Dec 25, 2023
1 parent 5306b81 commit 01bf674
Showing 17 changed files with 357 additions and 22 deletions.
2 changes: 1 addition & 1 deletion python/src/space/core/datasets.py
@@ -24,7 +24,7 @@


class Dataset:
"""Dataset is the interface to intract with Space storage."""
"""Dataset is the interface to interact with Space storage."""

def __init__(self, storage: Storage):
self._storage = storage
11 changes: 8 additions & 3 deletions python/src/space/core/fs/array_record.py
@@ -14,12 +14,13 @@
#
"""ArrayRecord file utilities."""

from typing import List
from typing import List, Optional

from space.core.utils.lazy_imports_utils import array_record_module as ar


def read_record_file(file_path: str, positions: List[int]) -> List[bytes]:
def read_record_file(file_path: str,
positions: Optional[List[int]] = None) -> List[bytes]:
"""Read records of an ArrayRecord file.
Args:
@@ -28,6 +29,10 @@ def read_record_file(file_path: str, positions: List[int]) -> List[bytes]:
"""
record_reader = ar.ArrayRecordReader(file_path)
records = record_reader.read(positions)
if positions is not None:
records = record_reader.read(positions)
else:
records = record_reader.read_all()

record_reader.close()
return records
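
For reference, a minimal sketch of how the widened signature might be called; the file path below is illustrative, not from this commit.

from space.core.fs.array_record import read_record_file

# Read only the records at the given row positions.
some_records = read_record_file("/data/features_abc.array_record", positions=[0, 2, 5])

# Omitting `positions` now reads every record in the file.
all_records = read_record_file("/data/features_abc.array_record")
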
2 changes: 1 addition & 1 deletion python/src/space/core/proto/runtime.proto
@@ -55,7 +55,7 @@ message Patch {
}

// Result of a job.
// NEXT_ID: 2
// NEXT_ID: 3
message JobResult {
enum State {
STATE_UNSPECIFIED = 0;
2 changes: 1 addition & 1 deletion python/src/space/core/proto/runtime_pb2.pyi
@@ -139,7 +139,7 @@ global___Patch = Patch
@typing_extensions.final
class JobResult(google.protobuf.message.Message):
"""Result of a job.
NEXT_ID: 2
NEXT_ID: 3
"""

DESCRIPTOR: google.protobuf.descriptor.Descriptor
20 changes: 19 additions & 1 deletion python/src/space/core/runners.py
@@ -26,6 +26,7 @@
from space.core.ops.base import InputData
import space.core.proto.runtime_pb2 as runtime
from space.core.storage import Storage
from space.tf.conversion import LocalConvertTfdsOp, TfdsIndexFn


class BaseRunner(ABC):
@@ -44,6 +45,17 @@ def read(self,
def append(self, data: InputData) -> runtime.JobResult:
"""Append data into the dataset."""

@abstractmethod
def append_tfds(self, tfds_path: str,
index_fn: TfdsIndexFn) -> runtime.JobResult:
"""Append data from a Tensorflow Dataset without copying data.
Args:
tfds_path: the folder of TFDS dataset files, should contain ArrowRecord
files.
index_fn: a function that build index fields from each TFDS record.
"""

@abstractmethod
def delete(self, filter_: pc.Expression) -> runtime.JobResult:
"""Delete data matching the filter from the dataset."""
@@ -72,6 +84,12 @@ def append(self, data: InputData) -> runtime.JobResult:
op.write(data)
return self._try_commit(op.finish())

def append_tfds(self, tfds_path: str,
index_fn: TfdsIndexFn) -> runtime.JobResult:
op = LocalConvertTfdsOp(self._storage.location, self._storage.metadata,
tfds_path, index_fn)
return self._try_commit(op.write())

def delete(self, filter_: pc.Expression) -> runtime.JobResult:
ds = self._storage
op = FileSetDeleteOp(self._storage.location, self._storage.metadata,
@@ -88,5 +106,5 @@ def _job_result(patch: Optional[runtime.Patch]) -> runtime.JobResult:
state=runtime.JobResult.State.SUCCEEDED,
storage_statistics_update=patch.storage_statistics_update)

logging.info(f'Job result:\n{result}')
logging.info(f"Job result:\n{result}")
return result
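
A hedged sketch of calling the new entry point; it assumes a LocalRunner instance `runner` bound to a dataset whose schema has an "id" index field and a "features" record field (these names are illustrative, not part of this commit).

import space.core.proto.runtime_pb2 as runtime

def build_index(record):
  # `record` is the deserialized dict handed over per TFDS record,
  # e.g. {"features": [nested_numpy_dict]}; return the index field values.
  features = record["features"][0]
  return {"id": int(features["id"])}

result = runner.append_tfds("/datasets/my_tfds/1.0.0", index_fn=build_index)
assert result.state == runtime.JobResult.State.SUCCEEDED
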
17 changes: 17 additions & 0 deletions python/src/space/core/schema/arrow.py
@@ -198,3 +198,20 @@ def binary_field(field: utils.Field) -> pa.Field:

def _set_field_type(field: utils.Field, type_: pa.DataType) -> pa.Field:
return pa.field(field.name, type_, metadata=field_metadata(field.field_id))


def logical_to_physical_schema(logical_schema: pa.Schema,
record_fields: Set[str]) -> pa.Schema:
"""Convert a logical schema to a physical schema."""
fields: List[pa.Field] = []
for f in logical_schema:
if f.name in record_fields:
fields.append(
pa.field(
f.name,
pa.struct(record_address_types()), # type: ignore[arg-type]
metadata=f.metadata))
else:
fields.append(f)

return pa.schema(fields)
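
A small sketch of what the new helper produces; the field types below are assumptions for illustration.

import pyarrow as pa
from space.core.schema.arrow import logical_to_physical_schema

logical = pa.schema([("id", pa.int64()), ("features", pa.binary())])
physical = logical_to_physical_schema(logical, record_fields={"features"})
# "features" is replaced by the record-address struct
# struct<_FILE: string, _ROW_ID: int32>; "id" passes through unchanged.
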
4 changes: 3 additions & 1 deletion python/src/space/core/serializers/__init__.py
@@ -14,4 +14,6 @@
#
"""Serializers (and deserializers) for unstructured record fields."""

from space.core.serializers.base import DeserializedData, FieldSerializer
from space.core.serializers.base import DeserializedData
from space.core.serializers.base import DictSerializer
from space.core.serializers.base import FieldSerializer
43 changes: 42 additions & 1 deletion python/src/space/core/serializers/base.py
@@ -15,13 +15,15 @@
"""Serializers (and deserializers) for unstructured record fields."""

from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Dict, List
from typing_extensions import TypeAlias

import pyarrow as pa
# pylint: disable=line-too-long
from tensorflow_datasets.core.dataset_utils import NumpyElem, Tree # type: ignore[import-untyped]

DeserializedData: TypeAlias = Tree[NumpyElem]
DictData: TypeAlias = Dict[str, List[DeserializedData]]


class FieldSerializer(ABC):
@@ -45,3 +47,42 @@ def deserialize(self, value_bytes: bytes) -> DeserializedData:
Returns:
Numpy-like nested dict.
"""


class DictSerializer:
"""A serializer (deserializer) of a dict of fields.
The fields are serialized by FieldSerializer. The dict is usually a Py dict
converted from an Arrow table, e.g., {"field": [values, ...], ...}
"""

def __init__(self, logical_schema: pa.Schema):
self._serializers: Dict[str, FieldSerializer] = {}

for field in logical_schema:
if isinstance(field.type, FieldSerializer):
self._serializers[field.name] = field.type

def serialize(self, value: DictData) -> DictData:
"""Serialize a value.
Args:
value: a dict of numpy-like nested dicts.
"""
for name, ser in self._serializers.items():
if name in value:
value[name] = [ser.serialize(d) for d in value[name]]

return value

def deserialize(self, value_bytes: DictData) -> DictData:
"""Deserialize a dict of bytes to a dict of values.
Returns:
A dict of numpy-like nested dicts.
"""
for name, ser in self._serializers.items():
if name in value_bytes:
value_bytes[name] = [ser.deserialize(d) for d in value_bytes[name]]

return value_bytes
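
A hedged sketch of the new DictSerializer; it assumes `logical_schema` contains a "features" field whose Arrow type is a FieldSerializer (for example a TFDS features extension type), which is not constructed here.

serializer = DictSerializer(logical_schema)

batch = {"id": [1, 2], "features": [example_a, example_b]}  # hypothetical values
stored = serializer.serialize(batch)        # "features" entries become bytes
restored = serializer.deserialize(stored)   # bytes back to numpy-like nested dicts
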
2 changes: 1 addition & 1 deletion python/src/space/core/utils/paths.py
@@ -30,7 +30,7 @@ def new_index_file_path(data_dir_: str):

def new_record_file_path(data_dir_: str, field_name: str):
"""Return a random record file path in a given data directory.."""
return path.join(data_dir_, f"{field_name}_{uuid_()}.arrowrecord")
return path.join(data_dir_, f"{field_name}_{uuid_()}.array_record")


def new_index_manifest_path(metadata_dir_: str):
Empty file added python/src/space/tf/__init__.py
111 changes: 111 additions & 0 deletions python/src/space/tf/conversion.py
@@ -0,0 +1,111 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""TFDS to Space dataset conversion."""

import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import pyarrow as pa
from typing_extensions import TypeAlias

from space.core.fs.array_record import read_record_file
from space.core.proto import metadata_pb2 as meta
from space.core.proto import runtime_pb2 as runtime
from space.core.ops import utils
from space.core.ops.append import LocalAppendOp
from space.core.schema import arrow
from space.core.serializers import DictSerializer
from space.core.utils.paths import StoragePaths

TfdsIndexFn: TypeAlias = Callable[[Dict[str, Any]], Dict[str, Any]]


class LocalConvertTfdsOp(StoragePaths):
"""Convert a TFDS dataset to a Space dataset without copying data."""

def __init__(self, location: str, metadata: meta.StorageMetadata,
tfds_path: str, index_fn: TfdsIndexFn):
StoragePaths.__init__(self, location)

self._metadata = metadata
self._tfds_path = tfds_path
self._index_fn = index_fn

record_fields = set(self._metadata.schema.record_fields)
logical_schema = arrow.arrow_schema(self._metadata.schema.fields,
record_fields,
physical=False)
self._physical_schema = arrow.logical_to_physical_schema(
logical_schema, record_fields)

_, self._record_fields = arrow.classify_fields(self._physical_schema,
record_fields,
selected_fields=None)

assert len(self._record_fields) == 1, "Only one record field is supported"
self._record_field = self._record_fields[0]

self._serializer = DictSerializer(logical_schema)
self._tfds_files = _list_tfds_files(tfds_path)

def write(self) -> Optional[runtime.Patch]:
"""Write files to append a TFDS dataset to Space."""
# TODO: convert files in parallel.
append_op = LocalAppendOp(self._location,
self._metadata,
record_address_input=True)

total_record_bytes = 0
for f in self._tfds_files:
index_data, record_bytes = self._build_index_for_array_record(f)
total_record_bytes += record_bytes
append_op.write(index_data)

patch = append_op.finish()
if patch is not None:
patch.storage_statistics_update.record_uncompressed_bytes += total_record_bytes # pylint: disable=line-too-long

return patch

def _build_index_for_array_record(self,
file_path: str) -> Tuple[pa.Table, int]:
record_field = self._record_field.name
# TODO: avoid loading all data into memory at once.
serialized_records = read_record_file(file_path)

indexes: List[Dict[str, Any]] = []
record_uncompressed_bytes = 0
for sr in serialized_records:
record_uncompressed_bytes += len(sr)
record = self._serializer.deserialize({record_field: [sr]})
indexes.append(self._index_fn(record))

index_data = pa.Table.from_pylist(indexes, schema=self._physical_schema)
index_data = index_data.drop(record_field) # type: ignore[attr-defined]
index_data = index_data.append_column(
record_field,
utils.address_column(file_path, start_row=0, num_rows=len(indexes)))

return index_data, record_uncompressed_bytes


def _list_tfds_files(tfds_path: str) -> List[str]:
files: List[str] = []
for f in os.listdir(tfds_path):
full_path = os.path.join(tfds_path, f)
if os.path.isfile(full_path) and '.array_record' in f:
files.append(full_path)

return files
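
For completeness, a hedged sketch of driving the conversion op directly, mirroring what LocalRunner.append_tfds does above; `location` and `metadata` are assumed to come from an existing Space storage, and the "id"/"features" field names are illustrative.

from space.tf.conversion import LocalConvertTfdsOp

op = LocalConvertTfdsOp(
    location, metadata, "/datasets/my_tfds/1.0.0",
    index_fn=lambda record: {"id": int(record["features"][0]["id"])})
patch = op.write()  # a runtime.Patch describing the appended files, or None
if patch is not None:
  print(patch.storage_statistics_update)
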
6 changes: 3 additions & 3 deletions python/tests/core/manifests/test_record.py
@@ -27,13 +27,13 @@ def test_write(self, tmp_path):
manifest_writer = RecordManifestWriter(metadata_dir=str(metadata_dir))

manifest_writer.write(
"data/file0.arrayrecord", 0,
"data/file0.array_record", 0,
meta.StorageStatistics(num_rows=123,
index_compressed_bytes=10,
index_uncompressed_bytes=20,
record_uncompressed_bytes=30))
manifest_writer.write(
"data/file1.arrayrecord", 1,
"data/file1.array_record", 1,
meta.StorageStatistics(num_rows=456,
index_compressed_bytes=10,
index_uncompressed_bytes=20,
@@ -43,7 +43,7 @@

assert manifest_path is not None
assert pq.read_table(manifest_path).to_pydict() == {
"_FILE": ["data/file0.arrayrecord", "data/file1.arrayrecord"],
"_FILE": ["data/file0.array_record", "data/file1.array_record"],
"_FIELD_ID": [0, 1],
"_NUM_ROWS": [123, 456],
"_UNCOMPRESSED_BYTES": [30, 100]
8 changes: 4 additions & 4 deletions python/tests/core/ops/test_utils.py
@@ -65,14 +65,14 @@ def test_update_record_stats_bytes():

def test_address_column():
result = [{
"_FILE": "data/file.arrayrecord",
"_FILE": "data/file.array_record",
"_ROW_ID": 2
}, {
"_FILE": "data/file.arrayrecord",
"_FILE": "data/file.array_record",
"_ROW_ID": 3
}, {
"_FILE": "data/file.arrayrecord",
"_FILE": "data/file.array_record",
"_ROW_ID": 4
}]
assert utils.address_column("data/file.arrayrecord", 2,
assert utils.address_column("data/file.array_record", 2,
3).to_pylist() == result
11 changes: 11 additions & 0 deletions python/tests/core/schema/test_arrow.py
@@ -108,3 +108,14 @@ def test_field_names():
utils.Field("list_struct", 220),
utils.Field("struct_list", 260)
]) == ["struct", "list_struct", "struct_list"]


def test_logical_to_physical_schema(tf_features_arrow_schema):
physical_schema = pa.schema([
pa.field("int64", pa.int64(), metadata=field_metadata(0)),
pa.field("features",
pa.struct([("_FILE", pa.string()), ("_ROW_ID", pa.int32())]),
metadata=field_metadata(1))
])
assert arrow.logical_to_physical_schema(tf_features_arrow_schema,
set(["features"])) == physical_schema