Commit eea8e44

Allow target to be str or FSSpecTarget (#460)
* Allow target to be str or FSSpecTarget

The latter allows them to be complex FSSpec objects that can
encode information about S3 / GCS and how to authenticate with
them (see the usage sketch below).

* Fix param renames

* Allow cache to be an FSSpec thing too
yuvipanda authored Jan 2, 2023
1 parent c25aa99 commit eea8e44
Showing 3 changed files with 18 additions and 12 deletions.
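
With this change, the cache and target parameters accept either a plain URL string or an fsspec-backed target object that can carry storage-specific credentials. Below is a minimal usage sketch, assuming the FSSpecTarget and CacheFSSpecTarget constructors from pangeo_forge_recipes.storage take a filesystem and a root path; the s3fs credentials and bucket paths are placeholders, not part of this commit.

# Hypothetical sketch: pass authenticated fsspec-backed targets instead of bare URL strings.
import s3fs

from pangeo_forge_recipes.storage import CacheFSSpecTarget, FSSpecTarget
from pangeo_forge_recipes.transforms import OpenURLWithFSSpec, PrepareZarrTarget

# An fsspec filesystem object can encode S3 / GCS credentials that a plain
# "s3://..." URL string cannot (placeholder keys below).
fs = s3fs.S3FileSystem(key="<access-key>", secret="<secret-key>")

cache = CacheFSSpecTarget(fs=fs, root_path="my-bucket/cache")      # assumed bucket layout
target = FSSpecTarget(fs=fs, root_path="my-bucket/store.zarr")     # assumed bucket layout

# Both transforms still accept a plain string, converted internally via .from_url().
open_with_cache = OpenURLWithFSSpec(cache=cache)    # or cache="s3://my-bucket/cache"
prepare_target = PrepareZarrTarget(target=target)   # or target="s3://my-bucket/store.zarr"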
pangeo_forge_recipes/transforms.py (24 changes: 15 additions & 9 deletions)
@@ -79,17 +79,20 @@ def wrapper(arg, **kwargs):
 class OpenURLWithFSSpec(beam.PTransform):
     """Open indexed string-based URLs with fsspec.
 
-    :param cache_url: If provided, data will be cached at this url path before opening.
+    :param cache: If provided, data will be cached at this url path before opening.
     :param secrets: If provided these secrets will be injected into the URL as a query string.
     :param open_kwargs: Extra arguments passed to fsspec.open.
     """
 
-    cache_url: Optional[str] = None
+    cache: Optional[str | CacheFSSpecTarget] = None
     secrets: Optional[dict] = None
     open_kwargs: Optional[dict] = None
 
     def expand(self, pcoll):
-        cache = CacheFSSpecTarget.from_url(self.cache_url) if self.cache_url else None
+        if isinstance(self.cache, str):
+            cache = CacheFSSpecTarget.from_url(self.cache)
+        else:
+            cache = self.cache
         return pcoll | "Open with fsspec" >> beam.Map(
             _add_keys(open_url),
             cache=cache,
@@ -208,18 +211,21 @@ class PrepareZarrTarget(beam.PTransform):
     Zarr store with the correct variables, dimensions, attributes and chunking.
     Note that the dimension coordinates will be initialized with dummy values.
 
-    :param target_url: Where to store the target Zarr dataset.
+    :param target: Where to store the target Zarr dataset.
     :param target_chunks: Dictionary mapping dimension names to chunks sizes.
         If a dimension is a not named, the chunks will be inferred from the schema.
         If chunking is present in the schema for a given dimension, the length of
        the first chunk will be used. Otherwise, the dimension will not be chunked.
     """
 
-    target_url: str
+    target: str | FSSpecTarget
     target_chunks: Dict[str, int] = field(default_factory=dict)
 
     def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
-        target = FSSpecTarget.from_url(self.target_url)
+        if isinstance(self.target, str):
+            target = FSSpecTarget.from_url(self.target)
+        else:
+            target = self.target
         store = target.get_mapper()
         initialized_target = pcoll | beam.Map(
             schema_to_zarr, target_store=store, target_chunks=self.target_chunks
@@ -262,21 +268,21 @@ class StoreToZarr(beam.PTransform):
     """Store a PCollection of Xarray datasets to Zarr.
 
     :param combine_dims: The dimensions to combine
-    :param target_url: Where to store the target Zarr dataset.
+    :param target: Where to store the target Zarr dataset.
     :param target_chunks: Dictionary mapping dimension names to chunks sizes.
         If a dimension is a not named, the chunks will be inferred from the data.
     """
 
     # TODO: make it so we don't have to explictly specify combine_dims
     # Could be inferred from the pattern instead
     combine_dims: List[Dimension]
-    target_url: str
+    target: str | FSSpecTarget
     target_chunks: Dict[str, int] = field(default_factory=dict)
 
     def expand(self, datasets: beam.PCollection) -> beam.PCollection:
         schema = datasets | DetermineSchema(combine_dims=self.combine_dims)
         indexed_datasets = datasets | IndexItems(schema=schema)
         target_store = schema | PrepareZarrTarget(
-            target_url=self.target_url, target_chunks=self.target_chunks
+            target=self.target, target_chunks=self.target_chunks
         )
         return indexed_datasets | StoreDatasetFragments(target_store=target_store)
tests/test_end_to_end.py (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ def test_xarray_zarr(
         | beam.Create(pattern.items())
         | OpenWithXarray(file_type=pattern.file_type)
         | StoreToZarr(
-            target_url=tmp_target_url,
+            target=tmp_target_url,
             target_chunks=target_chunks,
             combine_dims=pattern.combine_dim_keys,
         )
tests/test_transforms.py (4 changes: 2 additions & 2 deletions)
@@ -70,7 +70,7 @@ def cache_url(tmp_cache_url, request):
 def pcoll_opened_files(pattern, cache_url):
     input = beam.Create(pattern.items())
     output = input | OpenURLWithFSSpec(
-        cache_url=cache_url,
+        cache=cache_url,
         secrets=pattern.query_string_secrets,
         open_kwargs=pattern.fsspec_open_kwargs,
     )
@@ -192,7 +192,7 @@ def _check_target(actual):
 
     with pipeline as p:
         input = p | beam.Create([schema])
-        target = input | PrepareZarrTarget(target_url=tmp_target_url, target_chunks=target_chunks)
+        target = input | PrepareZarrTarget(target=tmp_target_url, target_chunks=target_chunks)
         assert_that(target, correct_target())
 