diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py
index 9c10b85f5..292bb25d7 100644
--- a/sup3r/preprocessing/batch_handling.py
+++ b/sup3r/preprocessing/batch_handling.py
@@ -618,14 +618,15 @@ def _parallel_normalization(self):
         max_workers = self.norm_workers
         if max_workers == 1:
             for dh in self.data_handlers:
-                dh.normalize(self.means, self.stds)
+                dh.normalize(self.means, self.stds,
+                             max_workers=dh.norm_workers)
         else:
             with ThreadPoolExecutor(max_workers=max_workers) as exe:
                 futures = {}
                 now = dt.now()
                 for idh, dh in enumerate(self.data_handlers):
                     future = exe.submit(dh.normalize, self.means, self.stds,
-                                        max_workers=1)
+                                        max_workers=dh.norm_workers)
                     futures[future] = idh

                 logger.info(f'Started normalizing {len(self.data_handlers)} '
@@ -691,7 +692,8 @@ def _get_stats(self):
                 future = exe.submit(dh._get_stats)
                 futures[future] = idh

-            for i, _ in enumerate(as_completed(futures)):
+            for i, future in enumerate(as_completed(futures)):
+                _ = future.result()
                 logger.debug(f'{i+1} out of {len(self.data_handlers)} '
                              'means calculated.')
@@ -731,10 +733,10 @@ def check_cached_stats(self):
         means_check = means_check and os.path.exists(self.means_file)
         if stdevs_check and means_check:
             logger.info(f'Loading stdevs from {self.stdevs_file}')
-            with open(self.stdevs_file, 'r') as fh:
+            with open(self.stdevs_file) as fh:
                 self.stds = json.load(fh)
             logger.info(f'Loading means from {self.means_file}')
-            with open(self.means_file, 'r') as fh:
+            with open(self.means_file) as fh:
                 self.means = json.load(fh)

             msg = ('The training features and cached statistics are '
@@ -777,8 +779,7 @@ def _get_feature_means(self, feature):
         feature : str
             Feature to get mean for
         """
-
-        logger.debug(f'Calculating mean for {feature}')
+        logger.debug(f'Calculating multi-handler mean for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             self.means[feature] += (self.handler_weights[idh]
                                     * dh.means[feature])
@@ -798,7 +799,7 @@ def _get_feature_stdev(self, feature):
             Feature to get stdev for
         """

-        logger.debug(f'Calculating stdev for {feature}')
+        logger.debug(f'Calculating multi-handler stdev for {feature}')
         for idh, dh in enumerate(self.data_handlers):
             variance = dh.stds[feature]**2
             self.stds[feature] += (variance * self.handler_weights[idh])
@@ -823,6 +824,9 @@ def normalize(self, means=None, stds=None):
             feature names and values: standard deviations. if None, this
             will be calculated. if norm is true these will be used for
             data normalization
+        features : list | None
+            Optional list of features used to index data array during
+            normalization. If this is None, self.features will be used.
         """
         if means is None or stds is None:
             self.get_stats()
diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py
index e6c51ef74..17bfbd8dd 100644
--- a/sup3r/preprocessing/data_handling/base.py
+++ b/sup3r/preprocessing/data_handling/base.py
@@ -450,17 +450,6 @@ def load_workers(self):
                                             n_procs)
         return load_workers

-    @property
-    def norm_workers(self):
-        """Get upper bound on workers used for normalization."""
-        if self.data is not None:
-            norm_workers = estimate_max_workers(self._norm_workers,
-                                                2 * self.feature_mem,
-                                                self.shape[-1])
-        else:
-            norm_workers = self._norm_workers
-        return norm_workers
-
     @property
     def time_chunks(self):
         """Get time chunks which will be extracted from source data
@@ -921,72 +910,6 @@ def get_cache_file_names(self,
                                          target,
                                          features)

-    @property
-    def means(self):
-        """Get the mean values for each feature.
-
-        Returns
-        -------
-        dict
-        """
-        self._get_stats()
-        return self._means
-
-    @property
-    def stds(self):
-        """Get the standard deviation values for each feature.
-
-        Returns
-        -------
-        dict
-        """
-        self._get_stats()
-        return self._stds
-
-    def _get_stats(self):
-        if self._means is None or self._stds is None:
-            msg = (f'DataHandler has {len(self.features)} features '
-                   f'and mismatched shape of {self.shape}')
-            assert len(self.features) == self.shape[-1], msg
-            self._stds = {}
-            self._means = {}
-            for idf, fname in enumerate(self.features):
-                self._means[fname] = np.nanmean(self.data[..., idf])
-                self._stds[fname] = np.nanstd(self.data[..., idf])
-
-    def normalize(self, means=None, stds=None, max_workers=None):
-        """Normalize all data features.
-
-        Parameters
-        ----------
-        means : dict | none
-            Dictionary of means for all features with keys: feature names and
-            values: mean values. If this is None, the self.means attribute will
-            be used. If this is not None, this DataHandler object means
-            attribute will be updated.
-        stds : dict | none
-            dictionary of standard deviation values for all features with keys:
-            feature names and values: standard deviations. If this is None, the
-            self.stds attribute will be used. If this is not None, this
-            DataHandler object stds attribute will be updated.
-        max_workers : None | int
-            Max workers to perform normalization. if None, self.norm_workers
-            will be used
-        """
-        if means is not None:
-            self._means = means
-        if stds is not None:
-            self._stds = stds
-
-        max_workers = max_workers or self.norm_workers
-        if self._is_normalized:
-            logger.info('Skipping DataHandler, already normalized')
-        else:
-            self._normalize(self.data,
-                            self.val_data,
-                            max_workers=max_workers)
-            self._is_normalized = True
-
     def get_next(self):
         """Get data for observation using random observation index.
         Loops repeatedly over randomized time index
diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py
index 232786e02..affcc95dc 100644
--- a/sup3r/preprocessing/data_handling/dual_data_handling.py
+++ b/sup3r/preprocessing/data_handling/dual_data_handling.py
@@ -72,27 +72,30 @@ def __init__(self,
         self.t_enhance = t_enhance
         self.lr_dh = lr_handler
         self.hr_dh = hr_handler
-        self._cache_pattern = cache_pattern
-        self._cached_features = None
-        self._noncached_features = None
         self.overwrite_cache = overwrite_cache
         self.val_split = val_split
         self.current_obs_index = None
         self.load_cached = load_cached
         self.regrid_workers = regrid_workers
         self.shuffle_time = shuffle_time
-        self._lr_lat_lon = None
-        self._hr_lat_lon = None
-        self._lr_input_data = None
         self.hr_data = None
         self.lr_val_data = None
         self.hr_val_data = None
-        lr_data_shape = (*self.lr_required_shape, len(self.lr_dh.features))
-        self.lr_data = np.zeros(lr_data_shape, dtype=np.float32)
+        self.lr_data = np.zeros(self.shape, dtype=np.float32)
         self.lr_time_index = lr_handler.time_index
         self.hr_time_index = hr_handler.time_index
         self.lr_val_time_index = lr_handler.val_time_index
         self.hr_val_time_index = hr_handler.val_time_index
+        self._lr_lat_lon = None
+        self._hr_lat_lon = None
+        self._lr_input_data = None
+        self._cache_pattern = cache_pattern
+        self._cached_features = None
+        self._noncached_features = None
+        self._means = None
+        self._stds = None
+        self._is_normalized = False
+        self._norm_workers = self.lr_dh.norm_workers

         if self.try_load and self.load_cached:
             self.load_cached_data()
@@ -162,7 +165,7 @@ def _val_split_check(self):

     def _get_stats(self):
         """Get mean/stdev stats for HR and LR data handlers"""
-        self.lr_dh._get_stats()
+        super()._get_stats(features=self.lr_dh.features)
         self.hr_dh._get_stats()

     @property
@@ -176,7 +179,7 @@ def means(self):
         dict
         """
         out = copy.deepcopy(self.hr_dh.means)
-        out.update(self.lr_dh.means)
+        out.update(super().means)
         return out

     @property
@@ -190,9 +193,10 @@ def stds(self):
         dict
         """
         out = copy.deepcopy(self.hr_dh.stds)
-        out.update(self.lr_dh.stds)
+        out.update(super().stds)
         return out

+    # pylint: disable=unused-argument
     def normalize(self, means=None, stds=None, max_workers=None):
         """Normalize low_res and high_res data

@@ -209,8 +213,7 @@ def normalize(self, means=None, stds=None, max_workers=None):
             self.stds attribute will be used. If this is not None, this
             DataHandler object stds attribute will be updated.
         max_workers : None | int
-            Max workers to perform normalization. if None, self.norm_workers
-            will be used
+            Has no effect. Used to match MixIn class signature.
         """
         if means is None:
             means = self.means
@@ -218,10 +221,14 @@ def normalize(self, means=None, stds=None, max_workers=None):
             stds = self.stds
         logger.info('Normalizing low resolution data features='
                     f'{self.lr_dh.features}')
-        self.lr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
+        super().normalize(means=means, stds=stds,
+                          features=self.lr_dh.features,
+                          max_workers=self.lr_dh.norm_workers)
         logger.info('Normalizing high resolution data features='
                     f'{self.hr_dh.features}')
-        self.hr_dh.normalize(means=means, stds=stds, max_workers=max_workers)
+        self.hr_dh.normalize(means=means, stds=stds,
+                             features=self.hr_dh.features,
+                             max_workers=self.hr_dh.norm_workers)

     @property
     def features(self):
@@ -363,9 +370,15 @@ def hr_sample_shape(self):
     @property
     def data(self):
         """Get low res data. Same as self.lr_data but used to match property
-        used by batch handler for computing means and stdevs"""
+        used for computing means and stdevs"""
         return self.lr_data

+    @property
+    def val_data(self):
+        """Get low res validation data. Same as self.lr_val_data but used to
+        match property used by normalization routine."""
+        return self.lr_val_data
+
     @property
     def lr_input_data(self):
         """Get low res data used as input to regridding routine"""
@@ -405,11 +418,6 @@ def lr_grid_shape(self):
         """Return grid shape for regridded low_res data"""
         return (self.lr_required_shape[0], self.lr_required_shape[1])

-    @property
-    def lr_requested_shape(self):
-        """Return requested shape for low_res data"""
-        return (*self.lr_required_shape, len(self.features))
-
     @property
     def lr_lat_lon(self):
         """Get low_res lat lon array"""
@@ -471,10 +479,10 @@ def load_lr_cached_data(self):
         """Load low_res cache data"""

         logger.info(
-            f'Loading cache with requested_shape={self.lr_requested_shape}.')
+            f'Loading cache with requested_shape={self.shape}.')
         self._load_cached_data(self.lr_data,
                                self.cache_files,
-                               self.features,
+                               self.lr_dh.features,
                                max_workers=self.hr_dh.load_workers)

     def load_cached_data(self):
diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py
index 3d2149e37..2d3037c5d 100644
--- a/sup3r/preprocessing/data_handling/mixin.py
+++ b/sup3r/preprocessing/data_handling/mixin.py
@@ -17,6 +17,7 @@
 from scipy.stats import mode

 from sup3r.utilities.utilities import (
+    estimate_max_workers,
     get_source_type,
     ignore_case_path_fetch,
     uniform_box_sampler,
@@ -924,8 +925,14 @@ class TrainingPrepMixIn:
     def __init__(self):
         """Initialize common attributes"""
         self.features = None
-        self.means = None
-        self.stds = None
+        self.data = None
+        self.val_data = None
+        self.feature_mem = None
+        self.shape = None
+        self._means = None
+        self._stds = None
+        self._is_normalized = False
+        self._norm_workers = None

     @classmethod
     def _split_data_indices(cls,
@@ -1031,7 +1038,7 @@ def _normalize_data(self, data, val_data, feature_index, mean, std):
         logger.debug(f'Finished normalizing {self.features[feature_index]} '
                      f'with mean {mean:.3e} and std {std:.3e}.')

-    def _normalize(self, data, val_data, max_workers=None):
+    def _normalize(self, data, val_data, features=None, max_workers=None):
         """Normalize all data features

         Parameters
         ----------
@@ -1042,27 +1049,31 @@ def _normalize(self, data, val_data, max_workers=None):
         val_data : np.ndarray
             Array of validation data.
             (spatial_1, spatial_2, temporal, n_features)
+        features : list | None
+            List of features used for indexing data array during normalization.
         max_workers : int | None
             Number of workers to use in thread pool for normalization.
""" + if features is None: + features = self.features - msg1 = (f'Not all feature names {self.features} were found in ' + msg1 = (f'Not all feature names {features} were found in ' f'self.means: {list(self.means.keys())}') - msg2 = (f'Not all feature names {self.features} were found in ' + msg2 = (f'Not all feature names {features} were found in ' f'self.stds: {list(self.stds.keys())}') - assert all(fn in self.means for fn in self.features), msg1 - assert all(fn in self.stds for fn in self.features), msg2 + assert all(fn in self.means for fn in features), msg1 + assert all(fn in self.stds for fn in features), msg2 - logger.info(f'Normalizing {data.shape[-1]} features: {self.features}') + logger.info(f'Normalizing {data.shape[-1]} features: {features}') if max_workers == 1: - for idf, feature in enumerate(self.features): + for idf, feature in enumerate(features): self._normalize_data(data, val_data, idf, self.means[feature], self.stds[feature]) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = [] - for idf, feature in enumerate(self.features): + for idf, feature in enumerate(features): future = exe.submit(self._normalize_data, data, val_data, idf, self.means[feature], @@ -1077,3 +1088,88 @@ def _normalize(self, data, val_data, max_workers=None): f'{futures[future]}.') logger.exception(msg) raise RuntimeError(msg) from e + + @property + def means(self): + """Get the mean values for each feature. + + Returns + ------- + dict + """ + self._get_stats() + return self._means + + @property + def stds(self): + """Get the standard deviation values for each feature. + + Returns + ------- + dict + """ + self._get_stats() + return self._stds + + def _get_stats(self, features=None): + """Get the mean/stdev for each feature in the data handler.""" + if features is None: + features = self.features + if self._means is None or self._stds is None: + msg = (f'DataHandler has {len(features)} features ' + f'and mismatched shape of {self.shape}') + assert len(features) == self.shape[-1], msg + self._stds = {} + self._means = {} + for idf, fname in enumerate(features): + self._means[fname] = np.nanmean( + self.data[..., idf].astype(np.float64)) + self._stds[fname] = np.nanstd( + self.data[..., idf].astype(np.float64)) + + def normalize(self, means=None, stds=None, features=None, + max_workers=None): + """Normalize all data features. + + Parameters + ---------- + means : dict | none + Dictionary of means for all features with keys: feature names and + values: mean values. If this is None, the self.means attribute will + be used. If this is not None, this DataHandler object means + attribute will be updated. + stds : dict | none + dictionary of standard deviation values for all features with keys: + feature names and values: standard deviations. If this is None, the + self.stds attribute will be used. If this is not None, this + DataHandler object stds attribute will be updated. + features : list | None + List of features used for indexing data array during normalization. + max_workers : None | int + Max workers to perform normalization. 
if None, self.norm_workers + will be used + """ + if means is not None: + self._means = means + if stds is not None: + self._stds = stds + + if self._is_normalized: + logger.info('Skipping DataHandler, already normalized') + else: + self._normalize(self.data, + self.val_data, + features=features, + max_workers=max_workers) + self._is_normalized = True + + @property + def norm_workers(self): + """Get upper bound on workers used for normalization.""" + if self.data is not None: + norm_workers = estimate_max_workers(self._norm_workers, + 2 * self.feature_mem, + self.shape[-1]) + else: + norm_workers = self._norm_workers + return norm_workers
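
Reviewer note on the batch_handling.py _get_stats change: iterating
as_completed(futures) without consuming each future silently discards any
exception raised inside a worker thread, while calling future.result()
re-raises it in the main thread. A minimal, self-contained sketch of the
difference (plain concurrent.futures, not sup3r code; the worker function is
hypothetical):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def worker(i):
        # Hypothetical stand-in for dh._get_stats that fails for one handler.
        if i == 2:
            raise ValueError(f'stats failed for handler {i}')
        return i

    with ThreadPoolExecutor(max_workers=4) as exe:
        futures = {exe.submit(worker, i): i for i in range(4)}
        try:
            for i, future in enumerate(as_completed(futures)):
                _ = future.result()  # re-raises the worker's ValueError
                print(f'{i + 1} out of {len(futures)} done')
        except ValueError as e:
            # With the old `for i, _ in enumerate(as_completed(futures))`
            # pattern this branch is never reached and the failure is lost.
            print(f'propagated: {e}')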
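
Reviewer note on the weighted statistics in batch_handling.py: the renamed
log messages in _get_feature_means/_get_feature_stdev reflect that these
methods accumulate a weighted average of per-handler means, and a weighted
average of per-handler variances for the stdev. A standalone numpy sketch of
that aggregation (the handler stats and weights below are made-up numbers,
and the final square root is assumed to happen after the accumulation shown
in the hunk):

    import numpy as np

    # Hypothetical per-handler stats; weights sum to 1, as handler_weights
    # would after being normalized across data handlers.
    handler_means = np.array([1.0, 3.0, 2.0])
    handler_stds = np.array([0.5, 1.5, 1.0])
    handler_weights = np.array([0.2, 0.5, 0.3])

    # mean = sum_i w_i * mu_i, matching the loop in _get_feature_means
    combined_mean = float(np.sum(handler_weights * handler_means))

    # stdev = sqrt(sum_i w_i * sigma_i**2), matching _get_feature_stdev
    combined_std = float(np.sqrt(np.sum(handler_weights * handler_stds**2)))

    print(f'mean={combined_mean:.3f}, std={combined_std:.3f}')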