Skip to content

Commit

Permalink
bias test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 4, 2025
1 parent aff87b6 commit 585a321
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 97 deletions.
67 changes: 30 additions & 37 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ def run(
if isinstance(self.base_dh, DataHandler):
max_workers = 1

task_args_list = []
task_kwargs_list = []
bias_gids = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)
Expand All @@ -231,52 +232,44 @@ def run(
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_args_list.append(
(
bias_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_ti,
self.decimals,
self.base_dh,
self.match_zero_rate,
)
bias_gids.append(bias_gid)
task_kwargs_list.append(
{
'bias_data': bias_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'daily_reduction': daily_reduction,
'bias_ti': self.bias_ti,
'decimals': self.decimals,
'match_zero_rate': self.match_zero_rate,
}
)

if max_workers == 1:
logger.debug('Running serial calculation.')
for i, args in enumerate(task_args_list):
single_out = self._run_single(*args)
raster_loc = np.where(self.bias_gid_raster == args[0])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(task_args_list),
)
results = [
self._run_single(**kwargs, base_dh_inst=self.base_dh)
for kwargs in task_kwargs_list
]
else:
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_args_list, max_workers=max_workers
self._run_single, task_kwargs_list, max_workers=max_workers
)
for i, single_out in enumerate(results):
raster_loc = np.where(self.bias_gid_raster == bias_gids[i])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)
for i, single_out in enumerate(results):
raster_loc = np.where(
self.bias_gid_raster == task_args_list[i][0]
)
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

Expand Down
56 changes: 30 additions & 26 deletions sup3r/bias/presrat.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,53 +428,57 @@ def run(
if isinstance(self.base_dh, DataHandler):
max_workers = 1

task_args_list = []
task_kwargs_list = []
bias_gids = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_gids.append(bias_gid)
bias_data = self.get_bias_data(bias_gid)
bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
task_args_list.append(
(
bias_data,
bias_fut_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_dh.time_index,
self.bias_fut_dh.time_index,
self.decimals,
self.dist,
self.relative,
self.sampling,
self.n_quantiles,
self.log_base,
self.n_time_steps,
self.window_size,
zero_rate_threshold,
)
task_kwargs_list.append(
{
'bias_data': bias_data,
'bias_fut_data': bias_fut_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'daily_reduction': daily_reduction,
'bias_ti': self.bias_dh.time_index,
'bias_fut_ti': self.bias_fut_dh.time_index,
'decimals': self.decimals,
'dist': self.dist,
'relative': self.relative,
'sampling': self.sampling,
'n_samples': self.n_quantiles,
'log_base': self.log_base,
'n_time_steps': self.n_time_steps,
'window_size': self.window_size,
'zero_rate_threshold': zero_rate_threshold,
}
)

if max_workers == 1:
logger.debug('Running serial calculation.')
results = [self._run_single(*args) for args in task_args_list]
results = [
self._run_single(**kwargs) for kwargs in task_kwargs_list
]
else:
logger.debug(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_args_list, max_workers=max_workers
self._run_single, task_kwargs_list, max_workers=max_workers
)

for i, single_out in enumerate(results):
raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0])
raster_loc = np.where(self.bias_gid_raster == bias_gids[i])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

Expand Down
55 changes: 29 additions & 26 deletions sup3r/bias/qdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,53 +537,56 @@ def run(
if isinstance(self.base_dh, DataHandler):
max_workers = 1

task_args_list = []
task_kwargs_list = []
bias_gids = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_gids.append(bias_gid)
bias_data = self.get_bias_data(bias_gid)
bias_fut_data = self.get_bias_data(bias_gid, self.bias_fut_dh)
task_args_list.append(
(
bias_data,
bias_fut_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_dh.time_index,
self.bias_fut_dh.time_index,
self.decimals,
self.dist,
self.relative,
self.sampling,
self.n_quantiles,
self.log_base,
self.n_time_steps,
self.window_size,
self.base_dh,
)
task_kwargs_list.append(
{
'bias_data': bias_data,
'bias_fut_data': bias_fut_data,
'base_fps': self.base_fps,
'bias_feature': self.bias_feature,
'base_dset': self.base_dset,
'base_gid': base_gid,
'base_handler': self.base_handler,
'daily_reduction': daily_reduction,
'bias_ti': self.bias_dh.time_index,
'bias_fut_ti': self.bias_fut_dh.time_index,
'decimals': self.decimals,
'dist': self.dist,
'relative': self.relative,
'sampling': self.sampling,
'n_samples': self.n_quantiles,
'log_base': self.log_base,
'n_time_steps': self.n_time_steps,
'window_size': self.window_size,
}
)

if max_workers == 1:
logger.debug('Running serial calculation.')
results = [self._run_single(*args) for args in task_args_list]
results = [
self._run_single(**kwargs) for kwargs in task_kwargs_list
]
else:
logger.debug(
'Running parallel calculation with %s workers.', max_workers
)
results = run_in_parallel(
self._run_single, task_args_list, max_workers=max_workers
self._run_single, task_kwargs_list, max_workers=max_workers
)

for i, single_out in enumerate(results):
raster_loc = np.where(self.bias_gid_raster == task_args_list[i][0])
raster_loc = np.where(self.bias_gid_raster == bias_gids[i])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

Expand Down
10 changes: 5 additions & 5 deletions sup3r/bias/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
logger = logging.getLogger(__name__)


def run_in_parallel(task_function, task_args_list, max_workers=None):
def run_in_parallel(task_function, task_kwargs_list, max_workers=None):
"""
Execute a list of tasks in parallel using ``ProcessPoolExecutor``.
Expand All @@ -28,8 +28,8 @@ def run_in_parallel(task_function, task_args_list, max_workers=None):
task_function : callable
The function to execute in parallel.
task_args_list : list
A list of argument tuples, where each tuple contains the arguments
for a single call to ``task_function``.
A list of keyword argument dictionaries for a single call to
``task_function``.
max_workers : int, optional
The maximum number of workers to use. If None, it uses all available.
Expand All @@ -41,8 +41,8 @@ def run_in_parallel(task_function, task_args_list, max_workers=None):
results = []
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {
exe.submit(task_function, *args): args
for args in task_args_list
exe.submit(task_function, **kwargs): kwargs
for kwargs in task_kwargs_list
}
for future in as_completed(futures):
result = future.result()
Expand Down
7 changes: 4 additions & 3 deletions sup3r/models/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def run(self):
kwargs=self.kwargs,
)
try:
logger.info('Starting training session.')
self.batch_handler.start()
logger.info(
'Starting training session. Training for %s epochs',
self.kwargs['n_epoch'],
)
model_thread.start()
except KeyboardInterrupt:
logger.info('Ending training session.')
Expand All @@ -53,7 +55,6 @@ def run(self):
sys.exit()

logger.info('Finished training')
self.batch_handler.stop()
model_thread.join()


Expand Down

0 comments on commit 585a321

Please sign in to comment.