Skip to content

Commit

Permalink
additional bias refactor: _run base method and _get_run_kwargs
Browse files Browse the repository at this point in the history
…method.
  • Loading branch information
bnb32 committed Jan 4, 2025
1 parent 585a321 commit 7f92049
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 221 deletions.
118 changes: 116 additions & 2 deletions sup3r/bias/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from sup3r.utilities import VERSION_RECORD, ModuleName
from sup3r.utilities.cli import BaseCLI

from .utilities import run_in_parallel

logger = logging.getLogger(__name__)


Expand All @@ -43,7 +45,7 @@ def __init__(
bias_handler_kwargs=None,
decimals=None,
match_zero_rate=False,
pre_load=True
pre_load=True,
):
"""
Parameters
Expand Down Expand Up @@ -178,7 +180,7 @@ class is used, all data will be loaded in this class'

self.nn_dist, self.nn_ind = self.bias_tree.query(
self.base_meta[['latitude', 'longitude']],
distance_upper_bound=self.distance_upper_bound
distance_upper_bound=self.distance_upper_bound,
)

if pre_load:
Expand Down Expand Up @@ -777,3 +779,115 @@ def _reduce_base_data(
assert base_data.shape == daily_ti.shape, msg

return base_data, daily_ti

def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""
task_kwargs = {}
for bias_gid in self.bias_meta.index:
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_kwargs[bias_gid] = {
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
**kwargs_extras
}
return task_kwargs

def _run(
    self,
    max_workers=None,
    fill_extend=True,
    smooth_extend=0,
    smooth_interior=0,
    **kwargs_extras,
):
    """Run correction factor calculations for every site in the bias
    dataset.

    Parameters
    ----------
    max_workers : int | None
        Number of workers to run in parallel. 1 is serial and None uses
        all available workers. Forced to 1 when ``self.base_dh`` is a
        sup3r ``DataHandler`` (see comment below).
    fill_extend : bool
        Flag to fill data past distance_upper_bound using spatial nearest
        neighbor. If False, the extended domain will be left as NaN.
    smooth_extend : float
        Option to smooth the scalar/adder data outside of the spatial
        domain set by the distance_upper_bound input. This alleviates the
        weird seams far from the domain of interest. This value is the
        standard deviation for the gaussian_filter kernel
    smooth_interior : float
        Option to smooth the scalar/adder data within the valid spatial
        domain. This can reduce the effect of extreme values within
        aggregations over large number of pixels.
    kwargs_extras : dict
        Additional kwargs that get sent to ``_run_single`` e.g.
        daily_reduction='avg', zero_rate_threshold=1.157e-7

    Returns
    -------
    out : dict
        Dictionary of values defining the mean/std of the bias + base
        data and the scalar + adder factors to correct the biased data
        like: bias_data * scalar + adder. Each value is of shape
        (lat, lon, time).
    """
    # must be reset before _get_run_kwargs, which appends gids with no
    # valid base data to this list
    self.bad_bias_gids = []

    task_kwargs = self._get_run_kwargs(**kwargs_extras)

    # sup3r DataHandler opening base files will load all data in parallel
    # during the init and should not be passed in parallel to workers
    if isinstance(self.base_dh, DataHandler):
        max_workers = 1

    if max_workers == 1:
        logger.debug('Running serial calculation.')
        results = {
            bias_gid: self._run_single(**kwargs, base_dh_inst=self.base_dh)
            for bias_gid, kwargs in task_kwargs.items()
        }
    else:
        logger.info(
            'Running parallel calculation with %s workers.', max_workers
        )
        results = run_in_parallel(
            self._run_single, task_kwargs, max_workers=max_workers
        )

    # scatter each single-site result back into the 2D output rasters
    for i, (bias_gid, single_out) in enumerate(results.items()):
        raster_loc = np.where(self.bias_gid_raster == bias_gid)
        for key, arr in single_out.items():
            self.out[key][raster_loc] = arr
        logger.info(
            'Completed bias calculations for %s out of %s sites',
            i + 1,
            len(results),
        )

    logger.info('Finished calculating bias correction factors.')

    self.out = self.fill_and_smooth(
        self.out, fill_extend, smooth_extend, smooth_interior
    )

    return self.out
70 changes: 6 additions & 64 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,8 @@
import numpy as np
from scipy import stats

from sup3r.preprocessing import DataHandler

from .base import DataRetrievalBase
from .mixins import FillAndSmoothMixin
from .utilities import run_in_parallel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -214,67 +211,12 @@ def run(
self.bias_gid_raster.shape
)
)

self.bad_bias_gids = []

# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

task_kwargs_list = []
bias_gids = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
bias_gids.append(bias_gid)
task_kwargs_list.append(
{
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'daily_reduction': daily_reduction,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
}
)

if max_workers == 1:
logger.debug('Running serial calculation.')
results = [
self._run_single(**kwargs, base_dh_inst=self.base_dh)
for kwargs in task_kwargs_list
]
else:
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_kwargs_list, max_workers=max_workers
)
for i, single_out in enumerate(results):
raster_loc = np.where(self.bias_gid_raster == bias_gids[i])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

self.out = self.fill_and_smooth(
self.out, fill_extend, smooth_extend, smooth_interior
self.out = self._run(
max_workers=max_workers,
daily_reduction=daily_reduction,
fill_extend=fill_extend,
smooth_extend=smooth_extend,
smooth_interior=smooth_interior,
)
self.write_outputs(fp_out, self.out)

Expand Down
116 changes: 42 additions & 74 deletions sup3r/bias/presrat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
QuantileDeltaMapping,
)

from sup3r.preprocessing import DataHandler

from .mixins import ZeroRateMixin
from .qdm import QuantileDeltaMappingCorrection
from .utilities import run_in_parallel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -362,6 +359,41 @@ def _run_single(

return out

def _get_run_kwargs(self, **kwargs_extras):
"""Get dictionary of kwarg dictionaries to use for calls to
``_run_single``. Each key-value pair is a bias_gid with the associated
``_run_single`` arguments for that gid"""
task_kwargs = {}
for bias_gid in self.bias_meta.index:
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
task_kwargs[bias_gid] = {
'bias_data': bias_data,
'bias_fut_data': bias_fut_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'bias_ti': self.bias_dh.time_index,
'bias_fut_ti': self.bias_fut_dh.time_index,
'decimals': self.decimals,
'dist': self.dist,
'relative': self.relative,
'sampling': self.sampling,
'n_samples': self.n_quantiles,
'log_base': self.log_base,
'n_time_steps': self.n_time_steps,
'window_size': self.window_size,
**kwargs_extras,
}
return task_kwargs

def run(
self,
fp_out=None,
Expand Down Expand Up @@ -421,77 +453,13 @@ def run(
self.bias_gid_raster.shape
)
)
self.bad_bias_gids = []

# sup3r DataHandler opening base files will load all data in parallel
# during the init and should not be passed in parallel to workers
if isinstance(self.base_dh, DataHandler):
max_workers = 1

task_kwargs_list = []
bias_gids = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_gids.append(bias_gid)
bias_data = self.get_bias_data(bias_gid)
bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
task_kwargs_list.append(
{
'bias_data': bias_data,
'bias_fut_data': bias_fut_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'daily_reduction': daily_reduction,
'bias_ti': self.bias_dh.time_index,
'bias_fut_ti': self.bias_fut_dh.time_index,
'decimals': self.decimals,
'dist': self.dist,
'relative': self.relative,
'sampling': self.sampling,
'n_samples': self.n_quantiles,
'log_base': self.log_base,
'n_time_steps': self.n_time_steps,
'window_size': self.window_size,
'zero_rate_threshold': zero_rate_threshold,
}
)

if max_workers == 1:
logger.debug('Running serial calculation.')
results = [
self._run_single(**kwargs) for kwargs in task_kwargs_list
]
else:
logger.debug(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_kwargs_list, max_workers=max_workers
)

for i, single_out in enumerate(results):
raster_loc = np.where(self.bias_gid_raster == bias_gids[i])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

self.out = self.fill_and_smooth(
self.out, fill_extend, smooth_extend, smooth_interior
self.out = self._run(
max_workers=max_workers,
daily_reduction=daily_reduction,
fill_extend=fill_extend,
smooth_extend=smooth_extend,
smooth_interior=smooth_interior,
zero_rate_threshold=zero_rate_threshold,
)

extra_attrs = {
Expand Down
Loading

0 comments on commit 7f92049

Please sign in to comment.