Skip to content

Commit

Permalink
alternative setup for ddh with inheritance from mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Nov 27, 2023
1 parent 067a928 commit 7a00a1f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 118 deletions.
20 changes: 12 additions & 8 deletions sup3r/preprocessing/batch_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,14 +618,15 @@ def _parallel_normalization(self):
max_workers = self.norm_workers
if max_workers == 1:
for dh in self.data_handlers:
dh.normalize(self.means, self.stds)
dh.normalize(self.means, self.stds,
max_workers=dh.norm_workers)
else:
with ThreadPoolExecutor(max_workers=max_workers) as exe:
futures = {}
now = dt.now()
for idh, dh in enumerate(self.data_handlers):
future = exe.submit(dh.normalize, self.means, self.stds,
max_workers=1)
max_workers=dh.norm_workers)
futures[future] = idh

logger.info(f'Started normalizing {len(self.data_handlers)} '
Expand Down Expand Up @@ -691,7 +692,8 @@ def _get_stats(self):
future = exe.submit(dh._get_stats)
futures[future] = idh

for i, _ in enumerate(as_completed(futures)):
for i, future in enumerate(as_completed(futures)):
_ = future.result()
logger.debug(f'{i+1} out of {len(self.data_handlers)} '
'means calculated.')

Expand Down Expand Up @@ -731,10 +733,10 @@ def check_cached_stats(self):
means_check = means_check and os.path.exists(self.means_file)
if stdevs_check and means_check:
logger.info(f'Loading stdevs from {self.stdevs_file}')
with open(self.stdevs_file, 'r') as fh:
with open(self.stdevs_file) as fh:
self.stds = json.load(fh)
logger.info(f'Loading means from {self.means_file}')
with open(self.means_file, 'r') as fh:
with open(self.means_file) as fh:
self.means = json.load(fh)

msg = ('The training features and cached statistics are '
Expand Down Expand Up @@ -777,8 +779,7 @@ def _get_feature_means(self, feature):
feature : str
Feature to get mean for
"""

logger.debug(f'Calculating mean for {feature}')
logger.debug(f'Calculating multi-handler mean for {feature}')
for idh, dh in enumerate(self.data_handlers):
self.means[feature] += (self.handler_weights[idh]
* dh.means[feature])
Expand All @@ -798,7 +799,7 @@ def _get_feature_stdev(self, feature):
Feature to get stdev for
"""

logger.debug(f'Calculating stdev for {feature}')
logger.debug(f'Calculating multi-handler stdev for {feature}')
for idh, dh in enumerate(self.data_handlers):
variance = dh.stds[feature]**2
self.stds[feature] += (variance * self.handler_weights[idh])
Expand All @@ -823,6 +824,9 @@ def normalize(self, means=None, stds=None):
feature names and values: standard deviations. if None, this will
be calculated. if norm is true these will be used for data
normalization
features : list | None
Optional list of features used to index data array during
normalization. If this is None self.features will be used.
"""
if means is None or stds is None:
self.get_stats()
Expand Down
77 changes: 0 additions & 77 deletions sup3r/preprocessing/data_handling/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,17 +450,6 @@ def load_workers(self):
n_procs)
return load_workers

@property
def norm_workers(self):
"""Get upper bound on workers used for normalization."""
if self.data is not None:
norm_workers = estimate_max_workers(self._norm_workers,
2 * self.feature_mem,
self.shape[-1])
else:
norm_workers = self._norm_workers
return norm_workers

@property
def time_chunks(self):
"""Get time chunks which will be extracted from source data
Expand Down Expand Up @@ -921,72 +910,6 @@ def get_cache_file_names(self,
target,
features)

@property
def means(self):
"""Get the mean values for each feature.
Returns
-------
dict
"""
self._get_stats()
return self._means

@property
def stds(self):
"""Get the standard deviation values for each feature.
Returns
-------
dict
"""
self._get_stats()
return self._stds

def _get_stats(self):
if self._means is None or self._stds is None:
msg = (f'DataHandler has {len(self.features)} features '
f'and mismatched shape of {self.shape}')
assert len(self.features) == self.shape[-1], msg
self._stds = {}
self._means = {}
for idf, fname in enumerate(self.features):
self._means[fname] = np.nanmean(self.data[..., idf])
self._stds[fname] = np.nanstd(self.data[..., idf])

def normalize(self, means=None, stds=None, max_workers=None):
"""Normalize all data features.
Parameters
----------
means : dict | none
Dictionary of means for all features with keys: feature names and
values: mean values. If this is None, the self.means attribute will
be used. If this is not None, this DataHandler object means
attribute will be updated.
stds : dict | none
dictionary of standard deviation values for all features with keys:
feature names and values: standard deviations. If this is None, the
self.stds attribute will be used. If this is not None, this
DataHandler object stds attribute will be updated.
max_workers : None | int
Max workers to perform normalization. if None, self.norm_workers
will be used
"""
if means is not None:
self._means = means
if stds is not None:
self._stds = stds

max_workers = max_workers or self.norm_workers
if self._is_normalized:
logger.info('Skipping DataHandler, already normalized')
else:
self._normalize(self.data,
self.val_data,
max_workers=max_workers)
self._is_normalized = True

def get_next(self):
"""Get data for observation using random observation index. Loops
repeatedly over randomized time index
Expand Down
54 changes: 31 additions & 23 deletions sup3r/preprocessing/data_handling/dual_data_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,30 @@ def __init__(self,
self.t_enhance = t_enhance
self.lr_dh = lr_handler
self.hr_dh = hr_handler
self._cache_pattern = cache_pattern
self._cached_features = None
self._noncached_features = None
self.overwrite_cache = overwrite_cache
self.val_split = val_split
self.current_obs_index = None
self.load_cached = load_cached
self.regrid_workers = regrid_workers
self.shuffle_time = shuffle_time
self._lr_lat_lon = None
self._hr_lat_lon = None
self._lr_input_data = None
self.hr_data = None
self.lr_val_data = None
self.hr_val_data = None
lr_data_shape = (*self.lr_required_shape, len(self.lr_dh.features))
self.lr_data = np.zeros(lr_data_shape, dtype=np.float32)
self.lr_data = np.zeros(self.shape, dtype=np.float32)
self.lr_time_index = lr_handler.time_index
self.hr_time_index = hr_handler.time_index
self.lr_val_time_index = lr_handler.val_time_index
self.hr_val_time_index = hr_handler.val_time_index
self._lr_lat_lon = None
self._hr_lat_lon = None
self._lr_input_data = None
self._cache_pattern = cache_pattern
self._cached_features = None
self._noncached_features = None
self._means = None
self._stds = None
self._is_normalized = False
self._norm_workers = self.lr_dh.norm_workers

if self.try_load and self.load_cached:
self.load_cached_data()
Expand Down Expand Up @@ -162,7 +165,7 @@ def _val_split_check(self):

def _get_stats(self):
"""Get mean/stdev stats for HR and LR data handlers"""
self.lr_dh._get_stats()
super()._get_stats(features=self.lr_dh.features)
self.hr_dh._get_stats()

@property
Expand All @@ -176,7 +179,7 @@ def means(self):
dict
"""
out = copy.deepcopy(self.hr_dh.means)
out.update(self.lr_dh.means)
out.update(super().means)
return out

@property
Expand All @@ -190,9 +193,10 @@ def stds(self):
dict
"""
out = copy.deepcopy(self.hr_dh.stds)
out.update(self.lr_dh.stds)
out.update(super().stds)
return out

# pylint: disable=unused-argument
def normalize(self, means=None, stds=None, max_workers=None):
"""Normalize low_res and high_res data
Expand All @@ -209,19 +213,22 @@ def normalize(self, means=None, stds=None, max_workers=None):
self.stds attribute will be used. If this is not None, this
DataHandler object stds attribute will be updated.
max_workers : None | int
Max workers to perform normalization. if None, self.norm_workers
will be used
Has no effect. Used to match MixIn class signature.
"""
if means is None:
means = self.means
if stds is None:
stds = self.stds
logger.info('Normalizing low resolution data features='
f'{self.lr_dh.features}')
self.lr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
super().normalize(means=means, stds=stds,
features=self.lr_dh.features,
max_workers=self.lr_dh.norm_workers)
logger.info('Normalizing high resolution data features='
f'{self.hr_dh.features}')
self.hr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
self.hr_dh.normalize(means=means, stds=stds,
features=self.hr_dh.features,
max_workers=self.hr_dh.norm_workers)

@property
def features(self):
Expand Down Expand Up @@ -363,9 +370,15 @@ def hr_sample_shape(self):
@property
def data(self):
"""Get low res data. Same as self.lr_data but used to match property
used by batch handler for computing means and stdevs"""
used for computing means and stdevs"""
return self.lr_data

@property
def val_data(self):
"""Get low res validation data. Same as self.lr_val_data but used to
match property used by normalization routine."""
return self.lr_val_data

@property
def lr_input_data(self):
"""Get low res data used as input to regridding routine"""
Expand Down Expand Up @@ -405,11 +418,6 @@ def lr_grid_shape(self):
"""Return grid shape for regridded low_res data"""
return (self.lr_required_shape[0], self.lr_required_shape[1])

@property
def lr_requested_shape(self):
"""Return requested shape for low_res data"""
return (*self.lr_required_shape, len(self.features))

@property
def lr_lat_lon(self):
"""Get low_res lat lon array"""
Expand Down Expand Up @@ -471,10 +479,10 @@ def load_lr_cached_data(self):
"""Load low_res cache data"""

logger.info(
f'Loading cache with requested_shape={self.lr_requested_shape}.')
f'Loading cache with requested_shape={self.shape}.')
self._load_cached_data(self.lr_data,
self.cache_files,
self.features,
self.lr_dh.features,
max_workers=self.hr_dh.load_workers)

def load_cached_data(self):
Expand Down
Loading

0 comments on commit 7a00a1f

Please sign in to comment.