diff --git a/sup3r/bias/abstract.py b/sup3r/bias/abstract.py
new file mode 100644
index 000000000..8fc5487fd
--- /dev/null
+++ b/sup3r/bias/abstract.py
@@ -0,0 +1,144 @@
+"""Bias correction class interface."""
+
+import logging
+from abc import ABC, abstractmethod
+from concurrent.futures import ProcessPoolExecutor, as_completed
+
+import numpy as np
+
+from sup3r.preprocessing import DataHandler
+
+logger = logging.getLogger(__name__)
+
+
+class AbstractBiasCorrection(ABC):
+    """Minimal interface for bias correction classes"""
+
+    @abstractmethod
+    def _get_run_kwargs(self, **run_single_kwargs):
+        """Get dictionary of kwarg dictionaries to use for calls to
+        ``_run_single``. Each key-value pair is a bias_gid with the associated
+        ``_run_single`` kwargs dict for that gid"""
+
+    def _run(
+        self,
+        out,
+        max_workers=None,
+        fill_extend=True,
+        smooth_extend=0,
+        smooth_interior=0,
+        **run_single_kwargs,
+    ):
+        """Run correction factor calculations for every site in the bias
+        dataset
+
+        Parameters
+        ----------
+        out : dict
+            Dictionary of arrays to fill with bias correction factors.
+        max_workers : int
+            Number of workers to run in parallel. 1 is serial and None is all
+            available.
+        daily_reduction : None | str
+            Option to do a reduction of the hourly+ source base data to daily
+            data. Can be None (no reduction, keep source time frequency), "avg"
+            (daily average), "max" (daily max), "min" (daily min),
+            "sum" (daily sum/total)
+        fill_extend : bool
+            Flag to fill data past distance_upper_bound using spatial nearest
+            neighbor. If False, the extended domain will be left as NaN.
+        smooth_extend : float
+            Option to smooth the scalar/adder data outside of the spatial
+            domain set by the distance_upper_bound input. This alleviates the
+            weird seams far from the domain of interest. This value is the
+            standard deviation for the gaussian_filter kernel
+        smooth_interior : float
+            Option to smooth the scalar/adder data within the valid spatial
+            domain. This can reduce the effect of extreme values within
+            aggregations over large numbers of pixels.
+        run_single_kwargs : dict
+            Additional kwargs that get sent to ``_run_single`` e.g.
+            daily_reduction='avg', zero_rate_threshold=1.157e-7
+
+        Returns
+        -------
+        out : dict
+            Dictionary of values defining the mean/std of the bias + base
+            data and correction factors to correct the biased data like:
+            bias_data * scalar + adder. Each value is of shape
+            (lat, lon, time).
+ """ + self.bad_bias_gids = [] + + task_kwargs = self._get_run_kwargs(**run_single_kwargs) + + # sup3r DataHandler opening base files will load all data in parallel + # during the init and should not be passed in parallel to workers + if isinstance(self.base_dh, DataHandler): + max_workers = 1 + + if max_workers == 1: + logger.debug('Running serial calculation.') + results = { + bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh) + for bias_gid, kwargs in task_kwargs.items() + } + else: + logger.info( + 'Running parallel calculation with %s workers.', max_workers + ) + results = {} + with ProcessPoolExecutor(max_workers=max_workers) as exe: + futures = { + exe.submit(self._run_single, **kwargs): bias_gid + for bias_gid, kwargs in task_kwargs.items() + } + for future in as_completed(futures): + bias_gid = futures[future] + results[bias_gid] = future.result() + + for i, (bias_gid, single_out) in enumerate(results.items()): + raster_loc = np.where(self.bias_gid_raster == bias_gid) + for key, arr in single_out.items(): + out[key][raster_loc] = arr + logger.info( + 'Completed bias calculations for %s out of %s sites', + i + 1, + len(results), + ) + + logger.info('Finished calculating bias correction factors.') + + return self.fill_and_smooth( + out, fill_extend, smooth_extend, smooth_interior + ) + + @abstractmethod + def run( + self, + fp_out=None, + max_workers=None, + daily_reduction='avg', + fill_extend=True, + smooth_extend=0, + smooth_interior=0, + ): + """Run correction factor calculations for every site in the bias + dataset""" + + @classmethod + @abstractmethod + def _run_single( + cls, + bias_data, + base_fps, + bias_feature, + base_dset, + base_gid, + base_handler, + daily_reduction, + bias_ti, + decimals, + base_dh_inst=None, + match_zero_rate=False, + ): + """Run correction factor calculations for a single site""" diff --git a/sup3r/bias/base.py b/sup3r/bias/base.py index a184f37ec..d7761a876 100644 --- a/sup3r/bias/base.py +++ b/sup3r/bias/base.py @@ -20,8 +20,6 @@ from sup3r.utilities import VERSION_RECORD, ModuleName from sup3r.utilities.cli import BaseCLI -from .utilities import run_in_parallel - logger = logging.getLogger(__name__) @@ -779,115 +777,3 @@ def _reduce_base_data( assert base_data.shape == daily_ti.shape, msg return base_data, daily_ti - - def _get_run_kwargs(self, **kwargs_extras): - """Get dictionary of kwarg dictionaries to use for calls to - ``_run_single``. Each key-value pair is a bias_gid with the associated - ``_run_single`` arguments for that gid""" - task_kwargs = {} - for bias_gid in self.bias_meta.index: - _, base_gid = self.get_base_gid(bias_gid) - - if not base_gid.any(): - self.bad_bias_gids.append(bias_gid) - else: - bias_data = self.get_bias_data(bias_gid) - task_kwargs[bias_gid] = { - 'bias_data': bias_data, - 'base_fps': self.base_fps, - 'bias_feature': self.bias_feature, - 'base_dset': self.base_dset, - 'base_gid': base_gid, - 'base_handler': self.base_handler, - 'bias_ti': self.bias_ti, - 'decimals': self.decimals, - 'match_zero_rate': self.match_zero_rate, - **kwargs_extras - } - return task_kwargs - - def _run( - self, - max_workers=None, - fill_extend=True, - smooth_extend=0, - smooth_interior=0, - **kwargs_extras - ): - """Run correction factor calculations for every site in the bias - dataset - - Parameters - ---------- - fp_out : str | None - Optional .h5 output file to write scalar and adder arrays. - max_workers : int - Number of workers to run in parallel. 1 is serial and None is all - available. 
-        daily_reduction : None | str
-            Option to do a reduction of the hourly+ source base data to daily
-            data. Can be None (no reduction, keep source time frequency), "avg"
-            (daily average), "max" (daily max), "min" (daily min),
-            "sum" (daily sum/total)
-        fill_extend : bool
-            Flag to fill data past distance_upper_bound using spatial nearest
-            neighbor. If False, the extended domain will be left as NaN.
-        smooth_extend : float
-            Option to smooth the scalar/adder data outside of the spatial
-            domain set by the distance_upper_bound input. This alleviates the
-            weird seams far from the domain of interest. This value is the
-            standard deviation for the gaussian_filter kernel
-        smooth_interior : float
-            Option to smooth the scalar/adder data within the valid spatial
-            domain. This can reduce the affect of extreme values within
-            aggregations over large number of pixels.
-        kwargs_extras: dict
-            Additional kwargs that get sent to ``_run_single`` e.g.
-            daily_reduction='avg', zero_rate_threshold=1.157e-7
-
-        Returns
-        -------
-        out : dict
-            Dictionary of values defining the mean/std of the bias + base
-            data and the scalar + adder factors to correct the biased data
-            like: bias_data * scalar + adder. Each value is of shape
-            (lat, lon, time).
-        """
-        self.bad_bias_gids = []
-
-        task_kwargs = self._get_run_kwargs(**kwargs_extras)
-        # sup3r DataHandler opening base files will load all data in parallel
-        # during the init and should not be passed in parallel to workers
-        if isinstance(self.base_dh, DataHandler):
-            max_workers = 1
-
-        if max_workers == 1:
-            logger.debug('Running serial calculation.')
-            results = {
-                bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
-                for bias_gid, kwargs in task_kwargs.items()
-            }
-        else:
-            logger.info(
-                'Running parallel calculation with %s workers.', max_workers
-            )
-            results = run_in_parallel(
-                self._run_single, task_kwargs, max_workers=max_workers
-            )
-        for i, (bias_gid, single_out) in enumerate(results.items()):
-            raster_loc = np.where(self.bias_gid_raster == bias_gid)
-            for key, arr in single_out.items():
-                self.out[key][raster_loc] = arr
-            logger.info(
-                'Completed bias calculations for %s out of %s sites',
-                i + 1,
-                len(results),
-            )
-
-        logger.info('Finished calculating bias correction factors.')
-
-        self.out = self.fill_and_smooth(
-            self.out, fill_extend, smooth_extend, smooth_interior
-        )
-
-        return self.out
diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py
index 13bf18dcc..4e3f40f72 100644
--- a/sup3r/bias/bias_calc.py
+++ b/sup3r/bias/bias_calc.py
@@ -11,13 +11,16 @@
 import numpy as np
 from scipy import stats
 
+from .abstract import AbstractBiasCorrection
 from .base import DataRetrievalBase
 from .mixins import FillAndSmoothMixin
 
 logger = logging.getLogger(__name__)
 
 
-class LinearCorrection(FillAndSmoothMixin, DataRetrievalBase):
+class LinearCorrection(
+    AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase
+):
     """Calculate linear correction *scalar +adder factors to bias correct data
 
     This calculation operates on single bias sites for the full time series of
@@ -159,6 +162,32 @@ def write_outputs(self, fp_out, out):
             'Wrote scalar adder factors to file: {}'.format(fp_out)
         )
 
+    def _get_run_kwargs(self, **kwargs_extras):
+        """Get dictionary of kwarg dictionaries to use for calls to
+        ``_run_single``. Each key-value pair is a bias_gid with the associated
+        ``_run_single`` arguments for that gid"""
+        task_kwargs = {}
+        for bias_gid in self.bias_meta.index:
+            _, base_gid = self.get_base_gid(bias_gid)
+
+            if not base_gid.any():
+                self.bad_bias_gids.append(bias_gid)
+            else:
+                bias_data = self.get_bias_data(bias_gid)
+                task_kwargs[bias_gid] = {
+                    'bias_data': bias_data,
+                    'base_fps': self.base_fps,
+                    'bias_feature': self.bias_feature,
+                    'base_dset': self.base_dset,
+                    'base_gid': base_gid,
+                    'base_handler': self.base_handler,
+                    'bias_ti': self.bias_ti,
+                    'decimals': self.decimals,
+                    'match_zero_rate': self.match_zero_rate,
+                    **kwargs_extras,
+                }
+        return task_kwargs
+
     def run(
         self,
         fp_out=None,
@@ -212,6 +241,7 @@ def run(
             )
         )
         self.out = self._run(
+            out=self.out,
             max_workers=max_workers,
             daily_reduction=daily_reduction,
             fill_extend=fill_extend,
diff --git a/sup3r/bias/presrat.py b/sup3r/bias/presrat.py
index b9694ac97..da09ad478 100644
--- a/sup3r/bias/presrat.py
+++ b/sup3r/bias/presrat.py
@@ -359,41 +359,6 @@ def _run_single(
 
         return out
 
-    def _get_run_kwargs(self, **kwargs_extras):
-        """Get dictionary of kwarg dictionaries to use for calls to
-        ``_run_single``. Each key-value pair is a bias_gid with the associated
-        ``_run_single`` arguments for that gid"""
-        task_kwargs = {}
-        for bias_gid in self.bias_meta.index:
-            _, base_gid = self.get_base_gid(bias_gid)
-
-            if not base_gid.any():
-                self.bad_bias_gids.append(bias_gid)
-            else:
-                bias_data = self.get_bias_data(bias_gid)
-                bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
-                task_kwargs[bias_gid] = {
-                    'bias_data': bias_data,
-                    'bias_fut_data': bias_fut_data,
-                    'base_fps': self.base_fps,
-                    'bias_feature': self.bias_feature,
-                    'base_dset': self.base_dset,
-                    'base_gid': base_gid,
-                    'base_handler': self.base_handler,
-                    'bias_ti': self.bias_dh.time_index,
-                    'bias_fut_ti': self.bias_fut_dh.time_index,
-                    'decimals': self.decimals,
-                    'dist': self.dist,
-                    'relative': self.relative,
-                    'sampling': self.sampling,
-                    'n_samples': self.n_quantiles,
-                    'log_base': self.log_base,
-                    'n_time_steps': self.n_time_steps,
-                    'window_size': self.window_size,
-                    **kwargs_extras,
-                }
-        return task_kwargs
-
     def run(
         self,
         fp_out=None,
@@ -454,6 +419,7 @@ def run(
             )
         )
         self.out = self._run(
+            out=self.out,
            max_workers=max_workers,
             daily_reduction=daily_reduction,
             fill_extend=fill_extend,
diff --git a/sup3r/bias/qdm.py b/sup3r/bias/qdm.py
index a1d2da23d..a635ee811 100644
--- a/sup3r/bias/qdm.py
+++ b/sup3r/bias/qdm.py
@@ -19,13 +19,16 @@
 
 from sup3r.preprocessing.utilities import expand_paths
 
+from .abstract import AbstractBiasCorrection
 from .base import DataRetrievalBase
 from .mixins import FillAndSmoothMixin
 
 logger = logging.getLogger(__name__)
 
 
-class QuantileDeltaMappingCorrection(FillAndSmoothMixin, DataRetrievalBase):
+class QuantileDeltaMappingCorrection(
+    AbstractBiasCorrection, FillAndSmoothMixin, DataRetrievalBase
+):
     """Estimate probability distributions required by Quantile Delta Mapping
 
     The main purpose of this class is to estimate the probability
@@ -565,6 +568,7 @@ def run(
             )
 
         self.out = self._run(
+            out=self.out,
             max_workers=max_workers,
             daily_reduction=daily_reduction,
             fill_extend=fill_extend,
diff --git a/sup3r/bias/utilities.py b/sup3r/bias/utilities.py
index 4c2e759e0..817620dde 100644
--- a/sup3r/bias/utilities.py
+++ b/sup3r/bias/utilities.py
@@ -2,7 +2,6 @@
 
 import logging
 import os
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from inspect import signature
 from warnings import warn
@@ -19,38 +18,6 @@
 logger = logging.getLogger(__name__)
 
 
-def run_in_parallel(task_function, task_kwargs, max_workers=None):
-    """
-    Execute a list of tasks in parallel using ``ProcessPoolExecutor``.
-
-    Parameters
-    ----------
-    task_function : callable
-        The function to execute in parallel.
-    task_kwargs : dictionary
-        A dictionary of keyword argument dictionaries for a single call to
-        ``task_function``.
-    max_workers : int, optional
-        The maximum number of workers to use. If None, it uses all available.
-
-    Returns
-    -------
-    results : dictionary
-        A dictionary of results from the executed tasks with the same keys as
-        ``task_kwargs``.
-    """
-    results = {}
-    with ProcessPoolExecutor(max_workers=max_workers) as exe:
-        futures = {
-            exe.submit(task_function, **kwargs): bias_gid
-            for bias_gid, kwargs in task_kwargs.items()
-        }
-        for future in as_completed(futures):
-            bias_gid = futures[future]
-            results[bias_gid] = future.result()
-    return results
-
-
 def lin_bc(handler, bc_files, bias_feature=None, threshold=0.1):
     """Bias correct the data in this DataHandler in place using linear bias
     correction factors from files output by MonthlyLinearCorrection or
diff --git a/tests/bias/test_presrat_bias_correction.py b/tests/bias/test_presrat_bias_correction.py
index 7137965ef..a9af3e8ff 100644
--- a/tests/bias/test_presrat_bias_correction.py
+++ b/tests/bias/test_presrat_bias_correction.py
@@ -82,27 +82,17 @@ def fp_resource(tmpdir_factory):
     """
     fn = tmpdir_factory.mktemp('data').join('precip_oh.h5')
 
-    # Reproducing FP_NSRDB before I can change it.
     time = pd.date_range(
-        '2018-01-01 00:00:00+0000', '2018-03-26 23:30:00+0000', freq='30m'
-    )
-    time = pd.DatetimeIndex(
-        np.arange(
-            np.datetime64('2018-01-01 00:00:00+00:00'),
-            np.datetime64('2019-01-01 00:00:00+00:00'),
-            np.timedelta64(6, 'h'),
-        )
+        '2018-01-01 00:00:00', '2019-01-01 00:00:00', freq='6h'
     )
     lat = np.arange(39.77, 39.00, -0.04)
     lon = np.arange(-105.14, -104.37, 0.04)
-    rng = np.random.default_rng()
-    ghi = rng.lognormal(0.0, 1.0, (time.size, lat.size, lon.size))
+    ghi = RANDOM_GENERATOR.lognormal(0.0, 1.0, (time.size, lat.size, lon.size))
 
     ds = xr.Dataset(
         data_vars={'ghi': (['time', 'lat', 'lon'], ghi)},
         coords={
             'time': ('time', time),
-            # "time_bnds": (["time", "bnds"], time_bnds),
            'lat': ('lat', lat),
             'lon': ('lon', lon),
         },
@@ -142,40 +132,16 @@ def fp_resource(tmpdir_factory):
 
 
 @pytest.fixture(scope='module')
 def precip():
-    """Synthetic historical modeled dataset
-
-    Note
-    ----
-    There are different expected patterns in different components of the
-    processing. For instance, lon might be expected as 0-360 in some places
-    but -180 to 180 in others, and expect a certain order that does not
-    necessarily match latitutde. So changes in the coordinates shall be
-    done carefullly.
- """ - # first value must conform with TARGET[0] - # n values must conform with SHAPE[0] - # dlat = -0.70175216 + """Synthetic historical modeled dataset""" lat = np.array( [40.3507847105177, 39.649032596592, 38.9472804370071, 38.2455282337738] ) - # assert np.allclose(lat[0], TARGET[0]) - # assert lat.size == SHAPE[0] - - # lon = np.linspace(254.4, 255.1, 10) - # first value must conform with TARGET[1] - # n values must conform with SHAPE[1] lon = np.array([254.53125, 255.234375, 255.9375, 256.640625]) - # assert np.allclose(lat[1], 360 + TARGET[0]) - # assert lon.size == SHAPE[0] - t0 = np.datetime64('2015-01-01T12:00:00') - time = t0 + np.arange( - 0, SAMPLE_TIME_DURATION, SAMPLE_TIME_RESOLUTION, dtype='timedelta64[D]' + time = pd.date_range( + '2015-01-01T12:00:00', '2016-12-31T12:00:00', freq='D' ) - # bnds = (-np.timedelta64(12, 'h'), np.timedelta64(12, 'h')) - # time_bnds = time[:, np.newaxis] + bnds - rng = np.random.default_rng() - pr = rng.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) + pr = RANDOM_GENERATOR.lognormal(0.0, 1.0, (time.size, lat.size, lon.size)) # Transform the upper tail into negligible to guarantee some 'zero # precipiation days'. @@ -190,7 +156,7 @@ def precip(): data=pr, dims=['time', 'lat', 'lon'], coords={ - 'time': ('time', time), + 'time': ('time', pd.DatetimeIndex(time)), 'lat': ('lat', lat), 'lon': ('lon', lon), },