From 4faaf3a4bf5be7325e30b4f07639f19273625e2f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 30 Jun 2018 22:52:39 -0700 Subject: [PATCH 01/39] WIP: xarray.backends.file_manager for managing file objects. This is intended to replace both PickleByReconstructionWrapper and DataStorePickleMixin with something more compartmentalized. xref GH2121 --- xarray/backends/file_manager.py | 155 +++++++++++++++++++++ xarray/tests/test_backends_file_manager.py | 70 ++++++++++ 2 files changed, 225 insertions(+) create mode 100644 xarray/backends/file_manager.py create mode 100644 xarray/tests/test_backends_file_manager.py diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py new file mode 100644 index 00000000000..34ad39ff885 --- /dev/null +++ b/xarray/backends/file_manager.py @@ -0,0 +1,155 @@ +import contextlib +import threading + + +class FileManager(object): + """Base class for context managers for managing file objects. + + Unlike files, FileManager objects should be safely. They must be explicitly + closed. + + Example usage: + + import functools + + manager = FileManager(functools.partial(open, filename), mode='w') + with manager.acquire() as f: + f.write(...) + manager.close() + """ + + def __init__(self, opener, mode=None): + """Initialize a FileManager. + + Parameters + ---------- + opener : callable + Callable that opens a given file when called, returning a file + object. + mode : str, optional + If provided, passed to opener as a keyword argument. + """ + raise NotImplementedError + + @contextlib.contextmanager + def acquire(self): + """Context manager for acquiring a file object. + + This method must be thread-safe: it should be safe to simultaneously + acquire a file in multiple threads at the same time (assuming that + the underlying file object is thread-safe). + + Yields + ------ + Open file object, as returned by opener(). + """ + raise NotImplementedError + + def close(self): + """Explicitly close any associated file object (if necessary).""" + raise NotImplementedError + + +_DEFAULT_MODE = object() + + +def _open(opener, mode): + return opener() if mode is _DEFAULT_MODE else opener(mode=mode) + + +class ExplicitFileManager(FileManager): + """A file manager that holds a file open until explicitly closed. + + This is mostly a reference implementation: must real use cases should use + ExplicitLazyFileContext for better performance. 
+ """ + + def __init__(self, opener, mode=_DEFAULT_MODE): + self._opener = opener + # file has already been created, don't override when restoring + self._mode = 'a' if mode == 'w' else mode + self._file = _open(opener, mode) + + @contextlib.contextmanager + def acquire(self): + yield self._file + + def close(self): + self._file.close() + + def __getstate__(self): + return {'opener': self._opener, 'mode': self._mode} + + def __setstate__(self, state): + self.__init__(**state) + + +class LazyFileManager(FileManager): + """An explicit file manager that lazily opens files.""" + + def __init__(self, opener, mode=_DEFAULT_MODE): + self._opener = opener + self._mode = mode + self._lock = threading.Lock() + self._file = None + + @contextlib.contextmanager + def acquire(self): + with self._lock: + if self._file is None: + self._file = _open(self._opener, self._mode) + # file has already been created, don't override when restoring + if self._mode == 'w': + self._mode = 'a' + yield self._file + + def close(self): + if self._file is not None: + self._file.close() + + def __getstate__(self): + return {'opener': self._opener, 'mode': self._mode} + + def __setstate__(self, state): + self.__init__(**state) + + +class AutoclosingFileManager(FileManager): + """A FileManager that automatically opens/closes files when used.""" + + def __init__(self, opener, mode=_DEFAULT_MODE): + self._opener = opener + self._mode = mode + self._lock = threading.Lock() + self._file = None + self._references = 0 + + @contextlib.contextmanager + def acquire(self): + with self._lock: + if self._file is None: + self._file = _open(self._opener, self._mode) + # file has already been created, don't override when restoring + if self._mode == 'w': + self._mode = 'a' + self._references += 1 + + yield self._file + + with self._lock: + self._references -= 1 + if not self._references: + self._file.close() + self._file = None + + def close(self): + pass + + def __getstate__(self): + return {'opener': self._opener, 'mode': self._mode} + + def __setstate__(self, state): + self.__init__(**state) + + +# TODO: write a FileManager that makes use of an LRU cache. 
diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py new file mode 100644 index 00000000000..2803ed71aa9 --- /dev/null +++ b/xarray/tests/test_backends_file_manager.py @@ -0,0 +1,70 @@ +import functools +import pickle + +import pytest + +from xarray.backends.file_manager import ( + ExplicitFileManager, LazyFileManager, AutoclosingFileManager +) + +FILE_MANAGERS = [ + ExplicitFileManager, LazyFileManager, AutoclosingFileManager, +] + +@pytest.mark.parametrize('manager_type', FILE_MANAGERS) +def test_file_manager_write_consecutive(tmpdir, manager_type): + path = str(tmpdir.join('testing.txt')) + manager = manager_type(functools.partial(open, path), mode='w') + with manager.acquire() as f: + f.write('foo') + with manager.acquire() as f: + f.write('bar') + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +@pytest.mark.parametrize('manager_type', FILE_MANAGERS) +def test_file_manager_write_concurrent(tmpdir, manager_type): + path = str(tmpdir.join('testing.txt')) + manager = manager_type(functools.partial(open, path), mode='w') + with manager.acquire() as f1: + with manager.acquire() as f2: + f1.write('foo') + f2.write('bar') + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +@pytest.mark.parametrize('manager_type', FILE_MANAGERS) +def test_file_manager_write_pickle(tmpdir, manager_type): + path = str(tmpdir.join('testing.txt')) + manager = manager_type( + functools.partial(open, path), mode='w') + with manager.acquire() as f: + f.write('foo') + f.flush() + manager2 = pickle.loads(pickle.dumps(manager)) + with manager2.acquire() as f: + f.write('bar') + manager2.close() + manager.close() + + with open(path, 'r') as f: + assert f.read() == 'foobar' + + +@pytest.mark.parametrize('manager_type', FILE_MANAGERS) +def test_file_manager_read(tmpdir, manager_type): + path = str(tmpdir.join('testing.txt')) + + with open(path, 'w') as f: + f.write('foobar') + + manager = manager_type(functools.partial(open, path)) + with manager.acquire() as f: + assert f.read() == 'foobar' + manager.close() From c82a38c3528ee161eb47720e35ab7ebc2f750cf5 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 30 Jun 2018 23:14:03 -0700 Subject: [PATCH 02/39] Switch rasterio to use FileManager --- xarray/backends/common.py | 27 ----- xarray/backends/rasterio_.py | 194 ++++++++++++++++++---------------- xarray/tests/test_backends.py | 29 +---- 3 files changed, 103 insertions(+), 147 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index d5eccd9be52..cccdbe9e193 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -508,30 +508,3 @@ def assert_open(self): if not self._isopen: raise AssertionError('internal failure: file must be open ' 'if `autoclose=True` is used.') - - -class PickleByReconstructionWrapper(object): - - def __init__(self, opener, file, mode='r', **kwargs): - self.opener = partial(opener, file, mode=mode, **kwargs) - self.mode = mode - self._ds = None - - @property - def value(self): - self._ds = self.opener() - return self._ds - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - if self.mode == 'w': - # file has already been created, don't override when restoring - state['mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - - def close(self): - self._ds.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 0f19a1b51be..53bc0564bf7 100644 --- 
a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -1,5 +1,6 @@ import os from collections import OrderedDict +import functools from distutils.version import LooseVersion import warnings @@ -8,7 +9,8 @@ from .. import DataArray from ..core import indexing from ..core.utils import is_scalar -from .common import BackendArray, PickleByReconstructionWrapper +from .common import BackendArray +from .file_manager import LazyFileManager try: from dask.utils import SerializableLock as Lock @@ -25,18 +27,20 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, riods): - self.riods = riods - self._shape = (riods.value.count, riods.value.height, - riods.value.width) - self._ndims = len(self.shape) + def __init__(self, manager, riods): + self.manager = manager - @property - def dtype(self): - dtypes = self.riods.value.dtypes + # cannot save riods as an attribute: this would break pickleability. + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes if not np.all(np.asarray(dtypes) == dtypes[0]): raise ValueError('All bands should have the same dtype') - return np.dtype(dtypes[0]) + self._dtype = np.dtype(dtypes[0]) + + @property + def dtype(self): + return self._dtype @property def shape(self): @@ -104,8 +108,8 @@ def _get_indexer(self, key): def __getitem__(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - - out = self.riods.value.read(band_key, window=tuple(window)) + with self.manager.acquire() as riods: + out = riods.read(band_key, window=tuple(window)) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) return indexing.NumpyIndexingAdapter(out)[np_inds] @@ -195,91 +199,97 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - riods = PickleByReconstructionWrapper(rasterio.open, filename, mode='r') + manager = LazyFileManager( + functools.partial(rasterio.open, filename), mode='r') if cache is None: cache = chunks is None coords = OrderedDict() - # Get bands - if riods.value.count < 1: - raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.value.indexes) - - # Get coordinates - if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.value.affine - else: - transform = riods.value.transform - if transform.is_rectilinear: - # 1d coordinates - parse = True if parse_coordinates is None else parse_coordinates - if parse: - nx, ny = riods.value.width, riods.value.height - # xarray coordinates are pixel centered - x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform - _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform - coords['y'] = y - coords['x'] = x - else: - # 2d coordinates - parse = False if (parse_coordinates is None) else parse_coordinates - if parse: - warnings.warn("The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. 
Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) - - # Attributes - attrs = dict() - # Affine transformation matrix (always available) - # This describes coefficients mapping pixel coordinates to CRS - # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see https://github.com/sgillies/affine) - attrs['transform'] = tuple(transform)[:6] - if hasattr(riods.value, 'crs') and riods.value.crs: - # CRS is a dict-like object specific to rasterio - # If CRS is not None, we convert it back to a PROJ4 string using - # rasterio itself - attrs['crs'] = riods.value.crs.to_string() - if hasattr(riods.value, 'res'): - # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.value.res - if hasattr(riods.value, 'is_tiled'): - # Is the TIF tiled? (bool) - # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.value.is_tiled) - with warnings.catch_warnings(): - # casting riods.value.transform to a tuple makes this future proof - warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods.value, 'transform'): - # Affine transformation matrix (tuple of floats) - # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.value.transform) - if hasattr(riods.value, 'nodatavals'): - # The nodata values for the raster bands - attrs['nodatavals'] = tuple([np.nan if nodataval is None else nodataval - for nodataval in riods.value.nodatavals]) - - # Parse extra metadata from tags, if supported - parsers = {'ENVI': _parse_envi} - - driver = riods.value.driver - if driver in parsers: - meta = parsers[driver](riods.value.tags(ns=driver)) - - for k, v in meta.items(): - # Add values as coordinates if they match the band count, - # as attributes otherwise - if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.value.count): - coords[k] = ('band', np.asarray(v)) - else: - attrs[k] = v - - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(riods)) + with manager.acquire() as riods: + # Get bands + if riods.count < 1: + raise ValueError('Unknown dims') + coords['band'] = np.asarray(riods.indexes) + + # Get coordinates + if LooseVersion(rasterio.__version__) < '1.0': + transform = riods.affine + else: + transform = riods.transform + if transform.is_rectilinear: + # 1d coordinates + parse = True if parse_coordinates is None else parse_coordinates + if parse: + nx, ny = riods.width, riods.height + # xarray coordinates are pixel centered + x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform + _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform + coords['y'] = y + coords['x'] = x + else: + # 2d coordinates + parse = False if (parse_coordinates is None) else parse_coordinates + if parse: + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. 
Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) + + # Attributes + attrs = dict() + # Affine transformation matrix (always available) + # This describes coefficients mapping pixel coordinates to CRS + # For serialization store as tuple of 6 floats, the last row being + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) + attrs['transform'] = tuple(transform)[:6] + if hasattr(riods, 'crs') and riods.crs: + # CRS is a dict-like object specific to rasterio + # If CRS is not None, we convert it back to a PROJ4 string using + # rasterio itself + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): + # (width, height) tuple of pixels in units of CRS + attrs['res'] = riods.res + if hasattr(riods, 'is_tiled'): + # Is the TIF tiled? (bool) + # We cast it to an int for netCDF compatibility + attrs['is_tiled'] = np.uint8(riods.is_tiled) + with warnings.catch_warnings(): + # casting riods.transform to a tuple makes this future proof + warnings.simplefilter('ignore', FutureWarning) + if hasattr(riods, 'transform'): + # Affine transformation matrix (tuple of floats) + # Describes coefficients mapping pixel coordinates to CRS + attrs['transform'] = tuple(riods.transform) + if hasattr(riods, 'nodatavals'): + # The nodata values for the raster bands + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) + + # Parse extra metadata from tags, if supported + parsers = {'ENVI': _parse_envi} + + driver = riods.driver + if driver in parsers: + meta = parsers[driver](riods.tags(ns=driver)) + + for k, v in meta.items(): + # Add values as coordinates if they match the band count, + # as attributes otherwise + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.count): + coords[k] = ('band', np.asarray(v)) + else: + attrs[k] = v + + data = indexing.LazilyOuterIndexedArray( + RasterioArrayWrapper(manager, riods)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) @@ -305,6 +315,6 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, lock=lock) # Make the file closeable - result._file_obj = riods + result._file_obj = manager return result diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index e83b80a6dd8..cb36938d7b9 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -19,8 +19,7 @@ from xarray import ( DataArray, Dataset, backends, open_dataarray, open_dataset, open_mfdataset, save_mfdataset) -from xarray.backends.common import (robust_getitem, - PickleByReconstructionWrapper) +from xarray.backends.common import robust_getitem from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing @@ -3281,29 +3280,3 @@ def test_dataarray_to_netcdf_no_name_pathlib(self): with open_dataarray(tmp) as loaded_da: assert_identical(original_da, loaded_da) - - -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = pickle.dumps(obj2) - 
obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines From 7a55a3074ac28556fe9d10bf2af46d2a3e05f464 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 4 Jul 2018 10:09:52 -0700 Subject: [PATCH 03/39] lint fixes --- xarray/backends/file_manager.py | 2 +- xarray/tests/test_backends_file_manager.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 34ad39ff885..e02e2c677d1 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -27,7 +27,7 @@ def __init__(self, opener, mode=None): Callable that opens a given file when called, returning a file object. mode : str, optional - If provided, passed to opener as a keyword argument. + If provided, passed to opener as a keyword argument. """ raise NotImplementedError diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 2803ed71aa9..1693f84f44f 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -11,6 +11,7 @@ ExplicitFileManager, LazyFileManager, AutoclosingFileManager, ] + @pytest.mark.parametrize('manager_type', FILE_MANAGERS) def test_file_manager_write_consecutive(tmpdir, manager_type): path = str(tmpdir.join('testing.txt')) From 51463ddbc05c5bd42e41034b9082c84431dd548e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 9 Jul 2018 15:48:19 -0700 Subject: [PATCH 04/39] WIP: rewrite FileManager to always use an LRUCache --- xarray/backends/file_manager.py | 185 +++++++++------------ xarray/backends/lru_cache.py | 78 +++++++++ xarray/backends/rasterio_.py | 5 +- xarray/tests/test_backends_file_manager.py | 35 ++-- 4 files changed, 166 insertions(+), 137 deletions(-) create mode 100644 xarray/backends/lru_cache.py diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index e02e2c677d1..ae90d775015 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,155 +1,118 @@ import contextlib import threading +from .lru_cache import LRUCache + + +# Global cache for storing open files. +FILE_CACHE = LRUCache(512, on_evict=lambda k, v: v.close()) + +# TODO(shoyer): add an option (xarray.set_options) for resizing the cache. + class FileManager(object): - """Base class for context managers for managing file objects. + """Wrapper for automatically opening and closing file objects. - Unlike files, FileManager objects should be safely. They must be explicitly - closed. + Unlike files, FileManager objects can be safely pickled and passed between + processes. They should be explicitly closed to release resources, but + a per-process least-recently-used cache for open files ensures that you can + safely create arbitrarily large numbers of FileManager objects. Example usage: - import functools - - manager = FileManager(functools.partial(open, filename), mode='w') + manager = FileManager(open, 'example.txt', mode='w') with manager.acquire() as f: f.write(...) manager.close() """ - def __init__(self, opener, mode=None): + def __init__(self, opener, *args, **kwargs): """Initialize a FileManager. Parameters ---------- opener : callable - Callable that opens a given file when called, returning a file - object. - mode : str, optional - If provided, passed to opener as a keyword argument. + Function that when called like ``opener(*args, **kwargs)`` returns + an open file object. The file object must implement a ``close()`` + method. 
+ *args + Positional arguments for opener. A ``mode`` argument should be + provided as a keyword argument (see below). + **kwargs + Keyword arguments for opener. The keyword argument ``mode`` has + special handling if it is provided with a value of 'w': on all + calls after the first, it is changed to 'a' instead to avoid + overriding the newly created file. All argument values must be + hashable. """ - raise NotImplementedError + self._opener = opener + self._args = args + self._kwargs = kwargs + self._key = _make_key(opener, args, kwargs) + self._lock = threading.RLock() @contextlib.contextmanager def acquire(self): """Context manager for acquiring a file object. - This method must be thread-safe: it should be safe to simultaneously - acquire a file in multiple threads at the same time (assuming that - the underlying file object is thread-safe). + A new file is only opened if it has expired from the + least-recently-used cache. + + This method uses a reentrant lock, which ensures that it is + thread-safe. You can safely acquire a file in multiple threads at the + same time, as long as the underlying file object is thread-safe. Yields ------ - Open file object, as returned by opener(). + Open file object, as returned by ``opener(*args, **kwargs)``. """ - raise NotImplementedError - - def close(self): - """Explicitly close any associated file object (if necessary).""" - raise NotImplementedError - - -_DEFAULT_MODE = object() - - -def _open(opener, mode): - return opener() if mode is _DEFAULT_MODE else opener(mode=mode) - - -class ExplicitFileManager(FileManager): - """A file manager that holds a file open until explicitly closed. - - This is mostly a reference implementation: must real use cases should use - ExplicitLazyFileContext for better performance. 
- """ - - def __init__(self, opener, mode=_DEFAULT_MODE): - self._opener = opener - # file has already been created, don't override when restoring - self._mode = 'a' if mode == 'w' else mode - self._file = _open(opener, mode) - - @contextlib.contextmanager - def acquire(self): - yield self._file - - def close(self): - self._file.close() - - def __getstate__(self): - return {'opener': self._opener, 'mode': self._mode} - - def __setstate__(self, state): - self.__init__(**state) - - -class LazyFileManager(FileManager): - """An explicit file manager that lazily opens files.""" - - def __init__(self, opener, mode=_DEFAULT_MODE): - self._opener = opener - self._mode = mode - self._lock = threading.Lock() - self._file = None - - @contextlib.contextmanager - def acquire(self): with self._lock: - if self._file is None: - self._file = _open(self._opener, self._mode) - # file has already been created, don't override when restoring - if self._mode == 'w': - self._mode = 'a' - yield self._file + try: + file = FILE_CACHE[self._key] + except KeyError: + file = self._opener(*self._args, **self._kwargs) + if self._kwargs.get('mode') == 'w': + # ensure file doesn't get overriden when opened again + self._kwargs['mode'] = 'a' + self._key = _make_key( + self._opener, self._args, self._kwargs) + FILE_CACHE[self._key] = file + yield file def close(self): - if self._file is not None: - self._file.close() + """Explicitly close any associated file object (if necessary).""" + file = FILE_CACHE.pop(self._key, default=None) + if file is not None: + file.close() def __getstate__(self): - return {'opener': self._opener, 'mode': self._mode} + """State for pickling.""" + return (self._opener, self._args, self._kwargs) def __setstate__(self, state): - self.__init__(**state) + """Restore from a pickle.""" + opener, args, kwargs = state + self.__init__(opener, *args, **kwargs) -class AutoclosingFileManager(FileManager): - """A FileManager that automatically opens/closes files when used.""" +class _HashedSequence(list): + """Speedup repeated look-ups by caching hash values. - def __init__(self, opener, mode=_DEFAULT_MODE): - self._opener = opener - self._mode = mode - self._lock = threading.Lock() - self._file = None - self._references = 0 - - @contextlib.contextmanager - def acquire(self): - with self._lock: - if self._file is None: - self._file = _open(self._opener, self._mode) - # file has already been created, don't override when restoring - if self._mode == 'w': - self._mode = 'a' - self._references += 1 - - yield self._file + Based on what Python uses internally in functools.lru_cache. - with self._lock: - self._references -= 1 - if not self._references: - self._file.close() - self._file = None - - def close(self): - pass + Python doesn't perform this optimization automatically: + https://bugs.python.org/issue1462796 + """ - def __getstate__(self): - return {'opener': self._opener, 'mode': self._mode} + def __init__(self, tuple_value): + self[:] = tuple_value + self.hashvalue = hash(tuple_value) - def __setstate__(self, state): - self.__init__(**state) + def __hash__(self): + return self.hashvalue -# TODO: write a FileManager that makes use of an LRU cache. 
+def _make_key(opener, args, kwargs): + """Make a key for caching files in the LRU cache.""" + value = (opener, args, tuple(sorted(kwargs.items()))) + return _HashedSequence(value) diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py new file mode 100644 index 00000000000..f9b811710ec --- /dev/null +++ b/xarray/backends/lru_cache.py @@ -0,0 +1,78 @@ +import collections +import threading + + +class LRUCache(collections.MutableMapping): + """Thread-safe LRUCache based on an OrderedDict. + + All dict operations (__getitem__, __setitem__, __contains__) update the + priority of the relevant key and take O(1) time. The dict is iterated over + in order from the oldest to newest key, which means that a complete pass + over the dict should not affect the order of any entries. + + When a new item is set and the maximum size of the cache is exceeded, the + oldest item is dropped and called with ``on_evict(key, value)``. + + The ``maxsize`` property can be used to view or resize the capacity of + the cache. + """ + def __init__(self, maxsize, on_evict=None): + """ + Parameters + ---------- + maxsize : int + Integer maximum number of items to hold in the cache. + on_evict: callable, optional + Function to call like ``on_evict(key, value)`` when items are + evicted. + """ + self._maxsize = maxsize + self._on_evict = on_evict + self._cache = collections.OrderedDict() + self._lock = threading.RLock() + + def __getitem__(self, key): + # record recent use of the key by moving it to the front of the list + with self._lock: + value = self._cache[key] + self._cache.move_to_end(key) + return value + + def _shrink(self, capacity): + """Shrink the cache if necessary, evicting the oldest items.""" + while len(self._cache) > capacity: + key, value = self._cache.popitem(last=False) + if self._on_evict is not None: + self._on_evict(key, value) + + def __setitem__(self, key, value): + with self._lock: + if self._maxsize: + if key in self._cache: + self._cache.move_to_end(key) + elif len(self._cache) >= self._maxsize: + self._shrink(self._maxsize - 1) + self._cache[key] = value + + def __delitem__(self, key): + del self._cache[key] + + def __iter__(self): + # create a list, so accessing the cache during iteration cannot change + # the iteration order + return iter(list(self._cache)) + + def __len__(self): + return len(self._cache) + + @property + def maxsize(self): + """Maximum number of items can be held in the cache.""" + return self._maxsize + + @maxsize.setter + def maxsize(self, size): + """Resize the cache, evicting the oldest items if necessary.""" + with self._lock: + self._shrink(size) + self._maxsize = size diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 53bc0564bf7..40b1bd8dca2 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -10,7 +10,7 @@ from ..core import indexing from ..core.utils import is_scalar from .common import BackendArray -from .file_manager import LazyFileManager +from .file_manager import FileManager try: from dask.utils import SerializableLock as Lock @@ -199,8 +199,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - manager = LazyFileManager( - functools.partial(rasterio.open, filename), mode='r') + manager = FileManager(rasterio.open, filename, mode='r') if cache is None: cache = chunks is None diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 1693f84f44f..916d8848eb3 100644 --- 
a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -1,21 +1,11 @@ -import functools import pickle -import pytest +from xarray.backends.file_manager import FileManager, FILE_CACHE -from xarray.backends.file_manager import ( - ExplicitFileManager, LazyFileManager, AutoclosingFileManager -) -FILE_MANAGERS = [ - ExplicitFileManager, LazyFileManager, AutoclosingFileManager, -] - - -@pytest.mark.parametrize('manager_type', FILE_MANAGERS) -def test_file_manager_write_consecutive(tmpdir, manager_type): +def test_file_manager_write_consecutive(tmpdir): path = str(tmpdir.join('testing.txt')) - manager = manager_type(functools.partial(open, path), mode='w') + manager = FileManager(open, path, mode='w') with manager.acquire() as f: f.write('foo') with manager.acquire() as f: @@ -26,10 +16,9 @@ def test_file_manager_write_consecutive(tmpdir, manager_type): assert f.read() == 'foobar' -@pytest.mark.parametrize('manager_type', FILE_MANAGERS) -def test_file_manager_write_concurrent(tmpdir, manager_type): +def test_file_manager_write_concurrent(tmpdir): path = str(tmpdir.join('testing.txt')) - manager = manager_type(functools.partial(open, path), mode='w') + manager = FileManager(open, path, mode='w') with manager.acquire() as f1: with manager.acquire() as f2: f1.write('foo') @@ -40,11 +29,9 @@ def test_file_manager_write_concurrent(tmpdir, manager_type): assert f.read() == 'foobar' -@pytest.mark.parametrize('manager_type', FILE_MANAGERS) -def test_file_manager_write_pickle(tmpdir, manager_type): +def test_file_manager_write_pickle(tmpdir): path = str(tmpdir.join('testing.txt')) - manager = manager_type( - functools.partial(open, path), mode='w') + manager = FileManager(open, path, mode='w') with manager.acquire() as f: f.write('foo') f.flush() @@ -58,14 +45,16 @@ def test_file_manager_write_pickle(tmpdir, manager_type): assert f.read() == 'foobar' -@pytest.mark.parametrize('manager_type', FILE_MANAGERS) -def test_file_manager_read(tmpdir, manager_type): +def test_file_manager_read(tmpdir): path = str(tmpdir.join('testing.txt')) with open(path, 'w') as f: f.write('foobar') - manager = manager_type(functools.partial(open, path)) + manager = FileManager(open, path) with manager.acquire() as f: assert f.read() == 'foobar' manager.close() + + +# TODO(shoyer): add test coverage for exceeding the max size of the file cache From 23e132ff31a174ed0c158a103050e8698d55bfdc Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Jul 2018 09:02:05 -0700 Subject: [PATCH 05/39] Test coverage --- xarray/backends/file_manager.py | 9 ++- xarray/backends/lru_cache.py | 22 ++++--- xarray/backends/rasterio_.py | 19 +++--- xarray/tests/test_backends_file_manager.py | 69 ++++++++++++++++++---- 4 files changed, 89 insertions(+), 30 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index ae90d775015..8dc918a14e5 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -8,6 +8,7 @@ FILE_CACHE = LRUCache(512, on_evict=lambda k, v: v.close()) # TODO(shoyer): add an option (xarray.set_options) for resizing the cache. +# Note: the cache has a minimum size of one. 
class FileManager(object): @@ -48,9 +49,12 @@ def __init__(self, opener, *args, **kwargs): self._opener = opener self._args = args self._kwargs = kwargs - self._key = _make_key(opener, args, kwargs) + self._key = self._make_key() self._lock = threading.RLock() + def _make_key(self): + return _make_key(self._opener, self._args, self._kwargs) + @contextlib.contextmanager def acquire(self): """Context manager for acquiring a file object. @@ -74,8 +78,7 @@ def acquire(self): if self._kwargs.get('mode') == 'w': # ensure file doesn't get overriden when opened again self._kwargs['mode'] = 'a' - self._key = _make_key( - self._opener, self._args, self._kwargs) + self._key = self._make_key() FILE_CACHE[self._key] = file yield file diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index f9b811710ec..c5ef515322c 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -13,8 +13,8 @@ class LRUCache(collections.MutableMapping): When a new item is set and the maximum size of the cache is exceeded, the oldest item is dropped and called with ``on_evict(key, value)``. - The ``maxsize`` property can be used to view or resize the capacity of - the cache. + The ``maxsize`` property can be used to view or adjust the capacity of + the cache, e.g., ``cache.maxsize = new_size``. """ def __init__(self, maxsize, on_evict=None): """ @@ -26,6 +26,8 @@ def __init__(self, maxsize, on_evict=None): Function to call like ``on_evict(key, value)`` when items are evicted. """ + if maxsize < 0: + raise ValueError('maxsize must be non-negative') self._maxsize = maxsize self._on_evict = on_evict self._cache = collections.OrderedDict() @@ -47,12 +49,16 @@ def _shrink(self, capacity): def __setitem__(self, key, value): with self._lock: - if self._maxsize: - if key in self._cache: - self._cache.move_to_end(key) - elif len(self._cache) >= self._maxsize: - self._shrink(self._maxsize - 1) + if key in self._cache: + self._cache.move_to_end(key) self._cache[key] = value + elif self._maxsize: + # make room if necessary + self._shrink(self._maxsize - 1) + self._cache[key] = value + elif self._on_evict is not None: + # not saving, immediately evict + self._on_evict(key, value) def __delitem__(self, key): del self._cache[key] @@ -73,6 +79,8 @@ def maxsize(self): @maxsize.setter def maxsize(self, size): """Resize the cache, evicting the oldest items if necessary.""" + if size < 0: + raise ValueError('maxsize must be non-negative') with self._lock: self._shrink(size) self._maxsize = size diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 40b1bd8dca2..13eb6e14a66 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -1,6 +1,5 @@ import os from collections import OrderedDict -import functools from distutils.version import LooseVersion import warnings @@ -27,16 +26,17 @@ class RasterioArrayWrapper(BackendArray): """A wrapper around rasterio dataset objects""" - def __init__(self, manager, riods): + def __init__(self, manager): self.manager = manager - # cannot save riods as an attribute: this would break pickleability. 
- self._shape = (riods.count, riods.height, riods.width) + with manager.acquire() as riods: + # cannot save riods as an attribute: this would break pickleability + self._shape = (riods.count, riods.height, riods.width) - dtypes = riods.dtypes - if not np.all(np.asarray(dtypes) == dtypes[0]): - raise ValueError('All bands should have the same dtype') - self._dtype = np.dtype(dtypes[0]) + dtypes = riods.dtypes + if not np.all(np.asarray(dtypes) == dtypes[0]): + raise ValueError('All bands should have the same dtype') + self._dtype = np.dtype(dtypes[0]) @property def dtype(self): @@ -287,8 +287,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, else: attrs[k] = v - data = indexing.LazilyOuterIndexedArray( - RasterioArrayWrapper(manager, riods)) + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 916d8848eb3..6fcca064890 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -1,9 +1,56 @@ import pickle +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest from xarray.backends.file_manager import FileManager, FILE_CACHE +from xarray.core.pycompat import suppress + + +@pytest.fixture(scope='module', params=[1, 2, 3]) +def file_cache(request): + contents = FILE_CACHE.items() + maxsize = FILE_CACHE.maxsize + FILE_CACHE.clear() + FILE_CACHE.maxsize = request.param + yield FILE_CACHE + FILE_CACHE.maxsize = maxsize + FILE_CACHE.clear() + FILE_CACHE.update(contents) + + +def test_file_manager_mock_write(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(return_value=mock_file) + + manager = FileManager(opener, 'filename') + with manager.acquire() as f: + f.write('contents') + manager.close() + + opener.assert_called_once_with('filename') + mock_file.write.assert_called_once_with('contents') + mock_file.close.assert_called_once_with() -def test_file_manager_write_consecutive(tmpdir): +def test_file_manager_mock_error(file_cache): + mock_file = mock.Mock() + opener = mock.Mock(return_value=mock_file) + + manager = FileManager(opener, 'mydata') + with suppress(ValueError): + with manager.acquire(): + raise ValueError + manager.close() + + opener.assert_called_once_with('mydata') + mock_file.close.assert_called_once_with() + + +def test_file_manager_write_consecutive(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) manager = FileManager(open, path, mode='w') with manager.acquire() as f: @@ -16,20 +63,25 @@ def test_file_manager_write_consecutive(tmpdir): assert f.read() == 'foobar' -def test_file_manager_write_concurrent(tmpdir): +def test_file_manager_write_concurrent(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) manager = FileManager(open, path, mode='w') with manager.acquire() as f1: with manager.acquire() as f2: - f1.write('foo') - f2.write('bar') + with manager.acquire() as f3: + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() manager.close() with open(path, 'r') as f: - assert f.read() == 'foobar' + assert f.read() == 'foobarbaz' -def test_file_manager_write_pickle(tmpdir): +def test_file_manager_write_pickle(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) manager = FileManager(open, path, mode='w') with manager.acquire() as f: @@ -45,7 +97,7 @@ def 
test_file_manager_write_pickle(tmpdir): assert f.read() == 'foobar' -def test_file_manager_read(tmpdir): +def test_file_manager_read(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) with open(path, 'w') as f: @@ -55,6 +107,3 @@ def test_file_manager_read(tmpdir): with manager.acquire() as f: assert f.read() == 'foobar' manager.close() - - -# TODO(shoyer): add test coverage for exceeding the max size of the file cache From 8fc81837f613b5e2cab64fd411838d1e5e0343c9 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Jul 2018 11:17:43 -0700 Subject: [PATCH 06/39] Don't use move_to_end --- xarray/backends/lru_cache.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index c5ef515322c..4b180644929 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -37,7 +37,9 @@ def __getitem__(self, key): # record recent use of the key by moving it to the front of the list with self._lock: value = self._cache[key] - self._cache.move_to_end(key) + # On Python 3, could just use: self._cache.move_to_end(key) + del self._cache[key] + self._cache[key] = value return value def _shrink(self, capacity): @@ -50,7 +52,7 @@ def _shrink(self, capacity): def __setitem__(self, key, value): with self._lock: if key in self._cache: - self._cache.move_to_end(key) + del self._cache[key] self._cache[key] = value elif self._maxsize: # make room if necessary From 422944bce5d9aef6ef091ce99f7678e2aa9163ed Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Jul 2018 11:25:21 -0700 Subject: [PATCH 07/39] minor clarification --- xarray/backends/lru_cache.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 4b180644929..02411408abd 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -42,7 +42,7 @@ def __getitem__(self, key): self._cache[key] = value return value - def _shrink(self, capacity): + def _maybe_shrink(self, capacity): """Shrink the cache if necessary, evicting the oldest items.""" while len(self._cache) > capacity: key, value = self._cache.popitem(last=False) @@ -52,11 +52,12 @@ def _shrink(self, capacity): def __setitem__(self, key, value): with self._lock: if key in self._cache: + # insert the new value at the end del self._cache[key] self._cache[key] = value elif self._maxsize: # make room if necessary - self._shrink(self._maxsize - 1) + self._maybe_shrink(self._maxsize - 1) self._cache[key] = value elif self._on_evict is not None: # not saving, immediately evict @@ -84,5 +85,5 @@ def maxsize(self, size): if size < 0: raise ValueError('maxsize must be non-negative') with self._lock: - self._shrink(size) + self._maybe_shrink(size) self._maxsize = size From aea0a1a58ff7f1fd9e3a9b01fbca79e4691d1511 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Jul 2018 19:39:18 -0700 Subject: [PATCH 08/39] Switch FileManager.acquire() to a method --- xarray/backends/file_manager.py | 105 +++++++----- xarray/backends/lru_cache.py | 6 +- xarray/backends/rasterio_.py | 183 +++++++++++---------- xarray/tests/test_backends_file_manager.py | 109 ++++++------ 4 files changed, 215 insertions(+), 188 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 8dc918a14e5..196b49c5cd5 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,6 +1,6 @@ -import contextlib import threading +from ..core import utils from 
.lru_cache import LRUCache @@ -11,6 +11,9 @@ # Note: the cache has a minimum size of one. +_DEFAULT_MODE = utils.ReprObject('') + + class FileManager(object): """Wrapper for automatically opening and closing file objects. @@ -19,15 +22,23 @@ class FileManager(object): a per-process least-recently-used cache for open files ensures that you can safely create arbitrarily large numbers of FileManager objects. + Don't directly close files acquired from a FileManager. Instead, call + FileManager.close(), which ensures that closed files are removed from the + cache as well. + Example usage: manager = FileManager(open, 'example.txt', mode='w') - with manager.acquire() as f: - f.write(...) - manager.close() + f = manager.acquire() + f.write(...) + manager.close() # ensures file is closed """ - def __init__(self, opener, *args, **kwargs): + def __init__(self, opener, *args, + mode=_DEFAULT_MODE, + kwargs=None, + lock=None, + cache=FILE_CACHE): """Initialize a FileManager. Parameters @@ -38,26 +49,46 @@ def __init__(self, opener, *args, **kwargs): method. *args Positional arguments for opener. A ``mode`` argument should be - provided as a keyword argument (see below). - **kwargs - Keyword arguments for opener. The keyword argument ``mode`` has - special handling if it is provided with a value of 'w': on all - calls after the first, it is changed to 'a' instead to avoid - overriding the newly created file. All argument values must be + provided as a keyword argument (see below). All arguments must be hashable. + mode : optional + If provided, passed as a keyword argument to ``opener`` along with + ``**kwargs``. ``mode='w' `` has special treatment: after the first + call it is replaced by ``mode='a'`` in all subsequent function to + avoid overriding the newly created file. + kwargs : dict, optional + Keyword arguments for opener, excluding ``mode``. All values must + be hashable. + lock : duck-compatible threading.Lock, optional + Lock to use when modifying the cache inside acquire() and close(). + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. + cache : MutableMapping, optional + Mapping to use as a cache for open files. By default, uses xarray's + global LRU file cache. Because ``cache`` typically points to a + global variable and contains non-picklable file objects, an + unpickled FileManager objects will be restored with the default + cache. """ self._opener = opener self._args = args - self._kwargs = kwargs + self._mode = mode + self._kwargs = {} if kwargs is None else dict(kwargs) + self._default_lock = lock is None + self._lock = threading.Lock() if self._default_lock else lock + self._cache = cache self._key = self._make_key() - self._lock = threading.RLock() def _make_key(self): - return _make_key(self._opener, self._args, self._kwargs) + """Make a key for caching files in the LRU cache.""" + value = (self._opener, + self._args, + self._mode, + tuple(sorted(self._kwargs.items()))) + return _HashedSequence(value) - @contextlib.contextmanager def acquire(self): - """Context manager for acquiring a file object. + """Acquiring a file object from the manager. A new file is only opened if it has expired from the least-recently-used cache. @@ -66,36 +97,42 @@ def acquire(self): thread-safe. You can safely acquire a file in multiple threads at the same time, as long as the underlying file object is thread-safe. - Yields - ------ - Open file object, as returned by ``opener(*args, **kwargs)``. 
+ Returns + ------- + An open file object, as returned by ``opener(*args, **kwargs)``. """ with self._lock: try: - file = FILE_CACHE[self._key] + file = self._cache[self._key] except KeyError: - file = self._opener(*self._args, **self._kwargs) - if self._kwargs.get('mode') == 'w': + kwargs = self._kwargs + if self._mode is not _DEFAULT_MODE: + kwargs = kwargs.copy() + kwargs['mode'] = self._mode + file = self._opener(*self._args, **kwargs) + if self._mode == 'w': # ensure file doesn't get overriden when opened again - self._kwargs['mode'] = 'a' + self._mode = 'a' self._key = self._make_key() - FILE_CACHE[self._key] = file - yield file + self._cache[self._key] = file + return file def close(self): """Explicitly close any associated file object (if necessary).""" - file = FILE_CACHE.pop(self._key, default=None) - if file is not None: - file.close() + with self._lock: + file = self._cache.pop(self._key, default=None) + if file is not None: + file.close() def __getstate__(self): """State for pickling.""" - return (self._opener, self._args, self._kwargs) + lock = None if self._default_lock else self._lock + return (self._opener, self._args, self._mode, self._kwargs, lock) def __setstate__(self, state): """Restore from a pickle.""" - opener, args, kwargs = state - self.__init__(opener, *args, **kwargs) + opener, args, mode, kwargs, lock = state + self.__init__(opener, *args, mode=mode, kwargs=kwargs, lock=lock) class _HashedSequence(list): @@ -113,9 +150,3 @@ def __init__(self, tuple_value): def __hash__(self): return self.hashvalue - - -def _make_key(opener, args, kwargs): - """Make a key for caching files in the LRU cache.""" - value = (opener, args, tuple(sorted(kwargs.items()))) - return _HashedSequence(value) diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 02411408abd..59f896c0101 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -42,7 +42,7 @@ def __getitem__(self, key): self._cache[key] = value return value - def _maybe_shrink(self, capacity): + def _enforce_size_limit(self, capacity): """Shrink the cache if necessary, evicting the oldest items.""" while len(self._cache) > capacity: key, value = self._cache.popitem(last=False) @@ -57,7 +57,7 @@ def __setitem__(self, key, value): self._cache[key] = value elif self._maxsize: # make room if necessary - self._maybe_shrink(self._maxsize - 1) + self._enforce_size_limit(self._maxsize - 1) self._cache[key] = value elif self._on_evict is not None: # not saving, immediately evict @@ -85,5 +85,5 @@ def maxsize(self, size): if size < 0: raise ValueError('maxsize must be non-negative') with self._lock: - self._maybe_shrink(size) + self._enforce_size_limit(size) self._maxsize = size diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 13eb6e14a66..d037cad27bd 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -29,14 +29,15 @@ class RasterioArrayWrapper(BackendArray): def __init__(self, manager): self.manager = manager - with manager.acquire() as riods: - # cannot save riods as an attribute: this would break pickleability - self._shape = (riods.count, riods.height, riods.width) + # cannot save riods as an attribute: this would break pickleability + riods = manager.acquire() - dtypes = riods.dtypes - if not np.all(np.asarray(dtypes) == dtypes[0]): - raise ValueError('All bands should have the same dtype') - self._dtype = np.dtype(dtypes[0]) + self._shape = (riods.count, riods.height, riods.width) + + dtypes = riods.dtypes + if not 
np.all(np.asarray(dtypes) == dtypes[0]): + raise ValueError('All bands should have the same dtype') + self._dtype = np.dtype(dtypes[0]) @property def dtype(self): @@ -108,8 +109,8 @@ def _get_indexer(self, key): def __getitem__(self, key): band_key, window, squeeze_axis, np_inds = self._get_indexer(key) - with self.manager.acquire() as riods: - out = riods.read(band_key, window=tuple(window)) + riods = self.manager.acquire() + out = riods.read(band_key, window=tuple(window)) if squeeze_axis: out = np.squeeze(out, axis=squeeze_axis) return indexing.NumpyIndexingAdapter(out)[np_inds] @@ -200,94 +201,94 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio manager = FileManager(rasterio.open, filename, mode='r') + riods = manager.acquire() if cache is None: cache = chunks is None coords = OrderedDict() - with manager.acquire() as riods: - # Get bands - if riods.count < 1: - raise ValueError('Unknown dims') - coords['band'] = np.asarray(riods.indexes) - - # Get coordinates - if LooseVersion(rasterio.__version__) < '1.0': - transform = riods.affine - else: - transform = riods.transform - if transform.is_rectilinear: - # 1d coordinates - parse = True if parse_coordinates is None else parse_coordinates - if parse: - nx, ny = riods.width, riods.height - # xarray coordinates are pixel centered - x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform - _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform - coords['y'] = y - coords['x'] = x - else: - # 2d coordinates - parse = False if (parse_coordinates is None) else parse_coordinates - if parse: - warnings.warn( - "The file coordinates' transformation isn't " - "rectilinear: xarray won't parse the coordinates " - "in this case. Set `parse_coordinates=False` to " - "suppress this warning.", - RuntimeWarning, stacklevel=3) - - # Attributes - attrs = dict() - # Affine transformation matrix (always available) - # This describes coefficients mapping pixel coordinates to CRS - # For serialization store as tuple of 6 floats, the last row being - # always (0, 0, 1) per definition (see - # https://github.com/sgillies/affine) - attrs['transform'] = tuple(transform)[:6] - if hasattr(riods, 'crs') and riods.crs: - # CRS is a dict-like object specific to rasterio - # If CRS is not None, we convert it back to a PROJ4 string using - # rasterio itself - attrs['crs'] = riods.crs.to_string() - if hasattr(riods, 'res'): - # (width, height) tuple of pixels in units of CRS - attrs['res'] = riods.res - if hasattr(riods, 'is_tiled'): - # Is the TIF tiled? 
(bool) - # We cast it to an int for netCDF compatibility - attrs['is_tiled'] = np.uint8(riods.is_tiled) - with warnings.catch_warnings(): - # casting riods.transform to a tuple makes this future proof - warnings.simplefilter('ignore', FutureWarning) - if hasattr(riods, 'transform'): - # Affine transformation matrix (tuple of floats) - # Describes coefficients mapping pixel coordinates to CRS - attrs['transform'] = tuple(riods.transform) - if hasattr(riods, 'nodatavals'): - # The nodata values for the raster bands - attrs['nodatavals'] = tuple( - np.nan if nodataval is None else nodataval - for nodataval in riods.nodatavals) - - # Parse extra metadata from tags, if supported - parsers = {'ENVI': _parse_envi} - - driver = riods.driver - if driver in parsers: - meta = parsers[driver](riods.tags(ns=driver)) - - for k, v in meta.items(): - # Add values as coordinates if they match the band count, - # as attributes otherwise - if (isinstance(v, (list, np.ndarray)) and - len(v) == riods.count): - coords[k] = ('band', np.asarray(v)) - else: - attrs[k] = v - - data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) + # Get bands + if riods.count < 1: + raise ValueError('Unknown dims') + coords['band'] = np.asarray(riods.indexes) + + # Get coordinates + if LooseVersion(rasterio.__version__) < '1.0': + transform = riods.affine + else: + transform = riods.transform + if transform.is_rectilinear: + # 1d coordinates + parse = True if parse_coordinates is None else parse_coordinates + if parse: + nx, ny = riods.width, riods.height + # xarray coordinates are pixel centered + x, _ = (np.arange(nx) + 0.5, np.zeros(nx) + 0.5) * transform + _, y = (np.zeros(ny) + 0.5, np.arange(ny) + 0.5) * transform + coords['y'] = y + coords['x'] = x + else: + # 2d coordinates + parse = False if (parse_coordinates is None) else parse_coordinates + if parse: + warnings.warn( + "The file coordinates' transformation isn't " + "rectilinear: xarray won't parse the coordinates " + "in this case. Set `parse_coordinates=False` to " + "suppress this warning.", + RuntimeWarning, stacklevel=3) + + # Attributes + attrs = dict() + # Affine transformation matrix (always available) + # This describes coefficients mapping pixel coordinates to CRS + # For serialization store as tuple of 6 floats, the last row being + # always (0, 0, 1) per definition (see + # https://github.com/sgillies/affine) + attrs['transform'] = tuple(transform)[:6] + if hasattr(riods, 'crs') and riods.crs: + # CRS is a dict-like object specific to rasterio + # If CRS is not None, we convert it back to a PROJ4 string using + # rasterio itself + attrs['crs'] = riods.crs.to_string() + if hasattr(riods, 'res'): + # (width, height) tuple of pixels in units of CRS + attrs['res'] = riods.res + if hasattr(riods, 'is_tiled'): + # Is the TIF tiled? 
(bool) + # We cast it to an int for netCDF compatibility + attrs['is_tiled'] = np.uint8(riods.is_tiled) + with warnings.catch_warnings(): + # casting riods.transform to a tuple makes this future proof + warnings.simplefilter('ignore', FutureWarning) + if hasattr(riods, 'transform'): + # Affine transformation matrix (tuple of floats) + # Describes coefficients mapping pixel coordinates to CRS + attrs['transform'] = tuple(riods.transform) + if hasattr(riods, 'nodatavals'): + # The nodata values for the raster bands + attrs['nodatavals'] = tuple( + np.nan if nodataval is None else nodataval + for nodataval in riods.nodatavals) + + # Parse extra metadata from tags, if supported + parsers = {'ENVI': _parse_envi} + + driver = riods.driver + if driver in parsers: + meta = parsers[driver](riods.tags(ns=driver)) + + for k, v in meta.items(): + # Add values as coordinates if they match the band count, + # as attributes otherwise + if (isinstance(v, (list, np.ndarray)) and + len(v) == riods.count): + coords[k] = ('band', np.asarray(v)) + else: + attrs[k] = v + + data = indexing.LazilyOuterIndexedArray(RasterioArrayWrapper(manager)) # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 6fcca064890..2246a1015e0 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -1,4 +1,5 @@ import pickle +import threading try: from unittest import mock except ImportError: @@ -6,75 +7,69 @@ import pytest -from xarray.backends.file_manager import FileManager, FILE_CACHE -from xarray.core.pycompat import suppress +from xarray.backends.file_manager import FileManager +from xarray.backends.lru_cache import LRUCache -@pytest.fixture(scope='module', params=[1, 2, 3]) +@pytest.fixture(params=[1, 2, 3]) def file_cache(request): - contents = FILE_CACHE.items() - maxsize = FILE_CACHE.maxsize - FILE_CACHE.clear() - FILE_CACHE.maxsize = request.param - yield FILE_CACHE - FILE_CACHE.maxsize = maxsize - FILE_CACHE.clear() - FILE_CACHE.update(contents) + yield LRUCache(maxsize=request.param) def test_file_manager_mock_write(file_cache): mock_file = mock.Mock() - opener = mock.Mock(return_value=mock_file) + opener = mock.Mock(spec=open, return_value=mock_file) + lock = mock.MagicMock(spec=threading.Lock()) - manager = FileManager(opener, 'filename') - with manager.acquire() as f: - f.write('contents') + manager = FileManager(opener, 'filename', lock=lock, cache=file_cache) + f = manager.acquire() + f.write('contents') manager.close() + assert not file_cache opener.assert_called_once_with('filename') mock_file.write.assert_called_once_with('contents') mock_file.close.assert_called_once_with() - - -def test_file_manager_mock_error(file_cache): - mock_file = mock.Mock() - opener = mock.Mock(return_value=mock_file) - - manager = FileManager(opener, 'mydata') - with suppress(ValueError): - with manager.acquire(): - raise ValueError - manager.close() - - opener.assert_called_once_with('mydata') - mock_file.close.assert_called_once_with() + lock.__enter__.assert_has_calls([mock.call(), mock.call()]) def test_file_manager_write_consecutive(tmpdir, file_cache): - path = str(tmpdir.join('testing.txt')) - manager = FileManager(open, path, mode='w') - with manager.acquire() as f: - f.write('foo') - with manager.acquire() as f: - f.write('bar') - manager.close() + path1 = str(tmpdir.join('testing1.txt')) + path2 = str(tmpdir.join('testing2.txt')) + 
manager1 = FileManager(open, path1, mode='w', cache=file_cache) + manager2 = FileManager(open, path2, mode='w', cache=file_cache) + f1a = manager1.acquire() + f1a.write('foo') + f1a.flush() + f2 = manager2.acquire() + f2.write('bar') + f2.flush() + f1b = manager1.acquire() + f1b.write('baz') + assert (file_cache.maxsize > 1) == (f1a is f1b) + manager1.close() + manager2.close() - with open(path, 'r') as f: - assert f.read() == 'foobar' + with open(path1, 'r') as f: + assert f.read() == 'foobaz' + with open(path2, 'r') as f: + assert f.read() == 'bar' def test_file_manager_write_concurrent(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) - manager = FileManager(open, path, mode='w') - with manager.acquire() as f1: - with manager.acquire() as f2: - with manager.acquire() as f3: - f1.write('foo') - f1.flush() - f2.write('bar') - f2.flush() - f3.write('baz') - f3.flush() + manager = FileManager(open, path, mode='w', cache=file_cache) + f1 = manager.acquire() + f2 = manager.acquire() + f3 = manager.acquire() + assert f1 is f2 + assert f2 is f3 + f1.write('foo') + f1.flush() + f2.write('bar') + f2.flush() + f3.write('baz') + f3.flush() manager.close() with open(path, 'r') as f: @@ -83,13 +78,13 @@ def test_file_manager_write_concurrent(tmpdir, file_cache): def test_file_manager_write_pickle(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) - manager = FileManager(open, path, mode='w') - with manager.acquire() as f: - f.write('foo') - f.flush() + manager = FileManager(open, path, mode='w', cache=file_cache) + f = manager.acquire() + f.write('foo') + f.flush() manager2 = pickle.loads(pickle.dumps(manager)) - with manager2.acquire() as f: - f.write('bar') + f2 = manager2.acquire() + f2.write('bar') manager2.close() manager.close() @@ -103,7 +98,7 @@ def test_file_manager_read(tmpdir, file_cache): with open(path, 'w') as f: f.write('foobar') - manager = FileManager(open, path) - with manager.acquire() as f: - assert f.read() == 'foobar' + manager = FileManager(open, path, cache=file_cache) + f = manager.acquire() + assert f.read() == 'foobar' manager.close() From 4366c0b17c9805b0f86af55ac24b22c89e28ab5f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 10 Jul 2018 19:53:46 -0700 Subject: [PATCH 09/39] Python 2 compat --- xarray/backends/file_manager.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 196b49c5cd5..850d415a3bf 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -34,11 +34,7 @@ class FileManager(object): manager.close() # ensures file is closed """ - def __init__(self, opener, *args, - mode=_DEFAULT_MODE, - kwargs=None, - lock=None, - cache=FILE_CACHE): + def __init__(self, opener, *args, **keywords): """Initialize a FileManager. Parameters @@ -70,6 +66,12 @@ def __init__(self, opener, *args, unpickled FileManager objects will be restored with the default cache. 
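To make the pickling behaviour described above concrete, here is a small sketch mirroring the accompanying tests ('example.txt' is a hypothetical path used only for illustration):

    import pickle
    from xarray.backends.file_manager import FileManager

    manager = FileManager(open, 'example.txt', mode='w')
    f = manager.acquire()
    f.write('foo')
    f.flush()

    # Pickling records how to reopen the file (opener, arguments, mode),
    # not the open file object itself; the restored manager re-acquires
    # a file on demand.
    manager2 = pickle.loads(pickle.dumps(manager))
    f2 = manager2.acquire()
    f2.write('bar')

    manager2.close()
    manager.close()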
""" + # TODO: replace with real keyword arguments when we drop Python 2 + # support + mode = keywords.pop('mode', _DEFAULT_MODE) + kwargs = keywords.pop('kwargs', None) + lock = keywords.pop('lock', None) + cache = keywords.pop('cache', FILE_CACHE) self._opener = opener self._args = args self._mode = mode From f35b7e73bea5c9d20f6a344ead18d3380de6e515 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 11 Jul 2018 10:08:34 -0500 Subject: [PATCH 10/39] Update xarray.set_options() to add file_cache_maxsize and validation --- xarray/backends/file_manager.py | 6 ++-- xarray/core/options.py | 62 ++++++++++++++++++++++++++++----- xarray/tests/test_options.py | 33 ++++++++++++++++++ 3 files changed, 89 insertions(+), 12 deletions(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 850d415a3bf..a57c4603951 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -1,14 +1,14 @@ import threading from ..core import utils +from ..core.options import OPTIONS from .lru_cache import LRUCache # Global cache for storing open files. -FILE_CACHE = LRUCache(512, on_evict=lambda k, v: v.close()) - -# TODO(shoyer): add an option (xarray.set_options) for resizing the cache. # Note: the cache has a minimum size of one. +FILE_CACHE = LRUCache( + OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) _DEFAULT_MODE = utils.ReprObject('') diff --git a/xarray/core/options.py b/xarray/core/options.py index 48d4567fc99..47d6128a4eb 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -1,9 +1,39 @@ from __future__ import absolute_import, division, print_function +DISPLAY_WIDTH = 'display_width' +ARITHMETIC_JOIN = 'arithmetic_join' +ENABLE_CFTIMEINDEX = 'enable_cftimeindex' +FILE_CACHE_MAXSIZE = 'file_cache_maxsize' + OPTIONS = { - 'display_width': 80, - 'arithmetic_join': 'inner', - 'enable_cftimeindex': False + DISPLAY_WIDTH: 80, + ARITHMETIC_JOIN: 'inner', + ENABLE_CFTIMEINDEX: False, + FILE_CACHE_MAXSIZE: 512, +} + +_JOIN_OPTIONS = frozenset(['inner', 'outer', 'left', 'right', 'exact']) + + +def _positive_integer(value): + return isinstance(value, int) and value > 0 + + +_VALIDATORS = { + DISPLAY_WIDTH: _positive_integer, + ARITHMETIC_JOIN: _JOIN_OPTIONS.__contains__, + ENABLE_CFTIMEINDEX: lambda value: isinstance(value, bool), + FILE_CACHE_MAXSIZE: _positive_integer, +} + + +def _set_file_cache_maxsize(value): + from ..backends.file_manager import FILE_CACHE + FILE_CACHE.maxsize = value + + +_SETTERS = { + FILE_CACHE_MAXSIZE: _set_file_cache_maxsize, } @@ -19,6 +49,10 @@ class set_options(object): - ``enable_cftimeindex``: flag to enable using a ``CFTimeIndex`` for time indexes with non-standard calendars or dates outside the Timestamp-valid range. Default: ``False``. + - ``file_cache_maxsize``: maximum number of open files to hold in xarray's + global least-recently-usage cached. This should be smaller than your + system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. + Default: 512. 
You can use ``set_options`` either as a context manager: @@ -38,16 +72,26 @@ class set_options(object): """ def __init__(self, **kwargs): - invalid_options = {k for k in kwargs if k not in OPTIONS} - if invalid_options: - raise ValueError('argument names %r are not in the set of valid ' - 'options %r' % (invalid_options, set(OPTIONS))) self.old = OPTIONS.copy() - OPTIONS.update(kwargs) + for k, v in kwargs.items(): + if k not in OPTIONS: + raise ValueError( + 'argument name %r is not in the set of valid options %r' + % (k, set(OPTIONS))) + if k in _VALIDATORS and not _VALIDATORS[k](v): + raise ValueError( + 'option %r given an invalid value: %r' % (k, v)) + self._apply_update(kwargs) + + def _apply_update(self, options_dict): + for k, v in options_dict.items(): + if k in _SETTERS: + _SETTERS[k](v) + OPTIONS.update(options_dict) def __enter__(self): return def __exit__(self, type, value, traceback): OPTIONS.clear() - OPTIONS.update(self.old) + self._apply_update(self.old) diff --git a/xarray/tests/test_options.py b/xarray/tests/test_options.py index aed96f1acb6..4441375a1b1 100644 --- a/xarray/tests/test_options.py +++ b/xarray/tests/test_options.py @@ -4,6 +4,7 @@ import xarray from xarray.core.options import OPTIONS +from xarray.backends.file_manager import FILE_CACHE def test_invalid_option_raises(): @@ -11,6 +12,38 @@ def test_invalid_option_raises(): xarray.set_options(not_a_valid_options=True) +def test_display_width(): + with pytest.raises(ValueError): + xarray.set_options(display_width=0) + with pytest.raises(ValueError): + xarray.set_options(display_width=-10) + with pytest.raises(ValueError): + xarray.set_options(display_width=3.5) + + +def test_arithmetic_join(): + with pytest.raises(ValueError): + xarray.set_options(arithmetic_join='invalid') + with xarray.set_options(arithmetic_join='exact'): + assert OPTIONS['arithmetic_join'] == 'exact' + + +def test_enable_cftimeindex(): + with pytest.raises(ValueError): + xarray.set_options(enable_cftimeindex=None) + with xarray.set_options(enable_cftimeindex=True): + assert OPTIONS['enable_cftimeindex'] + + +def test_file_cache_maxsize(): + with pytest.raises(ValueError): + xarray.set_options(file_cache_maxsize=0) + original_size = FILE_CACHE.maxsize + with xarray.set_options(file_cache_maxsize=123): + assert FILE_CACHE.maxsize == 123 + assert FILE_CACHE.maxsize == original_size + + def test_nested_options(): original = OPTIONS['display_width'] with xarray.set_options(display_width=1): From 057cad21ff2ebc93f9627a99bc2809a1a628cbea Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 11 Jul 2018 10:41:35 -0500 Subject: [PATCH 11/39] Add assert for FILE_CACHE.maxsize --- xarray/backends/file_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index a57c4603951..cd09dadc53a 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -6,9 +6,9 @@ # Global cache for storing open files. -# Note: the cache has a minimum size of one. 
FILE_CACHE = LRUCache( OPTIONS['file_cache_maxsize'], on_evict=lambda k, v: v.close()) +assert FILE_CACHE.maxsize, 'file cache must be at least size one' _DEFAULT_MODE = utils.ReprObject('') From 0f3e656c69f6d686d5fb15d2b1989e36d89f4fee Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 11 Jul 2018 11:00:44 -0500 Subject: [PATCH 12/39] More docstring for FileManager --- xarray/backends/file_manager.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index cd09dadc53a..d62025a7ec9 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -32,6 +32,14 @@ class FileManager(object): f = manager.acquire() f.write(...) manager.close() # ensures file is closed + + Note that as long as previous files are still cached, acquiring a file + multiple times from the same FileManager is essentially free: + + f1 = manager.acquire() + f2 = manager.acquire() + assert f1 is f2 + """ def __init__(self, opener, *args, **keywords): From 1a0cc104c55d19b25aa07f1e4d87a28ef1d173af Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 11 Jul 2018 11:02:47 -0500 Subject: [PATCH 13/39] Add accidentally omited tests for LRUCache --- xarray/tests/test_backends_lru_cache.py | 89 +++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 xarray/tests/test_backends_lru_cache.py diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py new file mode 100644 index 00000000000..e8302c55c9e --- /dev/null +++ b/xarray/tests/test_backends_lru_cache.py @@ -0,0 +1,89 @@ +try: + from unittest import mock +except ImportError: + import mock # noqa: F401 + +import pytest + +from xarray.backends.lru_cache import LRUCache + + +def test_simple(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + + assert cache['x'] == 1 + assert cache['y'] == 2 + assert len(cache) == 2 + assert dict(cache) == {'x': 1, 'y': 2} + assert list(cache.keys()) == ['x', 'y'] + assert list(cache.items()) == [('x', 1), ('y', 2)] + + cache['z'] = 3 + assert len(cache) == 2 + assert list(cache.items()) == [('y', 2), ('z', 3)] + + +def test_trivial(): + cache = LRUCache(maxsize=0) + cache['x'] = 1 + assert len(cache) == 0 + + +def test_invalid(): + with pytest.raises(ValueError): + LRUCache(maxsize=-1) + + +def test_update_priority(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + assert list(cache) == ['x', 'y'] + assert 'x' in cache # contains + assert list(cache) == ['y', 'x'] + assert cache['y'] == 2 # getitem + assert list(cache) == ['x', 'y'] + cache['x'] = 3 # setitem + assert list(cache.items()) == [('y', 2), ('x', 3)] + + +def test_del(): + cache = LRUCache(maxsize=2) + cache['x'] = 1 + cache['y'] = 2 + del cache['x'] + assert dict(cache) == {'y': 2} + + +def test_on_evict(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=1, on_evict=on_evict) + cache['x'] = 1 + cache['y'] = 2 + on_evict.assert_called_once_with('x', 1) + + +def test_on_evict_trivial(): + on_evict = mock.Mock() + cache = LRUCache(maxsize=0, on_evict=on_evict) + cache['x'] = 1 + on_evict.assert_called_once_with('x', 1) + + +def test_resize(): + cache = LRUCache(maxsize=2) + assert cache.maxsize == 2 + cache['w'] = 0 + cache['x'] = 1 + cache['y'] = 2 + assert list(cache.items()) == [('x', 1), ('y', 2)] + cache.maxsize = 10 + cache['z'] = 3 + assert list(cache.items()) == [('x', 1), ('y', 2), ('z', 3)] + cache.maxsize = 1 + assert list(cache.items()) == [('z', 3)] + + with 
pytest.raises(ValueError): + cache.maxsize = -1 From 83d9b10297e0a1a0398a668d7300830f8e0ef822 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 28 Jul 2018 16:57:35 -0700 Subject: [PATCH 14/39] Adapt scipy backend to use FileManager --- xarray/backends/common.py | 4 +- xarray/backends/scipy_.py | 116 ++++++++++++----------------- xarray/tests/test_backends.py | 134 ++++++++++++---------------------- 3 files changed, 94 insertions(+), 160 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index ab6b3fac46b..04be2feef37 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -176,7 +176,7 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): _autoclose = None _ds = None - _isopen = False + _isopen = None def __iter__(self): return iter(self.variables) @@ -330,7 +330,7 @@ def set_variable(self, k, v): # pragma: no cover raise NotImplementedError def sync(self, compute=True): - if self._isopen and self._autoclose: + if self._isopen is not None and self._isopen and self._autoclose: # datastore will be reopened during write self.close() self.delayed_store = self.writer.sync(compute=compute) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index cd84431f6b7..bb1fcada1bf 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import functools import warnings from distutils.version import LooseVersion from io import BytesIO @@ -11,7 +10,8 @@ from ..core.indexing import NumpyIndexingAdapter from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict -from .common import BackendArray, DataStorePickleMixin, WritableCFDataStore +from .common import BackendArray, WritableCFDataStore +from .file_manager import FileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -40,31 +40,28 @@ def __init__(self, variable_name, datastore): str(array.dtype.itemsize)) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name].data def __getitem__(self, key): - with self.datastore.ensure_open(autoclose=True): - data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even - # after closing associated files. - copy = self.datastore.ds.use_mmap - return np.array(data, dtype=self.dtype, copy=copy) + data = NumpyIndexingAdapter(self.get_array())[key] + # Copy data if the source file is mmapped. + # This makes things consistent + # with the netCDF4 library by ensuring + # we can safely read arrays even + # after closing associated files. 
+ copy = self.datastore.ds.use_mmap + return np.array(data, dtype=self.dtype, copy=copy) def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.datastore.ds.variables[self.variable_name] - try: - data[key] = value - except TypeError: - if key is Ellipsis: - # workaround for GH: scipy/scipy#6880 - data[:] = value - else: - raise + data = self.datastore.ds.variables[self.variable_name] + try: + data[key] = value + except TypeError: + if key is Ellipsis: + # workaround for GH: scipy/scipy#6880 + data[:] = value + else: + raise def _open_scipy_netcdf(filename, mode, mmap, version): @@ -106,7 +103,7 @@ def _open_scipy_netcdf(filename, mode, mmap, version): raise -class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): +class ScipyDataStore(WritableCFDataStore): """Store for reading and writing data via scipy.io.netcdf. This store has the advantage of being able to be initialized with a @@ -116,7 +113,7 @@ class ScipyDataStore(WritableCFDataStore, DataStorePickleMixin): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=False, lock=None): + writer=None, mmap=None, autoclose=None, lock=None): import scipy import scipy.io @@ -140,34 +137,28 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - opener = functools.partial(_open_scipy_netcdf, - filename=filename_or_obj, - mode=mode, mmap=mmap, version=version) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - + self._manager = FileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, + kwargs=dict(mmap=mmap, version=version)) super(ScipyDataStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - return Variable(var.dimensions, ScipyArrayWrapper(name, self), - _decode_attrs(var._attributes)) + return Variable(var.dimensions, ScipyArrayWrapper(name, self), + _decode_attrs(var._attributes)) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(_decode_attrs(self.ds._attributes)) + return Frozen(_decode_attrs(self.ds._attributes)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -176,22 +167,20 @@ def get_encoding(self): return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if name in self.ds.dimensions: - raise ValueError('%s does not support modifying dimensions' - % type(self).__name__) - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, dim_length) + if name in self.ds.dimensions: + raise ValueError('%s does not support modifying dimensions' + % type(self).__name__) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, dim_length) def _validate_attr_key(self, key): if not is_valid_nc3_name(key): raise ValueError("Not a valid attribute name") def set_attribute(self, key, value): - with 
self.ensure_open(autoclose=False): - self._validate_attr_key(key) - value = encode_nc3_attr_value(value) - setattr(self.ds, key, value) + self._validate_attr_key(key) + value = encode_nc3_attr_value(value) + setattr(self.ds, key, value) def encode_variable(self, variable): variable = encode_nc3_variable(variable) @@ -220,26 +209,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the scipy backend yet') - with self.ensure_open(autoclose=True): - super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + super(ScipyDataStore, self).sync(compute=compute) + self.ds.flush() def close(self): - self.ds.close() - self._isopen = False + self._manager.close() def __exit__(self, type, value, tb): self.close() - - def __setstate__(self, state): - filename = state['_opener'].keywords['filename'] - if hasattr(filename, 'seek'): - # it's a file-like object - # seek to the start of the file so scipy can read it - filename.seek(0) - super(ScipyDataStore, self).__setstate__(state) - self._ds = None - self._isopen = False diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 4336b2a9e72..9674cef0d64 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1565,10 +1565,9 @@ def test_to_netcdf_explicit_engine(self): # regression test for GH1321 Dataset({'foo': 42}).to_netcdf(engine='scipy') - @pytest.mark.skipif(PY2, reason='cannot pickle BytesIO on Python 2') - def test_bytesio_pickle(self): + def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) - fobj = BytesIO(data.to_netcdf()) + fobj = data.to_netcdf() with open_dataset(fobj, autoclose=self.autoclose) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) @@ -1897,6 +1896,10 @@ def test_dump_encodings_h5py(self): self.assertEqual(actual.x.encoding['compression'], 'lzf') self.assertEqual(actual.x.encoding['compression_opts'], None) + @pytest.mark.xfail(reason="won't work until we use FileManager here") + def test_roundtrip_bytes_with_fill_value(self): + super(H5NetCDFDataTest, self).test_roundtrip_bytes_with_fill_value() + # tests pending h5netcdf fix @unittest.skip @@ -2348,20 +2351,6 @@ def test_dataarray_compute(self): self.assertTrue(computed._in_memory) assert_allclose(actual, computed, decode_bytes=False) - def test_to_netcdf_compute_false_roundtrip(self): - from dask.delayed import Delayed - - original = create_test_data().chunk() - - with create_tmp_file() as tmp_file: - # dataset, path, **kwargs): - delayed_obj = self.save(original, tmp_file, compute=False) - assert isinstance(delayed_obj, Delayed) - delayed_obj.compute() - - with self.open(tmp_file) as actual: - assert_identical(original, actual) - def test_save_mfdataset_compute_false_roundtrip(self): from dask.delayed import Delayed @@ -2651,28 +2640,27 @@ def test_uamiv_format_read(self): """ with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) + message='IOAPI_ISPH') camxfile = open_example_dataset('example.uamiv', engine='pseudonetcdf', autoclose=True, backend_kwargs={'format': 'uamiv'}) - data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) - expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, - dict(units='ppm', long_name='O3'.ljust(16), - var_desc='O3'.ljust(80))) - actual = camxfile.variables['O3'] - 
assert_allclose(expected, actual) - - data = np.array(['2002-06-03'], 'datetime64[ns]') - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) - actual = camxfile.variables['time'] - assert_allclose(expected, actual) - camxfile.close() + data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data = np.array(['2002-06-03'], 'datetime64[ns]') + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() def test_uamiv_format_mfread(self): """ @@ -2680,8 +2668,7 @@ def test_uamiv_format_mfread(self): """ with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) + message='IOAPI_ISPH') camxfile = open_example_mfdataset( ['example.uamiv', 'example.uamiv'], @@ -2690,39 +2677,38 @@ def test_uamiv_format_mfread(self): concat_dim='TSTEP', backend_kwargs={'format': 'uamiv'}) - data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) - data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, - dict(units='ppm', long_name='O3'.ljust(16), - var_desc='O3'.ljust(80))) - actual = camxfile.variables['O3'] - assert_allclose(expected, actual) - - data1 = np.array(['2002-06-03'], 'datetime64[ns]') - data = np.concatenate([data1] * 2, axis=0) - expected = xr.Variable(('TSTEP',), data, - dict(bounds='time_bounds', - long_name=('synthesized time coordinate ' + - 'from SDATE, STIME, STEP ' + - 'global attributes'))) - actual = camxfile.variables['time'] - assert_allclose(expected, actual) - camxfile.close() + data1 = np.arange(20, dtype='f').reshape(1, 1, 4, 5) + data = np.concatenate([data1] * 2, axis=0) + expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, + dict(units='ppm', long_name='O3'.ljust(16), + var_desc='O3'.ljust(80))) + actual = camxfile.variables['O3'] + assert_allclose(expected, actual) + + data1 = np.array(['2002-06-03'], 'datetime64[ns]') + data = np.concatenate([data1] * 2, axis=0) + attrs = dict(bounds='time_bounds', + long_name=('synthesized time coordinate ' + + 'from SDATE, STIME, STEP ' + + 'global attributes')) + expected = xr.Variable(('TSTEP',), data, attrs) + actual = camxfile.variables['time'] + assert_allclose(expected, actual) + camxfile.close() def test_uamiv_format_write(self): fmtkw = {'format': 'uamiv'} with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=UserWarning, - message=('IOAPI_ISPH is assumed to be ' + - '6370000.; consistent with WRF')) + message='IOAPI_ISPH') expected = open_example_dataset('example.uamiv', engine='pseudonetcdf', autoclose=False, backend_kwargs=fmtkw) - with self.roundtrip(expected, - save_kwargs=fmtkw, - open_kwargs={'backend_kwargs': fmtkw}) as actual: - assert_identical(expected, actual) + with self.roundtrip(expected, + save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: + assert_identical(expected, actual) def save(self, dataset, path, **save_kwargs): import PseudoNetCDF as pnc @@ -3302,32 +3288,6 @@ def 
test_dataarray_to_netcdf_no_name_pathlib(self): assert_identical(original_da, loaded_da) -def test_pickle_reconstructor(): - - lines = ['foo bar spam eggs'] - - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp: - with open(tmp, 'w') as f: - f.writelines(lines) - - obj = PickleByReconstructionWrapper(open, tmp) - - assert obj.value.readlines() == lines - - p_obj = pickle.dumps(obj) - obj.value.close() # for windows - obj2 = pickle.loads(p_obj) - - assert obj2.value.readlines() == lines - - # roundtrip again to make sure we can fully restore the state - p_obj2 = pickle.dumps(obj2) - obj2.value.close() # for windows - obj3 = pickle.loads(p_obj2) - - assert obj3.value.readlines() == lines - - @requires_scipy_or_netCDF4 def test_no_warning_from_dask_effective_get(): with create_tmp_file() as tmpfile: From a0074ffcea46d242b45ab6be96dddfdfaa207b1c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 28 Jul 2018 16:59:11 -0700 Subject: [PATCH 15/39] Stickler fix --- xarray/tests/test_backends.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 9674cef0d64..099d2a312f0 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2705,9 +2705,10 @@ def test_uamiv_format_write(self): engine='pseudonetcdf', autoclose=False, backend_kwargs=fmtkw) - with self.roundtrip(expected, - save_kwargs=fmtkw, - open_kwargs={'backend_kwargs': fmtkw}) as actual: + with self.roundtrip( + expected, + save_kwargs=fmtkw, + open_kwargs={'backend_kwargs': fmtkw}) as actual: assert_identical(expected, actual) def save(self, dataset, path, **save_kwargs): From 062ba969dbaeb7a0ff43a80fdb1dc295bba93da1 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Jul 2018 11:36:03 -0700 Subject: [PATCH 16/39] Fix failure on Python 2.7 --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 099d2a312f0..a598537fbde 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1568,7 +1568,7 @@ def test_to_netcdf_explicit_engine(self): def test_bytes_pickle(self): data = Dataset({'foo': ('x', [1, 2, 3])}) fobj = data.to_netcdf() - with open_dataset(fobj, autoclose=self.autoclose) as ds: + with self.open(fobj) as ds: unpickled = pickle.loads(pickle.dumps(ds)) assert_identical(unpickled, data) From 2d41b296c3b03ddfd41582382e48e607bf1a0e67 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Jul 2018 16:47:46 -0700 Subject: [PATCH 17/39] Finish adjusting backends to use FileManager --- doc/whats-new.rst | 16 +- xarray/backends/__init__.py | 4 + xarray/backends/api.py | 59 +++--- xarray/backends/common.py | 63 ------- xarray/backends/file_manager.py | 45 ++++- xarray/backends/h5netcdf_.py | 146 +++++++-------- xarray/backends/lru_cache.py | 8 +- xarray/backends/netCDF4_.py | 198 ++++++++++----------- xarray/backends/pseudonetcdf_.py | 66 +++---- xarray/backends/pynio_.py | 41 ++--- xarray/backends/rasterio_.py | 4 +- xarray/backends/scipy_.py | 6 +- xarray/core/pycompat.py | 9 +- xarray/tests/test_backends.py | 144 +++++---------- xarray/tests/test_backends_file_manager.py | 30 ++-- xarray/tests/test_backends_lru_cache.py | 2 + 16 files changed, 358 insertions(+), 483 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index af485015094..59a27745400 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,11 +25,23 @@ What's New - `Python 3 Statement 
`__ - `Tips on porting to Python 3 `__ -.. _whats-new.0.10.9: +.. _whats-new.0.11.0: -v0.10.9 (unreleased) +v0.11.0 (unreleased) -------------------- +Breaking changes +~~~~~~~~~~~~~~~~ + +- Xarray's storage backends now automatically open and close files when + necessary, rather than requiring opening a file with ``autoclose=True``. A + global least-recently-used cache is used to store open files; the default + limit of 512 open files should suffice in most cases, but can be adjusted if + necessary with + ``xarray.set_options(file_cache_maxsize=...)``. + + TODO: Add some note about performance benefits. + Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/__init__.py b/xarray/backends/__init__.py index 47a2011a3af..a2f0d79a6d1 100644 --- a/xarray/backends/__init__.py +++ b/xarray/backends/__init__.py @@ -4,6 +4,7 @@ formats. They should not be used directly, but rather through Dataset objects. """ from .common import AbstractDataStore +from .file_manager import FileManager, CachingFileManager, DummyFileManager from .memory import InMemoryDataStore from .netCDF4_ import NetCDF4DataStore from .pydap_ import PydapDataStore @@ -15,6 +16,9 @@ __all__ = [ 'AbstractDataStore', + 'FileManager', + 'CachingFileManager', + 'DummyFileManager', 'InMemoryDataStore', 'NetCDF4DataStore', 'PydapDataStore', diff --git a/xarray/backends/api.py b/xarray/backends/api.py index b2c0df7b01b..0df98bb4c27 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -4,6 +4,7 @@ from glob import glob from io import BytesIO from numbers import Number +import warnings import numpy as np @@ -152,7 +153,7 @@ def _finalize_store(write, store): def open_dataset(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -235,6 +236,14 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, -------- open_mfdataset """ + if autoclose is not None: + warnings.warn( + 'The autoclose argument is no longer used by ' + 'xarray.open_dataset() and is now ignored; it will be removed in ' + 'xarray v0.12. 
If necessary, you can control the maximum number ' + 'of simultaneous open files with ' + 'xarray.set_options(file_cache_maxsize=...).', + FutureWarning, stacklevel=2) if mask_and_scale is None: mask_and_scale = not engine == 'pseudonetcdf' @@ -278,12 +287,6 @@ def maybe_decode_store(store, lock=False): else: ds2 = ds - # protect so that dataset store isn't necessarily closed, e.g., - # streams like BytesIO can't be reopened - # datastore backend is responsible for determining this capability - if store._autoclose: - store.close() - return ds2 if isinstance(filename_or_obj, path_type): @@ -314,28 +317,21 @@ def maybe_decode_store(store, lock=False): engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': - store = backends.NetCDF4DataStore.open(filename_or_obj, - group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.NetCDF4DataStore.open( + filename_or_obj, group=group, **backend_kwargs) elif engine == 'scipy': - store = backends.ScipyDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': - store = backends.PydapDataStore.open(filename_or_obj, - **backend_kwargs) + store = backends.PydapDataStore.open( + filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': - store = backends.H5NetCDFStore(filename_or_obj, group=group, - autoclose=autoclose, - **backend_kwargs) + store = backends.H5NetCDFStore( + filename_or_obj, group=group, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, - autoclose=autoclose, - **backend_kwargs) + store = backends.NioDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, autoclose=autoclose, **backend_kwargs) + filename_or_obj, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) @@ -355,7 +351,7 @@ def maybe_decode_store(store, lock=False): def open_dataarray(filename_or_obj, group=None, decode_cf=True, - mask_and_scale=None, decode_times=True, autoclose=False, + mask_and_scale=None, decode_times=True, autoclose=None, concat_characters=True, decode_coords=True, engine=None, chunks=None, lock=None, cache=None, drop_variables=None, backend_kwargs=None): @@ -390,10 +386,6 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, decode_times : bool, optional If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. concat_characters : bool, optional If True, concatenate along the last dimension of character arrays to form string arrays. Dimensions will only be concatenated over (and @@ -490,7 +482,7 @@ def close(self): def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, compat='no_conflicts', preprocess=None, engine=None, lock=None, data_vars='all', coords='different', - autoclose=False, parallel=False, **kwargs): + autoclose=None, parallel=False, **kwargs): """Open multiple files as a single dataset. Requires dask to be installed. See documentation for details on dask [1]. @@ -537,10 +529,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. 
If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - autoclose : bool, optional - If True, automatically close files to avoid OS Error of too many files - being open. However, this option doesn't work with streams, e.g., - BytesIO. lock : False, True or threading.Lock, optional This argument is passed on to :py:func:`dask.array.from_array`. By default, a per-variable lock is used when reading data from netCDF @@ -707,12 +695,9 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) lock = _get_lock(engine, scheduler, format, path_or_file) - autoclose = (have_chunks and - scheduler in ['distributed', 'multiprocessing']) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, - autoclose=autoclose, lock=lock) + store = store_open(target, mode, format, group, writer, lock=lock) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 04be2feef37..77fda5fc5b1 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -174,9 +174,6 @@ def __array__(self, dtype=None): class AbstractDataStore(Mapping): - _autoclose = None - _ds = None - _isopen = None def __iter__(self): return iter(self.variables) @@ -330,9 +327,6 @@ def set_variable(self, k, v): # pragma: no cover raise NotImplementedError def sync(self, compute=True): - if self._isopen is not None and self._isopen and self._autoclose: - # datastore will be reopened during write - self.close() self.delayed_store = self.writer.sync(compute=compute) def store_dataset(self, dataset): @@ -457,60 +451,3 @@ def encode(self, variables, attributes): attributes = OrderedDict([(k, self.encode_attribute(v)) for k, v in attributes.items()]) return variables, attributes - - -class DataStorePickleMixin(object): - """Subclasses must define `ds`, `_opener` and `_mode` attributes. - - Do not subclass this class: it is not part of xarray's external API. - """ - - def __getstate__(self): - state = self.__dict__.copy() - del state['_ds'] - del state['_isopen'] - if self._mode == 'w': - # file has already been created, don't override when restoring - state['_mode'] = 'a' - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self._ds = None - self._isopen = False - - @property - def ds(self): - if self._ds is not None and self._isopen: - return self._ds - ds = self._opener(mode=self._mode) - self._isopen = True - return ds - - @contextlib.contextmanager - def ensure_open(self, autoclose=None): - """ - Helper function to make sure datasets are closed and opened - at appropriate times to avoid too many open file errors. - - Use requires `autoclose=True` argument to `open_mfdataset`. 
- """ - - if autoclose is None: - autoclose = self._autoclose - - if not self._isopen: - try: - self._ds = self._opener() - self._isopen = True - yield - finally: - if autoclose: - self.close() - else: - yield - - def assert_open(self): - if not self._isopen: - raise AssertionError('internal failure: file must be open ' - 'if `autoclose=True` is used.') diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index d62025a7ec9..9317d85b5f6 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -15,12 +15,29 @@ class FileManager(object): + """Manager for acquiring and closing a file object. + + Use FileManager subclasses (CachingFileManager in particular) on backend + storage classes to automatically handle issues related to keeping track of + many open files and transferring them between multiple processes. + """ + + def acquire(self): + """Acquire the file object from this manager.""" + raise NotImplementedError + + def close(self): + """Close the file object associated with this manager, if needed.""" + raise NotImplementedError + + +class CachingFileManager(FileManager): """Wrapper for automatically opening and closing file objects. - Unlike files, FileManager objects can be safely pickled and passed between - processes. They should be explicitly closed to release resources, but - a per-process least-recently-used cache for open files ensures that you can - safely create arbitrarily large numbers of FileManager objects. + Unlike files, CachingFileManager objects can be safely pickled and passed + between processes. They should be explicitly closed to release resources, + but a per-process least-recently-used cache for open files ensures that you + can safely create arbitrarily large numbers of FileManager objects. Don't directly close files acquired from a FileManager. Instead, call FileManager.close(), which ensures that closed files are removed from the @@ -80,6 +97,10 @@ def __init__(self, opener, *args, **keywords): kwargs = keywords.pop('kwargs', None) lock = keywords.pop('lock', None) cache = keywords.pop('cache', FILE_CACHE) + if keywords: + raise TypeError('FileManager() got unexpected keyword arguments: ' + '%s' % list(keywords)) + self._opener = opener self._args = args self._mode = mode @@ -130,7 +151,8 @@ def acquire(self): def close(self): """Explicitly close any associated file object (if necessary).""" with self._lock: - file = self._cache.pop(self._key, default=None) + default = None + file = self._cache.pop(self._key, default) if file is not None: file.close() @@ -160,3 +182,16 @@ def __init__(self, tuple_value): def __hash__(self): return self.hashvalue + + +class DummyFileManager(FileManager): + """FileManager that simply wraps an open file in the FileManager interface. 
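A rough sketch of how the two concrete managers are meant to be used interchangeably behind the ``acquire()``/``close()`` interface (the ``file_manager_for`` helper and the 'example.txt' path below are hypothetical, not part of this patch):

    from xarray.backends import CachingFileManager, DummyFileManager

    def file_manager_for(filename_or_obj):
        # Hypothetical helper: a path gets the cached, picklable manager;
        # an already-open file object is simply wrapped so close() works.
        if isinstance(filename_or_obj, str):
            return CachingFileManager(open, filename_or_obj, mode='r')
        return DummyFileManager(filename_or_obj)

    manager = file_manager_for('example.txt')
    f = manager.acquire()   # same call regardless of manager type
    print(f.read())
    manager.close()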
+ """ + def __init__(self, value): + self._value = value + + def acquire(self): + return self._value + + def close(self): + self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 959cd221734..be27d2814b0 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -9,10 +9,11 @@ from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error from .common import ( - HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root) + HDF5_LOCK, WritableCFDataStore, find_root) +from .file_manager import CachingFileManager from .netCDF4_ import ( - BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding, - _get_datatype, _nc4_require_group) + BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, + _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) class H5NetCDFArrayWrapper(BaseNetCDF4Array): @@ -25,8 +26,7 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + return self.get_array()[key] def maybe_decode_bytes(txt): @@ -61,11 +61,12 @@ def _open_h5netcdf_group(filename, mode, group): import h5netcdf ds = h5netcdf.File(filename, mode=mode) with close_on_error(ds): - return _nc4_require_group( + ds = _nc4_require_group( ds, group, mode, create_group=_h5netcdf_create_group) + return GroupWrapper(ds) -class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin): +class H5NetCDFStore(WritableCFDataStore): """Store for reading and writing data via h5netcdf """ @@ -73,92 +74,82 @@ def __init__(self, filename, mode='r', format=None, group=None, writer=None, autoclose=False, lock=HDF5_LOCK): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') - opener = functools.partial(_open_h5netcdf_group, filename, mode=mode, - group=group) - self._ds = opener() - if autoclose: - raise NotImplementedError('autoclose=True is not implemented ' - 'for the h5netcdf backend pending ' - 'further exploration, e.g., bug fixes ' - '(in h5netcdf?)') - self._autoclose = False - self._isopen = True + self._manager = CachingFileManager( + _open_h5netcdf_group, filename, mode=mode, + kwargs=dict(group=group)) + self.format = format - self._opener = opener self._filename = filename self._mode = mode super(H5NetCDFStore, self).__init__(writer, lock=lock) + @property + def ds(self): + return self._manager.acquire().value + def open_store_variable(self, name, var): import h5py - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - H5NetCDFArrayWrapper(name, self)) - attrs = _read_attributes(var) - - # netCDF4 specific encoding - encoding = { - 'chunksizes': var.chunks, - 'fletcher32': var.fletcher32, - 'shuffle': var.shuffle, - } - # Convert h5py-style compression options to NetCDF4-Python - # style, if possible - if var.compression == 'gzip': - encoding['zlib'] = True - encoding['complevel'] = var.compression_opts - elif var.compression is not None: - encoding['compression'] = var.compression - encoding['compression_opts'] = var.compression_opts - - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - - vlen_dtype = h5py.check_dtype(vlen=var.dtype) - if vlen_dtype 
is unicode_type: - encoding['dtype'] = str - elif vlen_dtype is not None: # pragma: no cover - # xarray doesn't support writing arbitrary vlen dtypes yet. - pass - else: - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + H5NetCDFArrayWrapper(name, self)) + attrs = _read_attributes(var) + + # netCDF4 specific encoding + encoding = { + 'chunksizes': var.chunks, + 'fletcher32': var.fletcher32, + 'shuffle': var.shuffle, + } + # Convert h5py-style compression options to NetCDF4-Python + # style, if possible + if var.compression == 'gzip': + encoding['zlib'] = True + encoding['complevel'] = var.compression_opts + elif var.compression is not None: + encoding['compression'] = var.compression + encoding['compression_opts'] = var.compression_opts + + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + + vlen_dtype = h5py.check_dtype(vlen=var.dtype) + if vlen_dtype is unicode_type: + encoding['dtype'] = str + elif vlen_dtype is not None: # pragma: no cover + # xarray doesn't support writing arbitrary vlen dtypes yet. + pass + else: + encoding['dtype'] = var.dtype return Variable(dimensions, data, attrs, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in iteritems(self.ds.variables)) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in iteritems(self.ds.variables)) def get_attrs(self): - with self.ensure_open(autoclose=True): - return FrozenOrderedDict(_read_attributes(self.ds)) + return FrozenOrderedDict(_read_attributes(self.ds)) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return self.ds.dimensions + return self.ds.dimensions def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v is None} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v is None} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - if is_unlimited: - self.ds.dimensions[name] = None - self.ds.resize_dimension(name, length) - else: - self.ds.dimensions[name] = length + if is_unlimited: + self.ds.dimensions[name] = None + self.ds.resize_dimension(name, length) + else: + self.ds.dimensions[name] = length def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - self.ds.attrs[key] = value + self.ds.attrs[key] = value def encode_variable(self, variable): return _encode_nc4_variable(variable) @@ -227,17 +218,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data def sync(self, compute=True): - if not compute: - raise NotImplementedError( - 'compute=False is not supported for the h5netcdf backend yet') - with self.ensure_open(autoclose=True): - super(H5NetCDFStore, self).sync(compute=compute) - self.ds.sync() + super(H5NetCDFStore, self).sync(compute=compute) + self.ds.sync() def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if not ds._closed: - ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 59f896c0101..356a1b3edc7 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -1,6 +1,8 @@ import 
collections import threading +from ..pycompat import move_to_end + class LRUCache(collections.MutableMapping): """Thread-safe LRUCache based on an OrderedDict. @@ -26,6 +28,8 @@ def __init__(self, maxsize, on_evict=None): Function to call like ``on_evict(key, value)`` when items are evicted. """ + if not isinstance(maxsize, int): + raise TypeError('maxsize must be an integer') if maxsize < 0: raise ValueError('maxsize must be non-negative') self._maxsize = maxsize @@ -37,9 +41,7 @@ def __getitem__(self, key): # record recent use of the key by moving it to the front of the list with self._lock: value = self._cache[key] - # On Python 3, could just use: self._cache.move_to_end(key) - del self._cache[key] - self._cache[key] = value + move_to_end(self._cache, key) return value def _enforce_size_limit(self, capacity): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 5c6d82fd126..6f60f36bde5 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -14,8 +14,9 @@ PY3, OrderedDict, basestring, iteritems, suppress) from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, DataStorePickleMixin, WritableCFDataStore, + HDF5_LOCK, BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable # This lookup table maps from dtype.byteorder to a readable endian @@ -43,12 +44,10 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - with self.datastore.ensure_open(autoclose=True): - data = self.get_array() - data[key] = value + data = self.get_array() + data[key] = value def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] @@ -64,20 +63,19 @@ def _getitem(self, key): else: getitem = operator.getitem - with self.datastore.ensure_open(autoclose=True): - try: - array = getitem(self.get_array(), key) - except IndexError: - # Catch IndexError in netCDF4 and return a more informative - # error message. This is most often called when an unsorted - # indexer is used before the data is loaded from disk. - msg = ('The indexing operation you are attempting to perform ' - 'is not valid on netCDF4.Variable object. Try loading ' - 'your data into memory first by calling .load().') - if not PY3: - import traceback - msg += '\n\nOriginal traceback:\n' + traceback.format_exc() - raise IndexError(msg) + try: + array = getitem(self.get_array(), key) + except IndexError: + # Catch IndexError in netCDF4 and return a more informative + # error message. This is most often called when an unsorted + # indexer is used before the data is loaded from disk. + msg = ('The indexing operation you are attempting to perform ' + 'is not valid on netCDF4.Variable object. 
Try loading ' + 'your data into memory first by calling .load().') + if not PY3: + import traceback + msg += '\n\nOriginal traceback:\n' + traceback.format_exc() + raise IndexError(msg) return array @@ -224,6 +222,15 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, return encoding +class GroupWrapper(object): + def __init__(self, value): + self.value = value + + def close(self): + # netCDF4 only allows closing the root group + find_root(self.value).close() + + def _open_netcdf4_group(filename, mode, group=None, **kwargs): import netCDF4 as nc4 @@ -234,7 +241,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return ds + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -280,40 +287,32 @@ def _set_nc_attribute(obj, key, value): obj.setncattr(key, value) -class NetCDF4DataStore(WritableCFDataStore, DataStorePickleMixin): +class NetCDF4DataStore(WritableCFDataStore): """Store for reading and writing data via the Python-NetCDF4 library. This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, netcdf4_dataset, mode='r', writer=None, opener=None, - autoclose=False, lock=HDF5_LOCK): + def __init__(self, manager, writer=None, lock=HDF5_LOCK): + import netCDF4 - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + if isinstance(manager, netCDF4.Dataset): + _disable_auto_decode_group(manager) + manager = DummyFileManager(GroupWrapper(manager)) - _disable_auto_decode_group(netcdf4_dataset) - - self._ds = netcdf4_dataset - self._autoclose = autoclose - self._isopen = True + self._manager = manager self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._mode = mode = 'a' if mode == 'w' else mode - if opener: - self._opener = functools.partial(opener, mode=self._mode) - else: - self._opener = opener super(NetCDF4DataStore, self).__init__(writer, lock=lock) @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False, - autoclose=False, lock=HDF5_LOCK): - import netCDF4 as nc4 + lock=HDF5_LOCK): + import netCDF4 if (len(filename) == 88 and - LooseVersion(nc4.__version__) < "1.3.1"): + LooseVersion(netCDF4.__version__) < "1.3.1"): warnings.warn( 'A segmentation fault may occur when the ' 'file path has exactly 88 characters as it does ' @@ -324,86 +323,77 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' - opener = functools.partial(_open_netcdf4_group, filename, mode=mode, - group=group, clobber=clobber, - diskless=diskless, persist=persist, - format=format) - ds = opener() - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose, lock=lock) + manager = CachingFileManager( + _open_netcdf4_group, filename, mode=mode, + kwargs=dict(group=group, clobber=clobber, diskless=diskless, + persist=persist, format=format)) + return cls(manager, writer=writer, lock=lock) + + @property + def ds(self): + return self._manager.acquire().value def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - dimensions = var.dimensions - data = indexing.LazilyOuterIndexedArray( - NetCDF4ArrayWrapper(name, self)) - attributes = OrderedDict((k, var.getncattr(k)) - for k in var.ncattrs()) - _ensure_fill_value_valid(data, attributes) - # netCDF4 specific encoding; save _FillValue for later - 
encoding = {} - filters = var.filters() - if filters is not None: - encoding.update(filters) - chunking = var.chunking() - if chunking is not None: - if chunking == 'contiguous': - encoding['contiguous'] = True - encoding['chunksizes'] = None - else: - encoding['contiguous'] = False - encoding['chunksizes'] = tuple(chunking) - # TODO: figure out how to round-trip "endian-ness" without raising - # warnings from netCDF4 - # encoding['endian'] = var.endian() - pop_to(attributes, encoding, 'least_significant_digit') - # save source so __repr__ can detect if it's local or not - encoding['source'] = self._filename - encoding['original_shape'] = var.shape - encoding['dtype'] = var.dtype + dimensions = var.dimensions + data = indexing.LazilyOuterIndexedArray( + NetCDF4ArrayWrapper(name, self)) + attributes = OrderedDict((k, var.getncattr(k)) + for k in var.ncattrs()) + _ensure_fill_value_valid(data, attributes) + # netCDF4 specific encoding; save _FillValue for later + encoding = {} + filters = var.filters() + if filters is not None: + encoding.update(filters) + chunking = var.chunking() + if chunking is not None: + if chunking == 'contiguous': + encoding['contiguous'] = True + encoding['chunksizes'] = None + else: + encoding['contiguous'] = False + encoding['chunksizes'] = tuple(chunking) + # TODO: figure out how to round-trip "endian-ness" without raising + # warnings from netCDF4 + # encoding['endian'] = var.endian() + pop_to(attributes, encoding, 'least_significant_digit') + # save source so __repr__ can detect if it's local or not + encoding['source'] = self._filename + encoding['original_shape'] = var.shape + encoding['dtype'] = var.dtype return Variable(dimensions, data, attributes, encoding) def get_variables(self): - with self.ensure_open(autoclose=False): - dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in - iteritems(self.ds.variables)) + dsvars = FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in + iteritems(self.ds.variables)) return dsvars def get_attrs(self): - with self.ensure_open(autoclose=True): - attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) - for k in self.ds.ncattrs()) + attrs = FrozenOrderedDict((k, self.ds.getncattr(k)) + for k in self.ds.ncattrs()) return attrs def get_dimensions(self): - with self.ensure_open(autoclose=True): - dims = FrozenOrderedDict((k, len(v)) - for k, v in iteritems(self.ds.dimensions)) + dims = FrozenOrderedDict((k, len(v)) + for k, v in iteritems(self.ds.dimensions)) return dims def get_encoding(self): - with self.ensure_open(autoclose=True): - encoding = {} - encoding['unlimited_dims'] = { - k for k, v in self.ds.dimensions.items() if v.isunlimited()} + encoding = {} + encoding['unlimited_dims'] = { + k for k, v in self.ds.dimensions.items() if v.isunlimited()} return encoding def set_dimension(self, name, length, is_unlimited=False): - with self.ensure_open(autoclose=False): - dim_length = length if not is_unlimited else None - self.ds.createDimension(name, size=dim_length) + dim_length = length if not is_unlimited else None + self.ds.createDimension(name, size=dim_length) def set_attribute(self, key, value): - with self.ensure_open(autoclose=False): - if self.format != 'NETCDF4': - value = encode_nc3_attr_value(value) - _set_nc_attribute(self.ds, key, value) - - def set_variables(self, *args, **kwargs): - with self.ensure_open(autoclose=False): - super(NetCDF4DataStore, self).set_variables(*args, **kwargs) + if self.format != 'NETCDF4': + value = encode_nc3_attr_value(value) + 
_set_nc_attribute(self.ds, key, value) def encode_variable(self, variable): variable = _force_native_endianness(variable) @@ -462,14 +452,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data def sync(self, compute=True): - with self.ensure_open(autoclose=True): - super(NetCDF4DataStore, self).sync(compute=compute) - self.ds.sync() + super(NetCDF4DataStore, self).sync(compute=compute) + self.ds.sync() def close(self): - if self._isopen: - # netCDF4 only allows closing the root group - ds = find_root(self.ds) - if ds._isopen: - ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index d946c6fa927..dc2867b6ca1 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -2,8 +2,6 @@ from __future__ import division from __future__ import print_function -import functools - import numpy as np from .. import Variable @@ -11,7 +9,8 @@ from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing -from .common import AbstractDataStore, DataStorePickleMixin, BackendArray +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager class PncArrayWrapper(BackendArray): @@ -24,7 +23,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.dtype) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -33,57 +31,49 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - return self.get_array()[key] + return self.get_array()[key] -class PseudoNetCDFDataStore(AbstractDataStore, DataStorePickleMixin): +class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, format=None, writer=None, - autoclose=False, **format_kwds): + def open(cls, filename, **format_kwds): from PseudoNetCDF import pncopen - opener = functools.partial(pncopen, filename, **format_kwds) - ds = opener() - mode = format_kwds.get('mode', 'r') - return cls(ds, mode=mode, writer=writer, opener=opener, - autoclose=autoclose) - def __init__(self, pnc_dataset, mode='r', writer=None, opener=None, - autoclose=False): + keywords = dict(kwargs=format_kwds) + # only include mode if explicitly passed + mode = format_kwds.pop('mode', None) + if mode is not None: + keywords['mode'] = mode + + manager = CachingFileManager(pncopen, filename, **keywords) + return cls(manager) - if autoclose and opener is None: - raise ValueError('autoclose requires an opener') + def __init__(self, manager): + self._manager = manager - self._ds = pnc_dataset - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode - super(PseudoNetCDFDataStore, self).__init__() + @property + def ds(self): + return self._manager.acquire() def open_store_variable(self, name, var): - with self.ensure_open(autoclose=False): - data = indexing.LazilyOuterIndexedArray( - PncArrayWrapper(name, self) - ) + data = indexing.LazilyOuterIndexedArray( + PncArrayWrapper(name, self) + ) attrs = OrderedDict((k, getattr(var, k)) for k in var.ncattrs()) return Variable(var.dimensions, data, attrs) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, 
self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(dict([(k, getattr(self.ds, k)) - for k in self.ds.ncattrs()])) + return Frozen(dict([(k, getattr(self.ds, k)) + for k in self.ds.ncattrs()])) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -93,6 +83,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 98b76928597..d7692565b3c 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -1,13 +1,12 @@ from __future__ import absolute_import, division, print_function -import functools - import numpy as np from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray, DataStorePickleMixin +from .common import AbstractDataStore, BackendArray +from .file_manager import CachingFileManager class NioArrayWrapper(BackendArray): @@ -20,7 +19,6 @@ def __init__(self, variable_name, datastore): self.dtype = np.dtype(array.typecode()) def get_array(self): - self.datastore.assert_open() return self.datastore.ds.variables[self.variable_name] def __getitem__(self, key): @@ -28,26 +26,20 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): - with self.datastore.ensure_open(autoclose=True): - array = self.get_array() - if key == () and self.ndim == 0: - return array.get_value() + array = self.get_array() + if key == () and self.ndim == 0: + return array.get_value() - return array[key] + return array[key] -class NioDataStore(AbstractDataStore, DataStorePickleMixin): +class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', autoclose=False): + def __init__(self, filename, mode='r'): import Nio - opener = functools.partial(Nio.open_file, filename, mode=mode) - self._ds = opener() - self._autoclose = autoclose - self._isopen = True - self._opener = opener - self._mode = mode + self._manager = CachingFileManager(Nio.open_file, filename, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. 
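# A minimal sketch of the CachingFileManager pattern that the PseudoNetCDF and
# PyNIO stores above now follow; ``ExampleDataStore`` and its ``opener`` argument
# are hypothetical illustrations, while CachingFileManager, acquire() and close()
# are the API introduced in this series.
from xarray.backends.file_manager import CachingFileManager


class ExampleDataStore(object):
    def __init__(self, opener, filename, **open_kwargs):
        # ``opener`` is any callable returning a netCDF-like object (for
        # example PseudoNetCDF.pncopen or Nio.open_file); keyword arguments
        # are replayed each time the file needs to be (re)opened.
        self._manager = CachingFileManager(opener, filename, kwargs=open_kwargs)

    @property
    def ds(self):
        # Re-acquire on every access so the file can be transparently
        # reopened after eviction from the global LRU cache.
        return self._manager.acquire()

    def close(self):
        self._manager.close()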
self.ds.set_option('MaskedArrayMode', 'MaskedNever') @@ -57,17 +49,14 @@ def open_store_variable(self, name, var): return Variable(var.dimensions, data, var.attributes) def get_variables(self): - with self.ensure_open(autoclose=False): - return FrozenOrderedDict((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return FrozenOrderedDict((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.attributes) + return Frozen(self.ds.attributes) def get_dimensions(self): - with self.ensure_open(autoclose=True): - return Frozen(self.ds.dimensions) + return Frozen(self.ds.dimensions) def get_encoding(self): encoding = {} @@ -76,6 +65,4 @@ def get_encoding(self): return encoding def close(self): - if self._isopen: - self.ds.close() - self._isopen = False + self._manager.close() diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 2c01c6e8cfa..36ef45a2ff3 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -9,7 +9,7 @@ from ..core import indexing from ..core.utils import is_scalar from .common import BackendArray -from .file_manager import FileManager +from .file_manager import CachingFileManager try: from dask.utils import SerializableLock as Lock @@ -209,7 +209,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, import rasterio - manager = FileManager(rasterio.open, filename, mode='r') + manager = CachingFileManager(rasterio.open, filename, mode='r') riods = manager.acquire() if cache is None: diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index bb1fcada1bf..628aa64516c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -11,7 +11,7 @@ from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict from .common import BackendArray, WritableCFDataStore -from .file_manager import FileManager +from .file_manager import CachingFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -113,7 +113,7 @@ class ScipyDataStore(WritableCFDataStore): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, autoclose=None, lock=None): + writer=None, mmap=None, lock=None): import scipy import scipy.io @@ -137,7 +137,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - self._manager = FileManager( + self._manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, mode=mode, kwargs=dict(mmap=mmap, version=version)) super(ScipyDataStore, self).__init__(writer, lock=lock) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 78c26f1e92f..b980bc279b0 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -28,6 +28,9 @@ def itervalues(d): import builtins from urllib.request import urlretrieve from inspect import getfullargspec as getargspec + + def move_to_end(ordered_dict, key): + ordered_dict.move_to_end(key) else: # pragma: no cover # Python 2 basestring = basestring # noqa @@ -50,6 +53,11 @@ def itervalues(d): from urllib import urlretrieve from inspect import getargspec + def move_to_end(ordered_dict, key): + value = ordered_dict[key] + del ordered_dict[key] + ordered_dict[key] = value + integer_types = native_int_types + (np.integer,) try: @@ -76,7 +84,6 @@ def itervalues(d): except 
ImportError as e: path_type = () - try: from contextlib import suppress except ImportError: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a598537fbde..6b771a7a3c4 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -7,7 +7,6 @@ import shutil import sys import tempfile -import unittest import warnings from io import BytesIO @@ -24,7 +23,8 @@ from xarray.backends.pydap_ import PydapDataStore from xarray.core import indexing from xarray.core.pycompat import ( - PY2, ExitStack, basestring, dask_array_type, iteritems) + ExitStack, basestring, dask_array_type, iteritems) +from xarray.core.options import set_options from xarray.tests import mock from . import ( @@ -137,7 +137,6 @@ class NetCDF3Only(object): class DatasetIOTestCases(object): - autoclose = False engine = None file_format = None @@ -171,8 +170,7 @@ def save(self, dataset, path, **kwargs): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine=self.engine, autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine=self.engine, **kwargs) as ds: yield ds def test_zero_dimensional_variable(self): @@ -1161,10 +1159,10 @@ def test_already_open_dataset(self): v[...] = 42 nc = nc4.Dataset(tmp_file, mode='r') - with backends.NetCDF4DataStore(nc, autoclose=False) as store: - with open_dataset(store) as ds: - expected = Dataset({'x': ((), 42)}) - assert_identical(expected, ds) + store = backends.NetCDF4DataStore(nc) + with open_dataset(store) as ds: + expected = Dataset({'x': ((), 42)}) + assert_identical(expected, ds) def test_read_variable_len_strings(self): with create_tmp_file() as tmp_file: @@ -1183,8 +1181,6 @@ def test_read_variable_len_strings(self): @requires_netCDF4 class NetCDF4DataTest(BaseNetCDF4Test, TestCase): - autoclose = False - @contextlib.contextmanager def create_store(self): with create_tmp_file() as tmp_file: @@ -1249,9 +1245,13 @@ def test_setncattr_string(self): totest.attrs['bar']) assert one_string == totest.attrs['baz'] - -class NetCDF4DataStoreAutocloseTrue(NetCDF4DataTest): - autoclose = True + def test_autoclose_future_warning(self): + data = create_test_data() + with create_tmp_file() as tmp_file: + self.save(data, tmp_file) + with pytest.warns(FutureWarning): + with self.open(tmp_file, autoclose=True) as actual: + assert_identical(data, actual) @requires_netCDF4 @@ -1292,10 +1292,6 @@ def test_write_inconsistent_chunks(self): assert actual['y'].encoding['chunksizes'] == (100, 50) -class NetCDF4ViaDaskDataTestAutocloseTrue(NetCDF4ViaDaskDataTest): - autoclose = True - - @requires_zarr class BaseZarrTest(CFEncodedDataTest): @@ -1573,10 +1569,6 @@ def test_bytes_pickle(self): assert_identical(unpickled, data) -class ScipyInMemoryDataTestAutocloseTrue(ScipyInMemoryDataTest): - autoclose = True - - @requires_scipy class ScipyFileObjectTest(ScipyWriteTest, TestCase): engine = 'scipy' @@ -1642,10 +1634,6 @@ def test_nc4_scipy(self): open_dataset(tmp_file, engine='scipy') -class ScipyFilePathTestAutocloseTrue(ScipyFilePathTest): - autoclose = True - - @requires_netCDF4 class NetCDF3ViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): engine = 'netcdf4' @@ -1666,10 +1654,6 @@ def test_encoding_kwarg_vlen_string(self): pass -class NetCDF3ViaNetCDF4DataTestAutocloseTrue(NetCDF3ViaNetCDF4DataTest): - autoclose = True - - @requires_netCDF4 class NetCDF4ClassicViaNetCDF4DataTest(CFEncodedDataTest, NetCDF3Only, TestCase): @@ -1684,11 +1668,6 @@ def create_store(self): yield store -class 
NetCDF4ClassicViaNetCDF4DataTestAutocloseTrue( - NetCDF4ClassicViaNetCDF4DataTest): - autoclose = True - - @requires_scipy_or_netCDF4 class GenericNetCDFDataTest(CFEncodedDataTest, NetCDF3Only, TestCase): # verify that we can read and write netCDF3 files as long as we have scipy @@ -1765,10 +1744,6 @@ def test_encoding_unlimited_dims(self): assert_equal(ds, actual) -class GenericNetCDFDataTestAutocloseTrue(GenericNetCDFDataTest): - autoclose = True - - @requires_h5netcdf @requires_netCDF4 class H5NetCDFDataTest(BaseNetCDF4Test, TestCase): @@ -1896,30 +1871,25 @@ def test_dump_encodings_h5py(self): self.assertEqual(actual.x.encoding['compression'], 'lzf') self.assertEqual(actual.x.encoding['compression_opts'], None) - @pytest.mark.xfail(reason="won't work until we use FileManager here") - def test_roundtrip_bytes_with_fill_value(self): - super(H5NetCDFDataTest, self).test_roundtrip_bytes_with_fill_value() - - -# tests pending h5netcdf fix -@unittest.skip -class H5NetCDFDataTestAutocloseTrue(H5NetCDFDataTest): - autoclose = True - @pytest.fixture(params=['scipy', 'netcdf4', 'h5netcdf', 'pynio']) def readengine(request): return request.param -@pytest.fixture(params=[1, 100]) +@pytest.fixture(params=[1, 20]) def nfiles(request): return request.param -@pytest.fixture(params=[True, False]) -def autoclose(request): - return request.param +@pytest.fixture(params=[5, None]) +def file_cache_maxsize(request): + maxsize = request.param + if maxsize is not None: + with set_options(file_cache_maxsize=maxsize): + yield maxsize + else: + yield maxsize @pytest.fixture(params=[True, False]) @@ -1942,8 +1912,8 @@ def skip_if_not_engine(engine): pytest.importorskip(engine) -def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, - chunks): +def test_open_mfdataset_manyfiles(readengine, nfiles, parallel, chunks, + file_cache_maxsize): # skip certain combinations skip_if_not_engine(readengine) @@ -1951,9 +1921,6 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, if not has_dask and parallel: pytest.skip('parallel requires dask') - if readengine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose yet') - if ON_WINDOWS: pytest.skip('Skipping on Windows') @@ -1969,7 +1936,7 @@ def test_open_mfdataset_manyfiles(readengine, nfiles, autoclose, parallel, # check that calculation on opened datasets works properly actual = open_mfdataset(tmpfiles, engine=readengine, parallel=parallel, - autoclose=autoclose, chunks=chunks) + chunks=chunks) # check that using open_mfdataset returns dask arrays for variables assert isinstance(actual['foo'].data, dask_array_type) @@ -2168,22 +2135,20 @@ def test_open_mfdataset(self): with create_tmp_file() as tmp2: original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: self.assertIsInstance(actual.foo.variable.data, da.Array) self.assertEqual(actual.foo.variable.data.chunks, ((5, 5),)) assert_identical(original, actual) - with open_mfdataset([tmp1, tmp2], chunks={'x': 3}, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], chunks={'x': 3}) as actual: self.assertEqual(actual.foo.variable.data.chunks, ((3, 2, 3, 2),)) with raises_regex(IOError, 'no files to open'): - open_mfdataset('foo-bar-baz-*.nc', autoclose=self.autoclose) + open_mfdataset('foo-bar-baz-*.nc') with raises_regex(ValueError, 'wild-card'): - 
open_mfdataset('http://some/remote/uri', autoclose=self.autoclose) + open_mfdataset('http://some/remote/uri') @requires_pathlib def test_open_mfdataset_pathlib(self): @@ -2194,8 +2159,7 @@ def test_open_mfdataset_pathlib(self): tmp2 = Path(tmp2) original.isel(x=slice(5)).to_netcdf(tmp1) original.isel(x=slice(5, 10)).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(original, actual) def test_attrs_mfdataset(self): @@ -2226,8 +2190,7 @@ def preprocess(ds): return ds.assign_coords(z=0) expected = preprocess(original) - with open_mfdataset(tmp, preprocess=preprocess, - autoclose=self.autoclose) as actual: + with open_mfdataset(tmp, preprocess=preprocess) as actual: assert_identical(expected, actual) def test_save_mfdataset_roundtrip(self): @@ -2237,8 +2200,7 @@ def test_save_mfdataset_roundtrip(self): with create_tmp_file() as tmp1: with create_tmp_file() as tmp2: save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_save_mfdataset_invalid(self): @@ -2264,15 +2226,14 @@ def test_save_mfdataset_pathlib_roundtrip(self): tmp1 = Path(tmp1) tmp2 = Path(tmp2) save_mfdataset(datasets, [tmp1, tmp2]) - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) def test_open_and_do_math(self): original = Dataset({'foo': ('x', np.random.randn(10))}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: actual = 1.0 * ds assert_allclose(original, actual, decode_bytes=False) @@ -2282,8 +2243,7 @@ def test_open_mfdataset_concat_dim_none(self): data = Dataset({'x': 0}) data.to_netcdf(tmp1) Dataset({'x': np.nan}).to_netcdf(tmp2) - with open_mfdataset([tmp1, tmp2], concat_dim=None, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2], concat_dim=None) as actual: assert_identical(data, actual) def test_open_dataset(self): @@ -2310,8 +2270,7 @@ def test_open_single_dataset(self): {'baz': [100]}) with create_tmp_file() as tmp: original.to_netcdf(tmp) - with open_mfdataset([tmp], concat_dim=dim, - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp], concat_dim=dim) as actual: assert_identical(expected, actual) def test_dask_roundtrip(self): @@ -2330,10 +2289,10 @@ def test_deterministic_names(self): with create_tmp_file() as tmp: data = create_test_data() data.to_netcdf(tmp) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: original_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) - with open_mfdataset(tmp, autoclose=self.autoclose) as ds: + with open_mfdataset(tmp) as ds: repeat_names = dict((k, v.data.name) for k, v in ds.data_vars.items()) for var_name, dask_name in original_names.items(): @@ -2363,15 +2322,10 @@ def test_save_mfdataset_compute_false_roundtrip(self): engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) delayed_obj.compute() - with open_mfdataset([tmp1, tmp2], - autoclose=self.autoclose) as actual: + with open_mfdataset([tmp1, tmp2]) as actual: assert_identical(actual, original) -class DaskTestAutocloseTrue(DaskTest): - autoclose = True - - @requires_scipy_or_netCDF4 @requires_pydap class PydapTest(TestCase): @@ -2483,8 +2437,7 @@ def 
test_write_store(self): @contextlib.contextmanager def open(self, path, **kwargs): - with open_dataset(path, engine='pynio', autoclose=self.autoclose, - **kwargs) as ds: + with open_dataset(path, engine='pynio', **kwargs) as ds: yield ds def save(self, dataset, path, **kwargs): @@ -2502,18 +2455,10 @@ def test_weakrefs(self): assert_identical(actual, expected) -class PyNioTestAutocloseTrue(PyNioTest): - autoclose = True - - @requires_pseudonetcdf class PseudoNetCDFFormatTest(TestCase): - autoclose = True - def open(self, path, **kwargs): - return open_dataset(path, engine='pseudonetcdf', - autoclose=self.autoclose, - **kwargs) + return open_dataset(path, engine='pseudonetcdf', **kwargs) @contextlib.contextmanager def roundtrip(self, data, save_kwargs={}, open_kwargs={}, @@ -2530,7 +2475,6 @@ def test_ict_format(self): """ ictfile = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs={'format': 'ffi1001'}) stdattr = { 'fill_value': -9999.0, @@ -2628,7 +2572,6 @@ def test_ict_format_write(self): fmtkw = {'format': 'ffi1001'} expected = open_example_dataset('example.ict', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, open_kwargs={'backend_kwargs': fmtkw}) as actual: @@ -2643,7 +2586,6 @@ def test_uamiv_format_read(self): message='IOAPI_ISPH') camxfile = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=True, backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, @@ -2673,7 +2615,6 @@ def test_uamiv_format_mfread(self): ['example.uamiv', 'example.uamiv'], engine='pseudonetcdf', - autoclose=True, concat_dim='TSTEP', backend_kwargs={'format': 'uamiv'}) @@ -2703,7 +2644,6 @@ def test_uamiv_format_write(self): message='IOAPI_ISPH') expected = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip( expected, diff --git a/xarray/tests/test_backends_file_manager.py b/xarray/tests/test_backends_file_manager.py index 2246a1015e0..591c981cd45 100644 --- a/xarray/tests/test_backends_file_manager.py +++ b/xarray/tests/test_backends_file_manager.py @@ -7,13 +7,17 @@ import pytest -from xarray.backends.file_manager import FileManager +from xarray.backends.file_manager import CachingFileManager from xarray.backends.lru_cache import LRUCache -@pytest.fixture(params=[1, 2, 3]) +@pytest.fixture(params=[1, 2, 3, None]) def file_cache(request): - yield LRUCache(maxsize=request.param) + maxsize = request.param + if maxsize is None: + yield {} + else: + yield LRUCache(maxsize) def test_file_manager_mock_write(file_cache): @@ -21,7 +25,8 @@ def test_file_manager_mock_write(file_cache): opener = mock.Mock(spec=open, return_value=mock_file) lock = mock.MagicMock(spec=threading.Lock()) - manager = FileManager(opener, 'filename', lock=lock, cache=file_cache) + manager = CachingFileManager( + opener, 'filename', lock=lock, cache=file_cache) f = manager.acquire() f.write('contents') manager.close() @@ -36,8 +41,8 @@ def test_file_manager_mock_write(file_cache): def test_file_manager_write_consecutive(tmpdir, file_cache): path1 = str(tmpdir.join('testing1.txt')) path2 = str(tmpdir.join('testing2.txt')) - manager1 = FileManager(open, path1, mode='w', cache=file_cache) - manager2 = FileManager(open, path2, mode='w', cache=file_cache) + manager1 = CachingFileManager(open, path1, mode='w', cache=file_cache) + manager2 = 
CachingFileManager(open, path2, mode='w', cache=file_cache) f1a = manager1.acquire() f1a.write('foo') f1a.flush() @@ -46,7 +51,7 @@ def test_file_manager_write_consecutive(tmpdir, file_cache): f2.flush() f1b = manager1.acquire() f1b.write('baz') - assert (file_cache.maxsize > 1) == (f1a is f1b) + assert (getattr(file_cache, 'maxsize', float('inf')) > 1) == (f1a is f1b) manager1.close() manager2.close() @@ -58,7 +63,7 @@ def test_file_manager_write_consecutive(tmpdir, file_cache): def test_file_manager_write_concurrent(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) - manager = FileManager(open, path, mode='w', cache=file_cache) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) f1 = manager.acquire() f2 = manager.acquire() f3 = manager.acquire() @@ -78,7 +83,7 @@ def test_file_manager_write_concurrent(tmpdir, file_cache): def test_file_manager_write_pickle(tmpdir, file_cache): path = str(tmpdir.join('testing.txt')) - manager = FileManager(open, path, mode='w', cache=file_cache) + manager = CachingFileManager(open, path, mode='w', cache=file_cache) f = manager.acquire() f.write('foo') f.flush() @@ -98,7 +103,12 @@ def test_file_manager_read(tmpdir, file_cache): with open(path, 'w') as f: f.write('foobar') - manager = FileManager(open, path, cache=file_cache) + manager = CachingFileManager(open, path, cache=file_cache) f = manager.acquire() assert f.read() == 'foobar' manager.close() + + +def test_file_manager_invalid_kwargs(): + with pytest.raises(TypeError): + CachingFileManager(open, 'dummy', mode='w', invalid=True) diff --git a/xarray/tests/test_backends_lru_cache.py b/xarray/tests/test_backends_lru_cache.py index e8302c55c9e..03eb6dcf208 100644 --- a/xarray/tests/test_backends_lru_cache.py +++ b/xarray/tests/test_backends_lru_cache.py @@ -32,6 +32,8 @@ def test_trivial(): def test_invalid(): + with pytest.raises(TypeError): + LRUCache(maxsize=None) with pytest.raises(ValueError): LRUCache(maxsize=-1) From 2adf486fc745c1b414384503c793d6140c26b1ca Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 29 Jul 2018 17:00:19 -0700 Subject: [PATCH 18/39] Fix bad import --- xarray/backends/lru_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/backends/lru_cache.py b/xarray/backends/lru_cache.py index 356a1b3edc7..321a1ca4da4 100644 --- a/xarray/backends/lru_cache.py +++ b/xarray/backends/lru_cache.py @@ -1,7 +1,7 @@ import collections import threading -from ..pycompat import move_to_end +from ..core.pycompat import move_to_end class LRUCache(collections.MutableMapping): From 76f151ca9276776bbcff50b7b609a45ae96becf0 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 31 Jul 2018 18:09:26 -0700 Subject: [PATCH 19/39] WIP on distributed --- xarray/backends/api.py | 66 +++++++++++++-------- xarray/backends/common.py | 2 - xarray/backends/file_manager.py | 7 ++- xarray/backends/h5netcdf_.py | 3 +- xarray/backends/netCDF4_.py | 21 +++++-- xarray/backends/scipy_.py | 5 ++ xarray/tests/test_distributed.py | 98 +++++++++++++++----------------- 7 files changed, 112 insertions(+), 90 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 0df98bb4c27..92ad9ee8e2b 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -53,16 +53,16 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_lock(filename, engine): +def _default_read_lock(filename, engine): if filename.endswith('.gz'): - lock = False + lock = None else: if engine is None: engine = 
_get_default_engine(filename, allow_remote=True) if engine == 'netcdf4': if is_remote_uri(filename): - lock = False + lock = None else: # TODO: identify netcdf3 files and don't use the global lock # for them @@ -70,10 +70,25 @@ def _default_lock(filename, engine): elif engine in {'h5netcdf', 'pynio'}: lock = HDF5_LOCK else: - lock = False + lock = None return lock +def _get_write_lock(engine, scheduler, format, path_or_file): + """ Get the lock(s) that apply to a particular scheduler/engine/format""" + + locks = [] + + if (engine == 'h5netcdf' or engine == 'netcdf4' and + (format is None or format.startswith('NETCDF4'))): + locks.append(HDF5_LOCK) + + locks.append(_get_scheduler_lock(scheduler, path_or_file)) + + return CombinedLock(locks) + + + def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -131,18 +146,19 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_lock(engine, scheduler, format, path_or_file): +def _get_write_lock(engine, scheduler, format, path_or_file): """ Get the lock(s) that apply to a particular scheduler/engine/format""" locks = [] - if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']: + # if (engine == 'h5netcdf' or engine == 'netcdf4' and + # (format is None or format.startswith('NETCDF4'))): + # locks.append(HDF5_LOCK) + if (engine == 'h5netcdf' or engine == 'netcdf4'): locks.append(HDF5_LOCK) - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - # When we have more than one lock, use the CombinedLock wrapper class - lock = CombinedLock(locks) if len(locks) > 1 else locks[0] + locks.append(_get_scheduler_lock(scheduler, path_or_file)) - return lock + return CombinedLock(locks) def _finalize_store(write, store): @@ -281,8 +297,7 @@ def maybe_decode_store(store, lock=False): mask_and_scale, decode_times, concat_characters, decode_coords, engine, chunks, drop_variables) name_prefix = 'open_dataset-%s' % token - ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token, - lock=lock) + ds2 = ds.chunk(chunks, name_prefix=name_prefix, token=token) ds2._file_obj = ds._file_obj else: ds2 = ds @@ -313,12 +328,15 @@ def maybe_decode_store(store, lock=False): else: engine = 'scipy' + if lock is None: + lock = _default_read_lock(filename_or_obj, engine) + if engine is None: engine = _get_default_engine(filename_or_obj, allow_remote=True) if engine == 'netcdf4': store = backends.NetCDF4DataStore.open( - filename_or_obj, group=group, **backend_kwargs) + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'scipy': store = backends.ScipyDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pydap': @@ -326,7 +344,7 @@ def maybe_decode_store(store, lock=False): filename_or_obj, **backend_kwargs) elif engine == 'h5netcdf': store = backends.H5NetCDFStore( - filename_or_obj, group=group, **backend_kwargs) + filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': store = backends.NioDataStore(filename_or_obj, **backend_kwargs) elif engine == 'pseudonetcdf': @@ -336,10 +354,8 @@ def maybe_decode_store(store, lock=False): raise ValueError('unrecognized engine for open_dataset: %r' % engine) - if lock is None: - lock = _default_lock(filename_or_obj, engine) with close_on_error(store): - return maybe_decode_store(store, lock) + return maybe_decode_store(store) else: if engine is not None and engine != 'scipy': raise ValueError('can only read file-like objects with ' @@ -593,7 +609,7 @@ def 
open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, raise IOError('no files to open') if lock is None: - lock = _default_lock(paths[0], engine) + lock = _default_read_lock(paths[0], engine) open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -688,13 +704,13 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, # handle scheduler specific logic scheduler = _get_scheduler() - have_chunks = any(v.chunks for v in dataset.variables.values()) - if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - engine != 'netcdf4'): - raise NotImplementedError("Writing netCDF files with the %s backend " - "is not currently supported with dask's %s " - "scheduler" % (engine, scheduler)) - lock = _get_lock(engine, scheduler, format, path_or_file) + # have_chunks = any(v.chunks for v in dataset.variables.values()) + # if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and + # engine != 'netcdf4'): + # raise NotImplementedError("Writing netCDF files with the %s backend " + # "is not currently supported with dask's %s " + # "scheduler" % (engine, scheduler)) + lock = _get_write_lock(engine, scheduler, format, path_or_file) target = path_or_file if path_or_file is not None else BytesIO() store = store_open(target, mode, format, group, writer, lock=lock) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 77fda5fc5b1..3514d69f0ae 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function -import contextlib import logging import multiprocessing import threading @@ -8,7 +7,6 @@ import traceback import warnings from collections import Mapping, OrderedDict -from functools import partial import numpy as np diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 9317d85b5f6..f5cd592bdb7 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -82,8 +82,8 @@ def __init__(self, opener, *args, **keywords): be hashable. lock : duck-compatible threading.Lock, optional Lock to use when modifying the cache inside acquire() and close(). - By default, uses a new threading.Lock() object. If set, this object - should be pickleable. + Must be reentrant. By default, uses a new threading.RLock() object. + If set, this object should be pickleable. cache : MutableMapping, optional Mapping to use as a cache for open files. By default, uses xarray's global LRU file cache. 
Because ``cache`` typically points to a @@ -106,7 +106,7 @@ def __init__(self, opener, *args, **keywords): self._mode = mode self._kwargs = {} if kwargs is None else dict(kwargs) self._default_lock = lock is None - self._lock = threading.Lock() if self._default_lock else lock + self._lock = threading.RLock() if self._default_lock else lock self._cache = cache self._key = self._make_key() @@ -140,6 +140,7 @@ def acquire(self): if self._mode is not _DEFAULT_MODE: kwargs = kwargs.copy() kwargs['mode'] = self._mode + print("OPENING", id(self), self._opener, self._args, kwargs) file = self._opener(*self._args, **kwargs) if self._mode == 'w': # ensure file doesn't get overriden when opened again diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index be27d2814b0..54133da1171 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -71,7 +71,7 @@ class H5NetCDFStore(WritableCFDataStore): """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, autoclose=False, lock=HDF5_LOCK): + writer=None, lock=HDF5_LOCK): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') self._manager = CachingFileManager( @@ -81,6 +81,7 @@ def __init__(self, filename, mode='r', format=None, group=None, self.format = format self._filename = filename self._mode = mode + self._lock = lock super(H5NetCDFStore, self).__init__(writer, lock=lock) @property diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 6f60f36bde5..6cf6e4b47fd 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -64,7 +64,9 @@ def _getitem(self, key): getitem = operator.getitem try: - array = getitem(self.get_array(), key) + original_array = self.get_array() + with self.datastore._lock: + array = getitem(original_array, key) except IndexError: # Catch IndexError in netCDF4 and return a more informative # error message. 
This is most often called when an unsorted @@ -223,15 +225,21 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, class GroupWrapper(object): - def __init__(self, value): + def __init__(self, value, lock=HDF5_LOCK): self.value = value + self.lock = lock + + def sync(self): + with self.lock: + self.value.sync() def close(self): # netCDF4 only allows closing the root group - find_root(self.value).close() + with self.lock: + find_root(self.value).close() -def _open_netcdf4_group(filename, mode, group=None, **kwargs): +def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): import netCDF4 as nc4 ds = nc4.Dataset(filename, mode=mode, **kwargs) @@ -241,7 +249,7 @@ def _open_netcdf4_group(filename, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return GroupWrapper(ds) + return GroupWrapper(ds, lock) def _disable_auto_decode_variable(var): @@ -304,6 +312,7 @@ def __init__(self, manager, writer=None, lock=HDF5_LOCK): self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) + self._lock = lock super(NetCDF4DataStore, self).__init__(writer, lock=lock) @classmethod @@ -324,7 +333,7 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, if format is None: format = 'NETCDF4' manager = CachingFileManager( - _open_netcdf4_group, filename, mode=mode, + _open_netcdf4_group, filename, lock, mode=mode, kwargs=dict(group=group, clobber=clobber, diskless=diskless, persist=persist, format=format)) return cls(manager, writer=writer, lock=lock) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 628aa64516c..3281f392355 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -140,6 +140,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, self._manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, mode=mode, kwargs=dict(mmap=mmap, version=version)) + print('LOCKING', lock) super(ScipyDataStore, self).__init__(writer, lock=lock) @property @@ -217,3 +218,7 @@ def close(self): def __exit__(self, type, value, tb): self.close() + + def __getstate__(self): + self.sync() + return super(ScipyDataStore, self).__getstate__() diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 32035afdc57..5e09eacd69a 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -33,6 +33,11 @@ da = pytest.importorskip('dask.array') +@pytest.fixture +def tmp_netcdf_filename(tmpdir): + return str(tmpdir.join('testfile.nc')) + + ENGINES = [] if has_scipy: ENGINES.append('scipy') @@ -45,81 +50,68 @@ 'NETCDF3_64BIT_DATA', 'NETCDF4_CLASSIC', 'NETCDF4'], 'scipy': ['NETCDF3_CLASSIC', 'NETCDF3_64BIT'], 'h5netcdf': ['NETCDF4']} -TEST_FORMATS = ['NETCDF3_CLASSIC', 'NETCDF4_CLASSIC', 'NETCDF4'] + +ENGINES_AND_FORMATS = [ + ('netcdf4', 'NETCDF3_CLASSIC'), + ('netcdf4', 'NETCDF4_CLASSIC'), + ('netcdf4', 'NETCDF4'), + ('h5netcdf', 'NETCDF4'), + ('scipy', 'NETCDF3_64BIT'), +] @pytest.mark.xfail(sys.platform == 'win32', reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ['netcdf4']) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_netcdf_roundtrip(monkeypatch, loop, - engine, autoclose, nc_format): +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_netcdf_roundtrip( + monkeypatch, loop, tmp_netcdf_filename, engine, nc_format): + + if 
engine not in ENGINES: + pytest.skip('engine not available') monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - original = create_test_data().chunk(chunks) - original.to_netcdf(filename, engine=engine, format=nc_format) + original = create_test_data().chunk(chunks) + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) @pytest.mark.xfail(sys.platform == 'win32', reason='https://github.com/pydata/xarray/issues/1738') -@pytest.mark.parametrize('engine', ENGINES) -@pytest.mark.parametrize('autoclose', [True, False]) -@pytest.mark.parametrize('nc_format', TEST_FORMATS) -def test_dask_distributed_read_netcdf_integration_test(loop, engine, autoclose, - nc_format): - - if engine == 'h5netcdf' and autoclose: - pytest.skip('h5netcdf does not support autoclose') +@pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) +def test_dask_distributed_read_netcdf_integration_test( + loop, tmp_netcdf_filename, engine, nc_format): - if nc_format not in NC_FORMATS[engine]: - pytest.skip('invalid format for engine') + if engine not in ENGINES: + pytest.skip('engine not available') chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: - - original = create_test_data() - original.to_netcdf(filename, engine=engine, format=nc_format) - - with xr.open_dataset(filename, - chunks=chunks, - engine=engine, - autoclose=autoclose) as restored: - assert isinstance(restored.var1.data, da.Array) - computed = restored.compute() - assert_allclose(original, computed) - - -@pytest.mark.parametrize('engine', ['h5netcdf', 'scipy']) -def test_dask_distributed_netcdf_integration_test_not_implemented(loop, engine): - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} + with cluster() as (s, [a, b]): + with Client(s['address'], loop=loop) as c: - with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as filename: - with cluster() as (s, [a, b]): - with Client(s['address'], loop=loop) as c: + original = create_test_data() + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) - original = create_test_data().chunk(chunks) + with xr.open_dataset(tmp_netcdf_filename, + chunks=chunks, + engine=engine) as restored: + assert isinstance(restored.var1.data, da.Array) + computed = restored.compute() + assert_allclose(original, computed) - with raises_regex(NotImplementedError, 'distributed'): - original.to_netcdf(filename, engine=engine) @requires_zarr From 769f0790397fe661bc91c1a2e9aa73e3d23b386a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 5 Aug 2018 19:37:09 -0700 Subject: [PATCH 20/39] More WIP --- xarray/backends/scipy_.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 
3281f392355..628aa64516c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -140,7 +140,6 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, self._manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, mode=mode, kwargs=dict(mmap=mmap, version=version)) - print('LOCKING', lock) super(ScipyDataStore, self).__init__(writer, lock=lock) @property @@ -218,7 +217,3 @@ def close(self): def __exit__(self, type, value, tb): self.close() - - def __getstate__(self): - self.sync() - return super(ScipyDataStore, self).__getstate__() From 5e67efed7b88b3e6babb683659869639ace92686 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 19 Aug 2018 14:52:49 -0700 Subject: [PATCH 21/39] Fix distributed write tests --- xarray/backends/api.py | 40 ++++++++++----------------- xarray/backends/common.py | 24 ++++++++++++++--- xarray/backends/file_manager.py | 30 +++++++++++++-------- xarray/backends/h5netcdf_.py | 21 ++++++++------- xarray/backends/netCDF4_.py | 46 ++++++++++++++++---------------- xarray/backends/scipy_.py | 10 +++---- xarray/core/dataset.py | 2 ++ xarray/tests/test_backends.py | 7 +++-- xarray/tests/test_distributed.py | 15 ++++++++--- 9 files changed, 111 insertions(+), 84 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 92ad9ee8e2b..f8dffdbbc23 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -14,7 +14,7 @@ from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, ArrayWriter, CombinedLock, _get_scheduler, _get_scheduler_lock) + HDF5_LOCK, ArrayWriter, combine_locks, _get_scheduler, _get_scheduler_lock) DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -85,7 +85,7 @@ def _get_write_lock(engine, scheduler, format, path_or_file): locks.append(_get_scheduler_lock(scheduler, path_or_file)) - return CombinedLock(locks) + return combine_locks(locks) @@ -146,21 +146,6 @@ def _protect_dataset_variables_inplace(dataset, cache): variable.data = data -def _get_write_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - - locks = [] - # if (engine == 'h5netcdf' or engine == 'netcdf4' and - # (format is None or format.startswith('NETCDF4'))): - # locks.append(HDF5_LOCK) - if (engine == 'h5netcdf' or engine == 'netcdf4'): - locks.append(HDF5_LOCK) - - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - - return CombinedLock(locks) - - def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first @@ -704,16 +689,19 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, # handle scheduler specific logic scheduler = _get_scheduler() - # have_chunks = any(v.chunks for v in dataset.variables.values()) - # if (have_chunks and scheduler in ['distributed', 'multiprocessing'] and - # engine != 'netcdf4'): - # raise NotImplementedError("Writing netCDF files with the %s backend " - # "is not currently supported with dask's %s " - # "scheduler" % (engine, scheduler)) + have_chunks = any(v.chunks for v in dataset.variables.values()) + + autoclose = have_chunks and scheduler in ['distributed', 'multiprocessing'] + if autoclose and engine == 'scipy': + raise NotImplementedError("Writing netCDF files with the %s backend " + "is not currently supported with dask's %s " + "scheduler" % (engine, scheduler)) lock 
= _get_write_lock(engine, scheduler, format, path_or_file) target = path_or_file if path_or_file is not None else BytesIO() - store = store_open(target, mode, format, group, writer, lock=lock) + kwargs = dict(autoclose=True) if autoclose else {} + store = store_open( + target, mode, format, group, writer, lock=lock, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) @@ -730,12 +718,12 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, store.close() if not compute: - import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return store.delayed_store if not sync: return store + def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, engine=None, compute=True): """Write multiple datasets to disk as netCDF files simultaneously. diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 3514d69f0ae..bd07b84b0e6 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -164,6 +164,24 @@ def __repr__(self): return "CombinedLock(%r)" % list(self.locks) +def combine_locks(locks): + """Combine one or more locks into a CombinedLock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return None + + class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): @@ -254,7 +272,7 @@ def __exit__(self, exception_type, exception_value, traceback): class ArrayWriter(object): - def __init__(self, lock=HDF5_LOCK): + def __init__(self, lock=None): self.sources = [] self.targets = [] self.lock = lock @@ -278,9 +296,9 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None, lock=HDF5_LOCK): + def __init__(self, writer=None): if writer is None: - writer = ArrayWriter(lock=lock) + writer = ArrayWriter() self.writer = writer self.delayed_store = None diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index f5cd592bdb7..5bad91ba2d6 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -82,8 +82,8 @@ def __init__(self, opener, *args, **keywords): be hashable. lock : duck-compatible threading.Lock, optional Lock to use when modifying the cache inside acquire() and close(). - Must be reentrant. By default, uses a new threading.RLock() object. - If set, this object should be pickleable. + By default, uses a new threading.Lock() object. If set, this object + should be pickleable. cache : MutableMapping, optional Mapping to use as a cache for open files. By default, uses xarray's global LRU file cache. 
Because ``cache`` typically points to a @@ -106,7 +106,7 @@ def __init__(self, opener, *args, **keywords): self._mode = mode self._kwargs = {} if kwargs is None else dict(kwargs) self._default_lock = lock is None - self._lock = threading.RLock() if self._default_lock else lock + self._lock = threading.Lock() if self._default_lock else lock self._cache = cache self._key = self._make_key() @@ -140,7 +140,6 @@ def acquire(self): if self._mode is not _DEFAULT_MODE: kwargs = kwargs.copy() kwargs['mode'] = self._mode - print("OPENING", id(self), self._opener, self._args, kwargs) file = self._opener(*self._args, **kwargs) if self._mode == 'w': # ensure file doesn't get overriden when opened again @@ -149,13 +148,21 @@ def acquire(self): self._cache[self._key] = file return file - def close(self): + def _close(self): + default = None + file = self._cache.pop(self._key, default) + if file is not None: + file.close() + + def close(self, needs_lock=True): """Explicitly close any associated file object (if necessary).""" - with self._lock: - default = None - file = self._cache.pop(self._key, default) - if file is not None: - file.close() + # TODO: remove needs_lock if/when we have a reentrant lock in + # dask.distributed: https://github.com/dask/dask/issues/3832 + if needs_lock: + with self._lock: + self._close() + else: + self._close() def __getstate__(self): """State for pickling.""" @@ -194,5 +201,6 @@ def __init__(self, value): def acquire(self): return self._value - def close(self): + def close(self, needs_lock=True): + del needs_lock # ignored self._value.close() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 54133da1171..de41b241fc1 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,8 +8,7 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import ( - HDF5_LOCK, WritableCFDataStore, find_root) +from .common import HDF5_LOCK, WritableCFDataStore from .file_manager import CachingFileManager from .netCDF4_ import ( BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, @@ -26,7 +25,8 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) - return self.get_array()[key] + with self.datastore.lock: + return self.get_array()[key] def maybe_decode_bytes(txt): @@ -71,7 +71,7 @@ class H5NetCDFStore(WritableCFDataStore): """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, lock=HDF5_LOCK): + writer=None, lock=HDF5_LOCK, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') self._manager = CachingFileManager( @@ -81,8 +81,9 @@ def __init__(self, filename, mode='r', format=None, group=None, self.format = format self._filename = filename self._mode = mode - self._lock = lock - super(H5NetCDFStore, self).__init__(writer, lock=lock) + self.lock = lock + self.autoclose = autoclose + super(H5NetCDFStore, self).__init__(writer) @property def ds(self): @@ -219,8 +220,10 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data def sync(self, compute=True): - super(H5NetCDFStore, self).sync(compute=compute) self.ds.sync() + if self.autoclose: + self.close() + super(H5NetCDFStore, self).sync(compute=compute) - def close(self): - self._manager.close() + def close(self, 
**kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 6cf6e4b47fd..a4fb7f49b28 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -14,8 +14,7 @@ PY3, OrderedDict, basestring, iteritems, suppress) from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, WritableCFDataStore, - find_root, robust_getitem) + HDF5_LOCK, BackendArray, WritableCFDataStore, find_root, robust_getitem) from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable @@ -44,8 +43,11 @@ def __init__(self, variable_name, datastore): self.dtype = dtype def __setitem__(self, key, value): - data = self.get_array() - data[key] = value + with self.datastore.lock: + data = self.get_array() + data[key] = value + if self.datastore.autoclose: + self.datastore.close(needs_lock=False) def get_array(self): return self.datastore.ds.variables[self.variable_name] @@ -64,8 +66,8 @@ def _getitem(self, key): getitem = operator.getitem try: - original_array = self.get_array() - with self.datastore._lock: + with self.datastore.lock: + original_array = self.get_array() array = getitem(original_array, key) except IndexError: # Catch IndexError in netCDF4 and return a more informative @@ -225,18 +227,13 @@ def _extract_nc4_variable_encoding(variable, raise_on_invalid=False, class GroupWrapper(object): - def __init__(self, value, lock=HDF5_LOCK): + """Wrap netCDF4.Group objects so closing them closes the root group.""" + def __init__(self, value): self.value = value - self.lock = lock - - def sync(self): - with self.lock: - self.value.sync() def close(self): # netCDF4 only allows closing the root group - with self.lock: - find_root(self.value).close() + find_root(self.value).close() def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): @@ -249,7 +246,7 @@ def _open_netcdf4_group(filename, lock, mode, group=None, **kwargs): _disable_auto_decode_group(ds) - return GroupWrapper(ds, lock) + return GroupWrapper(ds) def _disable_auto_decode_variable(var): @@ -301,7 +298,7 @@ class NetCDF4DataStore(WritableCFDataStore): This store supports NetCDF3, NetCDF4 and OpenDAP datasets. 
""" - def __init__(self, manager, writer=None, lock=HDF5_LOCK): + def __init__(self, manager, writer=None, lock=HDF5_LOCK, autoclose=False): import netCDF4 if isinstance(manager, netCDF4.Dataset): @@ -312,13 +309,14 @@ def __init__(self, manager, writer=None, lock=HDF5_LOCK): self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self._lock = lock - super(NetCDF4DataStore, self).__init__(writer, lock=lock) + self.lock = lock + self.autoclose = autoclose + super(NetCDF4DataStore, self).__init__(writer) @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, writer=None, clobber=True, diskless=False, persist=False, - lock=HDF5_LOCK): + lock=HDF5_LOCK, autoclose=False): import netCDF4 if (len(filename) == 88 and LooseVersion(netCDF4.__version__) < "1.3.1"): @@ -336,7 +334,7 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, _open_netcdf4_group, filename, lock, mode=mode, kwargs=dict(group=group, clobber=clobber, diskless=diskless, persist=persist, format=format)) - return cls(manager, writer=writer, lock=lock) + return cls(manager, writer=writer, lock=lock, autoclose=autoclose) @property def ds(self): @@ -461,8 +459,10 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data def sync(self, compute=True): - super(NetCDF4DataStore, self).sync(compute=compute) self.ds.sync() + if self.autoclose: + self.close() + super(NetCDF4DataStore, self).sync(compute=compute) - def close(self): - self._manager.close() + def close(self, **kwargs): + self._manager.close(**kwargs) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 628aa64516c..18754a008e0 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -44,10 +44,8 @@ def get_array(self): def __getitem__(self, key): data = NumpyIndexingAdapter(self.get_array())[key] - # Copy data if the source file is mmapped. - # This makes things consistent - # with the netCDF4 library by ensuring - # we can safely read arrays even + # Copy data if the source file is mmapped. This makes things consistent + # with the netCDF4 library by ensuring we can safely read arrays even # after closing associated files. 
copy = self.datastore.ds.use_mmap return np.array(data, dtype=self.dtype, copy=copy) @@ -140,7 +138,7 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, self._manager = CachingFileManager( _open_scipy_netcdf, filename_or_obj, mode=mode, kwargs=dict(mmap=mmap, version=version)) - super(ScipyDataStore, self).__init__(writer, lock=lock) + super(ScipyDataStore, self).__init__(writer) @property def ds(self): @@ -210,7 +208,7 @@ def prepare_variable(self, name, variable, check_encoding=False, def sync(self, compute=True): super(ScipyDataStore, self).sync(compute=compute) - self.ds.flush() + self.ds.sync() def close(self): self._manager.close() diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 4b52178ad0e..4724422d948 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1074,7 +1074,9 @@ def dump_to_store(self, store, encoder=None, sync=True, encoding=None, store.store(variables, attrs, check_encoding, unlimited_dims=unlimited_dims) if sync: + print("COMPUTE START") store.sync(compute=compute) + print("COMPUTE FINISHED") def to_netcdf(self, path=None, mode='w', format=None, group=None, engine=None, encoding=None, unlimited_dims=None, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 6b771a7a3c4..606c359f801 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1756,8 +1756,11 @@ def create_store(self): def test_complex(self): expected = Dataset({'x': ('y', np.ones(5) + 1j * np.ones(5))}) - with self.roundtrip(expected) as actual: - assert_equal(expected, actual) + with pytest.warns(FutureWarning): + # TODO: make it possible to write invalid netCDF files from xarray + # without a warning + with self.roundtrip(expected) as actual: + assert_equal(expected, actual) @pytest.mark.xfail(reason='https://github.com/pydata/xarray/issues/535') def test_cross_engine_read_write_netcdf4(self): diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 5e09eacd69a..02ce73d11e3 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -60,8 +60,8 @@ def tmp_netcdf_filename(tmpdir): ] -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') +# @pytest.mark.xfail(sys.platform == 'win32', +# reason='https://github.com/pydata/xarray/issues/1738') @pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) def test_dask_distributed_netcdf_roundtrip( monkeypatch, loop, tmp_netcdf_filename, engine, nc_format): @@ -77,6 +77,13 @@ def test_dask_distributed_netcdf_roundtrip( with Client(s['address'], loop=loop) as c: original = create_test_data().chunk(chunks) + + if engine == 'scipy': + with pytest.raises(NotImplementedError): + original.to_netcdf(tmp_netcdf_filename, + engine=engine, format=nc_format) + return + original.to_netcdf(tmp_netcdf_filename, engine=engine, format=nc_format) @@ -87,8 +94,8 @@ def test_dask_distributed_netcdf_roundtrip( assert_allclose(original, computed) -@pytest.mark.xfail(sys.platform == 'win32', - reason='https://github.com/pydata/xarray/issues/1738') +# @pytest.mark.xfail(sys.platform == 'win32', +# reason='https://github.com/pydata/xarray/issues/1738') @pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) def test_dask_distributed_read_netcdf_integration_test( loop, tmp_netcdf_filename, engine, nc_format): From 1d38335414a3c33d4be6c6111331feb466a2b782 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 19 Aug 2018 15:24:37 -0700 Subject: 
[PATCH 22/39] Fixes --- xarray/backends/pynio_.py | 4 ++++ xarray/core/dataset.py | 2 -- xarray/tests/test_backends.py | 4 ++-- xarray/tests/test_distributed.py | 4 ---- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index d7692565b3c..a9b73132b57 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -44,6 +44,10 @@ def __init__(self, filename, mode='r'): # so turn off PyNIO's support for the same. self.ds.set_option('MaskedArrayMode', 'MaskedNever') + @property + def ds(self): + return self._manager.acquire() + def open_store_variable(self, name, var): data = indexing.LazilyOuterIndexedArray(NioArrayWrapper(name, self)) return Variable(var.dimensions, data, var.attributes) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dfbdbd69436..37544aca372 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1074,9 +1074,7 @@ def dump_to_store(self, store, encoder=None, sync=True, encoding=None, store.store(variables, attrs, check_encoding, unlimited_dims=unlimited_dims) if sync: - print("COMPUTE START") store.sync(compute=compute) - print("COMPUTE FINISHED") def to_netcdf(self, path=None, mode='w', format=None, group=None, engine=None, encoding=None, unlimited_dims=None, diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index d2d611fd1d7..c5e39898d35 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2328,8 +2328,8 @@ def test_save_mfdataset_compute_false_roundtrip(self): original = Dataset({'foo': ('x', np.random.randn(10))}).chunk() datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))] - with create_tmp_file() as tmp1: - with create_tmp_file() as tmp2: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp1: + with create_tmp_file(allow_cleanup_failure=ON_WINDOWS) as tmp2: delayed_obj = save_mfdataset(datasets, [tmp1, tmp2], engine=self.engine, compute=False) assert isinstance(delayed_obj, Delayed) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index 02ce73d11e3..bf8952f7436 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -60,8 +60,6 @@ def tmp_netcdf_filename(tmpdir): ] -# @pytest.mark.xfail(sys.platform == 'win32', -# reason='https://github.com/pydata/xarray/issues/1738') @pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) def test_dask_distributed_netcdf_roundtrip( monkeypatch, loop, tmp_netcdf_filename, engine, nc_format): @@ -94,8 +92,6 @@ def test_dask_distributed_netcdf_roundtrip( assert_allclose(original, computed) -# @pytest.mark.xfail(sys.platform == 'win32', -# reason='https://github.com/pydata/xarray/issues/1738') @pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) def test_dask_distributed_read_netcdf_integration_test( loop, tmp_netcdf_filename, engine, nc_format): From 6350ca6f3fb8ae3e1a276e4d41eaa53d55253cfa Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 20 Aug 2018 09:34:39 -0700 Subject: [PATCH 23/39] Minor fixup --- xarray/backends/api.py | 3 ++- xarray/backends/file_manager.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 370ca999539..53cca570107 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -718,7 +718,8 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, store.close() if not compute: - return store.delayed_store + import dask + return 
dask.delayed(_finalize_store)(store.delayed_store, store) if not sync: return store diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 5bad91ba2d6..2ecae8581e9 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -26,7 +26,7 @@ def acquire(self): """Acquire the file object from this manager.""" raise NotImplementedError - def close(self): + def close(self, needs_lock=True): """Close the file object associated with this manager, if needed.""" raise NotImplementedError From 4aa0df7cc6dcee1ac76268d19c8b08a8802f50a4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 29 Aug 2018 20:15:34 -0700 Subject: [PATCH 24/39] whats new --- doc/api.rst | 3 +++ doc/whats-new.rst | 5 ++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index 927c0aa072c..6b1f8d3be00 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -615,3 +615,6 @@ arguments for the ``from_store`` and ``dump_to_store`` Dataset methods: backends.H5NetCDFStore backends.PydapDataStore backends.ScipyDataStore + backends.FileManager + backends.CachingFileManager + backends.DummyFileManager diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 33b8336b51b..daa3b1545e8 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,7 +40,10 @@ Breaking changes necessary with ``xarray.set_options(file_cache_maxsize=...)``. - TODO: Add some note about performance benefits. + This change significantly simplies the work required to write a new backend + class, and also should improves performance for reading and writing + netCDF files when using dask. + By `Stephan Hoyer `_ Documentation ~~~~~~~~~~~~~ From 67377c7047cbec7a3d6f5e21b36aa7f96294401b Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 31 Aug 2018 08:12:50 -0700 Subject: [PATCH 25/39] More refactoring: remove state from backends entirely --- asv_bench/asv.conf.json | 1 + asv_bench/benchmarks/dataset_io.py | 41 +++++++++ xarray/backends/api.py | 142 ++++++++++++++++++++--------- xarray/backends/common.py | 58 +++++++++--- xarray/backends/h5netcdf_.py | 11 +-- xarray/backends/netCDF4_.py | 12 +-- xarray/backends/pseudonetcdf_.py | 10 +- xarray/backends/pynio_.py | 15 +-- xarray/backends/scipy_.py | 26 +++--- xarray/backends/zarr.py | 30 +++--- xarray/core/dataset.py | 25 +---- 11 files changed, 243 insertions(+), 128 deletions(-) diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index b5953436387..e3933b400e6 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -64,6 +64,7 @@ "scipy": [""], "bottleneck": ["", null], "dask": [""], + "distributed": [""], }, diff --git a/asv_bench/benchmarks/dataset_io.py b/asv_bench/benchmarks/dataset_io.py index 54ed9ac9fa2..01c6d6d2d16 100644 --- a/asv_bench/benchmarks/dataset_io.py +++ b/asv_bench/benchmarks/dataset_io.py @@ -1,5 +1,7 @@ from __future__ import absolute_import, division, print_function +import os + import numpy as np import pandas as pd @@ -14,6 +16,9 @@ pass +os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE' + + class IOSingleNetCDF(object): """ A few examples that benchmark reading/writing a single netCDF file with @@ -405,3 +410,39 @@ def time_open_dataset_scipy_with_time_chunks(self): with dask.set_options(get=dask.multiprocessing.get): xr.open_mfdataset(self.filenames_list, engine='scipy', chunks=self.time_chunks) + + +def create_delayed_write(): + import dask.array as da + vals = da.random.random(300, chunks=(1,)) + ds = xr.Dataset({'vals': (['a'], vals)}) + return ds.to_netcdf('file.nc', 
engine='netcdf4', compute=False) + + +class IOWriteNetCDFDask(object): + timeout = 60 + repeat = 1 + number = 5 + + def setup(self): + requires_dask() + self.write = create_delayed_write() + + def time_write(self): + self.write.compute() + + +class IOWriteNetCDFDaskDistributed(object): + def setup(self): + try: + import distributed + except ImportError: + raise NotImplementedError + self.client = distributed.Client() + self.write = create_delayed_write() + + def cleanup(self): + self.client.shutdown() + + def time_write(self): + self.write.compute() diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 53cca570107..5fc7d707d88 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -14,7 +14,8 @@ from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, ArrayWriter, combine_locks, _get_scheduler, _get_scheduler_lock) + HDF5_LOCK, NETCDFC_LOCK, ArrayWriter, combine_locks, _get_scheduler, + _get_scheduler_lock) DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -55,23 +56,31 @@ def _normalize_path(path): def _default_read_lock(filename, engine): if filename.endswith('.gz'): - lock = None + locks = [] else: if engine is None: engine = _get_default_engine(filename, allow_remote=True) if engine == 'netcdf4': if is_remote_uri(filename): - lock = None + locks = [NETCDFC_LOCK] else: - # TODO: identify netcdf3 files and don't use the global lock - # for them - lock = HDF5_LOCK - elif engine in {'h5netcdf', 'pynio'}: - lock = HDF5_LOCK + # TODO: identify netcdf3 files and don't use the global HDF5 + # lock for them + locks = [NETCDFC_LOCK, HDF5_LOCK] + elif engine == 'h5netcdf': + locks = [HDF5_LOCK] + elif engine == 'pynio': + # pynio can invoke netCDF libraries internally + locks = [HDF5_LOCK, NETCDFC_LOCK, PYNIO_LOCK] + elif engine == 'psuedonetcdf': + # psuedonetcdf can invoke netCDF libraries internally + locks = [HDF5_LOCK, NETCDFC_LOCK] else: - lock = None - return lock + # no locking needed by default, e.g., for scipy or pynio + locks = [] + + return combine_locks(locks) def _get_write_lock(engine, scheduler, format, path_or_file): @@ -83,6 +92,9 @@ def _get_write_lock(engine, scheduler, format, path_or_file): (format is None or format.startswith('NETCDF4'))): locks.append(HDF5_LOCK) + if engine == 'netcdf4': + locks.append(NETCDFC_LOCK) + locks.append(_get_scheduler_lock(scheduler, path_or_file)) return combine_locks(locks) @@ -149,7 +161,6 @@ def _protect_dataset_variables_inplace(dataset, cache): def _finalize_store(write, store): """ Finalize this store by explicitly syncing and closing""" del write # ensure writing is done first - store.sync() store.close() @@ -331,10 +342,11 @@ def maybe_decode_store(store, lock=False): store = backends.H5NetCDFStore( filename_or_obj, group=group, lock=lock, **backend_kwargs) elif engine == 'pynio': - store = backends.NioDataStore(filename_or_obj, **backend_kwargs) + store = backends.NioDataStore( + filename_or_obj, lock=lock, **backend_kwargs) elif engine == 'pseudonetcdf': store = backends.PseudoNetCDFDataStore.open( - filename_or_obj, **backend_kwargs) + filename_or_obj, lock=lock, **backend_kwargs) else: raise ValueError('unrecognized engine for open_dataset: %r' % engine) @@ -645,19 +657,21 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, - engine=None, writer=None, encoding=None, 
unlimited_dims=None, - compute=True): + engine=None, encoding=None, unlimited_dims=None, compute=True, + multifile=False): """This function creates an appropriate datastore for writing a dataset to disk as a netCDF file See `Dataset.to_netcdf` for full API docs. - The ``writer`` argument is only for the private use of save_mfdataset. + The ``multifile`` argument is only for the private use of save_mfdataset. """ if isinstance(path_or_file, path_type): path_or_file = str(path_or_file) + if encoding is None: encoding = {} + if path_or_file is None: if engine is None: engine = 'scipy' @@ -665,6 +679,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise ValueError('invalid engine for creating bytes with ' 'to_netcdf: %r. Only the default engine ' "or engine='scipy' is supported" % engine) + if not compute: + raise NotImplementedError( + 'to_netcdf() with compute=False is not yet implemented when ' + 'returning bytes') elif isinstance(path_or_file, basestring): if engine is None: engine = _get_default_engine(path_or_file) @@ -672,6 +690,9 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, else: # file-like object engine = 'scipy' + if path_or_file is None and not compute: + raise ValueError + # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) _validate_attrs(dataset) @@ -684,9 +705,6 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, if format is not None: format = format.upper() - # if a writer is provided, store asynchronously - sync = writer is None - # handle scheduler specific logic scheduler = _get_scheduler() have_chunks = any(v.chunks for v in dataset.variables.values()) @@ -701,28 +719,65 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, target = path_or_file if path_or_file is not None else BytesIO() kwargs = dict(autoclose=True) if autoclose else {} store = store_open( - target, mode, format, group, writer, lock=lock, **kwargs) + target, mode, format, group, lock=lock, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) if isinstance(unlimited_dims, basestring): unlimited_dims = [unlimited_dims] + writer = ArrayWriter() + + # TODO: figure out how to refactor this logic (here and in save_mfdataset) + # to avoid this mess of conditionals try: - dataset.dump_to_store(store, sync=sync, encoding=encoding, - unlimited_dims=unlimited_dims, compute=compute) + # TODO: allow this work (setting up the file for writing array data) + # to be parallelized with dask + dump_to_store(dataset, store, writer, encoding=encoding, + unlimited_dims=unlimited_dims) + if autoclose: + store.close() + + if multifile: + return writer, store + + writes = writer.sync(compute=compute) + if path_or_file is None: + store.sync() return target.getvalue() finally: - if sync and isinstance(path_or_file, basestring): + if not multifile and compute: store.close() if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) - if not sync: - return store + +def dump_to_store(dataset, store, writer=None, encoder=None, + encoding=None, unlimited_dims=None): + """Store dataset contents to a backends.*DataStore object.""" + if writer is None: + writer = ArrayWriter() + + if encoding is None: + encoding = {} + + variables, attrs = conventions.encode_dataset_coordinates(dataset) + + check_encoding = set() + for k, enc in 
encoding.items(): + # no need to shallow copy the variable again; that already happened + # in encode_dataset_coordinates + variables[k].encoding = enc + check_encoding.add(k) + + if encoder: + variables, attrs = encoder(variables, attrs) + + store.store(variables, attrs, check_encoding, writer, + unlimited_dims=unlimited_dims) def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, @@ -806,22 +861,22 @@ def save_mfdataset(datasets, paths, mode='w', format=None, groups=None, 'datasets, paths and groups arguments to ' 'save_mfdataset') - writer = ArrayWriter() if compute else None - stores = [to_netcdf(ds, path, mode, format, group, engine, writer, - compute=compute) - for ds, path, group in zip(datasets, paths, groups)] - - if not compute: - import dask - return dask.delayed(stores) + writers, stores = zip(*[ + to_netcdf(ds, path, mode, format, group, engine, compute=compute, + multifile=True) + for ds, path, group in zip(datasets, paths, groups)]) try: - delayed = writer.sync(compute=compute) - for store in stores: - store.sync() + writes = [w.sync(compute=compute) for w in writers] finally: - for store in stores: - store.close() + if compute: + for store in stores: + store.close() + + if not compute: + import dask + return dask.delayed([dask.delayed(_finalize_store)(w, s) + for w, s in zip(writes, stores)]) def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, @@ -842,13 +897,14 @@ def to_zarr(dataset, store=None, mode='w-', synchronizer=None, group=None, store = backends.ZarrStore.open_group(store=store, mode=mode, synchronizer=synchronizer, - group=group, writer=None) + group=group) - # I think zarr stores should always be sync'd immediately + writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims - dataset.dump_to_store(store, sync=True, encoding=encoding, compute=compute) + dump_to_store(dataset, store, writer, encoding=encoding) + writes = writer.sync(compute=compute) if not compute: import dask - return dask.delayed(_finalize_store)(store.delayed_store, store) + return dask.delayed(_finalize_store)(writes, store) return store diff --git a/xarray/backends/common.py b/xarray/backends/common.py index bd07b84b0e6..f9f59741e87 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -15,12 +15,11 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin -# Import default lock try: from dask.utils import SerializableLock - HDF5_LOCK = SerializableLock() except ImportError: - HDF5_LOCK = threading.Lock() + # no need to worry about serializing the lock + SerializableLock = threading.Lock # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -164,6 +163,26 @@ def __repr__(self): return "CombinedLock(%r)" % list(self.locks) +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + def combine_locks(locks): """Combine one or more locks into a CombinedLock.""" all_locks = [] @@ -179,7 +198,14 @@ def combine_locks(locks): elif num_locks == 1: return all_locks[0] else: - return None + return DummyLock() + + +# Neither HDF5 nor the netCDF-C library are thread-safe. 
+HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() +# TODO: determine if we need a separate lock for PyNIO or not +PYNIO_LOCK = SerializableLock() class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): @@ -287,6 +313,9 @@ def add(self, source, target): def sync(self, compute=True): if self.sources: import dask.array as da + # TODO: consider wrapping targets with dask.delayed, if this makes + # for any discernable difference in perforance, e.g., + # targets = [dask.delayed(t) for t in self.targets] delayed_store = da.store(self.sources, self.targets, lock=self.lock, compute=compute, flush=True) @@ -296,11 +325,6 @@ def sync(self, compute=True): class AbstractWritableDataStore(AbstractDataStore): - def __init__(self, writer=None): - if writer is None: - writer = ArrayWriter() - self.writer = writer - self.delayed_store = None def encode(self, variables, attributes): """ @@ -342,8 +366,8 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + # def sync(self, compute=True): + # self.delayed_store = self.writer.sync(compute=compute) def store_dataset(self, dataset): """ @@ -355,7 +379,7 @@ def store_dataset(self, dataset): self.store(dataset, dataset.attrs) def store(self, variables, attributes, check_encoding_set=frozenset(), - unlimited_dims=None): + writer=None, unlimited_dims=None): """ Top level method for putting data on this store, this method: - encodes variables/attributes @@ -371,16 +395,19 @@ def store(self, variables, attributes, check_encoding_set=frozenset(), check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. """ + if writer is None: + writer = ArrayWriter() variables, attributes = self.encode(variables, attributes) self.set_attributes(attributes) self.set_dimensions(variables, unlimited_dims=unlimited_dims) - self.set_variables(variables, check_encoding_set, + self.set_variables(variables, check_encoding_set, writer, unlimited_dims=unlimited_dims) def set_attributes(self, attributes): @@ -396,7 +423,7 @@ def set_attributes(self, attributes): for k, v in iteritems(attributes): self.set_attribute(k, v) - def set_variables(self, variables, check_encoding_set, + def set_variables(self, variables, check_encoding_set, writer, unlimited_dims=None): """ This provides a centralized method to set the variables on the data @@ -409,6 +436,7 @@ def set_variables(self, variables, check_encoding_set, check_encoding_set : list-like List of variables that should be checked for invalid encoding values + writer : ArrayWriter unlimited_dims : list-like List of dimension names that should be treated as unlimited dimensions. 
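For orientation, the calling pattern that the explicit ``writer`` argument enables looks roughly like the condensed sketch below. It restates the ``dump_to_store`` helper added in this patch rather than any public, stable API; ``store`` stands for an arbitrary ``AbstractWritableDataStore`` subclass and the keyword names follow the signature documented above.

    from xarray import conventions
    from xarray.backends.common import ArrayWriter

    def sketch_dump_to_store(dataset, store, compute=True):
        # Encode coordinates/attributes once, outside the data store.
        variables, attrs = conventions.encode_dataset_coordinates(dataset)

        # The store no longer owns a writer; it records pending
        # (source, target) array writes on whatever writer the caller passes.
        writer = ArrayWriter()
        store.store(variables, attrs, check_encoding_set=set(), writer=writer)

        # The caller decides when the dask writes actually run; compute=False
        # returns the delayed store operation for later computation.
        return writer.sync(compute=compute)
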
@@ -420,7 +448,7 @@ def set_variables(self, variables, check_encoding_set, target, source = self.prepare_variable( name, v, check, unlimited_dims=unlimited_dims) - self.writer.add(source, target) + writer.add(source, target) def set_dimensions(self, variables, unlimited_dims=None): """ diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index de41b241fc1..2a11ccb1086 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -71,7 +71,7 @@ class H5NetCDFStore(WritableCFDataStore): """ def __init__(self, filename, mode='r', format=None, group=None, - writer=None, lock=HDF5_LOCK, autoclose=False): + lock=HDF5_LOCK, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') self._manager = CachingFileManager( @@ -83,7 +83,6 @@ def __init__(self, filename, mode='r', format=None, group=None, self._mode = mode self.lock = lock self.autoclose = autoclose - super(H5NetCDFStore, self).__init__(writer) @property def ds(self): @@ -219,11 +218,11 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): + def sync(self): self.ds.sync() - if self.autoclose: - self.close() - super(H5NetCDFStore, self).sync(compute=compute) + # if self.autoclose: + # self.close() + # super(H5NetCDFStore, self).sync(compute=compute) def close(self, **kwargs): self._manager.close(**kwargs) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a4fb7f49b28..36a06d1d9e1 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -298,7 +298,7 @@ class NetCDF4DataStore(WritableCFDataStore): This store supports NetCDF3, NetCDF4 and OpenDAP datasets. """ - def __init__(self, manager, writer=None, lock=HDF5_LOCK, autoclose=False): + def __init__(self, manager, lock=HDF5_LOCK, autoclose=False): import netCDF4 if isinstance(manager, netCDF4.Dataset): @@ -311,11 +311,10 @@ def __init__(self, manager, writer=None, lock=HDF5_LOCK, autoclose=False): self.is_remote = is_remote_uri(self._filename) self.lock = lock self.autoclose = autoclose - super(NetCDF4DataStore, self).__init__(writer) @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, - writer=None, clobber=True, diskless=False, persist=False, + clobber=True, diskless=False, persist=False, lock=HDF5_LOCK, autoclose=False): import netCDF4 if (len(filename) == 88 and @@ -334,7 +333,7 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, _open_netcdf4_group, filename, lock, mode=mode, kwargs=dict(group=group, clobber=clobber, diskless=diskless, persist=persist, format=format)) - return cls(manager, writer=writer, lock=lock, autoclose=autoclose) + return cls(manager, lock=lock, autoclose=autoclose) @property def ds(self): @@ -458,11 +457,8 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, variable.data - def sync(self, compute=True): + def sync(self): self.ds.sync() - if self.autoclose: - self.close() - super(NetCDF4DataStore, self).sync(compute=compute) def close(self, **kwargs): self._manager.close(**kwargs) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index dc2867b6ca1..e5efc1b4efb 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -9,7 +9,7 @@ from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing -from .common import AbstractDataStore, BackendArray +from .common import AbstractDataStore, BackendArray, DummyLock 
from .file_manager import CachingFileManager @@ -31,14 +31,15 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): - return self.get_array()[key] + with self.datastore.lock: + return self.get_array()[key] class PseudoNetCDFDataStore(AbstractDataStore): """Store for accessing datasets via PseudoNetCDF """ @classmethod - def open(cls, filename, **format_kwds): + def open(cls, filename, lock=None, **format_kwds): from PseudoNetCDF import pncopen keywords = dict(kwargs=format_kwds) @@ -50,8 +51,9 @@ def open(cls, filename, **format_kwds): manager = CachingFileManager(pncopen, filename, **keywords) return cls(manager) - def __init__(self, manager): + def __init__(self, manager, lock=None): self._manager = manager + self.lock = DummyLock() if lock is None else lock @property def ds(self): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index a9b73132b57..de291f04571 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -5,7 +5,7 @@ from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import AbstractDataStore, BackendArray +from .common import PYNIO_LOCK, AbstractDataStore, BackendArray from .file_manager import CachingFileManager @@ -26,19 +26,20 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): - array = self.get_array() - if key == () and self.ndim == 0: - return array.get_value() - - return array[key] + with self.datastore.lock: + array = self.get_array() + if key == () and self.ndim == 0: + return array.get_value() + return array[key] class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r'): + def __init__(self, filename, mode='r', lock=PYNIO_LOCK): import Nio + self.lock = lock self._manager = CachingFileManager(Nio.open_file, filename, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. 
diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 18754a008e0..b5bf3834f09 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -11,7 +11,7 @@ from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict from .common import BackendArray, WritableCFDataStore -from .file_manager import CachingFileManager +from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -111,7 +111,7 @@ class ScipyDataStore(WritableCFDataStore): """ def __init__(self, filename_or_obj, mode='r', format=None, group=None, - writer=None, mmap=None, lock=None): + mmap=None, lock=None): import scipy import scipy.io @@ -135,10 +135,16 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - self._manager = CachingFileManager( - _open_scipy_netcdf, filename_or_obj, mode=mode, - kwargs=dict(mmap=mmap, version=version)) - super(ScipyDataStore, self).__init__(writer) + if isinstance(filename_or_obj, basestring): + manager = CachingFileManager( + _open_scipy_netcdf, filename_or_obj, mode=mode, + kwargs=dict(mmap=mmap, version=version)) + else: + scipy_dataset = _open_scipy_netcdf( + filename_or_obj, mode=mode, mmap=mmap, version=version) + manager = DummyFileManager(scipy_dataset) + + self._manager = manager @property def ds(self): @@ -206,12 +212,10 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data - def sync(self, compute=True): - super(ScipyDataStore, self).sync(compute=compute) + def sync(self): + # super(ScipyDataStore, self).sync(compute=compute) self.ds.sync() def close(self): + # self.sync() self._manager.close() - - def __exit__(self, type, value, tb): - self.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 47b90c8a617..5f1188f2f47 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -217,8 +217,7 @@ class ZarrStore(AbstractWritableDataStore): """ @classmethod - def open_group(cls, store, mode='r', synchronizer=None, group=None, - writer=None): + def open_group(cls, store, mode='r', synchronizer=None, group=None): import zarr min_zarr = '2.2' @@ -230,23 +229,23 @@ def open_group(cls, store, mode='r', synchronizer=None, group=None, "#installation" % min_zarr) zarr_group = zarr.open_group(store=store, mode=mode, synchronizer=synchronizer, path=group) - return cls(zarr_group, writer=writer) + return cls(zarr_group) - def __init__(self, zarr_group, writer=None): + def __init__(self, zarr_group): self.ds = zarr_group self._read_only = self.ds.read_only self._synchronizer = self.ds.synchronizer self._group = self.ds.path - if writer is None: - # by default, we should not need a lock for writing zarr because - # we do not (yet) allow overlapping chunks during write - zarr_writer = ArrayWriter(lock=False) - else: - zarr_writer = writer + # if writer is None: + # # by default, we should not need a lock for writing zarr because + # # we do not (yet) allow overlapping chunks during write + # zarr_writer = ArrayWriter(lock=False) + # else: + # zarr_writer = writer - # do we need to define attributes for all of the opener keyword args? - super(ZarrStore, self).__init__(zarr_writer) + # # do we need to define attributes for all of the opener keyword args? 
+ # super(ZarrStore, self).__init__(zarr_writer) def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) @@ -334,8 +333,11 @@ def store(self, variables, attributes, *args, **kwargs): AbstractWritableDataStore.store(self, variables, attributes, *args, **kwargs) - def sync(self, compute=True): - self.delayed_store = self.writer.sync(compute=compute) + def sync(self): + pass + + # def sync(self, compute=True): + # self.delayed_store = self.writer.sync(compute=compute) def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 37544aca372..53286890ceb 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1054,27 +1054,12 @@ def reset_coords(self, names=None, drop=False, inplace=False): del obj._variables[name] return obj - def dump_to_store(self, store, encoder=None, sync=True, encoding=None, - unlimited_dims=None, compute=True): + def dump_to_store(self, store, **kwargs): """Store dataset contents to a backends.*DataStore object.""" - if encoding is None: - encoding = {} - variables, attrs = conventions.encode_dataset_coordinates(self) - - check_encoding = set() - for k, enc in encoding.items(): - # no need to shallow copy the variable again; that already happened - # in encode_dataset_coordinates - variables[k].encoding = enc - check_encoding.add(k) - - if encoder: - variables, attrs = encoder(variables, attrs) - - store.store(variables, attrs, check_encoding, - unlimited_dims=unlimited_dims) - if sync: - store.sync(compute=compute) + from ..backends.api import dump_to_store + # TODO: rename and/or cleanup this method to make it more consistent + # with to_netcdf() + return dump_to_store(self, store, **kwargs) def to_netcdf(self, path=None, mode='w', format=None, group=None, engine=None, encoding=None, unlimited_dims=None, From 2a5d1f02a6c4cd562272033e2d44a11dc7828b48 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 09:40:57 -0700 Subject: [PATCH 26/39] Cleanup --- xarray/backends/common.py | 5 +---- xarray/backends/scipy_.py | 2 -- xarray/backends/zarr.py | 13 ------------- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index f9f59741e87..31b2ed80fbc 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -180,7 +180,7 @@ def __exit__(self, *args): @property def locked(self): - return False + return False def combine_locks(locks): @@ -366,9 +366,6 @@ def set_attribute(self, k, v): # pragma: no cover def set_variable(self, k, v): # pragma: no cover raise NotImplementedError - # def sync(self, compute=True): - # self.delayed_store = self.writer.sync(compute=compute) - def store_dataset(self, dataset): """ in stores, variables are all variables AND coordinates diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index b5bf3834f09..c54a50c384c 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -213,9 +213,7 @@ def prepare_variable(self, name, variable, check_encoding=False, return target, data def sync(self): - # super(ScipyDataStore, self).sync(compute=compute) self.ds.sync() def close(self): - # self.sync() self._manager.close() diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 5f1188f2f47..5f19c826289 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -237,16 +237,6 @@ def __init__(self, zarr_group): self._synchronizer = self.ds.synchronizer self._group = 
self.ds.path - # if writer is None: - # # by default, we should not need a lock for writing zarr because - # # we do not (yet) allow overlapping chunks during write - # zarr_writer = ArrayWriter(lock=False) - # else: - # zarr_writer = writer - - # # do we need to define attributes for all of the opener keyword args? - # super(ZarrStore, self).__init__(zarr_writer) - def open_store_variable(self, name, zarr_array): data = indexing.LazilyOuterIndexedArray(ZarrArrayWrapper(name, self)) dimensions, attributes = _get_zarr_dims_and_attrs(zarr_array, @@ -336,9 +326,6 @@ def store(self, variables, attributes, *args, **kwargs): def sync(self): pass - # def sync(self, compute=True): - # self.delayed_store = self.writer.sync(compute=compute) - def open_zarr(store, group=None, synchronizer=None, auto_chunk=True, decode_cf=True, mask_and_scale=True, decode_times=True, From a6c170bd7713100e7d80c2e198b999bc440c740f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 11:12:47 -0700 Subject: [PATCH 27/39] Fix failing in-memory datastore tests --- xarray/backends/memory.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/xarray/backends/memory.py b/xarray/backends/memory.py index dcf092557b8..195d4647534 100644 --- a/xarray/backends/memory.py +++ b/xarray/backends/memory.py @@ -17,10 +17,9 @@ class InMemoryDataStore(AbstractWritableDataStore): This store exists purely for internal testing purposes. """ - def __init__(self, variables=None, attributes=None, writer=None): + def __init__(self, variables=None, attributes=None): self._variables = OrderedDict() if variables is None else variables self._attributes = OrderedDict() if attributes is None else attributes - super(InMemoryDataStore, self).__init__(writer) def get_attrs(self): return self._attributes From 009e30d6155a698812d7099e92e3028e5b2bda3f Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 11:14:43 -0700 Subject: [PATCH 28/39] Fix inaccessible datastore --- xarray/tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index d22d8470dc6..c83568bcd59 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -63,8 +63,8 @@ def create_test_multiindex(): class InaccessibleVariableDataStore(backends.InMemoryDataStore): - def __init__(self, writer=None): - super(InaccessibleVariableDataStore, self).__init__(writer) + def __init__(self): + super(InaccessibleVariableDataStore, self).__init__() self._indexvars = set() def store(self, variables, *args, **kwargs): From 14118ead14dce5b6a407ddf3a546d6d2da49e699 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 11:19:31 -0700 Subject: [PATCH 29/39] fix autoclose warnings --- xarray/tests/test_backends.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index bdaaaed56ac..a8d5b128ea2 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2597,7 +2597,6 @@ def test_uamiv_format_read(self): camxfile = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=True, backend_kwargs={'format': 'uamiv'}) data = np.arange(20, dtype='f').reshape(1, 1, 4, 5) expected = xr.Variable(('TSTEP', 'LAY', 'ROW', 'COL'), data, @@ -2625,7 +2624,6 @@ def test_uamiv_format_mfread(self): ['example.uamiv', 'example.uamiv'], engine='pseudonetcdf', - autoclose=True, concat_dim='TSTEP', backend_kwargs={'format': 'uamiv'}) @@ -2653,7 +2651,6 @@ def 
test_uamiv_format_write(self): expected = open_example_dataset('example.uamiv', engine='pseudonetcdf', - autoclose=False, backend_kwargs=fmtkw) with self.roundtrip(expected, save_kwargs=fmtkw, From c778488667e69c466d35304a347f5ee8361bbe4d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 12:32:00 -0700 Subject: [PATCH 30/39] Fix PyNIO failures --- xarray/backends/api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 5fc7d707d88..ac975997ac0 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -14,8 +14,8 @@ from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, NETCDFC_LOCK, ArrayWriter, combine_locks, _get_scheduler, - _get_scheduler_lock) + HDF5_LOCK, NETCDFC_LOCK, PYNIO_LOCK, ArrayWriter, combine_locks, + _get_scheduler, _get_scheduler_lock) DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -55,6 +55,7 @@ def _normalize_path(path): def _default_read_lock(filename, engine): + # TODO: move this logic to the data store classes if filename.endswith('.gz'): locks = [] else: @@ -85,6 +86,7 @@ def _default_read_lock(filename, engine): def _get_write_lock(engine, scheduler, format, path_or_file): """ Get the lock(s) that apply to a particular scheduler/engine/format""" + # TODO: move this logic to the data store classes locks = [] From fe14ebfcc1e295abddbc9943f6f2ed4a0a3ffb91 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 20:13:39 -0700 Subject: [PATCH 31/39] No longer disable HDF5 file locking We longer need to explicitly HDF5_USE_FILE_LOCKING='FALSE' because we properly close open files. --- xarray/tests/test_distributed.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index bf8952f7436..c8b99aaf37d 100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -15,6 +15,7 @@ from distributed.utils_test import cluster, gen_cluster from distributed.utils_test import loop # flake8: noqa from distributed.client import futures_of +import numpy as np import xarray as xr from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, @@ -62,13 +63,11 @@ def tmp_netcdf_filename(tmpdir): @pytest.mark.parametrize('engine,nc_format', ENGINES_AND_FORMATS) def test_dask_distributed_netcdf_roundtrip( - monkeypatch, loop, tmp_netcdf_filename, engine, nc_format): + loop, tmp_netcdf_filename, engine, nc_format): if engine not in ENGINES: pytest.skip('engine not available') - monkeypatch.setenv('HDF5_USE_FILE_LOCKING', 'FALSE') - chunks = {'dim1': 4, 'dim2': 3, 'dim3': 6} with cluster() as (s, [a, b]): From f1026ce6ac15598647dd020604e573174edd0a4a Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 6 Sep 2018 20:21:50 -0700 Subject: [PATCH 32/39] whats new and default file cache size --- doc/whats-new.rst | 11 ++++++----- xarray/core/options.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 55df39d13bf..18515ea73c9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,14 +36,15 @@ Breaking changes - Xarray's storage backends now automatically open and close files when necessary, rather than requiring opening a file with ``autoclose=True``. 
A global least-recently-used cache is used to store open files; the default - limit of 512 open files should suffice in most cases, but can be adjusted if + limit of 128 open files should suffice in most cases, but can be adjusted if necessary with - ``xarray.set_options(file_cache_maxsize=...)``. + ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument + to ``open_dataset`` has been deprecated, because it is no longer necessary. This change significantly simplies the work required to write a new backend - class, and also should improves performance for reading and writing - netCDF files when using dask. - By `Stephan Hoyer `_ + class, and should significantly improve performance when reading and writing + netCDF files with dask, especially when reading/writing many files or using + dask-distributed. By `Stephan Hoyer `_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/core/options.py b/xarray/core/options.py index d4b4f425666..04ea0be7172 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -11,7 +11,7 @@ DISPLAY_WIDTH: 80, ARITHMETIC_JOIN: 'inner', ENABLE_CFTIMEINDEX: False, - FILE_CACHE_MAXSIZE: 512, + FILE_CACHE_MAXSIZE: 128, CMAP_SEQUENTIAL: 'viridis', CMAP_DIVERGENT: 'RdBu_r', } @@ -56,7 +56,7 @@ class set_options(object): - ``file_cache_maxsize``: maximum number of open files to hold in xarray's global least-recently-usage cached. This should be smaller than your system's per-process file descriptor limit, e.g., ``ulimit -n`` on Linux. - Default: 512. + Default: 128. - ``cmap_sequential``: colormap to use for nondivergent data plots. Default: ``viridis``. If string, must be matplotlib built-in colormap. Can also be a Colormap object (e.g. mpl.cm.magma) From e13406be381274c64180d380152c7f509261b992 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 7 Sep 2018 08:24:38 -0700 Subject: [PATCH 33/39] Whats new tweak --- doc/whats-new.rst | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 18515ea73c9..a09a90d6744 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -39,12 +39,13 @@ Breaking changes limit of 128 open files should suffice in most cases, but can be adjusted if necessary with ``xarray.set_options(file_cache_maxsize=...)``. The ``autoclose`` argument - to ``open_dataset`` has been deprecated, because it is no longer necessary. + to ``open_dataset`` and related functions has been deprecated and is now a + no-op. - This change significantly simplies the work required to write a new backend - class, and should significantly improve performance when reading and writing - netCDF files with dask, especially when reading/writing many files or using - dask-distributed. By `Stephan Hoyer `_ + This change, along with an internal refactor of xarray's storage backends, + should significantly improve performance when reading and writing + netCDF files with Dask, especially when working with many files or using + Dask Distributed. 
By `Stephan Hoyer `_ Documentation ~~~~~~~~~~~~~ From 465dfaed4e430845c40586b13b5b8db66b6fe59d Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 9 Sep 2018 22:31:43 -0700 Subject: [PATCH 34/39] Refactor default lock logic to backend classes --- xarray/backends/api.py | 97 +++----------- xarray/backends/common.py | 140 +------------------- xarray/backends/file_manager.py | 2 +- xarray/backends/h5netcdf_.py | 13 +- xarray/backends/locks.py | 191 ++++++++++++++++++++++++++++ xarray/backends/netCDF4_.py | 27 +++- xarray/backends/pseudonetcdf_.py | 16 ++- xarray/backends/pynio_.py | 19 ++- xarray/backends/rasterio_.py | 10 +- xarray/backends/scipy_.py | 7 +- xarray/tests/test_backends_locks.py | 13 ++ xarray/tests/test_distributed.py | 2 +- 12 files changed, 296 insertions(+), 241 deletions(-) create mode 100644 xarray/backends/locks.py create mode 100644 xarray/tests/test_backends_locks.py diff --git a/xarray/backends/api.py b/xarray/backends/api.py index ac975997ac0..f20061b48b6 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -13,9 +13,9 @@ from ..core.combine import auto_combine from ..core.pycompat import basestring, path_type from ..core.utils import close_on_error, is_remote_uri -from .common import ( - HDF5_LOCK, NETCDFC_LOCK, PYNIO_LOCK, ArrayWriter, combine_locks, - _get_scheduler, _get_scheduler_lock) +from .common import ArrayWriter +from .locks import _get_scheduler + DATAARRAY_NAME = '__xarray_dataarray_name__' DATAARRAY_VARIABLE = '__xarray_dataarray_variable__' @@ -54,55 +54,6 @@ def _normalize_path(path): return os.path.abspath(os.path.expanduser(path)) -def _default_read_lock(filename, engine): - # TODO: move this logic to the data store classes - if filename.endswith('.gz'): - locks = [] - else: - if engine is None: - engine = _get_default_engine(filename, allow_remote=True) - - if engine == 'netcdf4': - if is_remote_uri(filename): - locks = [NETCDFC_LOCK] - else: - # TODO: identify netcdf3 files and don't use the global HDF5 - # lock for them - locks = [NETCDFC_LOCK, HDF5_LOCK] - elif engine == 'h5netcdf': - locks = [HDF5_LOCK] - elif engine == 'pynio': - # pynio can invoke netCDF libraries internally - locks = [HDF5_LOCK, NETCDFC_LOCK, PYNIO_LOCK] - elif engine == 'psuedonetcdf': - # psuedonetcdf can invoke netCDF libraries internally - locks = [HDF5_LOCK, NETCDFC_LOCK] - else: - # no locking needed by default, e.g., for scipy or pynio - locks = [] - - return combine_locks(locks) - - -def _get_write_lock(engine, scheduler, format, path_or_file): - """ Get the lock(s) that apply to a particular scheduler/engine/format""" - # TODO: move this logic to the data store classes - - locks = [] - - if (engine == 'h5netcdf' or engine == 'netcdf4' and - (format is None or format.startswith('NETCDF4'))): - locks.append(HDF5_LOCK) - - if engine == 'netcdf4': - locks.append(NETCDFC_LOCK) - - locks.append(_get_scheduler_lock(scheduler, path_or_file)) - - return combine_locks(locks) - - - def _validate_dataset_names(dataset): """DataArray.name and Dataset keys must be a string or None""" def check_name(name): @@ -219,12 +170,11 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True, If chunks is provided, it used to load the new dataset into dask arrays. ``chunks={}`` loads the dataset with dask using a single chunk for all arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. 
By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -326,9 +276,6 @@ def maybe_decode_store(store, lock=False): else: engine = 'scipy' - if lock is None: - lock = _default_read_lock(filename_or_obj, engine) - if engine is None: engine = _get_default_engine(filename_or_obj, allow_remote=True) @@ -416,12 +363,11 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True, chunks : int or dict, optional If chunks is provided, it used to load the new dataset into dask arrays. - lock : False, True or threading.Lock, optional - If chunks is provided, this argument is passed on to - :py:func:`dask.array.from_array`. By default, a global lock is - used when reading data from netCDF files with the netcdf4 and h5netcdf - engines to avoid issues with concurrent access when using dask's - multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. cache : bool, optional If True, cache data loaded from the underlying datastore in memory as NumPy arrays when accessed to avoid reading from the underlying data- @@ -544,11 +490,11 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, Engine to use when reading files. If not provided, the default engine is chosen based on available dependencies, with a preference for 'netcdf4'. - lock : False, True or threading.Lock, optional - This argument is passed on to :py:func:`dask.array.from_array`. By - default, a per-variable lock is used when reading data from netCDF - files with the netcdf4 and h5netcdf engines to avoid issues with - concurrent access when using dask's multithreaded backend. + lock : False or duck threading.Lock, optional + Resource lock to use when reading data from disk. Only relevant when + using dask or another form of parallelism. By default, appropriate + locks are chosen to safely read and write files with the currently + active dask scheduler. 
data_vars : {'minimal', 'different', 'all' or list of str}, optional These data variables will be concatenated together: * 'minimal': Only data variables in which the dimension already @@ -607,9 +553,6 @@ def open_mfdataset(paths, chunks=None, concat_dim=_CONCAT_DIM_DEFAULT, if not paths: raise IOError('no files to open') - if lock is None: - lock = _default_read_lock(paths[0], engine) - open_kwargs = dict(engine=engine, chunks=chunks or {}, lock=lock, autoclose=autoclose, **kwargs) @@ -716,12 +659,10 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, raise NotImplementedError("Writing netCDF files with the %s backend " "is not currently supported with dask's %s " "scheduler" % (engine, scheduler)) - lock = _get_write_lock(engine, scheduler, format, path_or_file) target = path_or_file if path_or_file is not None else BytesIO() kwargs = dict(autoclose=True) if autoclose else {} - store = store_open( - target, mode, format, group, lock=lock, **kwargs) + store = store_open(target, mode, format, group, **kwargs) if unlimited_dims is None: unlimited_dims = dataset.encoding.get('unlimited_dims', None) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 31b2ed80fbc..555eb2bfe8b 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -6,6 +6,7 @@ import time import traceback import warnings +import weakref from collections import Mapping, OrderedDict import numpy as np @@ -15,12 +16,6 @@ from ..core.pycompat import dask_array_type, iteritems from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin -try: - from dask.utils import SerializableLock -except ImportError: - # no need to worry about serializing the lock - SerializableLock = threading.Lock - # Create a logger object, but don't add any handlers. Leave that to user code. logger = logging.getLogger(__name__) @@ -28,61 +23,6 @@ NONE_VAR_NAME = '__values__' -def _get_scheduler(get=None, collection=None): - """ Determine the dask scheduler that is being used. - - None is returned if not dask scheduler is active. - - See also - -------- - dask.base.get_scheduler - """ - try: - # dask 0.18.1 and later - from dask.base import get_scheduler - actual_get = get_scheduler(get, collection) - except ImportError: - try: - from dask.utils import effective_get - actual_get = effective_get(get, collection) - except ImportError: - return None - - try: - from dask.distributed import Client - if isinstance(actual_get.__self__, Client): - return 'distributed' - except (ImportError, AttributeError): - try: - import dask.multiprocessing - if actual_get == dask.multiprocessing.get: - return 'multiprocessing' - else: - return 'threaded' - except ImportError: - return 'threaded' - - -def _get_scheduler_lock(scheduler, path_or_file=None): - """ Get the appropriate lock for a certain situation based onthe dask - scheduler used. - - See Also - -------- - dask.utils.get_scheduler_lock - """ - - if scheduler == 'distributed': - from dask.distributed import Lock - return Lock(path_or_file) - elif scheduler == 'multiprocessing': - return multiprocessing.Lock() - elif scheduler == 'threaded': - from dask.utils import SerializableLock - return SerializableLock() - else: - return threading.Lock() - def _encode_variable_name(name): if name is None: @@ -130,84 +70,6 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, time.sleep(1e-3 * next_delay) -class CombinedLock(object): - """A combination of multiple locks. 
- - Like a locked door, a CombinedLock is locked if any of its constituent - locks are locked. - """ - - def __init__(self, locks): - self.locks = tuple(set(locks)) # remove duplicates - - def acquire(self, *args): - return all(lock.acquire(*args) for lock in self.locks) - - def release(self, *args): - for lock in self.locks: - lock.release(*args) - - def __enter__(self): - for lock in self.locks: - lock.__enter__() - - def __exit__(self, *args): - for lock in self.locks: - lock.__exit__(*args) - - @property - def locked(self): - return any(lock.locked for lock in self.locks) - - def __repr__(self): - return "CombinedLock(%r)" % list(self.locks) - - -class DummyLock(object): - """DummyLock provides the lock API without any actual locking.""" - - def acquire(self, *args): - pass - - def release(self, *args): - pass - - def __enter__(self): - pass - - def __exit__(self, *args): - pass - - @property - def locked(self): - return False - - -def combine_locks(locks): - """Combine one or more locks into a CombinedLock.""" - all_locks = [] - for lock in locks: - if isinstance(lock, CombinedLock): - all_locks.extend(lock.locks) - elif lock is not None: - all_locks.append(lock) - - num_locks = len(all_locks) - if num_locks > 1: - return CombinedLock(all_locks) - elif num_locks == 1: - return all_locks[0] - else: - return DummyLock() - - -# Neither HDF5 nor the netCDF-C library are thread-safe. -HDF5_LOCK = SerializableLock() -NETCDFC_LOCK = SerializableLock() -# TODO: determine if we need a separate lock for PyNIO or not -PYNIO_LOCK = SerializableLock() - - class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed): def __array__(self, dtype=None): diff --git a/xarray/backends/file_manager.py b/xarray/backends/file_manager.py index 2ecae8581e9..a93285370b2 100644 --- a/xarray/backends/file_manager.py +++ b/xarray/backends/file_manager.py @@ -105,7 +105,7 @@ def __init__(self, opener, *args, **keywords): self._args = args self._mode = mode self._kwargs = {} if kwargs is None else dict(kwargs) - self._default_lock = lock is None + self._default_lock = lock is None or lock is False self._lock = threading.Lock() if self._default_lock else lock self._cache = cache self._key = self._make_key() diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 2a11ccb1086..faa7d94405b 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -8,8 +8,9 @@ from ..core import indexing from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type from ..core.utils import FrozenOrderedDict, close_on_error -from .common import HDF5_LOCK, WritableCFDataStore +from .common import WritableCFDataStore from .file_manager import CachingFileManager +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_resource_lock from .netCDF4_ import ( BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) @@ -71,17 +72,23 @@ class H5NetCDFStore(WritableCFDataStore): """ def __init__(self, filename, mode='r', format=None, group=None, - lock=HDF5_LOCK, autoclose=False): + lock=None, autoclose=False): if format not in [None, 'NETCDF4']: raise ValueError('invalid format for h5netcdf backend') self._manager = CachingFileManager( _open_h5netcdf_group, filename, mode=mode, kwargs=dict(group=group)) + if lock is None: + if mode == 'r': + lock = HDF5_LOCK + else: + lock = combine_locks([HDF5_LOCK, get_resource_lock(filename)]) + self.format = format self._filename = filename self._mode = mode - self.lock = lock + 
self.lock = ensure_lock(lock) self.autoclose = autoclose @property diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py new file mode 100644 index 00000000000..a35680c13fe --- /dev/null +++ b/xarray/backends/locks.py @@ -0,0 +1,191 @@ +import multiprocessing +import threading +import weakref + +try: + from dask.utils import SerializableLock +except ImportError: + # no need to worry about serializing the lock + SerializableLock = threading.Lock + + +# Locks used by multiple backends. +# Neither HDF5 nor the netCDF-C library are thread-safe. +HDF5_LOCK = SerializableLock() +NETCDFC_LOCK = SerializableLock() + + +_FILE_LOCKS = weakref.WeakValueDictionary() + + +def _get_threaded_lock(key): + try: + lock = _FILE_LOCKS[key] + except KeyError: + lock = _FILE_LOCKS[key] = threading.Lock() + return lock + + +def _get_multiprocessing_lock(key): + # TODO: make use of the key -- maybe use locket.py? + # https://github.com/mwilliamson/locket.py + del key # unused + return multiprocessing.Lock() + + +def _get_distributed_lock(key): + from dask.distributed import Lock + return Lock(key) + + +_LOCK_MAKERS = { + None: _get_threaded_lock, + 'threaded': _get_threaded_lock, + 'multiprocessing': _get_multiprocessing_lock, + 'distributed': _get_distributed_lock, +} + + +def _get_lock_maker(scheduler=None): + """Returns an appropriate function for creating resource locks. + + Parameters + ---------- + scheduler : str or None + Dask scheduler being used. + + See Also + -------- + dask.utils.get_scheduler_lock + """ + return _LOCK_MAKERS[scheduler] + + +def _get_scheduler(get=None, collection=None): + """Determine the dask scheduler that is being used. + + None is returned if no dask scheduler is active. + + See also + -------- + dask.base.get_scheduler + """ + try: + # dask 0.18.1 and later + from dask.base import get_scheduler + actual_get = get_scheduler(get, collection) + except ImportError: + try: + from dask.utils import effective_get + actual_get = effective_get(get, collection) + except ImportError: + return None + + try: + from dask.distributed import Client + if isinstance(actual_get.__self__, Client): + return 'distributed' + except (ImportError, AttributeError): + try: + import dask.multiprocessing + if actual_get == dask.multiprocessing.get: + return 'multiprocessing' + else: + return 'threaded' + except ImportError: + return 'threaded' + + +def get_resource_lock(key): + """Get a scheduler appropriate lock for writing to the given resource. + + Parameters + ---------- + key : str + Name of the resource for which to acquire a lock. Typically a filename. + + Returns + ------- + Lock object that can be used like a threading.Lock object. + """ + scheduler = _get_scheduler() + lock_maker = _get_lock_maker(scheduler) + return lock_maker(key) + + +class CombinedLock(object): + """A combination of multiple locks. + + Like a locked door, a CombinedLock is locked if any of its constituent + locks are locked. 
+ """ + + def __init__(self, locks): + self.locks = tuple(set(locks)) # remove duplicates + + def acquire(self, *args): + return all(lock.acquire(*args) for lock in self.locks) + + def release(self, *args): + for lock in self.locks: + lock.release(*args) + + def __enter__(self): + for lock in self.locks: + lock.__enter__() + + def __exit__(self, *args): + for lock in self.locks: + lock.__exit__(*args) + + @property + def locked(self): + return any(lock.locked for lock in self.locks) + + def __repr__(self): + return "CombinedLock(%r)" % list(self.locks) + + +class DummyLock(object): + """DummyLock provides the lock API without any actual locking.""" + + def acquire(self, *args): + pass + + def release(self, *args): + pass + + def __enter__(self): + pass + + def __exit__(self, *args): + pass + + @property + def locked(self): + return False + + +def combine_locks(locks): + """Combine a sequence of locks into a single lock.""" + all_locks = [] + for lock in locks: + if isinstance(lock, CombinedLock): + all_locks.extend(lock.locks) + elif lock is not None: + all_locks.append(lock) + + num_locks = len(all_locks) + if num_locks > 1: + return CombinedLock(all_locks) + elif num_locks == 1: + return all_locks[0] + else: + return DummyLock() + + +def ensure_lock(lock): + """Ensure that the given object is a lock.""" + if lock is None or lock is False: + return DummyLock() + return lock diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index 36a06d1d9e1..daf2538323f 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -14,7 +14,9 @@ PY3, OrderedDict, basestring, iteritems, suppress) from ..core.utils import FrozenOrderedDict, close_on_error, is_remote_uri from .common import ( - HDF5_LOCK, BackendArray, WritableCFDataStore, find_root, robust_getitem) + BackendArray, WritableCFDataStore, find_root, robust_getitem) +from .locks import (NETCDFC_LOCK, HDF5_LOCK, + combine_locks, ensure_lock, get_resource_lock) from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable @@ -26,6 +28,9 @@ '|': 'native'} +NETCDF4_PYTHON_LOCK = combine_locks([NETCDFC_LOCK, HDF5_LOCK]) + + class BaseNetCDF4Array(BackendArray): def __init__(self, variable_name, datastore): self.datastore = datastore @@ -298,7 +303,7 @@ class NetCDF4DataStore(WritableCFDataStore): This store supports NetCDF3, NetCDF4 and OpenDAP datasets. 
""" - def __init__(self, manager, lock=HDF5_LOCK, autoclose=False): + def __init__(self, manager, lock=NETCDF4_PYTHON_LOCK, autoclose=False): import netCDF4 if isinstance(manager, netCDF4.Dataset): @@ -309,13 +314,13 @@ def __init__(self, manager, lock=HDF5_LOCK, autoclose=False): self.format = self.ds.data_model self._filename = self.ds.filepath() self.is_remote = is_remote_uri(self._filename) - self.lock = lock + self.lock = ensure_lock(lock) self.autoclose = autoclose @classmethod def open(cls, filename, mode='r', format='NETCDF4', group=None, clobber=True, diskless=False, persist=False, - lock=HDF5_LOCK, autoclose=False): + lock=None, lock_maker=None, autoclose=False): import netCDF4 if (len(filename) == 88 and LooseVersion(netCDF4.__version__) < "1.3.1"): @@ -329,6 +334,20 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, 'https://github.com/pydata/xarray/issues/1745') if format is None: format = 'NETCDF4' + + if lock is None: + if mode == 'r': + if is_remote_uri(filename): + lock = NETCDFC_LOCK + else: + lock = NETCDF4_PYTHON_LOCK + else: + if format is None or format.startswith('NETCDF4'): + base_lock = NETCDF4_PYTHON_LOCK + else: + base_lock = NETCDFC_LOCK + lock = combine_locks([base_lock, get_resource_lock(filename)]) + manager = CachingFileManager( _open_netcdf4_group, filename, lock, mode=mode, kwargs=dict(group=group, clobber=clobber, diskless=diskless, diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index e5efc1b4efb..9ca3549ad49 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -9,8 +9,13 @@ from ..core.utils import (FrozenOrderedDict, Frozen) from ..core import indexing -from .common import AbstractDataStore, BackendArray, DummyLock +from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager +from .locks import NETCDFC_LOCK, HDF5_LOCK, combine_locks, ensure_lock + + +# psuedonetcdf can invoke netCDF libraries internally +PNETCDF_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK]) class PncArrayWrapper(BackendArray): @@ -48,12 +53,15 @@ def open(cls, filename, lock=None, **format_kwds): if mode is not None: keywords['mode'] = mode - manager = CachingFileManager(pncopen, filename, **keywords) - return cls(manager) + if lock is None: + lock = PNETCDF_LOCK + + manager = CachingFileManager(pncopen, filename, lock=lock, **keywords) + return cls(manager, lock) def __init__(self, manager, lock=None): self._manager = manager - self.lock = DummyLock() if lock is None else lock + self.lock = ensure_lock(lock) @property def ds(self): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index de291f04571..107840046fe 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -5,8 +5,16 @@ from .. import Variable from ..core import indexing from ..core.utils import Frozen, FrozenOrderedDict -from .common import PYNIO_LOCK, AbstractDataStore, BackendArray +from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager +from .locks import ( + HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock, SerializableLock) + + +# PyNIO can invoke netCDF libraries internally +# Add a dedicated lock just in case NCL as well isn't thread-safe. 
+NCL_LOCK = SerializableLock() +PYNIO_LOCK = combine_locks([HDF5_LOCK, NETCDFC_LOCK, NCL_LOCK]) class NioArrayWrapper(BackendArray): @@ -37,10 +45,13 @@ class NioDataStore(AbstractDataStore): """Store for accessing datasets via PyNIO """ - def __init__(self, filename, mode='r', lock=PYNIO_LOCK): + def __init__(self, filename, mode='r', lock=None): import Nio - self.lock = lock - self._manager = CachingFileManager(Nio.open_file, filename, mode=mode) + if lock is None: + lock = PYNIO_LOCK + self.lock = ensure_lock(lock) + self._manager = CachingFileManager( + Nio.open_file, filename, lock=lock, mode=mode) # xarray provides its own support for FillValue, # so turn off PyNIO's support for the same. self.ds.set_option('MaskedArrayMode', 'MaskedNever') diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index 3622aab86a2..840ec96a87d 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -10,13 +10,11 @@ from ..core.utils import is_scalar from .common import BackendArray from .file_manager import CachingFileManager +from .locks import SerializableLock -try: - from dask.utils import SerializableLock as Lock -except ImportError: - from threading import Lock -RASTERIO_LOCK = Lock() +# TODO: should this be GDAL_LOCK instead? +RASTERIO_LOCK = SerializableLock() _ERROR_MSG = ('The kind of indexing operation you are trying to do is not ' 'valid on rasterio files. Try to load your data with ds.load()' @@ -294,7 +292,7 @@ def open_rasterio(filename, parse_coordinates=None, chunks=None, cache=None, # this lets you write arrays loaded with rasterio data = indexing.CopyOnWriteArray(data) - if cache and (chunks is None): + if cache and chunks is None: data = indexing.MemoryCachedArray(data) result = DataArray(data=data, dims=('band', 'y', 'x'), diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index c54a50c384c..f542a4638a7 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -11,6 +11,7 @@ from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict from .common import BackendArray, WritableCFDataStore +from .locks import get_resource_lock from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -135,9 +136,13 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) + if (lock is None and mode != 'r' + and isinstance(filename_or_obj, basestring)): + lock = get_resource_lock(filename_or_obj) + if isinstance(filename_or_obj, basestring): manager = CachingFileManager( - _open_scipy_netcdf, filename_or_obj, mode=mode, + _open_scipy_netcdf, filename_or_obj, mode=mode, lock=lock, kwargs=dict(mmap=mmap, version=version)) else: scipy_dataset = _open_scipy_netcdf( diff --git a/xarray/tests/test_backends_locks.py b/xarray/tests/test_backends_locks.py new file mode 100644 index 00000000000..5f83321802e --- /dev/null +++ b/xarray/tests/test_backends_locks.py @@ -0,0 +1,13 @@ +import threading + +from xarray.backends import locks + + +def test_threaded_lock(): + lock1 = locks._get_threaded_lock('foo') + assert isinstance(lock1, type(threading.Lock())) + lock2 = locks._get_threaded_lock('foo') + assert lock1 is lock2 + + lock3 = locks._get_threaded_lock('bar') + assert lock1 is not lock3 diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py index c8b99aaf37d..7c77a62d3c9 
100644 --- a/xarray/tests/test_distributed.py +++ b/xarray/tests/test_distributed.py @@ -18,10 +18,10 @@ import numpy as np import xarray as xr +from xarray.backends.locks import HDF5_LOCK, CombinedLock from xarray.tests.test_backends import (ON_WINDOWS, create_tmp_file, create_tmp_geotiff) from xarray.tests.test_dataset import create_test_data -from xarray.backends.common import HDF5_LOCK, CombinedLock from . import ( assert_allclose, has_h5netcdf, has_netCDF4, requires_rasterio, has_scipy, From 55d35c85ed38bad153c57d1e1d8c14236d04b8e4 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 10 Sep 2018 08:22:29 -0700 Subject: [PATCH 35/39] Rename get_resource_lock -> get_write_lock --- xarray/backends/common.py | 4 ---- xarray/backends/h5netcdf_.py | 4 ++-- xarray/backends/locks.py | 2 +- xarray/backends/netCDF4_.py | 4 ++-- xarray/backends/scipy_.py | 8 ++++---- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index 555eb2bfe8b..405d989f4af 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -1,12 +1,9 @@ from __future__ import absolute_import, division, print_function import logging -import multiprocessing -import threading import time import traceback import warnings -import weakref from collections import Mapping, OrderedDict import numpy as np @@ -23,7 +20,6 @@ NONE_VAR_NAME = '__values__' - def _encode_variable_name(name): if name is None: name = NONE_VAR_NAME diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index faa7d94405b..a8f29b416b7 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -10,7 +10,7 @@ from ..core.utils import FrozenOrderedDict, close_on_error from .common import WritableCFDataStore from .file_manager import CachingFileManager -from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_resource_lock +from .locks import HDF5_LOCK, combine_locks, ensure_lock, get_write_lock from .netCDF4_ import ( BaseNetCDF4Array, GroupWrapper, _encode_nc4_variable, _extract_nc4_variable_encoding, _get_datatype, _nc4_require_group) @@ -83,7 +83,7 @@ def __init__(self, filename, mode='r', format=None, group=None, if mode == 'r': lock = HDF5_LOCK else: - lock = combine_locks([HDF5_LOCK, get_resource_lock(filename)]) + lock = combine_locks([HDF5_LOCK, get_write_lock(filename)]) self.format = format self._filename = filename diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py index a35680c13fe..f633280ef1d 100644 --- a/xarray/backends/locks.py +++ b/xarray/backends/locks.py @@ -96,7 +96,7 @@ def _get_scheduler(get=None, collection=None): return 'threaded' -def get_resource_lock(key): +def get_write_lock(key): """Get a scheduler appropriate lock for writing to the given resource. 
Parameters diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index daf2538323f..c393fda9255 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -16,7 +16,7 @@ from .common import ( BackendArray, WritableCFDataStore, find_root, robust_getitem) from .locks import (NETCDFC_LOCK, HDF5_LOCK, - combine_locks, ensure_lock, get_resource_lock) + combine_locks, ensure_lock, get_write_lock) from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import encode_nc3_attr_value, encode_nc3_variable @@ -346,7 +346,7 @@ def open(cls, filename, mode='r', format='NETCDF4', group=None, base_lock = NETCDF4_PYTHON_LOCK else: base_lock = NETCDFC_LOCK - lock = combine_locks([base_lock, get_resource_lock(filename)]) + lock = combine_locks([base_lock, get_write_lock(filename)]) manager = CachingFileManager( _open_netcdf4_group, filename, lock, mode=mode, diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index f542a4638a7..b009342efb6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -11,7 +11,7 @@ from ..core.pycompat import OrderedDict, basestring, iteritems from ..core.utils import Frozen, FrozenOrderedDict from .common import BackendArray, WritableCFDataStore -from .locks import get_resource_lock +from .locks import get_write_lock from .file_manager import CachingFileManager, DummyFileManager from .netcdf3 import ( encode_nc3_attr_value, encode_nc3_variable, is_valid_nc3_name) @@ -136,9 +136,9 @@ def __init__(self, filename_or_obj, mode='r', format=None, group=None, raise ValueError('invalid format for scipy.io.netcdf backend: %r' % format) - if (lock is None and mode != 'r' - and isinstance(filename_or_obj, basestring)): - lock = get_resource_lock(filename_or_obj) + if (lock is None and mode != 'r' and + isinstance(filename_or_obj, basestring)): + lock = get_write_lock(filename_or_obj) if isinstance(filename_or_obj, basestring): manager = CachingFileManager( From c8fbadcc1f9074fbf1b385a4df9dbf65e58abc00 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 10 Sep 2018 11:44:22 -0700 Subject: [PATCH 36/39] Don't acquire unnecessary locks in __getitem__ --- xarray/backends/h5netcdf_.py | 3 ++- xarray/backends/netCDF4_.py | 3 ++- xarray/backends/pseudonetcdf_.py | 3 ++- xarray/backends/pynio_.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index a8f29b416b7..59cd4e84793 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -26,8 +26,9 @@ def _getitem(self, key): # h5py requires using lists for fancy indexing: # https://github.com/h5py/h5py/issues/992 key = tuple(list(k) if isinstance(k, np.ndarray) else k for k in key) + array = self.get_array() with self.datastore.lock: - return self.get_array()[key] + return array[key] def maybe_decode_bytes(txt): diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index c393fda9255..2ba63d64333 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -70,9 +70,10 @@ def _getitem(self, key): else: getitem = operator.getitem + original_array = self.get_array() + try: with self.datastore.lock: - original_array = self.get_array() array = getitem(original_array, key) except IndexError: # Catch IndexError in netCDF4 and return a more informative diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 9ca3549ad49..026fea21610 100644 --- a/xarray/backends/pseudonetcdf_.py +++ 
b/xarray/backends/pseudonetcdf_.py @@ -36,8 +36,9 @@ def __getitem__(self, key): self._getitem) def _getitem(self, key): + array = self.get_array() with self.datastore.lock: - return self.get_array()[key] + return array[key] class PseudoNetCDFDataStore(AbstractDataStore): diff --git a/xarray/backends/pynio_.py b/xarray/backends/pynio_.py index 107840046fe..574fff744e3 100644 --- a/xarray/backends/pynio_.py +++ b/xarray/backends/pynio_.py @@ -34,8 +34,8 @@ def __getitem__(self, key): key, self.shape, indexing.IndexingSupport.BASIC, self._getitem) def _getitem(self, key): + array = self.get_array() with self.datastore.lock: - array = self.get_array() if key == () and self.ndim == 0: return array.get_value() return array[key] From 36f1156253f15e5932c96f970fb29d013022c977 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 17:29:06 -0700 Subject: [PATCH 37/39] Fix bad merge --- xarray/backends/pseudonetcdf_.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index 017aff4714c..dc70a8588ad 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -4,6 +4,7 @@ from .. import Variable from ..core import indexing +from ..core.pycompat import Frozen, OrderedDict from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock @@ -71,9 +72,8 @@ def open_store_variable(self, name, var): return Variable(var.dimensions, data, attrs) def get_variables(self): - return - ((k, self.open_store_variable(k, v)) - for k, v in self.ds.variables.items()) + return ((k, self.open_store_variable(k, v)) + for k, v in self.ds.variables.items()) def get_attrs(self): return Frozen(dict([(k, getattr(self.ds, k)) From c6f43ddbd70910f63ff24608f706f5f71d6c6888 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 17:58:37 -0700 Subject: [PATCH 38/39] Fix import --- xarray/backends/pseudonetcdf_.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xarray/backends/pseudonetcdf_.py b/xarray/backends/pseudonetcdf_.py index dc70a8588ad..e4691d1f7e1 100644 --- a/xarray/backends/pseudonetcdf_.py +++ b/xarray/backends/pseudonetcdf_.py @@ -4,7 +4,8 @@ from .. import Variable from ..core import indexing -from ..core.pycompat import Frozen, OrderedDict +from ..core.pycompat import OrderedDict +from ..core.utils import Frozen from .common import AbstractDataStore, BackendArray from .file_manager import CachingFileManager from .locks import HDF5_LOCK, NETCDFC_LOCK, combine_locks, ensure_lock From 8916bc7540b6ef2a8e3b2ea513470d1918e42ad7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 8 Oct 2018 19:30:35 -0700 Subject: [PATCH 39/39] Remove unreachable code --- xarray/backends/api.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index f20061b48b6..65112527045 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -635,9 +635,6 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None, else: # file-like object engine = 'scipy' - if path_or_file is None and not compute: - raise ValueError - # validate Dataset keys, DataArray names, and attr keys/values _validate_dataset_names(dataset) _validate_attrs(dataset)
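
For reference, a minimal usage sketch of the lock-selection pattern this series introduces in xarray.backends.locks. This is not part of the patches; it assumes xarray with the series applied, and choose_lock is a hypothetical helper named here only for illustration. It mirrors the per-backend defaults in the diffs above: reads only serialize calls into the non-thread-safe HDF5/netCDF C libraries, while writes additionally take a per-file lock appropriate for the active dask scheduler (threaded, multiprocessing or distributed), obtained via get_write_lock().

    from xarray.backends.locks import (
        HDF5_LOCK, combine_locks, ensure_lock, get_write_lock)


    def choose_lock(filename, mode='r'):
        # Hypothetical helper: pick a default lock the way the h5netcdf and
        # netCDF4 backends do in this series.
        if mode == 'r':
            # Reads only need to serialize access to the C libraries.
            return HDF5_LOCK
        # Writes also serialize access to the target file with a
        # scheduler-appropriate lock.
        return combine_locks([HDF5_LOCK, get_write_lock(filename)])


    # ensure_lock() maps None/False to a DummyLock, so the result can always
    # be used as a context manager regardless of what the caller passed in.
    lock = ensure_lock(choose_lock('example.nc', mode='w'))
    with lock:
        pass  # open and write the file while holding the combined lock

Because combine_locks() flattens nested CombinedLock instances and drops None entries, backends can layer library locks and per-file write locks without worrying about duplicates; with no dask scheduler active, get_write_lock() degrades to a plain per-key threading.Lock.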