Skip to content

Commit

Permalink
bias refact: run_in_parallel function to remove duplicate calls to ``ProcessPoolExecutor``
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Jan 4, 2025
1 parent e1ee79a commit 0832a29
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 430 deletions.
107 changes: 42 additions & 65 deletions sup3r/bias/bias_calc.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
"""Utilities to calculate the bias correction factors for biased data that is
going to be fed into the sup3r downscaling models. This is typically used to
bias correct GCM data vs. some historical record like the WTK or NSRDB.
TODO: Generalize the ``with ProcessPoolExecutor() as exe: ...`` so we don't
need to duplicate this wherever we kickoff a process or thread pool
"""
bias correct GCM data vs. some historical record like the WTK or NSRDB."""

import copy
import json
import logging
import os
from concurrent.futures import ProcessPoolExecutor, as_completed

import h5py
import numpy as np
Expand All @@ -20,6 +15,7 @@

from .base import DataRetrievalBase
from .mixins import FillAndSmoothMixin
from .utilities import run_in_parallel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -226,17 +222,17 @@ def run(
if isinstance(self.base_dh, DataHandler):
max_workers = 1

if max_workers == 1:
logger.debug('Running serial calculation.')
for i, bias_gid in enumerate(self.bias_meta.index):
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
single_out = self._run_single(
task_args_list = []
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
task_args_list.append(
(
bias_data,
self.base_fps,
self.bias_feature,
Expand All @@ -246,66 +242,47 @@ def run(
daily_reduction,
self.bias_ti,
self.decimals,
base_dh_inst=self.base_dh,
match_zero_rate=self.match_zero_rate,
self.base_dh,
self.match_zero_rate,
)
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
)

if max_workers == 1:
logger.debug('Running serial calculation.')
for i, args in enumerate(task_args_list):
single_out = self._run_single(*args)
raster_loc = np.where(self.bias_gid_raster == args[0])
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for {} out of {} '
'sites'.format(i + 1, len(self.bias_meta))
'Completed bias calculations for %s out of %s sites',
i + 1,
len(task_args_list),
)

else:
logger.debug(
'Running parallel calculation with {} workers.'.format(
max_workers
)
logger.info(
'Running parallel calculation with %s workers.', max_workers
)
with ProcessPoolExecutor(max_workers=max_workers) as exe:
futures = {}
for bias_gid in self.bias_meta.index:
raster_loc = np.where(self.bias_gid_raster == bias_gid)
_, base_gid = self.get_base_gid(bias_gid)

if not base_gid.any():
self.bad_bias_gids.append(bias_gid)
else:
bias_data = self.get_bias_data(bias_gid)
future = exe.submit(
self._run_single,
bias_data,
self.base_fps,
self.bias_feature,
self.base_dset,
base_gid,
self.base_handler,
daily_reduction,
self.bias_ti,
self.decimals,
match_zero_rate=self.match_zero_rate,
)
futures[future] = raster_loc

logger.debug('Finished launching futures.')
for i, future in enumerate(as_completed(futures)):
raster_loc = futures[future]
single_out = future.result()
for key, arr in single_out.items():
self.out[key][raster_loc] = arr

logger.info(
'Completed bias calculations for {} out of {} '
'sites'.format(i + 1, len(futures))
)
results = run_in_parallel(
self._run_single, task_args_list, max_workers=max_workers
)
for i, single_out in enumerate(results):
raster_loc = np.where(
self.bias_gid_raster == task_args_list[i][0]
)
for key, arr in single_out.items():
self.out[key][raster_loc] = arr
logger.info(
'Completed bias calculations for %s out of %s sites',
i + 1,
len(results),
)

logger.info('Finished calculating bias correction factors.')

self.out = self.fill_and_smooth(
self.out, fill_extend, smooth_extend, smooth_interior
)

self.write_outputs(fp_out, self.out)

return copy.deepcopy(self.out)
Expand Down
Loading

0 comments on commit 0832a29

Please sign in to comment.