diff --git a/examples/sup3rcc/run_configs/wind/config_fwp.json b/examples/sup3rcc/run_configs/wind/config_fwp.json index acf0e5564..16ce97b48 100755 --- a/examples/sup3rcc/run_configs/wind/config_fwp.json +++ b/examples/sup3rcc/run_configs/wind/config_fwp.json @@ -1,13 +1,12 @@ { "file_paths": "PLACEHOLDER", "model_kwargs": { - "spatial_model_dirs": [ + "model_dirs": [ "./sup3rcc_models_202303/sup3rcc_wind_step1_5x_1x_6f/", - "./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/" - ], - "temporal_model_dirs": "./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/" + "./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/", + "./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/"] }, - "model_class": "SpatialThenTemporalGan", + "model_class": "MultiStepGan", "out_pattern": "./chunks/sup3r_chunk_{file_id}.h5", "log_pattern": "./logs/sup3r_fwp_log_{node_index}.log", "bias_correct_method": "monthly_local_linear_bc", diff --git a/sup3r/bias/bias_calc_cli.py b/sup3r/bias/bias_calc_cli.py index 9e34ee571..d20b16f09 100644 --- a/sup3r/bias/bias_calc_cli.py +++ b/sup3r/bias/bias_calc_cli.py @@ -3,15 +3,15 @@ sup3r bias correction calculation CLI entry points. """ import copy -import click import logging import os +import click + import sup3r.bias.bias_calc from sup3r.utilities import ModuleName -from sup3r.version import __version__ from sup3r.utilities.cli import BaseCLI - +from sup3r.version import __version__ logger = logging.getLogger(__name__) diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 779b65cf0..1179231d9 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -4,16 +4,9 @@ from .conditional_moments import Sup3rCondMom from .data_centric import Sup3rGanDC from .linear import LinearInterp -from .multi_step import ( - MultiStepGan, - MultiStepSurfaceMetGan, - SolarMultiStepGan, - SpatialThenTemporalGan, - TemporalThenSpatialGan, -) +from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel -SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan, - MultiStepSurfaceMetGan, +SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 49d6fa177..79dd05ede 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -130,9 +130,15 @@ def t_enhance(self): @property def input_resolution(self): - """Resolution of input data. Given as a dictionary {'spatial':..., - 'temporal':...}""" - return self.meta.get('input_resolution', None) + """Resolution of input data. Given as a dictionary {'spatial': '...km', + 'temporal': '...min'}. The numbers are required to be integers in the + units specified. The units are not strict as long as the resolution + of the exogenous data, when extracting exogenous data, is specified + in the same units.""" + input_resolution = self.meta.get('input_resolution', None) + msg = 'model.input_resolution is None. This needs to be set.' + assert input_resolution is not None, msg + return input_resolution def _get_numerical_resolutions(self): """Get the input and output resolutions without units""" @@ -185,11 +191,11 @@ def _ensure_valid_enhancement_factors(self): @property def output_resolution(self): - """Resolution of output data. Given as a dictionary {'spatial':..., - 'temporal':...}""" + """Resolution of output data. Given as a dictionary + {'spatial': '...km', 'temporal': '...min'}. This is computed from the + input resolution and the enhancement factors.""" output_res = self.meta.get('output_resolution', None) if self.input_resolution is not None and output_res is None: - output_res = self.input_resolution.copy() ires_num, ores_num = self._get_numerical_resolutions() output_res = {k: v.replace(str(ires_num[k]), str(ores_num[k])) for k, v in self.input_resolution.items()} @@ -223,19 +229,20 @@ def _combine_fwp_input(self, low_res, exogenous_data=None): (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - low_res_shape = low_res.shape - check = (exogenous_data is not None - and low_res.shape[-1] < len(self.training_features)) - if check: - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} - for i, (feature, entry) in enumerate(exo_data.items()): - f_idx = low_res_shape[-1] + i - training_feature = self.training_features[f_idx] - msg = ('The ordering of features in exogenous_data conflicts ' - 'with the ordering of training features. Received ' - f'{feature} instead of {training_feature}.') - assert feature == training_feature, msg + if exogenous_data is None: + return low_res + + training_features = ([] if self.training_features is None + else self.training_features) + fnum_diff = len(training_features) - low_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.training_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] combine_types = [step['combine_type'] for step in entry['steps']] if 'input' in combine_types: @@ -272,19 +279,20 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - hi_res_shape = hi_res.shape[-1] - check = (exogenous_data is not None - and hi_res.shape[-1] < len(self.output_features)) - if check: - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} - for i, (feature, entry) in enumerate(exo_data.items()): - f_idx = hi_res_shape[-1] + i - training_feature = self.training_features[f_idx] - msg = ('The ordering of features in exogenous_data conflicts ' - 'with the ordering of training features. Received ' - f'{feature} instead of {training_feature}.') - assert feature == training_feature, msg + if exogenous_data is None: + return hi_res + + output_features = ([] if self.output_features is None + else self.output_features) + fnum_diff = len(output_features) - hi_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.output_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] combine_types = [step['combine_type'] for step in entry['steps']] if 'output' in combine_types: @@ -318,27 +326,6 @@ def _combine_loss_input(self, high_res_true, high_res_gen): high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) return high_res_gen - def _get_exo_val_loss_input(self, high_res): - """Get exogenous feature data from high_res - - Parameters - ---------- - high_res : tf.Tensor - Ground truth high resolution spatiotemporal data. - - Returns - ------- - exo_data : dict - Dictionary of exogenous feature data used as input to tf_generate. - e.g. {'topography': np.ndarray(...)} - """ - exo_data = {} - for feature in self.exogenous_features: - f_idx = self.training_features.index(feature) - exo_fdata = high_res[..., f_idx: f_idx + 1] - exo_data[feature] = exo_fdata - return exo_data - @property def exogenous_features(self): """Get list of exogenous filter names the model uses. If the model has @@ -810,6 +797,27 @@ def load_saved_params(out_dir, verbose=True): return params + def get_exo_loss_input(self, high_res): + """Get exogenous feature data from high_res + + Parameters + ---------- + high_res : tf.Tensor + Ground truth high resolution spatiotemporal data. + + Returns + ------- + exo_data : dict + Dictionary of exogenous feature data used as input to tf_generate. + e.g. {'topography': tf.Tensor(...)} + """ + exo_data = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_fdata = high_res[..., f_idx: f_idx + 1] + exo_data[feature] = exo_fdata + return exo_data + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -1290,9 +1298,7 @@ def generate(self, (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - low_res = self._combine_fwp_input(low_res, exogenous_data) - if norm_in and self._means is not None: low_res = self.norm_input(low_res) @@ -1408,15 +1414,11 @@ def get_single_grad(self, loss_details : dict Namespace of the breakdown of loss components """ - hi_res_exo = {} - for feature in self.exogenous_features: - f_idx = self.training_features.index(feature) - hi_res_exo[feature] = hi_res_true[..., f_idx: f_idx + 1] - with tf.device(device_name): with tf.GradientTape(watch_accessed_variables=False) as tape: tape.watch(training_weights) + hi_res_exo = self.get_exo_loss_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 30d6ab22d..b6fa7bc52 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -602,7 +602,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + val_exo_data = self.get_exo_loss_input(val_batch.high_res) high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.high_res, high_res_gen, diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index af7835e9d..b1415f616 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -283,7 +283,7 @@ def calc_val_loss(self, batch_handler, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + val_exo_data = self.get_exo_loss_input(val_batch.high_res) output_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.output, output_gen, val_batch.mask) diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py index 2afc75896..2f08fef17 100644 --- a/sup3r/models/data_centric.py +++ b/sup3r/models/data_centric.py @@ -38,7 +38,7 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): """ losses = [] for obs in batch_handler.val_data: - exo_data = self._get_exo_val_loss_input(obs.high_res) + exo_data = self.get_exo_loss_input(obs.high_res) gen = self._tf_generate(obs.low_res, exo_data) loss, _ = self.calc_loss(obs.high_res, gen, weight_gen_advers=weight_gen_advers, @@ -66,7 +66,7 @@ def calc_val_loss_gen_content(self, batch_handler): """ losses = [] for obs in batch_handler.val_data: - exo_data = self._get_exo_val_loss_input(obs.high_res) + exo_data = self.get_exo_loss_input(obs.high_res) gen = self._tf_generate(obs.low_res, exo_data) loss = self.calc_loss_gen_content(obs.high_res, gen) losses.append(float(loss)) diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 58fe80f3e..c76fd47df 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -426,244 +426,7 @@ def _split_exo_dict(self, split_step, exogenous_data=None): return split_exo_1, split_exo_2 -class SpatialThenTemporalGan(SpatialThenTemporalBase): - """A two-step GAN where the first step is a spatial-only enhancement on a - 4D tensor and the second step is a (spatio)temporal enhancement on a 5D - tensor. - - NOTE: The low res input to the spatial enhancement should be a 4D tensor of - the shape (temporal, spatial_1, spatial_2, features) where temporal - (usually the observation index) is a series of sequential timesteps that - will be transposed to a 5D tensor of shape - (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the - 2nd-step (spatio)temporal GAN. - """ - - @property - def models(self): - """Get an ordered tuple of the Sup3rGan models that are part of this - MultiStepGan - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.models - else: - spatial_models = [self.spatial_models] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.models - else: - temporal_models = [self.temporal_models] - - return (*spatial_models, *temporal_models) - - @property - def meta(self): - """Get a tuple of meta data dictionaries for all models - - Returns - ------- - tuple - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.meta - else: - spatial_models = [self.spatial_models.meta] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.meta - else: - temporal_models = [self.temporal_models.meta] - return (*spatial_models, *temporal_models) - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data to the 1st step spatial GAN, which is a - 4D array of shape: (temporal, spatial_1, spatial_2, n_features) - norm_in : bool - Flag to normalize low_res input data if the self.means, - self.stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a spatial + temporal multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. - Each array in the list has 3D or 4D shape: - (spatial_1, spatial_2, n_features) - (temporal, spatial_1, spatial_2, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data output from the 2nd - step (spatio)temporal GAN with a 5D array shape: - (1, spatial_1, spatial_2, n_temporal, n_features) - """ - logger.debug('Data input to the 1st step spatial-only ' - 'enhancement has shape {}'.format(low_res.shape)) - s_exo, t_exo = self._split_exo_dict( - split_step=len(self.spatial_models), exogenous_data=exogenous_data) - try: - hi_res = self.spatial_models.generate( - low_res, norm_in=norm_in, un_norm_out=True, - exogenous_data=s_exo) - except Exception as e: - msg = ('Could not run the 1st step spatial-only GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Data output from the 1st step spatial-only ' - 'enhancement has shape {}'.format(hi_res.shape)) - hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3)) - hi_res = np.expand_dims(hi_res, axis=0) - logger.debug('Data from the 1st step spatial-only enhancement has ' - 'been reshaped to {}'.format(hi_res.shape)) - - try: - hi_res = self.temporal_models.generate( - hi_res, norm_in=True, un_norm_out=un_norm_out, - exogenous_data=t_exo) - except Exception as e: - msg = ('Could not run the 2nd step (spatio)temporal GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Final multistep GAN output has shape: {}' - .format(hi_res.shape)) - - return hi_res - - -class TemporalThenSpatialGan(SpatialThenTemporalBase): - """A two-step GAN where the first step is a spatiotemporal enhancement on a - 5D tensor and the second step is a spatial enhancement on a 4D tensor. - """ - - @property - def models(self): - """Get an ordered tuple of the Sup3rGan models that are part of this - MultiStepGan - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.models - else: - spatial_models = [self.spatial_models] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.models - else: - temporal_models = [self.temporal_models] - - return (*temporal_models, *spatial_models) - - @property - def meta(self): - """Get a tuple of meta data dictionaries for all models - - Returns - ------- - tuple - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.meta - else: - spatial_models = [self.spatial_models.meta] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.meta - else: - temporal_models = [self.temporal_models.meta] - - return (*temporal_models, *spatial_models) - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, a 5D array of shape: - (1, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self.means, - self.stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a temporal + spatial multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. - Each array in the list has 3D or 4D shape: - (spatial_1, spatial_2, n_features) - (temporal, spatial_1, spatial_2, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data output from the 2nd - step (spatio)temporal GAN with a 5D array shape: - (1, spatial_1, spatial_2, n_temporal, n_features) - """ - logger.debug('Data input to the 1st step (spatio)temporal ' - 'enhancement has shape {}'.format(low_res.shape)) - t_exo, s_exo = self._split_exo_dict( - split_step=len(self.temporal_models), - exogenous_data=exogenous_data) - - assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!' - - try: - hi_res = self.temporal_models.generate( - low_res, norm_in=norm_in, un_norm_out=True, - exogenous_data=t_exo) - except Exception as e: - msg = ('Could not run the 1st step (spatio)temporal GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Data output from the 1st step (spatio)temporal ' - 'enhancement has shape {}'.format(hi_res.shape)) - hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3)) - logger.debug('Data from the 1st step (spatio)temporal enhancement has ' - 'been reshaped to {}'.format(hi_res.shape)) - - try: - hi_res = self.spatial_models.generate( - hi_res, norm_in=True, un_norm_out=un_norm_out, - exogenous_data=s_exo) - except Exception as e: - msg = ('Could not run the 2nd step spatial GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3)) - hi_res = np.expand_dims(hi_res, axis=0) - - logger.debug('Final multistep GAN output has shape: {}' - .format(hi_res.shape)) - - return hi_res - - -class MultiStepSurfaceMetGan(SpatialThenTemporalGan): +class MultiStepSurfaceMetGan(MultiStepGan): """A two-step GAN where the first step is a spatial-only enhancement on a 4D tensor of near-surface temperature and relative humidity data, and the second step is a (spatio)temporal enhancement on a 5D tensor. @@ -777,10 +540,12 @@ def load(cls, surface_model_class='SurfaceSpatialMetModel', t_models = TemporalModelClass.load(verbose=verbose, **temporal_model_kwargs) - return cls(s_models, t_models) + s_models = getattr(s_models, 'models', [s_models]) + t_models = getattr(t_models, 'models', [t_models]) + return cls([*s_models, *t_models]) -class SolarMultiStepGan(SpatialThenTemporalGan): +class SolarMultiStepGan(SpatialThenTemporalBase): """Special multi step model for solar clearsky ratio super resolution. This model takes in two parallel models for wind-only and solar-only @@ -796,11 +561,11 @@ def __init__(self, spatial_solar_models, spatial_wind_models, ---------- spatial_solar_models : MultiStepGan A loaded MultiStepGan object representing the one or more spatial - super resolution steps in this composite SpatialThenTemporalGan + super resolution steps in this composite MultiStepGan model that inputs and outputs clearsky_ratio spatial_wind_models : MultiStepGan A loaded MultiStepGan object representing the one or more spatial - super resolution steps in this composite SpatialThenTemporalGan + super resolution steps in this composite MultiStepGan model that inputs and outputs wind u/v features and must include U_200m + V_200m as output features. temporal_solar_models : MultiStepGan @@ -869,13 +634,13 @@ def preflight(self): @property def spatial_models(self): - """Alias for spatial_solar_models to preserve SpatialThenTemporalGan + """Alias for spatial_solar_models to preserve MultiStepGan interface.""" return self.spatial_solar_models @property def temporal_models(self): - """Alias for temporal_solar_models to preserve SpatialThenTemporalGan + """Alias for temporal_solar_models to preserve MultiStepGan interface.""" return self.temporal_solar_models diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index c5c2a9c79..4a96257be 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -7,7 +7,6 @@ import copy import logging import os -import re import warnings from concurrent.futures import as_completed from datetime import datetime as dt @@ -53,8 +52,7 @@ def __init__(self, s_enhancements, t_enhancements, spatial_pad, - temporal_pad, - ): + temporal_pad): """ Parameters ---------- @@ -610,8 +608,7 @@ def __init__(self, exo_kwargs=None, bias_correct_method=None, bias_correct_kwargs=None, - max_nodes=None, - ): + max_nodes=None): """Use these inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator. @@ -700,7 +697,11 @@ def __init__(self, be used in the model. e.g. {'topography': {'file_paths': 'path to input files', 'source_file': 'path to exo data', 'exo_resolution': {'spatial': '1km', 'temporal': None}, 'steps': [{'model': 0, - 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]} + 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]}. + Each step can also include s_agg_factor/t_agg_factor to manually + set aggregation of exogenous_data. If they are not specified they + will be computed based on the model resolution and the exogenous + data resolution. bias_correct_method : str | None Optional bias correction function name that can be imported from the :mod:`sup3r.bias.bias_transforms` module. This will transform @@ -848,8 +849,7 @@ def init_handler(self): out = self.input_handler_class(self.file_paths[0], [], target=self.target, shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1), - ) + worker_kwargs=dict(ti_workers=1)) self._init_handler = out return self._init_handler @@ -1093,7 +1093,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.max_workers = strategy.max_workers self.pass_workers = strategy.pass_workers self.output_workers = strategy.output_workers - self.exo_kwargs = self.update_exo_extract_kwargs(strategy.exo_kwargs) + self.exo_kwargs = strategy.exo_kwargs self.exo_features = ([] if not self.exo_kwargs else list(self.exo_kwargs)) self.exogenous_data = self.load_exo_data() @@ -1120,169 +1120,6 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], self.lr_slice[1]] - def _get_res_ratio(self, input_res, exo_res): - """Compute resolution ratio given input and output resolution - - Parameters - ---------- - input_res : str | None - Input resolution. e.g. '30km' or '60min' - exo_res : str | None - Exo resolution. e.g. '1km' or '5min' - - Returns - ------- - res_ratio : int | None - Ratio of input / exo resolution - """ - ires_num = (None if input_res is None - else int(re.search(r'\d+', input_res).group(0))) - eres_num = (None if exo_res is None - else int(re.search(r'\d+', exo_res).group(0))) - i_units = (None if input_res is None - else input_res.replace(str(ires_num), '')) - e_units = (None if exo_res is None - else exo_res.replace(str(eres_num), '')) - msg = 'Received conflicting units for input and exo resolution' - if e_units is not None: - assert i_units == e_units, msg - if ires_num is not None and eres_num is not None: - res_ratio = int(ires_num / eres_num) - else: - res_ratio = None - return res_ratio - - def get_agg_factors(self, input_res, exo_res): - """Compute aggregation ratio for exo data given input and output - resolution - - Parameters - ---------- - input_res : dict | None - Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'} - exo_res : dict | None - Exogenous data resolution. e.g. - {'spatial': '1km', 'temporal': '5min'} - - Returns - ------- - s_agg_factor : int - Spatial aggregation factor for exogenous data extraction. - t_agg_factor : int - Temporal aggregation factor for exogenous data extraction. - """ - input_s_res = None if input_res is None else input_res['spatial'] - exo_s_res = None if exo_res is None else exo_res['spatial'] - s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res) - s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2 - input_t_res = None if input_res is None else input_res['temporal'] - exo_t_res = None if exo_res is None else exo_res['temporal'] - t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res) - return s_agg_factor, t_agg_factor - - def _get_single_step_agg_and_enhancement(self, step_dict, exo_resolution): - """Compute agg and enhancement factors for exogenous data extraction - using exo_kwargs single model step. These factors are computed using - exo_resolution and the input/output resolution of each model step. - - Parameters - ---------- - step_dict : dict - Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} - exo_resolution : dict - Resolution of exogenous data. e.g. {'temporal': 15min, 'spatial': - '1km'} - - Returns - ------- - updated_step_dict : dict - Same as input dictionary with s_agg_factor, t_agg_factor, - s_enhance, t_enhance added - """ - model_step = step_dict['model'] - s_enhance = None - t_enhance = None - s_agg_factor = None - t_agg_factor = None - models = getattr(self.model, 'models', [self.model]) - msg = (f'Model index from exo_kwargs ({model_step} exceeds number ' - f'of model steps ({len(models)})') - assert len(models) > model_step, msg - model = models[model_step] - input_res = model.input_resolution - output_res = model.output_resolution - combine_type = step_dict.get('combine_type', None) - - if combine_type.lower() == 'input': - if model_step == 0: - s_enhance = 1 - t_enhance = 1 - else: - s_enhance = np.product( - self.strategy.s_enhancements[:model_step]) - t_enhance = np.product( - self.strategy.t_enhancements[:model_step]) - s_agg_factor, t_agg_factor = self.get_agg_factors( - input_res, exo_resolution) - resolution = input_res - - elif combine_type.lower() in ('output', 'layer'): - s_enhance = np.product( - self.strategy.s_enhancements[:model_step + 1]) - t_enhance = np.product( - self.strategy.t_enhancements[:model_step + 1]) - s_agg_factor, t_agg_factor = self.get_agg_factors( - output_res, exo_resolution) - resolution = output_res - - else: - msg = ('Received exo_kwargs entry without valid combine_type ' - '(input/layer/output)') - raise OSError(msg) - - updated_dict = step_dict.copy() - updated_dict.update({'s_enhance': s_enhance, 't_enhance': t_enhance, - 's_agg_factor': s_agg_factor, - 't_agg_factor': t_agg_factor, - 'resolution': resolution}) - return updated_dict - - def update_exo_extract_kwargs(self, exo_kwargs): - """Compute agg and enhancement factors for all model steps for all - features. - - Parameters - ---------- - exo_kwargs: dict - Full exo_kwargs dictionary with all feature entries. - e.g. {'topography': {'exo_resolution': {'spatial': '1km', - 'temporal': None}, 'steps': [{'model': 0, 'combine_type': 'input'}, - {'model': 0, 'combine_type': 'layer'}]}} - - Returns - ------- - updated_dict : dict - Same as input dictionary with s_agg_factor, t_agg_factor, - s_enhance, t_enhance added to each step entry for all features - """ - if exo_kwargs is not None: - for feature in exo_kwargs: - exo_resolution = exo_kwargs[feature]['exo_resolution'] - steps = exo_kwargs[feature]['steps'] - for i, step in enumerate(steps): - out = self._get_single_step_agg_and_enhancement( - step, exo_resolution) - exo_kwargs[feature]['steps'][i] = out - exo_kwargs[feature]['s_agg_factors'] = [step['s_agg_factor'] - for step in steps] - exo_kwargs[feature]['t_agg_factors'] = [step['t_agg_factor'] - for step in steps] - exo_kwargs[feature]['s_enhancements'] = [step['s_enhance'] - for step in steps] - exo_kwargs[feature]['t_enhancements'] = [step['t_enhance'] - for step in steps] - return exo_kwargs - def load_exo_data(self): """Extract exogenous data for each exo feature and store data in dictionary with key for each exo feature @@ -1304,12 +1141,13 @@ def load_exo_data(self): exo_kwargs['target'] = self.target exo_kwargs['shape'] = self.shape exo_kwargs['temporal_slice'] = self.ti_pad_slice - steps = exo_kwargs['steps'] + exo_kwargs['models'] = getattr(self.model, 'models', + [self.model]) sig = signature(ExogenousDataHandler) exo_kwargs = {k: v for k, v in exo_kwargs.items() if k in sig.parameters} data = ExogenousDataHandler(**exo_kwargs).data - for i, _ in enumerate(steps): + for i, _ in enumerate(exo_kwargs['steps']): exo_data[feature]['steps'][i]['data'] = data[i] shapes = [None if d is None else d.shape for d in data] logger.info( @@ -1602,11 +1440,41 @@ def pad_width(self): return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end), (pad_t_start, pad_t_end)) - @staticmethod - def pad_source_data(input_data, - pad_width, - exo_data, - mode='reflect'): + def _get_step_enhance(self, step): + """Get enhancement factors for a given step and combine type. + + Parameters + ---------- + step : dict + Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} + + Returns + ------- + s_enhance : int + Spatial enhancement factor for given step and combine type + t_enhance : int + Temporal enhancement factor for given step and combine type + """ + combine_type = step['combine_type'] + model_step = step['model'] + if combine_type.lower() == 'input': + if model_step == 0: + s_enhance = 1 + t_enhance = 1 + else: + s_enhance = np.product( + self.strategy.s_enhancements[:model_step]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step]) + + elif combine_type.lower() in ('output', 'layer'): + s_enhance = np.product( + self.strategy.s_enhancements[:model_step + 1]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step + 1]) + return s_enhance, t_enhance + + def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'): """Pad the edges of the source data from the data handler. Parameters @@ -1647,12 +1515,13 @@ def pad_source_data(input_data, if exo_data is not None: for feature in exo_data: for i, step in enumerate(exo_data[feature]['steps']): - exo_pad_width = ((step['s_enhance'] * pad_width[0][0], - step['s_enhance'] * pad_width[0][1]), - (step['s_enhance'] * pad_width[1][0], - step['s_enhance'] * pad_width[1][1]), - (step['t_enhance'] * pad_width[2][0], - step['t_enhance'] * pad_width[2][1]), + s_enhance, t_enhance = self._get_step_enhance(step) + exo_pad_width = ((s_enhance * pad_width[0][0], + s_enhance * pad_width[0][1]), + (s_enhance * pad_width[1][0], + s_enhance * pad_width[1][1]), + (t_enhance * pad_width[2][0], + t_enhance * pad_width[2][1]), (0, 0)) new_exo = np.pad(step['data'], exo_pad_width, mode=mode) exo_data[feature]['steps'][i]['data'] = new_exo diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 884d4a7c4..d77906544 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -2,9 +2,12 @@ import logging import os import pickle +import re import shutil from typing import ClassVar +import numpy as np + from sup3r.preprocessing.data_handling import exo_extraction from sup3r.preprocessing.data_handling.exo_extraction import ( SzaExtract, @@ -35,10 +38,9 @@ class ExogenousDataHandler: def __init__(self, file_paths, feature, - s_enhancements, - t_enhancements, - s_agg_factors, - t_agg_factors, + steps, + models=None, + exo_resolution=None, source_file=None, target=None, shape=None, @@ -60,34 +62,19 @@ def __init__(self, sup3r resolved. feature : str Exogenous feature to extract from source_h5 - s_enhancements : list - List of factors by which the Sup3rGan model will enhance the - spatial dimensions of low resolution data from file_paths input - where the total spatial enhancement is the product of these - factors. For example, if file_paths has 100km data and there are 2 - spatial enhancement steps of 4x and 5x to a nominal resolution of - 5km, s_enhancements should be [1, 4, 5] and exo_steps should be - [0, 1, 2] so that the input to the 4x model gets exogenous data - at 100km (s_enhance=1, exo_step=0), the input to the 5x model gets - exogenous data at 25km (s_enhance=4, exo_step=1), and there is a - 20x (1*4*5) exogeneous data layer available if the second model can - receive a high-res input feature. The length of this list should be - equal to the number of s_agg_factors - t_enhancements : list - List of factors by which the Sup3rGan model will enhance the - temporal dimension of low resolution data from file_paths input - where the total temporal enhancement is the product of these - factors. - s_agg_factors : list - List of factors by which to aggregate the exo_source - data to the spatial resolution of the file_paths input enhanced by - s_enhance. The length of this list should be equal to the number of - s_enhancements - t_agg_factors : list - List of factors by which to aggregate the exo_source - data to the temporal resolution of the file_paths input enhanced by - t_enhance. The length of this list should be equal to the number of - t_enhancements + models : list + List of models used with the given steps list + steps : list + List of dictionaries containing info on which models to use for a + given step index and what type of exo data the step requires. e.g. + [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}] + Each step entry can also contain s_enhance, t_enhance, + s_agg_factor, t_agg_factor. If they are not included they will be + computed using exo_resolution and model attributes + exo_resolution : dict + Dictionary of spatiotemporal resolution for the given exo data + source. e.g. {'spatial': '4km', 'temporal': '60min'} source_file : str Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or 4km) data from which will be mapped to the enhanced grid of the @@ -128,10 +115,9 @@ def __init__(self, """ self.feature = feature - self.s_enhancements = s_enhancements - self.t_enhancements = t_enhancements - self.s_agg_factors = s_agg_factors - self.t_agg_factors = t_agg_factors + self.steps = steps + self.models = models + self.exo_res = exo_resolution self.source_file = source_file self.file_paths = file_paths self.exo_handler = exo_handler @@ -145,6 +131,13 @@ def __init__(self, self.cache_dir = cache_dir self.data = [] + self.input_check() + agg_enhance = self._get_all_agg_and_enhancement() + self.s_enhancements = agg_enhance['s_enhancements'] + self.t_enhancements = agg_enhance['t_enhancements'] + self.s_agg_factors = agg_enhance['s_agg_factors'] + self.t_agg_factors = agg_enhance['t_agg_factors'] + msg = ('Need to provide the same number of enhancement factors and ' f'agg factors. Received s_enhancements={self.s_enhancements}, ' f'and s_agg_factors={self.s_agg_factors}.') @@ -175,6 +168,202 @@ def __init__(self, f" Received {feature}.") raise NotImplementedError(msg) + def input_check(self): + """Make sure agg factors are provided or exo_resolution and models are + provided. Make sure enhancement factors are provided or models are + provided""" + agg_check = all('s_agg_factor' in v for v in self.steps) + agg_check = agg_check and all('t_agg_factor' in v for v in self.steps) + agg_check = (agg_check + or self.models is not None and self.exo_res is not None) + msg = ("ExogenousDataHandler needs s_agg_factor and t_agg_factor " + "provided in each step in steps list or models and " + "exo_resolution") + assert agg_check, msg + en_check = all('s_enhance' in v for v in self.steps) + en_check = en_check and all('t_enhance' in v for v in self.steps) + en_check = en_check or self.models is not None + msg = ("ExogenousDataHandler needs s_enhance and t_enhance " + "provided in each step in steps list or models") + assert en_check, msg + + def _get_res_ratio(self, input_res, exo_res): + """Compute resolution ratio given input and output resolution + + Parameters + ---------- + input_res : str | None + Input resolution. e.g. '30km' or '60min' + exo_res : str | None + Exo resolution. e.g. '1km' or '5min' + + Returns + ------- + res_ratio : int | None + Ratio of input / exo resolution + """ + ires_num = (None if input_res is None + else int(re.search(r'\d+', input_res).group(0))) + eres_num = (None if exo_res is None + else int(re.search(r'\d+', exo_res).group(0))) + i_units = (None if input_res is None + else input_res.replace(str(ires_num), '')) + e_units = (None if exo_res is None + else exo_res.replace(str(eres_num), '')) + msg = 'Received conflicting units for input and exo resolution' + if e_units is not None: + assert i_units == e_units, msg + if ires_num is not None and eres_num is not None: + res_ratio = int(ires_num / eres_num) + else: + res_ratio = None + return res_ratio + + def get_agg_factors(self, input_res, exo_res): + """Compute aggregation ratio for exo data given input and output + resolution + + Parameters + ---------- + input_res : dict | None + Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'} + exo_res : dict | None + Exogenous data resolution. e.g. + {'spatial': '1km', 'temporal': '5min'} + + Returns + ------- + s_agg_factor : int + Spatial aggregation factor for exogenous data extraction. + t_agg_factor : int + Temporal aggregation factor for exogenous data extraction. + """ + input_s_res = None if input_res is None else input_res['spatial'] + exo_s_res = None if exo_res is None else exo_res['spatial'] + s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res) + s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2 + input_t_res = None if input_res is None else input_res['temporal'] + exo_t_res = None if exo_res is None else exo_res['temporal'] + t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res) + return s_agg_factor, t_agg_factor + + def _get_single_step_agg(self, step): + """Compute agg factors for exogenous data extraction + using exo_kwargs single model step. These factors are computed using + exo_resolution and the input/output resolution of each model step. If + agg factors are already provided in step they are not overwritten. + + Parameters + ---------- + step : dict + Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} + + Returns + ------- + updated_step : dict + Same as input dictionary with s_agg_factor, t_agg_factor added + """ + if all(key in step for key in ['s_agg_factor', 't_agg_factor']): + return step + + model_step = step['model'] + combine_type = step.get('combine_type', None) + msg = (f'Model index from exo_kwargs ({model_step} exceeds number ' + f'of model steps ({len(self.models)})') + assert len(self.models) > model_step, msg + model = self.models[model_step] + input_res = model.input_resolution + output_res = model.output_resolution + if combine_type.lower() == 'input': + s_agg_factor, t_agg_factor = self.get_agg_factors( + input_res, self.exo_res) + + elif combine_type.lower() in ('output', 'layer'): + s_agg_factor, t_agg_factor = self.get_agg_factors( + output_res, self.exo_res) + + else: + msg = ('Received exo_kwargs entry without valid combine_type ' + '(input/layer/output)') + raise OSError(msg) + + step.update({'s_agg_factor': s_agg_factor, + 't_agg_factor': t_agg_factor}) + return step + + def _get_single_step_enhance(self, step): + """Get enhancement factors for exogenous data extraction + using exo_kwargs single model step. These factors are computed using + stored enhance attributes of each model and the model step provided. + If enhancement factors are already provided in step they are not + overwritten. + + Parameters + ---------- + step : dict + Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} + + Returns + ------- + updated_step : dict + Same as input dictionary with s_enhance, t_enhance added + """ + if all(key in step for key in ['s_enhance', 't_enhance']): + return step + + model_step = step['model'] + combine_type = step.get('combine_type', None) + msg = (f'Model index from exo_kwargs ({model_step} exceeds number ' + f'of model steps ({len(self.models)})') + assert len(self.models) > model_step, msg + + s_enhancements = [model.s_enhance for model in self.models] + t_enhancements = [model.t_enhance for model in self.models] + if combine_type.lower() == 'input': + if model_step == 0: + s_enhance = 1 + t_enhance = 1 + else: + s_enhance = np.product(s_enhancements[:model_step]) + t_enhance = np.product(t_enhancements[:model_step]) + + elif combine_type.lower() in ('output', 'layer'): + s_enhance = np.product(s_enhancements[:model_step + 1]) + t_enhance = np.product(t_enhancements[:model_step + 1]) + + else: + msg = ('Received exo_kwargs entry without valid combine_type ' + '(input/layer/output)') + raise OSError(msg) + + step.update({'s_enhance': s_enhance, 't_enhance': t_enhance}) + return step + + def _get_all_agg_and_enhancement(self): + """Compute agg and enhancement factors for all model steps for all + features. + + Returns + ------- + agg_enhance_dict : dict + Dictionary with list of agg and enhancement factors for each model + step + """ + agg_enhance_dict = {} + for i, step in enumerate(self.steps): + out = self._get_single_step_agg(step) + out = self._get_single_step_enhance(out) + self.steps[i] = out + agg_enhance_dict['s_agg_factors'] = [step['s_agg_factor'] + for step in self.steps] + agg_enhance_dict['t_agg_factors'] = [step['t_agg_factor'] + for step in self.steps] + agg_enhance_dict['s_enhancements'] = [step['s_enhance'] + for step in self.steps] + agg_enhance_dict['t_enhancements'] = [step['t_enhance'] + for step in self.steps] + return agg_enhance_dict + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): """Get cache file name diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py index 5bdfd264a..93091733c 100644 --- a/tests/data_handling/test_exo_data_handling.py +++ b/tests/data_handling/test_exo_data_handling.py @@ -27,13 +27,17 @@ def test_exo_cache(feature): """Test exogenous data caching and re-load""" # no cached data + steps = [] + for s_en, t_en, s_agg, t_agg in zip(S_ENHANCE, T_ENHANCE, S_AGG_FACTORS, + T_AGG_FACTORS): + steps.append({'s_enhance': s_en, + 't_enhance': t_en, + 's_agg_factor': s_agg, + 't_agg_factor': t_agg}) try: base = ExogenousDataHandler(FILE_PATHS, feature, source_file=FP_WTK, - s_enhancements=S_ENHANCE, - t_enhancements=T_ENHANCE, - s_agg_factors=S_AGG_FACTORS, - t_agg_factors=T_AGG_FACTORS, + steps=steps, target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') for i, arr in enumerate(base.data): @@ -51,10 +55,7 @@ def test_exo_cache(feature): try: cache = ExogenousDataHandler(FILE_PATHS, feature, source_file=FP_WTK, - s_enhancements=S_ENHANCE, - t_enhancements=T_ENHANCE, - s_agg_factors=S_AGG_FACTORS, - t_agg_factors=T_AGG_FACTORS, + steps=steps, target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') except Exception as e: diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 1518c6bc5..04f99081c 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -600,8 +600,7 @@ def test_fwp_multi_step_model(): t_enhance = 4 model_kwargs = { - 'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s_out_dir, st_out_dir] } input_handler_kwargs = dict( @@ -613,7 +612,7 @@ def test_fwp_multi_step_model(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=0, temporal_pad=0, diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index f30b83852..1bf24b286 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -60,7 +60,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['training_features'] = ['U_100m', 'V_100m'] st_model.meta['output_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 @@ -99,8 +99,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -113,7 +112,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, input_handler_kwargs=input_handler_kwargs, spatial_pad=0, @@ -317,8 +316,7 @@ def test_fwp_multi_step_model_topo_noskip(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -331,7 +329,7 @@ def test_fwp_multi_step_model_topo_noskip(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, @@ -637,10 +635,6 @@ def test_fwp_multi_step_wind_hi_res_topo(): } } - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } model_kwargs = { 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } @@ -807,8 +801,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): } model_kwargs = { - 'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': t_out_dir + 'model_dirs': [s_out_dir, t_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict(target=target, @@ -820,7 +813,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=(4, 4, 8), spatial_pad=1, temporal_pad=1, @@ -916,8 +909,7 @@ def test_fwp_multi_step_model_multi_exo(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -930,7 +922,7 @@ def test_fwp_multi_step_model_multi_exo(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, @@ -1173,8 +1165,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict(target=target, @@ -1186,7 +1177,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=(4, 4, 8), spatial_pad=1, temporal_pad=1, diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index f610829f6..3da773b44 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -11,7 +11,6 @@ LinearInterp, MultiStepGan, SolarMultiStepGan, - SpatialThenTemporalGan, Sup3rGan, ) @@ -70,6 +69,9 @@ def test_multi_step_norm(norm_option): model2.set_norm_stats([0.1, 0.8], [0.04, 0.02]) model3.set_norm_stats([0.1, 0.8], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} + model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} + model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -118,6 +120,8 @@ def test_spatial_then_temporal_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -129,7 +133,7 @@ def test_spatial_then_temporal_gan(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -150,6 +154,8 @@ def test_temporal_then_spatial_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -178,6 +184,7 @@ def test_spatial_gan_then_linear_interp(): model2 = LinearInterp(features=FEATURES, s_enhance=3, t_enhance=4) model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) @@ -187,7 +194,7 @@ def test_spatial_gan_then_linear_interp(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -204,6 +211,7 @@ def test_solar_multistep(): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats([0.7], [0.04]) + model1.meta['input_resolution'] = {'spatial': '8km', 'temporal': '40min'} model1.set_model_params(training_features=features1, output_features=features1) @@ -213,6 +221,7 @@ def test_solar_multistep(): model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, len(features2)))) model2.set_norm_stats([4.2, 5.6], [1.1, 1.3]) + model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} model2.set_model_params(training_features=features2, output_features=features2) @@ -223,6 +232,7 @@ def test_solar_multistep(): model3 = Sup3rGan(fp_gen, fp_disc) _ = model3.generate(np.ones((4, 10, 10, 3, len(features_in_3)))) model3.set_norm_stats([0.7, 4.2, 5.6], [0.04, 1.1, 1.3]) + model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params(training_features=features_in_3, output_features=features_out_3)