diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index cebb87176..c2c241b4c 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -92,6 +92,11 @@ 'units': 's-2', 'dtype': 'int16', 'chunks': (2000, 500)}, + 'pr': {'scale_factor': 1, + 'units': 'kg m-2 s-1', + 'dtype': 'float32', + 'min': 0, + 'chunks': (2000, 250)}, } @@ -293,8 +298,14 @@ def enforce_limits(features, data): maxs = [] mins = [] for fn in features: - max = H5_ATTRS[Feature.get_basename(fn)].get('max', np.inf) - min = H5_ATTRS[Feature.get_basename(fn)].get('min', -np.inf) + dset_name = Feature.get_basename(fn) + if dset_name not in H5_ATTRS: + msg = ('Could not find "{dset_name}" in H5_ATTRS dict!') + logger.error(msg) + raise KeyError(msg) + + max = H5_ATTRS[dset_name].get('max', np.inf) + min = H5_ATTRS[dset_name].get('min', -np.inf) logger.debug(f'Enforcing range of ({max}, {min} for "{fn}")') maxs.append(max) mins.append(min) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index b0700c017..cf965199c 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -148,6 +148,9 @@ def __init__(self, self.shape = shape self.res_kwargs = res_kwargs + # for subclasses + self._source_handler = None + if input_handler is None: in_type = get_source_type(file_paths) if in_type == 'nc': @@ -566,23 +569,20 @@ def get_cache_file(self, feature, s_enhance, t_enhance, t_agg_factor): class TopoExtractNC(TopoExtractH5): """TopoExtract for netCDF files""" - def __init__(self, *args, **kwargs): - """Parameters - ---------- - args : list - Same positional arguments as TopoExtract - kwargs : dict - Same keyword arguments as TopoExtract - """ - super().__init__(*args, **kwargs) - logger.info('Getting topography for full domain from ' - f'{self._exo_source}') - self.source_handler = DataHandlerNC( - self._exo_source, - features=['topography'], - worker_kwargs={'ti_workers': self.ti_workers}, - val_split=0.0, - ) + @property + def source_handler(self): + """Get the DataHandlerNC object that handles the .nc source topography + data file.""" + if self._source_handler is None: + logger.info('Getting topography for full domain from ' + f'{self._exo_source}') + self._source_handler = DataHandlerNC( + self._exo_source, + features=['topography'], + worker_kwargs={'ti_workers': self.ti_workers}, + val_split=0.0, + ) + return self._source_handler @property def source_data(self):