# xarray_zarr.py
"""
A Pangeo Forge Recipe
"""
from __future__ import annotations
import itertools
import logging
import os
import warnings
from contextlib import ExitStack, contextmanager
from dataclasses import dataclass, field, replace
from itertools import chain, product
from math import ceil
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple
import dask
import numpy as np
import xarray as xr
import zarr
from ..chunk_grid import ChunkGrid
from ..executors.base import Pipeline, Stage
from ..patterns import CombineOp, DimIndex, FilePattern, Index
from ..storage import CacheFSSpecTarget, FSSpecTarget, MetadataTarget, file_opener
from ..utils import calc_subsets, fix_scalar_attr_encoding, lock_for_conflicts
from .base import BaseRecipe, FilePatternMixin
# use this filename to store global recipe metadata in the metadata_cache
# it will be written once (by prepare_target) and read many times (by store_chunk)
_GLOBAL_METADATA_KEY = "pangeo-forge-recipe-metadata.json"
_ARRAY_DIMENSIONS = "_ARRAY_DIMENSIONS"
MAX_MEMORY = (
int(os.getenv("PANGEO_FORGE_MAX_MEMORY")) # type: ignore
if os.getenv("PANGEO_FORGE_MAX_MEMORY")
else 500_000_000
)
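# Editor note (illustrative, not in the original module): the value above is the
# per-variable in-memory limit checked in store_chunk below; for example,
# exporting PANGEO_FORGE_MAX_MEMORY=1000000000 before running raises it to ~1 GB.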
logger = logging.getLogger(__name__)
# Some types that help us keep things organized
InputKey = Index # input keys are the same as the file pattern keys
ChunkKey = Index
SubsetSpec = Dict[str, int]
# SubsetSpec is a dictionary mapping dimension names to the number of subsets along that dimension
# (e.g. {'time': 5, 'depth': 2})
def _input_metadata_fname(input_key):
key_str = "-".join([f"{k.name}_{k.index}" for k in input_key])
return "input-meta-" + key_str + ".json"
def inputs_for_chunk(
chunk_key: ChunkKey, inputs_per_chunk: int, ninputs: int
) -> Sequence[InputKey]:
"""For a chunk key, figure out which inputs belong to it.
Returns at least one InputKey."""
merge_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.MERGE]
concat_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.CONCAT]
# Ignore subset dims, we don't need them here
# subset_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.SUBSET]
assert len(merge_dims) <= 1
if len(merge_dims) == 1:
merge_dim = merge_dims[0] # type: Optional[DimIndex]
else:
merge_dim = None
assert len(concat_dims) == 1
concat_dim = concat_dims[0]
input_keys = []
for n in range(inputs_per_chunk):
input_index = (inputs_per_chunk * concat_dim.index) + n
if input_index >= ninputs:
break
input_concat_dim = DimIndex(concat_dim.name, input_index, ninputs, CombineOp.CONCAT)
input_key = [input_concat_dim]
if merge_dim is not None:
input_key.append(merge_dim)
input_keys.append(Index(input_key))
return input_keys
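# Editor note (illustrative example, not in the original module): with
# inputs_per_chunk=2 and ninputs=5, chunk 0 along the concat dim maps to inputs
# 0 and 1, chunk 1 to inputs 2 and 3, and the final chunk 2 to the single
# leftover input 4; any merge-dim index on the chunk key is carried through
# unchanged to each input key.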
def expand_target_dim(target: FSSpecTarget, concat_dim: Optional[str], dimsize: int) -> None:
target_mapper = target.get_mapper()
zgroup = zarr.open_group(target_mapper)
ds = open_target(target)
sequence_axes = {
v: ds[v].get_axis_num(concat_dim) for v in ds.variables if concat_dim in ds[v].dims
}
for v, axis in sequence_axes.items():
arr = zgroup[v]
shape = list(arr.shape)
shape[axis] = dimsize
logger.debug(f"resizing array {v} to shape {shape}")
arr.resize(shape)
# now explicitly write the sequence coordinate to avoid missing data
# when reopening
if concat_dim in zgroup:
zgroup[concat_dim][:] = 0
def open_target(target: FSSpecTarget) -> xr.Dataset:
return xr.open_zarr(target.get_mapper())
def input_position(input_key: InputKey) -> int:
"""Return the position of the input within the input sequence."""
for dim_idx in input_key:
# assumes there is one and only one concat dim
if dim_idx.operation == CombineOp.CONCAT:
return dim_idx.index
return -1 # make mypy happy
def chunk_position(chunk_key: ChunkKey) -> int:
"""Return the position of the input within the input sequence."""
concat_idx = -1
for dim_idx in chunk_key:
# assumes there is one and only one concat dim
if dim_idx.operation == CombineOp.CONCAT:
concat_idx = dim_idx.index
concat_dim = dim_idx.name
if concat_idx == -1:
raise ValueError("Couldn't find concat_dim")
subset_idx = 0
subset_factor = 1
for dim_idx in chunk_key:
if dim_idx.operation == CombineOp.SUBSET:
if dim_idx.name == concat_dim:
subset_idx = dim_idx.index
subset_factor = dim_idx.sequence_len
return subset_factor * concat_idx + subset_idx
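# Editor note (illustrative example): with the concat dim subset into 5 pieces
# (subset_factor=5), the chunk whose concat index is 2 and subset index is 3
# sits at position 5 * 2 + 3 == 13 in the overall chunk sequence; without
# subsetting, the position is simply the concat index.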
def cache_input(input_key: InputKey, *, config: XarrayZarrRecipe) -> None:
if config.cache_inputs:
if config.file_pattern.is_opendap:
raise ValueError("Can't cache opendap inputs")
if config.input_cache is None:
raise ValueError("input_cache is not set.")
logger.info(f"Caching input '{input_key!s}'")
fname = config.file_pattern[input_key]
config.input_cache.cache_file(
fname,
config.file_pattern.query_string_secrets,
**config.file_pattern.fsspec_open_kwargs,
)
if config.cache_metadata:
if config.metadata_cache is None:
raise ValueError("metadata_cache is not set.")
logger.info(f"Caching metadata for input '{input_key!s}'")
with open_input(input_key, config=config) as ds:
input_metadata = ds.to_dict(data=False)
config.metadata_cache[_input_metadata_fname(input_key)] = input_metadata
def region_and_conflicts_for_chunk(
config: XarrayZarrRecipe, chunk_key: ChunkKey
) -> Tuple[Dict[str, slice], Dict[str, Set[int]]]:
# return a dict suitable to pass to xr.to_zarr(region=...)
# specifies where in the overall array to put this chunk's data
# also return the conflicts with other chunks
if config.nitems_per_input:
input_sequence_lens = (config.nitems_per_input,) * config.file_pattern.dims[
config.concat_dim
] # type: ignore
else:
assert config.metadata_cache is not None # for mypy
global_metadata = config.metadata_cache[_GLOBAL_METADATA_KEY]
input_sequence_lens = global_metadata["input_sequence_lens"]
total_len = sum(input_sequence_lens)
# for now this will just have one key since we only allow one concat_dim
# but it could expand to accommodate multiple concat dims
chunk_index = {config.concat_dim: chunk_position(chunk_key)}
input_chunk_grid = ChunkGrid({config.concat_dim: input_sequence_lens})
if config.subset_inputs and config.concat_dim in config.subset_inputs:
assert (
config.inputs_per_chunk == 1
), "Doesn't make sense to have multiple inputs per chunk plus subsetting"
chunk_grid = input_chunk_grid.subset(config.subset_inputs)
elif config.inputs_per_chunk > 1:
chunk_grid = input_chunk_grid.consolidate({config.concat_dim: config.inputs_per_chunk})
else:
chunk_grid = input_chunk_grid
assert chunk_grid.shape[config.concat_dim] == total_len
region = chunk_grid.chunk_index_to_array_slice(chunk_index)
assert config.concat_dim_chunks is not None
target_grid = ChunkGrid.from_uniform_grid(
{config.concat_dim: (config.concat_dim_chunks, total_len)}
)
conflicts = chunk_grid.chunk_conflicts(chunk_index, target_grid)
return region, conflicts
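# Editor note (illustrative example, assuming uniform inputs): with 10 items per
# input along "time" and inputs_per_chunk=1, the chunk at position 3 gets the
# write region {"time": slice(30, 40)}; `conflicts` maps each dimension to the
# set of target Zarr chunk indices this write shares with neighboring recipe
# chunks (and must therefore lock), which is empty when the write region lines
# up exactly with the target chunk boundaries.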
@contextmanager
def open_input(input_key: InputKey, *, config: XarrayZarrRecipe) -> xr.Dataset:
fname = config.file_pattern[input_key]
logger.info(f"Opening input with Xarray {input_key!s}: '{fname}'")
if config.file_pattern.is_opendap:
if config.input_cache:
raise ValueError("Can't cache opendap inputs")
if config.copy_input_to_local_file:
raise ValueError("Can't copy opendap inputs to local file")
cache = config.input_cache if config.cache_inputs else None
with file_opener(
fname,
cache=cache,
copy_to_local=config.copy_input_to_local_file,
bypass_open=config.file_pattern.is_opendap,
secrets=config.file_pattern.query_string_secrets,
**config.file_pattern.fsspec_open_kwargs,
) as f:
with dask.config.set(scheduler="single-threaded"): # make sure we don't use a scheduler
kw = config.xarray_open_kwargs.copy()
if "engine" not in kw:
kw["engine"] = "h5netcdf"
logger.debug(f"about to enter xr.open_dataset context on {f}")
with xr.open_dataset(f, **kw) as ds:
logger.debug("successfully opened dataset")
ds = fix_scalar_attr_encoding(ds)
if config.delete_input_encoding:
for var in ds.variables:
ds[var].encoding = {}
if config.process_input is not None:
ds = config.process_input(ds, str(fname))
logger.debug(f"{ds}")
yield ds
def subset_dataset(ds: xr.Dataset, subset_spec: DimIndex) -> xr.Dataset:
assert subset_spec.operation == CombineOp.SUBSET
dim = subset_spec.name
dim_len = ds.dims[dim]
subset_lens = calc_subsets(dim_len, subset_spec.sequence_len)
start = sum(subset_lens[: subset_spec.index])
stop = sum(subset_lens[: (subset_spec.index + 1)])
subset_slice = slice(start, stop)
indexer = {dim: subset_slice}
logger.debug(f"Subsetting dataset with indexer {indexer}")
return ds.isel(**indexer)
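# Editor note (illustrative example): subsetting a dataset whose "time" dimension
# has length 10 with a SUBSET DimIndex of sequence_len=2 yields the slices 0:5
# and 5:10; an index of 1 therefore returns ds.isel(time=slice(5, 10)).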
@contextmanager
def open_chunk(chunk_key: ChunkKey, *, config: XarrayZarrRecipe) -> xr.Dataset:
logger.info(f"Opening inputs for chunk {chunk_key!s}")
ninputs = config.file_pattern.dims[config.file_pattern.concat_dims[0]]
inputs = inputs_for_chunk(chunk_key, config.inputs_per_chunk, ninputs)
# need to open an unknown number of contexts at the same time
with ExitStack() as stack:
dsets = [stack.enter_context(open_input(input_key, config=config)) for input_key in inputs]
# subset before chunking; hopefully lazy
subset_dims = [dim_idx for dim_idx in chunk_key if dim_idx.operation == CombineOp.SUBSET]
for subset_dim in subset_dims:
logger.info(f"Subsetting input according to {subset_dim}")
dsets = [subset_dataset(ds, subset_dim) for ds in dsets]
# explicitly chunking prevents eager evaluation during concat
dsets = [ds.chunk() for ds in dsets]
logger.info(f"Combining inputs for chunk '{chunk_key!s}'")
if len(dsets) > 1:
# During concat, attributes and encoding are taken from the first dataset
# https://github.com/pydata/xarray/issues/1614
with dask.config.set(scheduler="single-threaded"): # make sure we don't use a scheduler
ds = xr.concat(dsets, config.concat_dim, **config.xarray_concat_kwargs)
elif len(dsets) == 1:
ds = dsets[0]
else: # pragma: no cover
assert False, "Should never happen"
if config.process_chunk is not None:
with dask.config.set(scheduler="single-threaded"): # make sure we don't use a scheduler
ds = config.process_chunk(ds)
with dask.config.set(scheduler="single-threaded"): # make sure we don't use a scheduler
logger.debug(f"{ds}")
if config.target_chunks:
# The input may be too large to process in memory at once, so
# rechunk it to the target chunks.
ds = ds.chunk(config.target_chunks)
yield ds
def get_input_meta(metadata_cache: Optional[MetadataTarget], *input_keys: InputKey) -> Dict:
# getitems should be async; much faster than serial calls
if metadata_cache is None:
raise ValueError("metadata_cache is not set.")
return metadata_cache.getitems([_input_metadata_fname(k) for k in input_keys])
def calculate_sequence_lens(
nitems_per_input: Optional[int],
file_pattern: FilePattern,
metadata_cache: Optional[MetadataTarget],
) -> List[int]:
assert len(file_pattern.concat_dims) == 1
concat_dim = file_pattern.concat_dims[0]
if nitems_per_input:
concat_dim = file_pattern.concat_dims[0]
return list((nitems_per_input,) * file_pattern.dims[concat_dim])
# read per-input metadata; this is distinct from global metadata
# get the sequence length of every file
# this line could become problematic for large (> 10_000) lists of files
input_meta = get_input_meta(metadata_cache, *file_pattern)
# use a numpy array to allow reshaping
all_lens = np.array([m["dims"][concat_dim] for m in input_meta.values()])
all_lens.shape = list(file_pattern.dims.values())
# check that all lens are the same along the concat dim
concat_dim_axis = list(file_pattern.dims).index(concat_dim)
selector = [slice(0, 1)] * len(file_pattern.dims)
selector[concat_dim_axis] = slice(None) # this should broadcast correctly against all_lens
sequence_lens = all_lens[tuple(selector)]
if not (all_lens == sequence_lens).all():
raise ValueError(f"Inconsistent sequence lengths found: f{all_lens}")
return np.atleast_1d(sequence_lens.squeeze()).tolist()
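# Editor note (illustrative example): when nitems_per_input is None, the cached
# per-input metadata may yield uneven lengths along the concat dim, e.g.
# [31, 28, 31, 30] for monthly files; the consistency check above only requires
# that, at a given concat position, every input along the other (merge)
# dimensions reports the same length.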
def prepare_target(*, config: XarrayZarrRecipe) -> None:
if config.target is None:
raise ValueError("Cannot proceed without a target")
try:
ds = open_target(config.target)
logger.info("Found an existing dataset in target")
logger.debug(f"{ds}")
if config.target_chunks:
# TODO: check that target_chunks is compatible with the
# existing chunks
pass
except (FileNotFoundError, IOError, zarr.errors.GroupNotFoundError):
logger.info("Creating a new dataset in target")
# need to rewrite this as an append loop
def filter_init_chunks(chunk_key):
for dim_idx in chunk_key:
if (dim_idx.operation != CombineOp.MERGE) and (dim_idx.index > 0):
return False
return True
init_chunks = list(filter(filter_init_chunks, config.iter_chunks()))
for chunk_key in init_chunks:
with open_chunk(chunk_key, config=config) as ds:
# ds is already chunked
# https://github.com/pydata/xarray/blob/5287c7b2546fc8848f539bb5ee66bb8d91d8496f/xarray/core/variable.py#L1069
for v in ds.variables:
if config.target_chunks:
this_var = ds[v]
chunks = {
this_var.get_axis_num(dim): chunk
for dim, chunk in config.target_chunks.items()
if dim in this_var.dims
}
encoding_chunks = tuple(
chunks.get(n, s) for n, s in enumerate(this_var.shape)
)
else:
encoding_chunks = ds[v].shape
logger.debug(f"Setting variable {v} encoding chunks to {encoding_chunks}")
ds[v].encoding["chunks"] = encoding_chunks
# load all variables that don't have the sequence dim in them
# these are usually coordinates.
# Variables that are loaded will be written even with compute=False
# TODO: make this behavior customizable
for v in ds.variables:
if config.concat_dim not in ds[v].dims:
ds[v].load()
target_mapper = config.target.get_mapper()
logger.info(f"Storing dataset in {config.target.root_path}")
logger.debug(f"{ds}")
with warnings.catch_warnings():
warnings.simplefilter(
"ignore"
) # suppress the warning that comes with safe_chunks
ds.to_zarr(target_mapper, mode="a", compute=False, safe_chunks=False)
# Regardless of whether there is an existing dataset or we are creating a new one,
# we need to expand the concat_dim to hold the entire expected size of the data
input_sequence_lens = calculate_sequence_lens(
config.nitems_per_input, config.file_pattern, config.metadata_cache,
)
n_sequence = sum(input_sequence_lens)
logger.info(f"Expanding target concat dim '{config.concat_dim}' to size {n_sequence}")
expand_target_dim(config.target, config.concat_dim, n_sequence)
# TODO: handle possible subsetting
# The init chunks might not cover the whole dataset along multiple dimensions!
if config.cache_metadata:
# if nitems_per_input is not constant, we need to cache this info
assert config.metadata_cache is not None # for mypy
recipe_meta = {"input_sequence_lens": input_sequence_lens}
config.metadata_cache[_GLOBAL_METADATA_KEY] = recipe_meta
def store_chunk(chunk_key: ChunkKey, *, config: XarrayZarrRecipe) -> None:
if config.target is None:
raise ValueError("target has not been set.")
with open_chunk(chunk_key, config=config) as ds_chunk:
# writing a region means that all the variables MUST have concat_dim
to_drop = [v for v in ds_chunk.variables if config.concat_dim not in ds_chunk[v].dims]
ds_chunk = ds_chunk.drop_vars(to_drop)
target_mapper = config.target.get_mapper()
write_region, conflicts = region_and_conflicts_for_chunk(config, chunk_key)
zgroup = zarr.open_group(target_mapper)
for vname, var_coded in ds_chunk.variables.items():
zarr_array = zgroup[vname]
# get encoding for variable from zarr attributes
# could this backfire in some way?
var_coded.encoding.update(zarr_array.attrs)
# just delete all attributes from the var;
# they are not used anyway, and there can be conflicts
# related to xarray.coding.variables.safe_setitem
var_coded.attrs = {}
with dask.config.set(scheduler="single-threaded"): # make sure we don't use a scheduler
var = xr.backends.zarr.encode_zarr_variable(var_coded)
logger.debug(
f"Converting variable {vname} of {var.data.nbytes} bytes to `numpy.ndarray`"
)
if var.data.nbytes > MAX_MEMORY:
factor = round((var.data.nbytes / MAX_MEMORY), 2)
cdim = config.concat_dim
logger.warning(
f"Variable {vname} of {var.data.nbytes} bytes is {factor} times larger "
f"than specified maximum variable array size of {MAX_MEMORY} bytes. "
f'Consider re-instantiating recipe with `subset_inputs = {{"{cdim}": '
f'{ceil(factor)}}}`. If `len(ds["{cdim}"])` < {ceil(factor)}, '
f'replace "{cdim}" with any name in ds["{vname}"].dims whose length is '
f">= {ceil(factor)}, or consider subsetting along multiple dimensions."
" Setting the PANGEO_FORGE_MAX_MEMORY env variable changes the size threshold"
" that triggers this warning."
)
data = np.asarray(
var.data
) # TODO: can we buffer large data rather than loading it all?
zarr_region = tuple(write_region.get(dim, slice(None)) for dim in var.dims)
lock_keys = [
f"{vname}-{dim}-{c}"
for dim, dim_conflicts in conflicts.items()
for c in dim_conflicts
]
logger.debug(f"Acquiring locks {lock_keys}")
with lock_for_conflicts(lock_keys, timeout=config.lock_timeout):
logger.info(
f"Storing variable {vname} chunk {chunk_key!s} " f"to Zarr region {zarr_region}"
)
zarr_array[zarr_region] = data
def _gather_coordinate_dimensions(group: zarr.Group) -> List[str]:
return list(
set(itertools.chain(*(group[var].attrs.get(_ARRAY_DIMENSIONS, []) for var in group)))
)
def finalize_target(*, config: XarrayZarrRecipe) -> None:
if config.target is None:
raise ValueError("target has not been set.")
if config.consolidate_dimension_coordinates:
logger.info("Consolidating dimension coordinate arrays")
target_mapper = config.target.get_mapper()
group = zarr.open(target_mapper, mode="a")
# https://github.com/pangeo-forge/pangeo-forge-recipes/issues/214
# intersect the dims from the array metadata with the Zarr group
# to handle coordinateless dimensions.
dims = set(_gather_coordinate_dimensions(group)) & set(group)
for dim in dims:
arr = group[dim]
attrs = dict(arr.attrs)
new = group.array(
dim,
arr[:],
chunks=arr.shape,
dtype=arr.dtype,
compressor=arr.compressor,
fill_value=arr.fill_value,
order=arr.order,
filters=arr.filters,
overwrite=True,
)
new.attrs.update(attrs)
if config.consolidate_zarr:
logger.info("Consolidating Zarr metadata")
target_mapper = config.target.get_mapper()
zarr.consolidate_metadata(target_mapper)
def xarray_zarr_recipe_compiler(recipe: XarrayZarrRecipe) -> Pipeline:
stages = [
Stage(name="cache_input", function=cache_input, mappable=list(recipe.iter_inputs())),
Stage(name="prepare_target", function=prepare_target),
Stage(name="store_chunk", function=store_chunk, mappable=list(recipe.iter_chunks())),
Stage(name="finalize_target", function=finalize_target),
]
return Pipeline(stages=stages, config=recipe)
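# Editor note (illustrative): the compiled Pipeline runs four ordered stages,
# the first and third mapped over many keys. Roughly:
#
#   pipeline = xarray_zarr_recipe_compiler(recipe)
#   # cache_input(input_key, config=recipe)   -- once per input key
#   # prepare_target(config=recipe)           -- once
#   # store_chunk(chunk_key, config=recipe)   -- once per chunk key
#   # finalize_target(config=recipe)          -- once
#
# The Pipeline/Stage objects come from ..executors.base; how they are executed
# is up to the chosen executor.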
# Notes about dataclasses:
# - https://www.python.org/dev/peps/pep-0557/#inheritance
# - https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses
@dataclass
class XarrayZarrRecipe(BaseRecipe, FilePatternMixin):
"""This configuration represents a dataset composed of many individual NetCDF files.
This class uses Xarray to read and write data and writes its output to Zarr.
The organization of the source files is described by the ``file_pattern``.
Currently this recipe supports at most one ``MergeDim`` and one ``ConcatDim``
in the File Pattern.
:param file_pattern: An object which describes the organization of the input files.
:param inputs_per_chunk: The number of inputs to use in each chunk along the concat dim.
Must be an integer >= 1.
:param target_chunks: Desired chunk structure for the target dataset. This is a dictionary
mapping dimension names to chunk size. When using a :class:`patterns.FilePattern` with
a :class:`patterns.ConcatDim` that specifies ``n_items_per_file``, then you don't need
to include the concat dim in ``target_chunks``.
:param target: A location in which to put the dataset. Can also be assigned at run time.
:param input_cache: A location in which to cache temporary data.
:param metadata_cache: A location in which to cache metadata for inputs and chunks.
Required if ``nitems_per_file=None`` on concat dim in file pattern.
:param cache_inputs: If ``True``, inputs are copied to ``input_cache`` before
opening. If ``False``, try to open inputs directly from their source location.
:param copy_input_to_local_file: Whether to copy the inputs to a temporary
local file. In this case, a path (rather than file object) is passed to
``xr.open_dataset``. This is required for engines that can't open
file-like objects (e.g. pynio).
:param consolidate_zarr: Whether to consolidate the resulting Zarr dataset.
:param consolidate_dimension_coordinates: Whether to rewrite coordinate variables as a
single chunk. We recommend consolidating coordinate variables to avoid
many small read requests to get the coordinates in xarray.
:param xarray_open_kwargs: Extra options for opening the inputs with Xarray.
:param xarray_concat_kwargs: Extra options to pass to Xarray when concatenating
the inputs to form a chunk.
:param delete_input_encoding: Whether to remove Xarray encoding from variables
in the input dataset.
:param process_input: Function to call on each opened input, with signature
`(ds: xr.Dataset, filename: str) -> ds: xr.Dataset`.
:param process_chunk: Function to call on each concatenated chunk, with signature
`(ds: xr.Dataset) -> ds: xr.Dataset`.
:param lock_timeout: The default timeout for acquiring a chunk lock.
:param subset_inputs: If set, break each input file up into multiple chunks
along each named dimension according to the specified mapping. For example,
``{'time': 5}`` would split each input file into 5 chunks along the
time dimension. Multiple dimensions are allowed.
"""
_compiler = xarray_zarr_recipe_compiler
inputs_per_chunk: int = 1
target_chunks: Dict[str, int] = field(default_factory=dict)
target: Optional[FSSpecTarget] = None
input_cache: Optional[CacheFSSpecTarget] = None
metadata_cache: Optional[MetadataTarget] = None
cache_inputs: Optional[bool] = None
copy_input_to_local_file: bool = False
consolidate_zarr: bool = True
consolidate_dimension_coordinates: bool = True
xarray_open_kwargs: dict = field(default_factory=dict)
xarray_concat_kwargs: dict = field(default_factory=dict)
delete_input_encoding: bool = True
process_input: Optional[Callable[[xr.Dataset, str], xr.Dataset]] = None
process_chunk: Optional[Callable[[xr.Dataset], xr.Dataset]] = None
lock_timeout: Optional[int] = None
subset_inputs: SubsetSpec = field(default_factory=dict)
# internal attributes not meant to be seen or accessed by user
concat_dim: str = field(default_factory=str, repr=False, init=False)
"""The concatenation dimension name."""
concat_dim_chunks: Optional[int] = field(default=None, repr=False, init=False)
"""The desired chunking along the sequence dimension."""
init_chunks: List[ChunkKey] = field(default_factory=list, repr=False, init=False)
"""List of chunks needed to initialize the recipe."""
cache_metadata: bool = field(default=False, repr=False, init=False)
"""Whether metadata caching is needed."""
nitems_per_input: Optional[int] = field(default=None, repr=False, init=False)
"""How many items per input along concat_dim."""
def __post_init__(self):
self._validate_file_pattern()
# from here on we know there is at most one merge dim and one concat dim
self.concat_dim = self.file_pattern.concat_dims[0]
self.cache_metadata = any(
[v is None for v in self.file_pattern.concat_sequence_lens.values()]
)
self.nitems_per_input = self.file_pattern.nitems_per_input[self.concat_dim]
if self.file_pattern.is_opendap:
if self.cache_inputs:
raise ValueError("Can't cache opendap inputs.")
else:
self.cache_inputs = False
if "engine" in self.xarray_open_kwargs:
if self.xarray_open_kwargs["engine"] != "netcdf4":
raise ValueError(
"Opendap inputs only work with `xarray_open_kwargs['engine'] == 'netcdf4'`"
)
else:
new_kw = self.xarray_open_kwargs.copy()
new_kw["engine"] = "netcdf4"
self.xarray_open_kwargs = new_kw
elif self.cache_inputs is None:
self.cache_inputs = True # old default
self._validate_input_and_chunk_keys()
# set concat_dim_chunks
target_concat_dim_chunks = self.target_chunks.get(self.concat_dim)
if (self.nitems_per_input is None) and (target_concat_dim_chunks is None):
raise ValueError(
"Unable to determine target chunks. Please specify either "
"`target_chunks` or `nitems_per_input`"
)
elif target_concat_dim_chunks:
self.concat_dim_chunks = target_concat_dim_chunks
else:
self.concat_dim_chunks = self.nitems_per_input * self.inputs_per_chunk
def _validate_file_pattern(self):
if len(self.file_pattern.merge_dims) > 1:
raise NotImplementedError("This Recipe class can't handle more than one merge dim.")
if len(self.file_pattern.concat_dims) > 1:
raise NotImplementedError("This Recipe class can't handle more than one concat dim.")
def _validate_input_and_chunk_keys(self):
all_input_keys = set(self.iter_inputs())
ninputs = self.file_pattern.dims[self.file_pattern.concat_dims[0]]
all_inputs_for_chunks = set(
list(
chain(
*(
inputs_for_chunk(chunk_key, self.inputs_per_chunk, ninputs)
for chunk_key in self.iter_chunks()
)
)
)
)
if all_input_keys != all_inputs_for_chunks:
chunk_key = next(iter(self.iter_chunks()))
print("First chunk", chunk_key)
print("Inputs_for_chunk", inputs_for_chunk(chunk_key, self.inputs_per_chunk, ninputs))
raise ValueError("Inputs and chunks are inconsistent")
def iter_inputs(self) -> Iterator[InputKey]:
yield from self.file_pattern
def iter_chunks(self) -> Iterator[ChunkKey]:
for input_key in self.iter_inputs():
concat_dims = [
dim_idx for dim_idx in input_key if dim_idx.operation == CombineOp.CONCAT
]
assert len(concat_dims) == 1
concat_dim = concat_dims[0]
input_concat_index = concat_dim.index
if input_concat_index % self.inputs_per_chunk > 0:
continue # don't emit a chunk
chunk_concat_index = input_concat_index // self.inputs_per_chunk
chunk_sequence_len = ceil(concat_dim.sequence_len / self.inputs_per_chunk)
chunk_concat_dim = replace(
concat_dim, index=chunk_concat_index, sequence_len=chunk_sequence_len
)
chunk_key_base = [
chunk_concat_dim if dim_idx.operation == CombineOp.CONCAT else dim_idx
for dim_idx in input_key
]
if len(self.subset_inputs) == 0:
yield Index(chunk_key_base)
# no subsets
continue
subset_iterators = [range(v) for k, v in self.subset_inputs.items()]
for i in product(*subset_iterators):
# TODO: remove redundant name
subset_dims = [
DimIndex(*args, CombineOp.SUBSET)
for args in zip(self.subset_inputs.keys(), i, self.subset_inputs.values())
]
yield Index((chunk_key_base + subset_dims))
def inputs_for_chunk(self, chunk_key: ChunkKey) -> Sequence[InputKey]:
"""Convenience function for users to introspect recipe."""
ninputs = self.file_pattern.dims[self.file_pattern.concat_dims[0]]
return inputs_for_chunk(chunk_key, self.inputs_per_chunk, ninputs)
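# ---------------------------------------------------------------------------
# Editor note: an illustrative usage sketch, not part of the original module.
# The URL template, the make_url helper, and the chunking choices below are
# hypothetical placeholders; the import paths assume the package layout implied
# by the relative imports at the top of this file.
#
#   import pandas as pd
#   from pangeo_forge_recipes.patterns import ConcatDim, FilePattern
#   from pangeo_forge_recipes.recipes.xarray_zarr import XarrayZarrRecipe
#
#   dates = pd.date_range("2000-01-01", "2000-12-31", freq="D")
#
#   def make_url(time):
#       # hypothetical source layout: one NetCDF file per day
#       return f"https://example.com/data/{time:%Y%m%d}.nc"
#
#   pattern = FilePattern(make_url, ConcatDim("time", dates, nitems_per_file=1))
#   recipe = XarrayZarrRecipe(pattern, inputs_per_chunk=10)
#
#   # Assign storage targets (recipe.target, recipe.input_cache, ...), then hand
#   # the recipe to an executor, which runs the pipeline produced by
#   # xarray_zarr_recipe_compiler: cache_input -> prepare_target -> store_chunk
#   # -> finalize_target.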