xarray_zarr.py
"""
A Pangeo Forge Recipe
"""
import logging
import warnings
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field, replace
from itertools import product
from typing import Callable, Dict, List, Optional, Sequence, Tuple
import dask
import numpy as np
import xarray as xr
import zarr
from ..patterns import FilePattern, prune_pattern
from ..storage import AbstractTarget, CacheFSSpecTarget, MetadataTarget, file_opener
from ..utils import (
chunk_bounds_and_conflicts,
chunked_iterable,
fix_scalar_attr_encoding,
lock_for_conflicts,
)
from .base import BaseRecipe, closure
# use this filename to store global recipe metadata in the metadata_cache
# it will be written once (by prepare_target) and read many times (by store_chunk)
_GLOBAL_METADATA_KEY = "pangeo-forge-recipe-metadata.json"
logger = logging.getLogger(__name__)
def _encode_key(key) -> str:
return "-".join([str(k) for k in key])
def _input_metadata_fname(input_key):
return "input-meta-" + _encode_key(input_key) + ".json"
def _chunk_metadata_fname(chunk_key) -> str:
return "chunk-meta-" + _encode_key(chunk_key) + ".json"
ChunkKey = Tuple[int]
InputKey = Tuple[int]
# Notes about dataclasses:
# - https://www.python.org/dev/peps/pep-0557/#inheritance
# - https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class XarrayZarrRecipe(BaseRecipe):
    """This class represents a dataset composed of many individual NetCDF files.
    This class uses Xarray to read and write data and writes its output to Zarr.
    The organization of the source files is described by the ``file_pattern``.
    Currently this recipe supports at most one ``MergeDim`` and one ``ConcatDim``
    in the File Pattern.

    :param file_pattern: An object which describes the organization of the input files.
    :param inputs_per_chunk: The number of inputs to use in each chunk along the concat dim.
       Must be an integer >= 1.
    :param target_chunks: Desired chunk structure for the target dataset.
    :param target: A location in which to put the dataset. Can also be assigned at run time.
    :param input_cache: A location in which to cache temporary data.
    :param metadata_cache: A location in which to cache metadata for inputs and chunks.
       Required if ``nitems_per_file=None`` on concat dim in file pattern.
    :param cache_inputs: If ``True``, inputs are copied to ``input_cache`` before
       opening. If ``False``, try to open inputs directly from their source location.
    :param copy_input_to_local_file: Whether to copy the inputs to a temporary
       local file. In this case, a path (rather than file object) is passed to
       ``xr.open_dataset``. This is required for engines that can't open
       file-like objects (e.g. pynio).
    :param consolidate_zarr: Whether to consolidate the resulting Zarr dataset.
    :param xarray_open_kwargs: Extra options for opening the inputs with Xarray.
    :param xarray_concat_kwargs: Extra options to pass to Xarray when concatenating
       the inputs to form a chunk.
    :param delete_input_encoding: Whether to remove Xarray encoding from variables
       in the input dataset.
    :param fsspec_open_kwargs: Extra options for opening the inputs with fsspec.
    :param process_input: Function to call on each opened input, with signature
       `(ds: xr.Dataset, filename: str) -> ds: xr.Dataset`.
    :param process_chunk: Function to call on each concatenated chunk, with signature
       `(ds: xr.Dataset) -> ds: xr.Dataset`.
    :param lock_timeout: The default timeout for acquiring a chunk lock.
    """

    file_pattern: FilePattern
    inputs_per_chunk: Optional[int] = 1
    target_chunks: Dict[str, int] = field(default_factory=dict)
    target: Optional[AbstractTarget] = None
    input_cache: Optional[CacheFSSpecTarget] = None
    metadata_cache: Optional[MetadataTarget] = None
    cache_inputs: bool = True
    copy_input_to_local_file: bool = False
    consolidate_zarr: bool = True
    xarray_open_kwargs: dict = field(default_factory=dict)
    xarray_concat_kwargs: dict = field(default_factory=dict)
    delete_input_encoding: bool = True
    fsspec_open_kwargs: dict = field(default_factory=dict)
    process_input: Optional[Callable[[xr.Dataset, str], xr.Dataset]] = None
    process_chunk: Optional[Callable[[xr.Dataset], xr.Dataset]] = None
    lock_timeout: Optional[int] = None

    # internal attributes not meant to be seen or accessed by user
    _concat_dim: Optional[str] = None
    """The concatenation dimension name."""

    _concat_dim_chunks: Optional[int] = None
    """The desired chunking along the sequence dimension."""

    # In general there may be a many-to-many relationship between input keys and chunk keys
    _inputs_chunks: Dict[InputKey, Tuple[ChunkKey]] = field(
        default_factory=dict, repr=False, init=False
    )
    """Mapping of input keys to chunk keys."""

    _chunks_inputs: Dict[ChunkKey, Tuple[InputKey]] = field(
        default_factory=dict, repr=False, init=False
    )
    """Mapping of chunk keys to input keys."""

    _init_chunks: List[ChunkKey] = field(default_factory=list, repr=False, init=False)
    """List of chunks needed to initialize the recipe."""

    def __post_init__(self):
        self._validate_file_pattern()
        # from here on we know there is at most one merge dim and one concat dim
        self._concat_dim = self.file_pattern.concat_dims[0]
        self._cache_metadata = any(
            [v is None for v in self.file_pattern.concat_sequence_lens.values()]
        )
        self._nitems_per_input = self.file_pattern.nitems_per_input[self._concat_dim]

        # now for the fancy bit: we have to define the mappings _inputs_chunks and _chunks_inputs
        # this is where refactoring would need to happen to support more complex file patterns
        # (e.g. multiple concat dims)
        # for now we assume 1:many chunk_keys:input_keys
        # theoretically this could handle more than one merge dimension

        # a list of iterators that iterate over merge dims normally
        # but over concat dims in chunks
        dimension_iterators = [
            range(v)
            if k != self._concat_dim
            else enumerate(chunked_iterable(range(v), self.inputs_per_chunk))
            for k, v in self.file_pattern.dims.items()
        ]
        for k in product(*dimension_iterators):
            # a typical k would look like (0, (0, (0, 1)))
            chunk_key = tuple([v[0] if hasattr(v, "__len__") else v for v in k])
            all_as_tuples = tuple([v[1] if hasattr(v, "__len__") else (v,) for v in k])
            input_keys = tuple(v for v in product(*all_as_tuples))
            self._chunks_inputs[chunk_key] = input_keys
            for input_key in input_keys:
                self._inputs_chunks[input_key] = (chunk_key,)

        # init chunks are all elements from the merge dim and the first element from the concat dim
        merge_dimension_iterators = [
            range(v) if k != self._concat_dim else (range(1))
            for k, v in self.file_pattern.dims.items()
        ]
        self._init_chunks = list(product(*merge_dimension_iterators))

        self._validate_input_and_chunk_keys()
        self._set_target_chunks()
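
    # Illustration (hypothetical numbers): with a single concat dim of length 4,
    # no merge dim, and inputs_per_chunk=2, the mappings built above would be
    #   _chunks_inputs = {(0,): ((0,), (1,)), (1,): ((2,), (3,))}
    #   _inputs_chunks = {(0,): ((0,),), (1,): ((0,),), (2,): ((1,),), (3,): ((1,),)}
    #   _init_chunks   = [(0,)]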

    def _validate_file_pattern(self):
        if len(self.file_pattern.merge_dims) > 1:
            raise NotImplementedError("This Recipe class can't handle more than one merge dim.")
        if len(self.file_pattern.concat_dims) > 1:
            raise NotImplementedError("This Recipe class can't handle more than one concat dim.")

    def _validate_input_and_chunk_keys(self):
        all_input_keys = set(self._inputs_chunks.keys())
        all_chunk_keys = set(self._chunks_inputs.keys())
        if not all_input_keys == set(self.file_pattern):
            raise ValueError("_inputs_chunks and file_pattern don't have the same keys")
        if not all_input_keys == set([c for val in self._chunks_inputs.values() for c in val]):
            raise ValueError("_inputs_chunks and _chunks_inputs don't use the same input keys.")
        if not all_chunk_keys == set([c for val in self._inputs_chunks.values() for c in val]):
            raise ValueError("_inputs_chunks and _chunks_inputs don't use the same chunk keys.")

    def copy_pruned(self, nkeep: int = 2) -> BaseRecipe:
        """Make a copy of this recipe with a pruned file pattern.

        :param nkeep: The number of items to keep from each ConcatDim sequence.
        """
        new_pattern = prune_pattern(self.file_pattern, nkeep=nkeep)
        return replace(self, file_pattern=new_pattern)
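
    # For example, `recipe.copy_pruned(nkeep=2)` returns a new recipe that processes
    # only the first two items along each ConcatDim, which is handy for testing a
    # recipe end to end before running it at full scale.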

    # below here are methods that are part of recipe execution

    def _set_target_chunks(self):
        target_concat_dim_chunks = self.target_chunks.get(self._concat_dim)
        if (self._nitems_per_input is None) and (target_concat_dim_chunks is None):
            raise ValueError(
                "Unable to determine target chunks. Please specify either "
                "`target_chunks` or `nitems_per_input`"
            )
        elif target_concat_dim_chunks:
            self._concat_dim_chunks = target_concat_dim_chunks
        else:
            self._concat_dim_chunks = self._nitems_per_input * self.inputs_per_chunk
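
    # Illustration (hypothetical numbers): with nitems_per_file=24 in the file pattern,
    # inputs_per_chunk=2, and no explicit target_chunks entry for the concat dim,
    # the target is written with 48 items per chunk along the concat dim.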

    @property  # type: ignore
    @closure
    def prepare_target(self) -> None:
        if self.target is None:
            raise ValueError("target is not set.")

        try:
            ds = self.open_target()
            logger.info("Found an existing dataset in target")
            logger.debug(f"{ds}")

            if self.target_chunks:
                # TODO: check that target_chunks is compatible with the
                # existing chunks
                pass
        except (FileNotFoundError, IOError, zarr.errors.GroupNotFoundError):
            logger.info("Creating a new dataset in target")

            # need to rewrite this as an append loop
            for chunk_key in self._init_chunks:
                with self.open_chunk(chunk_key) as ds:
                    # ds is already chunked
                    # https://github.com/pydata/xarray/blob/5287c7b2546fc8848f539bb5ee66bb8d91d8496f/xarray/core/variable.py#L1069
                    for v in ds.variables:
                        if self.target_chunks:
                            this_var = ds[v]
                            chunks = {
                                this_var.get_axis_num(dim): chunk
                                for dim, chunk in self.target_chunks.items()
                                if dim in this_var.dims
                            }
                            encoding_chunks = tuple(
                                chunks.get(n, s) for n, s in enumerate(this_var.shape)
                            )
                        else:
                            encoding_chunks = ds[v].shape
                        logger.debug(f"Setting variable {v} encoding chunks to {encoding_chunks}")
                        ds[v].encoding["chunks"] = encoding_chunks

                    # load all variables that don't have the sequence dim in them
                    # these are usually coordinates.
                    # Variables that are loaded will be written even with compute=False
                    # TODO: make this behavior customizable
                    for v in ds.variables:
                        if self._concat_dim not in ds[v].dims:
                            ds[v].load()

                    target_mapper = self.target.get_mapper()
                    logger.info(f"Storing dataset in {self.target.root_path}")  # type: ignore
                    logger.debug(f"{ds}")
                    with warnings.catch_warnings():
                        warnings.simplefilter(
                            "ignore"
                        )  # suppress the warning that comes with safe_chunks
                        ds.to_zarr(target_mapper, mode="a", compute=False, safe_chunks=False)

        # Regardless of whether there is an existing dataset or we are creating a new one,
        # we need to expand the concat_dim to hold the entire expected size of the data
        input_sequence_lens = self.calculate_sequence_lens()
        n_sequence = sum(input_sequence_lens)
        logger.info(f"Expanding target concat dim '{self._concat_dim}' to size {n_sequence}")
        self.expand_target_dim(self._concat_dim, n_sequence)

        if self._cache_metadata:
            if self.metadata_cache is None:
                raise ValueError("metadata_cache is not set")
            # if nitems_per_input is not constant, we need to cache this info
            recipe_meta = {"input_sequence_lens": input_sequence_lens}
            self.metadata_cache[_GLOBAL_METADATA_KEY] = recipe_meta

    # TODO: figure out how to make mypy happy with this convoluted structure
    @property  # type: ignore
    @closure
    def cache_input(self, input_key: InputKey) -> None:  # type: ignore
        if self.cache_inputs:
            if self.input_cache is None:
                raise ValueError("input_cache is not set.")
            logger.info(f"Caching input '{input_key}'")
            fname = self.file_pattern[input_key]
            self.input_cache.cache_file(fname, **self.fsspec_open_kwargs)

        if self._cache_metadata:
            self.cache_input_metadata(input_key)

    @property  # type: ignore
    @closure
    def store_chunk(self, chunk_key: ChunkKey) -> None:  # type: ignore
        if self.target is None:
            raise ValueError("target has not been set.")

        with self.open_chunk(chunk_key) as ds_chunk:
            # writing a region means that all the variables MUST have concat_dim
            to_drop = [v for v in ds_chunk.variables if self._concat_dim not in ds_chunk[v].dims]
            ds_chunk = ds_chunk.drop_vars(to_drop)

            target_mapper = self.target.get_mapper()
            write_region, conflicts = self.region_and_conflicts_for_chunk(chunk_key)

            zgroup = zarr.open_group(target_mapper)
            for vname, var_coded in ds_chunk.variables.items():
                zarr_array = zgroup[vname]
                # get encoding for variable from zarr attributes
                # could this backfire some way?
                var_coded.encoding.update(zarr_array.attrs)
                # just delete all attributes from the var;
                # they are not used anyway, and there can be conflicts
                # related to xarray.coding.variables.safe_setitem
                var_coded.attrs = {}
                with dask.config.set(
                    scheduler="single-threaded"
                ):  # make sure we don't use a scheduler
                    var = xr.backends.zarr.encode_zarr_variable(var_coded)
                    data = np.asarray(
                        var.data
                    )  # TODO: can we buffer large data rather than loading it all?
                zarr_region = tuple(write_region.get(dim, slice(None)) for dim in var.dims)
                lock_keys = [f"{vname}-{c}" for c in conflicts]
                logger.debug(f"Acquiring locks {lock_keys}")
                with lock_for_conflicts(lock_keys, timeout=self.lock_timeout):
                    logger.info(
                        f"Storing variable {vname} chunk {chunk_key} "
                        f"to Zarr region {zarr_region}"
                    )
                    zarr_array[zarr_region] = data

    @property  # type: ignore
    @closure
    def finalize_target(self) -> None:
        if self.target is None:
            raise ValueError("target has not been set.")
        if self.consolidate_zarr:
            logger.info("Consolidating Zarr metadata")
            target_mapper = self.target.get_mapper()
            zarr.consolidate_metadata(target_mapper)

    @contextmanager
    def open_input(self, input_key: InputKey):
        fname = self.file_pattern[input_key]
        logger.info(f"Opening input with Xarray {input_key}: '{fname}'")
        cache = self.input_cache if self.cache_inputs else None
        with file_opener(fname, cache=cache, copy_to_local=self.copy_input_to_local_file) as f:
            with dask.config.set(scheduler="single-threaded"):  # make sure we don't use a scheduler
                logger.debug(f"about to call xr.open_dataset on {f}")
                kw = self.xarray_open_kwargs.copy()
                if "engine" not in kw:
                    kw["engine"] = "h5netcdf"
                ds = xr.open_dataset(f, **kw)
                logger.debug("successfully opened dataset")
                ds = fix_scalar_attr_encoding(ds)

                if self.delete_input_encoding:
                    for var in ds.variables:
                        ds[var].encoding = {}

                if self.process_input is not None:
                    ds = self.process_input(ds, str(fname))

                logger.debug(f"{ds}")
            yield ds
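
    # For example, passing xarray_open_kwargs={"engine": "netcdf4", "decode_times": False}
    # overrides the default "h5netcdf" engine above and disables time decoding for each input.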

    def cache_input_metadata(self, input_key: InputKey):
        if self.metadata_cache is None:
            raise ValueError("metadata_cache is not set.")
        logger.info(f"Caching metadata for input '{input_key}'")
        with self.open_input(input_key) as ds:
            input_metadata = ds.to_dict(data=False)
            self.metadata_cache[_input_metadata_fname(input_key)] = input_metadata

    @contextmanager
    def open_chunk(self, chunk_key: ChunkKey):
        logger.info(f"Opening inputs for chunk {chunk_key}")
        inputs = self._chunks_inputs[chunk_key]

        # need to open an unknown number of contexts at the same time
        with ExitStack() as stack:
            dsets = [stack.enter_context(self.open_input(i)) for i in inputs]

            # explicitly chunking prevents eager evaluation during concat
            dsets = [ds.chunk() for ds in dsets]

            logger.info(f"Combining inputs for chunk '{chunk_key}'")
            if len(dsets) > 1:
                # During concat, attributes and encoding are taken from the first dataset
                # https://github.com/pydata/xarray/issues/1614
                with dask.config.set(
                    scheduler="single-threaded"
                ):  # make sure we don't use a scheduler
                    ds = xr.concat(dsets, self._concat_dim, **self.xarray_concat_kwargs)
            elif len(dsets) == 1:
                ds = dsets[0]
            else:  # pragma: no cover
                assert False, "Should never happen"

            if self.process_chunk is not None:
                with dask.config.set(
                    scheduler="single-threaded"
                ):  # make sure we don't use a scheduler
                    ds = self.process_chunk(ds)

            with dask.config.set(scheduler="single-threaded"):  # make sure we don't use a scheduler
                logger.debug(f"{ds}")

            # TODO: maybe do some chunking here?
            yield ds

    def open_target(self):
        target_mapper = self.target.get_mapper()
        return xr.open_zarr(target_mapper)

    def expand_target_dim(self, dim, dimsize):
        target_mapper = self.target.get_mapper()
        zgroup = zarr.open_group(target_mapper)
        ds = self.open_target()
        sequence_axes = {v: ds[v].get_axis_num(dim) for v in ds.variables if dim in ds[v].dims}

        for v, axis in sequence_axes.items():
            arr = zgroup[v]
            shape = list(arr.shape)
            shape[axis] = dimsize
            logger.debug(f"resizing array {v} to shape {shape}")
            arr.resize(shape)

        # now explicitly write the sequence coordinate to avoid missing data
        # when reopening
        if dim in zgroup:
            zgroup[dim][:] = 0

    def iter_inputs(self):
        for input in self._inputs_chunks:
            yield input

    def region_and_conflicts_for_chunk(self, chunk_key: ChunkKey):
        # return a dict suitable to pass to xr.to_zarr(region=...)
        # specifies where in the overall array to put this chunk's data
        # also return the conflicts with other chunks
        input_keys = self._chunks_inputs[chunk_key]

        if self._nitems_per_input:
            input_sequence_lens = (self._nitems_per_input,) * self.file_pattern.dims[
                self._concat_dim  # type: ignore
            ]
        else:
            if self.metadata_cache is None:
                raise ValueError("metadata_cache is not set.")
            global_metadata = self.metadata_cache[_GLOBAL_METADATA_KEY]
            input_sequence_lens = global_metadata["input_sequence_lens"]

        chunk_bounds, all_chunk_conflicts = chunk_bounds_and_conflicts(
            input_sequence_lens, self._concat_dim_chunks  # type: ignore
        )
        input_positions = [self.input_position(input_key) for input_key in input_keys]
        start = chunk_bounds[min(input_positions)]
        stop = chunk_bounds[max(input_positions) + 1]

        this_chunk_conflicts = set()
        for k in input_keys:
            # for multi-variable recipes, the conflicts will usually be the same
            # for each variable. using a set avoids duplicate locks
            for input_conflict in all_chunk_conflicts[self.input_position(k)]:
                this_chunk_conflicts.add(input_conflict)
        region_slice = slice(start, stop)
        return {self._concat_dim: region_slice}, this_chunk_conflicts
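
    # Illustration (hypothetical numbers): with 24 items per input and two inputs per
    # chunk, the chunk containing inputs 6 and 7 would get the region
    # {concat_dim: slice(144, 192)}, assuming chunk_bounds holds the cumulative
    # input offsets [0, 24, 48, ...].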

    def iter_chunks(self):
        for k in self._chunks_inputs:
            yield k

    def get_input_meta(self, *input_keys: Sequence[InputKey]) -> Dict:
        # getitems should be async; much faster than serial calls
        if self.metadata_cache is None:
            raise ValueError("metadata_cache is not set.")
        return self.metadata_cache.getitems([_input_metadata_fname(k) for k in input_keys])

    def input_position(self, input_key):
        # returns the index position of an input key wrt the concat_dim
        concat_dim_axis = list(self.file_pattern.dims).index(self._concat_dim)
        return input_key[concat_dim_axis]

    def calculate_sequence_lens(self):
        if self._nitems_per_input:
            return list((self._nitems_per_input,) * self.file_pattern.dims[self._concat_dim])

        # read per-input metadata; this is distinct from global metadata
        # get the sequence length of every file
        # this line could become problematic for large (> 10_000) lists of files
        input_meta = self.get_input_meta(*self._inputs_chunks)
        # use a numpy array to allow reshaping
        all_lens = np.array([m["dims"][self._concat_dim] for m in input_meta.values()])
        all_lens.shape = list(self.file_pattern.dims.values())
        # check that all lens are the same along the concat dim
        concat_dim_axis = list(self.file_pattern.dims).index(self._concat_dim)
        selector = [slice(0, 1)] * len(self.file_pattern.dims)
        selector[concat_dim_axis] = slice(None)  # this should broadcast correctly against all_lens
        sequence_lens = all_lens[tuple(selector)]
        if not (all_lens == sequence_lens).all():
            raise ValueError(f"Inconsistent sequence lengths found: {all_lens}")
        return sequence_lens.squeeze().tolist()

    def inputs_for_chunk(self, chunk_key: ChunkKey) -> Tuple[InputKey]:
        """Convenience function for users to introspect recipe."""
        return self._chunks_inputs[chunk_key]
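

# Minimal usage sketch (illustrative only, not part of the recipe class).
# The loop below mirrors the execution methods defined above; the FilePattern,
# ConcatDim, and storage-target constructors are assumptions about the surrounding
# pangeo_forge_recipes API, and the URLs and paths are hypothetical.
#
#     import pandas as pd
#     from fsspec.implementations.local import LocalFileSystem
#     from pangeo_forge_recipes.patterns import ConcatDim, FilePattern
#     from pangeo_forge_recipes.recipes import XarrayZarrRecipe
#     from pangeo_forge_recipes.storage import CacheFSSpecTarget, FSSpecTarget, MetadataTarget
#
#     def make_url(time):
#         return f"https://example.com/data/{time:%Y%m%d}.nc"  # hypothetical source
#
#     dates = pd.date_range("2000-01-01", "2000-01-10", freq="D")
#     pattern = FilePattern(make_url, ConcatDim("time", dates, nitems_per_file=24))
#     recipe = XarrayZarrRecipe(pattern, inputs_per_chunk=2)
#
#     fs = LocalFileSystem()
#     recipe.target = FSSpecTarget(fs, "/tmp/target.zarr")
#     recipe.input_cache = CacheFSSpecTarget(fs, "/tmp/input-cache")
#     recipe.metadata_cache = MetadataTarget(fs, "/tmp/metadata-cache")
#
#     # cache the inputs, prepare the target, store each chunk, then finalize
#     for input_key in recipe.iter_inputs():
#         recipe.cache_input(input_key)
#     recipe.prepare_target()
#     for chunk_key in recipe.iter_chunks():
#         recipe.store_chunk(chunk_key)
#     recipe.finalize_target()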