From d0575a7060fc1855bac179d70b350c06335f6b13 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Tue, 12 Sep 2023 13:39:28 -0600
Subject: [PATCH 01/15] generalized TopoExtract and renamed it to ExoExtract.
 changed exo_kwargs in the forward pass to use a dictionary of features. added
 multi-exo tests in test_forward_pass.py

---
 sup3r/models/abstract.py                      |  215 +--
 sup3r/models/multi_step.py                    |   28 +-
 sup3r/pipeline/forward_pass.py                |  267 ++--
 .../data_handling/exo_extraction.py}          |  295 +++--
 .../data_handling/exogenous_data_handling.py  |  231 +++-
 .../data_handling/nc_data_handling.py         |   59 +-
 tests/forward_pass/test_forward_pass.py       | 1172 ++++++++++++-----
 7 files changed, 1450 insertions(+), 817 deletions(-)
 rename sup3r/{utilities/topo.py => preprocessing/data_handling/exo_extraction.py} (52%)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 26d7f4be7..70a23ea78 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -14,10 +14,10 @@
 
 import numpy as np
 import tensorflow as tf
-from tensorflow.keras import optimizers
 from phygnn import CustomNetwork
 from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat
 from rex.utilities.utilities import safe_json_load
+from tensorflow.keras import optimizers
 
 import sup3r.utilities.loss_metrics
 from sup3r.utilities import VERSION_RECORD
@@ -118,6 +118,56 @@ def t_enhance(self):
         model training during high res coarsening"""
         return self.meta.get('t_enhance', None)
 
+    @property
+    def needs_hr_exo(self):
+        """Determine whether or not the sup3r model needs hi-res exogenous
+        data
+
+        Returns
+        -------
+        needs_hr_exo : bool
+            True if the model requires high-resolution exogenous data,
+            typically because of the use of Sup3rAdder or Sup3rConcat layers.
+        """
+        # pylint: disable=E1101
+        return (hasattr(self, '_gen') and any(
+            isinstance(layer, (Sup3rAdder, Sup3rConcat))
+            for layer in self._gen.layers))
+
+    def _needs_lr_exo(self, low_res):
+        """Determine whether or not the sup3r model needs low-res exogenous
+        data
+
+        Parameters
+        ----------
+        low_res : np.ndarray
+            Low-resolution input data, usually a 4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+
+        Returns
+        -------
+        needs_lr_exo : bool
+            True if the model requires low-resolution exogenous data.
+        """
+        return low_res.shape[-1] < len(self.training_features)
+
+    @property
+    def exogenous_features(self):
+        """Get list of exogenous feature names the model uses. If the model
+        has N concat or add layers this list will be the last N features in
+        the training features list. The ordering is assumed to be the same as
+        the order of concat or add layers. If training features is [..., topo,
+        sza], and the model has 2 concat or add layers, exo features will be
+        [topo, sza].
Topo will then be used in the first concat layer and sza
+        will be used in the second"""
+        # pylint: disable=E1101
+        layer_count = 0
+        if hasattr(self, '_gen'):
+            layer_count = sum(
+                isinstance(layer, (Sup3rAdder, Sup3rConcat))
+                for layer in self._gen.layers)
+        # guard against layer_count == 0: slicing with [-0:] would return the
+        # full training feature list instead of no exo features
+        if layer_count == 0:
+            return []
+        return self.training_features[-layer_count:]
+
     @property
     @abstractmethod
     def meta(self):
@@ -189,8 +239,9 @@ def set_model_params(self, **kwargs):
                 self.meta[var] = kwargs[var]
             elif val != kwargs[var]:
                 msg = ('Model was previously trained with {var}={} but '
-                       'received new {var}={}'
-                       .format(val, kwargs[var], var=var))
+                       'received new {var}={}'.format(val,
+                                                      kwargs[var],
+                                                      var=var))
                 logger.warning(msg)
                 warn(msg)
 
@@ -255,16 +306,14 @@ def load_network(self, model, name):
             self._meta[f'config_{name}'] = model
             if 'hidden_layers' in model:
                 model = model['hidden_layers']
-            elif ('meta' in model
-                  and f'config_{name}' in model['meta']
+            elif ('meta' in model and f'config_{name}' in model['meta']
                   and 'hidden_layers' in model['meta'][f'config_{name}']):
                 model = model['meta'][f'config_{name}']['hidden_layers']
             else:
                 msg = ('Could not load model from json config, need '
                        '"hidden_layers" key or '
                        f'"meta/config_{name}/hidden_layers" '
-                       ' at top level but only found: {}'
-                       .format(model.keys()))
+                       ' at top level but only found: {}'.format(model.keys()))
                 logger.error(msg)
                 raise KeyError(msg)
 
@@ -277,8 +326,8 @@ def load_network(self, model, name):
 
         if not isinstance(model, CustomNetwork):
             msg = ('Something went wrong. Tried to load a custom network '
-                   'but ended up with a model of type "{}"'
-                   .format(type(model)))
+                   'but ended up with a model of type "{}"'.format(
+                       type(model)))
             logger.error(msg)
             raise TypeError(msg)
 
@@ -308,23 +357,27 @@ def stdevs(self):
     def output_stdevs(self):
         """Get the data normalization standard deviation values for only the
         output features
+
         Returns
         -------
         np.ndarray
         """
-        indices = [self.training_features.index(f)
-                   for f in self.output_features]
+        indices = [
+            self.training_features.index(f) for f in self.output_features
+        ]
         return self._stdevs[indices]
 
     @property
     def output_means(self):
         """Get the data normalization mean values for only the output
         features
+
        Returns
         -------
         np.ndarray
         """
-        indices = [self.training_features.index(f)
-                   for f in self.output_features]
+        indices = [
+            self.training_features.index(f) for f in self.output_features
+        ]
         return self._means[indices]
 
     def set_norm_stats(self, new_means, new_stdevs):
@@ -341,10 +394,10 @@ def set_norm_stats(self, new_means, new_stdevs):
 
         if self._means is not None:
             logger.info('Setting new normalization statistics...')
-            logger.info("Model's previous data mean values: {}"
-                        .format(self._means))
-            logger.info("Model's previous data stdev values: {}"
-                        .format(self._stdevs))
+            logger.info("Model's previous data mean values: {}".format(
+                self._means))
+            logger.info("Model's previous data stdev values: {}".format(
+                self._stdevs))
 
         self._means = new_means
         self._stdevs = new_stdevs
@@ -354,10 +407,10 @@ def set_norm_stats(self, new_means, new_stdevs):
         if not isinstance(self._stdevs, np.ndarray):
             self._stdevs = np.array(self._stdevs)
 
-        logger.info('Set data normalization mean values: {}'
-                    .format(self._means))
-        logger.info('Set data normalization stdev values: {}'
-                    .format(self._stdevs))
+        logger.info('Set data normalization mean values: {}'.format(
+            self._means))
+        logger.info('Set data normalization stdev values: {}'.format(
+            self._stdevs))
 
     def norm_input(self, low_res):
         """Normalize low resolution data being input to the generator.
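The ordering contract documented by the new exogenous_features property above
can be sketched in a few lines (the feature names and layer count are
illustrative assumptions, not taken from this patch): the last N training
features map one-to-one, in order, onto the model's N Sup3rAdder/Sup3rConcat
layers.

    # minimal sketch of the exogenous_features ordering contract
    training_features = ['U_100m', 'V_100m', 'topography', 'sza']  # assumed
    n_exo_layers = 2  # e.g. one Sup3rConcat for topo, one for sza
    exo_features = training_features[-n_exo_layers:] if n_exo_layers else []
    assert exo_features == ['topography', 'sza']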
@@ -459,25 +512,6 @@ def generator_weights(self): """ return self.generator.weights - def _needs_lr_exo(self, low_res): - """Determine whether or not the sup3r model needs low-res exogenous - data - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - needs_lr_exo : bool - True if the model requires low-resolution exogenous data. - """ - - return low_res.shape[-1] < len(self.training_features) - @staticmethod def init_optimizer(optimizer, learning_rate): """Initialize keras optimizer object. @@ -500,8 +534,10 @@ def init_optimizer(optimizer, learning_rate): class_name = optimizer['name'] OptimizerClass = getattr(optimizers, class_name) sig = signature(OptimizerClass) - optimizer_kwargs = {k: v for k, v in optimizer.items() - if k in sig.parameters} + optimizer_kwargs = { + k: v + for k, v in optimizer.items() if k in sig.parameters + } optimizer = OptimizerClass.from_config(optimizer_kwargs) elif optimizer is None: optimizer = optimizers.Adam(learning_rate=learning_rate) @@ -544,8 +580,8 @@ def load_saved_params(out_dir, verbose=True): if verbose: logger.info('Loading model from disk ' 'that was created with the ' - 'following package versions: \n{}' - .format(pprint.pformat(version_record, indent=2))) + 'following package versions: \n{}'.format( + pprint.pformat(version_record, indent=2))) return params @@ -575,8 +611,8 @@ def get_loss_fun(loss): if out is None: msg = ('Could not find requested loss function "{}" in ' - 'sup3r.utilities.loss_metrics or tf.keras.losses.' - .format(loss)) + 'sup3r.utilities.loss_metrics or tf.keras.losses.'.format( + loss)) logger.error(msg) raise KeyError(msg) @@ -635,8 +671,8 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): for key, new_value in new_data.items(): key = key if prefix is None else prefix + key - new_value = (new_value if not isinstance(new_value, tf.Tensor) - else new_value.numpy()) + new_value = (new_value if not isinstance(new_value, tf.Tensor) else + new_value.numpy()) if key in loss_details: saved_value = loss_details[key] @@ -710,8 +746,8 @@ def early_stop(history, column, threshold=0.005, n_epoch=5): stop = True logger.info('Found early stop condition, loss values "{}" ' 'have absolute relative differences less than ' - 'threshold {}: {}' - .format(column, threshold, diffs[-n_epoch:])) + 'threshold {}: {}'.format(column, threshold, + diffs[-n_epoch:])) return stop @@ -726,10 +762,17 @@ def save(self, out_dir): if it does not already exist. 
""" - def finish_epoch(self, epoch, epochs, t0, loss_details, - checkpoint_int, out_dir, - early_stop_on, early_stop_threshold, - early_stop_n_epoch, extras=None): + def finish_epoch(self, + epoch, + epochs, + t0, + loss_details, + checkpoint_int, + out_dir, + early_stop_on, + early_stop_threshold, + early_stop_n_epoch, + extras=None): """Perform finishing checks after an epoch is done training Parameters @@ -790,7 +833,8 @@ def finish_epoch(self, epoch, epochs, t0, loss_details, stop = False if early_stop_on is not None and early_stop_on in self._history: - stop = self.early_stop(self._history, early_stop_on, + stop = self.early_stop(self._history, + early_stop_on, threshold=early_stop_threshold, n_epoch=early_stop_n_epoch) if stop: @@ -803,8 +847,12 @@ def finish_epoch(self, epoch, epochs, t0, loss_details, return stop @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): + def get_single_grad(self, + low_res, + hi_res_true, + training_weights, + device_name=None, + **calc_loss_kwargs): """Run gradient descent for one mini-batch of (low_res, hi_res_true), do not update weights, just return gradient details. @@ -851,8 +899,12 @@ def get_single_grad(self, low_res, hi_res_true, training_weights, return grad, loss_details - def run_gradient_descent(self, low_res, hi_res_true, training_weights, - optimizer=None, multi_gpu=False, + def run_gradient_descent(self, + low_res, + hi_res_true, + training_weights, + optimizer=None, + multi_gpu=False, **calc_loss_kwargs): # pylint: disable=E0602 """Run gradient descent for one mini-batch of (low_res, hi_res_true) @@ -918,12 +970,13 @@ def run_gradient_descent(self, low_res, hi_res_true, training_weights, for i in range(len(self.gpu_list)): if split_mask: calc_loss_kwargs['mask'] = mask_chunks[i] - futures.append(exe.submit(self.get_single_grad, - lr_chunks[i], - hr_true_chunks[i], - training_weights, - device_name=f'/gpu:{i}', - **calc_loss_kwargs)) + futures.append( + exe.submit(self.get_single_grad, + lr_chunks[i], + hr_true_chunks[i], + training_weights, + device_name=f'/gpu:{i}', + **calc_loss_kwargs)) for i, future in enumerate(futures): grad, loss_details = future.result() optimizer.apply_gradients(zip(grad, training_weights)) @@ -965,8 +1018,8 @@ def set_model_params(**kwargs): """ output_features = kwargs['output_features'] msg = ('Last output feature from the data handler must be topography ' - 'to train the WindCC model, but received output features: {}' - .format(output_features)) + 'to train the WindCC model, but received output features: {}'. + format(output_features)) assert output_features[-1] == 'topography', msg output_features.remove('topography') kwargs['output_features'] = output_features @@ -1035,7 +1088,10 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): return hi_res_topo - def generate(self, low_res, norm_in=True, un_norm_out=True, + def generate(self, + low_res, + norm_in=True, + un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res input. This is the public generate function. 
@@ -1079,8 +1135,8 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, hi_res_topo = exogenous_data[1] exo_check = (low_res is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check - else np.concatenate((low_res, low_res_topo), axis=-1)) + low_res = (low_res if exo_check else np.concatenate( + (low_res, low_res_topo), axis=-1)) if norm_in and self._means is not None: low_res = self.norm_input(low_res) @@ -1090,14 +1146,15 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, try: if (isinstance(layer, (Sup3rAdder, Sup3rConcat)) and hi_res_topo is not None): - hi_res_topo = self._reshape_norm_topo(hi_res, hi_res_topo, + hi_res_topo = self._reshape_norm_topo(hi_res, + hi_res_topo, norm_in=norm_in) hi_res = layer(hi_res, hi_res_topo) else: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e @@ -1141,16 +1198,20 @@ def _tf_generate(self, low_res, hi_res_topo): else: hi_res = layer(hi_res) except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, hi_res.shape)) + msg = ('Could not run layer #{} "{}" on tensor of shape {}'. + format(i + 1, layer, hi_res.shape)) logger.error(msg) raise RuntimeError(msg) from e return hi_res @tf.function() - def get_single_grad(self, low_res, hi_res_true, training_weights, - device_name=None, **calc_loss_kwargs): + def get_single_grad(self, + low_res, + hi_res_true, + training_weights, + device_name=None, + **calc_loss_kwargs): """Run gradient descent for one mini-batch of (low_res, hi_res_true), do not update weights, just return gradient details. diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 43d94d848..3a0d82ff5 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -1,17 +1,16 @@ # -*- coding: utf-8 -*- """Sup3r multi step model frameworks""" -import os import json import logging +import os + import numpy as np -from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat # pylint: disable=cyclic-import import sup3r.models from sup3r.models.abstract import AbstractInterface from sup3r.models.base import Sup3rGan - logger = logging.getLogger(__name__) @@ -33,25 +32,6 @@ def __len__(self): """Get number of model steps""" return len(self._models) - @staticmethod - def _needs_hr_exo(model): - """Determine whether or not the sup3r model needs hi-res exogenous data - - Parameters - ---------- - model : Sup3rGan | WindGan - Sup3r GAN model based on Sup3rGan with a .generator attribute - - Returns - ------- - needs_hr_exo : bool - True if the model requires high-resolution exogenous data, - typically because of the use of Sup3rAdder or Sup3rConcat layers. 
- """ - return (hasattr(model, 'generator') - and any(isinstance(layer, (Sup3rAdder, Sup3rConcat)) - for layer in model.generator.layers)) - @classmethod def load(cls, model_dirs, verbose=True): """Load the GANs with its sub-networks from a previously saved-to @@ -80,7 +60,7 @@ def load(cls, model_dirs, verbose=True): for model_dir in model_dirs: fp_params = os.path.join(model_dir, 'model_params.json') assert os.path.exists(fp_params), f'Could not find: {fp_params}' - with open(fp_params, 'r') as f: + with open(fp_params) as f: params = json.load(f) meta = params.get('meta', {'class': 'Sup3rGan'}) @@ -182,7 +162,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, else True) i_exo_data = exo_data[i] - if self._needs_hr_exo(model): + if model.needs_hr_exo: i_exo_data = [exo_data[i], exo_data[i + 1]] try: diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index bcf86fc7e..7731f0f95 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -20,7 +20,8 @@ import sup3r.bias.bias_transforms import sup3r.models from sup3r.postprocessing.file_handling import (OutputHandler, OutputHandlerH5, - OutputHandlerNC) + OutputHandlerNC, + ) from sup3r.preprocessing.data_handling import ExogenousDataHandler from sup3r.preprocessing.data_handling.base import InputMixIn from sup3r.utilities import ModuleName @@ -28,7 +29,8 @@ from sup3r.utilities.execution import DistributedProcess from sup3r.utilities.utilities import (get_chunk_slices, get_input_handler_class, - get_source_type) + get_source_type, + ) np.random.seed(42) @@ -38,17 +40,9 @@ class ForwardPassSlicer: """Get slices for sending data chunks through model.""" - def __init__( - self, - coarse_shape, - time_steps, - temporal_slice, - chunk_shape, - s_enhancements, - t_enhancements, - spatial_pad, - temporal_pad, - ): + def __init__(self, coarse_shape, time_steps, temporal_slice, chunk_shape, + s_enhancements, t_enhancements, spatial_pad, temporal_pad, + ): """ Parameters ---------- @@ -218,10 +212,7 @@ def t_lr_pad_slices(self): """ if self._t_lr_pad_slices is None: self._t_lr_pad_slices = self.get_padded_slices( - self.t_lr_slices, - self.time_steps, - 1, - self.temporal_pad, + self.t_lr_slices, self.time_steps, 1, self.temporal_pad, self.temporal_slice.step, ) return self._t_lr_pad_slices @@ -321,12 +312,9 @@ def s_lr_crop_slices(self): self.s2_lr_pad_slices, 1) for i, _ in enumerate(self.s1_lr_slices): for j, _ in enumerate(self.s2_lr_slices): - lr_crop_slice = ( - s1_crop_slices[i], - s2_crop_slices[j], - slice(None), - slice(None), - ) + lr_crop_slice = (s1_crop_slices[i], s2_crop_slices[j], + slice(None), slice(None), + ) self._s_lr_crop_slices.append(lr_crop_slice) return self._s_lr_crop_slices @@ -594,24 +582,23 @@ class ForwardPassStrategy(InputMixIn, DistributedProcess): crop generator output to stich the chunks back togerther. 
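 
     With this change, exo_kwargs is keyed by exogenous feature name instead
     of carrying a single features list. A minimal sketch of the new format
     (file names and factors are illustrative assumptions, not from this
     patch); each per-feature dictionary is deep-copied and passed through to
     ExogenousDataHandler with target and shape filled in by the forward pass:
 
         exo_kwargs = {'topography': {'file_paths': 'input_files*.nc',
                                      'feature': 'topography',
                                      'source_file': 'source_wtk.h5',
                                      's_enhancements': [1, 4],
                                      'agg_factors': [1, 36]}}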
""" - def __init__( - self, - file_paths, - model_kwargs, - fwp_chunk_shape, - spatial_pad, - temporal_pad, - model_class='Sup3rGan', - out_pattern=None, - input_handler=None, - input_handler_kwargs=None, - incremental=True, - worker_kwargs=None, - exo_kwargs=None, - bias_correct_method=None, - bias_correct_kwargs=None, - max_nodes=None, - ): + def __init__(self, + file_paths, + model_kwargs, + fwp_chunk_shape, + spatial_pad, + temporal_pad, + model_class='Sup3rGan', + out_pattern=None, + input_handler=None, + input_handler_kwargs=None, + incremental=True, + worker_kwargs=None, + exo_kwargs=None, + bias_correct_method=None, + bias_correct_kwargs=None, + max_nodes=None, + ): """Use these inputs to initialize data handlers on different nodes and to define the size of the data chunks that will be passed through the generator. @@ -718,14 +705,13 @@ def __init__( raster_index = self._input_handler_kwargs.get('raster_index', None) temporal_slice = self._input_handler_kwargs.get( 'temporal_slice', slice(None, None, 1)) - InputMixIn.__init__( - self, - target=target, - shape=grid_shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice, - ) + InputMixIn.__init__(self, + target=target, + shape=grid_shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice, + ) self.file_paths = file_paths self.model_kwargs = model_kwargs @@ -778,22 +764,16 @@ def __init__( self.output_features = model.output_features self.fwp_slicer = ForwardPassSlicer( - self.grid_shape, - self.raw_tsteps, - self.temporal_slice, - self.fwp_chunk_shape, - self.s_enhancements, - self.t_enhancements, - self.spatial_pad, - self.temporal_pad, + self.grid_shape, self.raw_tsteps, self.temporal_slice, + self.fwp_chunk_shape, self.s_enhancements, self.t_enhancements, + self.spatial_pad, self.temporal_pad, ) - DistributedProcess.__init__( - self, - max_nodes=max_nodes, - max_chunks=self.fwp_slicer.n_chunks, - incremental=self.incremental, - ) + DistributedProcess.__init__(self, + max_nodes=max_nodes, + max_chunks=self.fwp_slicer.n_chunks, + incremental=self.incremental, + ) self.preflight() @@ -821,10 +801,9 @@ def preflight(self): logger.warning(msg) warnings.warn(msg) - hr_data_shape = ( - self.grid_shape[0] * self.s_enhance, - self.grid_shape[1] * self.s_enhance, - ) + hr_data_shape = (self.grid_shape[0] * self.s_enhance, + self.grid_shape[1] * self.s_enhance, + ) self.gids = np.arange(np.product(hr_data_shape)) self.gids = self.gids.reshape(hr_data_shape) @@ -845,13 +824,11 @@ def init_handler(self): """Get initial input handler used for extracting handler features and low res grid""" if self._init_handler is None: - out = self.input_handler_class( - self.file_paths[0], - [], - target=self.target, - shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1), - ) + out = self.input_handler_class(self.file_paths[0], [], + target=self.target, + shape=self.grid_shape, + worker_kwargs=dict(ti_workers=1), + ) self._init_handler = out return self._init_handler @@ -1093,24 +1070,9 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.pass_workers = strategy.pass_workers self.output_workers = strategy.output_workers self.exo_kwargs = strategy.exo_kwargs - - self.exogenous_handler = None - self.exogenous_data = None - if self.exo_kwargs: - exo_features = self.exo_kwargs.get('features', []) - exo_kwargs = copy.deepcopy(self.exo_kwargs) - exo_kwargs['target'] = self.target - exo_kwargs['shape'] = self.shape - self.features = [f for f in self.features if f not in 
exo_features]
-            self.exogenous_handler = ExogenousDataHandler(**exo_kwargs)
-            self.exogenous_data = self.exogenous_handler.data
-            shapes = [
-                None if d is None else d.shape for d in self.exogenous_data
-            ]
-            logger.info(
-                'Got exogenous_data of length {} with shapes: {}'.format(
-                    len(self.exogenous_data), shapes))
-
+        self.exo_features = ([]
+                             if not self.exo_kwargs else list(self.exo_kwargs))
+        self.exogenous_data = self.load_exo_data()
         self.input_handler_class = strategy.input_handler_class
 
         if strategy.output_type == 'nc':
@@ -1129,12 +1091,40 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
             self.input_data, self.strategy.lr_lat_lon)
 
         exo_s_en = self.exo_kwargs.get('s_enhancements', None)
+        exo_t_en = self.exo_kwargs.get('t_enhancements', None)
         out = self.pad_source_data(self.input_data, self.pad_width,
-                                   self.exogenous_data, exo_s_en)
+                                   self.exogenous_data, exo_s_en, exo_t_en)
         self.input_data, self.exogenous_data = out
         self.unpadded_input_data = self.data_handler.data[self.lr_slice[0],
                                                           self.lr_slice[1]]
 
+    def load_exo_data(self):
+        """Extract exogenous data for each exo feature and store data in a
+        dictionary with a key for each exo feature
+
+        Returns
+        -------
+        exo_data : dict
+            Dictionary of data arrays with keys for each exogenous feature
+        """
+        exo_data = {}
+        if self.exo_kwargs:
+            self.features = [
+                f for f in self.features if f not in self.exo_features
+            ]
+            for feature in self.exo_features:
+                exo_kwargs = copy.deepcopy(self.exo_kwargs[feature])
+                exo_kwargs['target'] = self.target
+                exo_kwargs['shape'] = self.shape
+                exo_data[feature] = ExogenousDataHandler(**exo_kwargs).data
+                shapes = [
+                    None if d is None else d.shape for d in exo_data[feature]
+                ]
+                logger.info(
+                    'Got exogenous_data of length {} with shapes: {}'.format(
+                        len(exo_data[feature]), shapes))
+        return exo_data
+
     def update_input_handler_kwargs(self, strategy):
         """Update the kwargs for the input handler for the current forward
         pass chunk
@@ -1425,6 +1415,7 @@ def pad_source_data(input_data,
                         pad_width,
                         exo_data,
                         exo_s_enhancements,
+                        exo_t_enhancements,
                         mode='reflect'):
         """Pad the edges of the source data from the data handler.
 
@@ -1449,6 +1440,9 @@ def pad_source_data(input_data,
         exo_s_enhancements : list
             List of spatial enhancement factors for each step of the sup3r
             resolution model corresponding to the exo_data order.
+        exo_t_enhancements : list
+            List of temporal enhancement factors for each step of the sup3r
+            resolution model corresponding to the exo_data order.
         mode : str
             Padding mode for np.pad(). Reflect is a good default for the
             convolutional sup3r work.
@@ -1476,11 +1470,17 @@ def pad_source_data(input_data,
                     s for s in total_s_enhance if s is not None
                 ]
                 total_s_enhance = np.product(total_s_enhance)
+                total_t_enhance = exo_t_enhancements[:i + 1]
+                total_t_enhance = [
+                    t for t in total_t_enhance if t is not None
+                ]
+                total_t_enhance = np.product(total_t_enhance)
                 exo_pad_width = ((total_s_enhance * pad_width[0][0],
                                   total_s_enhance * pad_width[0][1]),
                                  (total_s_enhance * pad_width[1][0],
-                                  total_s_enhance * pad_width[1][1]), (0,
-                                                                       0))
+                                  total_s_enhance * pad_width[1][1]),
+                                 (total_t_enhance * pad_width[2][0],
+                                  total_t_enhance * pad_width[2][1]))
                 exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode)
 
         return out, exo_data
@@ -1537,47 +1537,48 @@ def _prep_exogenous_input(self, chunk_shape):
 
         Returns
         -------
-        exo_data : list
-            List of arrays of exogenous data.
If there are 2 spatial
-            enhancement steps this is a list of 3 arrays each with the
-            appropriate shape based on the enhancement factor
+        exo_data : list
+            List of arrays of exogenous data, one entry per model step, with
+            the data for each exo feature combined along the last (channel)
+            axis. If there are 2 spatial enhancement steps this is a list of
+            3 arrays each with the appropriate shape based on the enhancement
+            factor
         """
 
         exo_data = []
-        if self.exogenous_data is not None:
-            for arr in self.exogenous_data:
+        for i in range(len(self.exogenous_data[self.exo_features[0]])):
+            exo_data_f = []
+            for feature in self.exo_features:
+                arr = self.exogenous_data[feature][i]
                 if arr is not None:
                     og_shape = arr.shape
                     arr = np.expand_dims(arr, axis=2)
                     arr = np.repeat(arr, chunk_shape[2], axis=2)
 
-                    target_shape = (
-                        arr.shape[0],
-                        arr.shape[1],
-                        chunk_shape[2],
-                        arr.shape[-1],
-                    )
-                    msg = ('Target shape for exogenous data in forward pass '
-                           'chunk was {}, but something went wrong and i '
-                           'resized original data shape from {} to {}'.format(
-                               target_shape, og_shape, arr.shape))
+                    target_shape = (arr.shape[0], arr.shape[1], chunk_shape[2],
+                                    arr.shape[-1],
+                                    )
+                    msg = ('Target shape for exogenous data in forward '
+                           'pass chunk was {}, but something went wrong '
+                           'and I resized original data shape from {} to '
+                           '{}'.format(target_shape, og_shape, arr.shape))
                     assert arr.shape == target_shape, msg
-                    exo_data.append(arr)
-
+                    exo_data_f.append(arr)
+            # combine the per-feature arrays along the channel axis rather
+            # than stacking them spatially
+            exo_data.append(
+                np.concatenate([arr for arr in exo_data_f if arr is not None],
+                               axis=-1))
         return exo_data
 
     @classmethod
-    def _run_generator(
-        cls,
-        data_chunk,
-        hr_crop_slices,
-        model=None,
-        model_kwargs=None,
-        model_class=None,
-        s_enhance=None,
-        t_enhance=None,
-        exo_data=None,
-    ):
+    def _run_generator(cls,
+                       data_chunk,
+                       hr_crop_slices,
+                       model=None,
+                       model_kwargs=None,
+                       model_class=None,
+                       s_enhance=None,
+                       t_enhance=None,
+                       exo_data=None,
+                       ):
         """Run forward pass of the generator on smallest data chunk. Each
         chunk has a maximum shape given by self.strategy.fwp_chunk_shape.
 
@@ -1857,11 +1858,10 @@ def _run_serial(cls, strategy, node_index):
                     'serial.')
         for i, chunk_index in enumerate(strategy.node_chunks[node_index]):
             now = dt.now()
-            cls._single_proc_run(
-                strategy=strategy,
-                node_index=node_index,
-                chunk_index=chunk_index,
-            )
+            cls._single_proc_run(strategy=strategy,
+                                 node_index=node_index,
+                                 chunk_index=chunk_index,
+                                 )
             mem = psutil.virtual_memory()
             logger.info('Finished forward pass on chunk_index='
                         f'{chunk_index} in {dt.now() - now}.
{i + 1} of ' @@ -1898,12 +1898,11 @@ def _run_parallel(cls, strategy, node_index): with SpawnProcessPool(**pool_kws) as exe: now = dt.now() for _i, chunk_index in enumerate(strategy.node_chunks[node_index]): - fut = exe.submit( - cls._single_proc_run, - strategy=strategy, - node_index=node_index, - chunk_index=chunk_index, - ) + fut = exe.submit(cls._single_proc_run, + strategy=strategy, + node_index=node_index, + chunk_index=chunk_index, + ) futures[fut] = { 'chunk_index': chunk_index, 'start_time': dt.now(), diff --git a/sup3r/utilities/topo.py b/sup3r/preprocessing/data_handling/exo_extraction.py similarity index 52% rename from sup3r/utilities/topo.py rename to sup3r/preprocessing/data_handling/exo_extraction.py index ac01cecc1..80ccd915b 100644 --- a/sup3r/utilities/topo.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -16,26 +16,26 @@ logger = logging.getLogger(__name__) -class TopoExtract(ABC): - """Class to extract high-res (4km+) topography rasters for new +class ExoExtract(ABC): + """Class to extract high-res (4km+) data rasters for new spatially-enhanced datasets (e.g. GCM files after spatial enhancement) using nearest neighbor mapping and aggregation from NREL datasets (e.g. WTK or NSRDB) """ - def __init__( - self, - file_paths, - topo_source, - s_enhance, - agg_factor, - target=None, - shape=None, - raster_file=None, - max_delta=20, - input_handler=None, - ti_workers=1, - ): + def __init__(self, + file_paths, + exo_source, + s_enhance, + agg_factor, + target=None, + shape=None, + temporal_slice=None, + raster_file=None, + max_delta=20, + input_handler=None, + ti_workers=1, + t_enhance=1): """ Parameters ---------- @@ -45,34 +45,43 @@ def __init__( file path which will be passed through glob.glob. This is typically low-res WRF output or GCM netcdf data files that is source low-resolution data intended to be sup3r resolved. - topo_source : str + exo_source : str Filepath to source wtk or nsrdb file to get hi-res (2km or 4km) elevation data from which will be mapped to the enhanced grid of the file_paths input s_enhance : int Factor by which the Sup3rGan model will enhance the spatial dimensions of low resolution data from file_paths input. For - example, if file_paths has 100km data and s_enhance is 4, this - class will output a topography raster corresponding to the - file_paths grid enhanced 4x to ~25km + example, if getting topography data, file_paths has 100km data, and + s_enhance is 4, this class will output a topography raster + corresponding to the file_paths grid enhanced 4x to ~25km agg_factor : int Factor by which to aggregate the topo_source_h5 elevation data to the resolution of the file_paths input enhanced by - s_enhance. For example, if file_paths has 100km data and s_enhance - is 4 resulting in a desired resolution of ~25km and topo_source_h5 - has a resolution of 4km, the agg_factor should be 36 so that 6x6 - 4km cells are averaged to the ~25km enhanced grid. + s_enhance. For example, if getting topography data, file_paths has + 100km data, and s_enhance is 4 resulting in a desired resolution of + ~25km and topo_source_h5 has a resolution of 4km, the agg_factor + should be 36 so that 6x6 4km cells are averaged to the ~25km + enhanced grid. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. shape : tuple (rows, cols) grid size. Either need target+shape or raster_file. 
+ temporal_slice : slice | None + slice used to extract interval from temporal dimension for input + data and source data raster_file : str | None File for raster_index array for the corresponding target and shape. If specified the raster_index will be loaded from the file if it exists or written to the file if it does not yet exist. If None raster_index will be calculated directly. Either need target+shape or raster_file. + max_delta : int, optional + Optional maximum limit on the raster shape that is retrieved at + once. If shape is (20, 20) and max_delta=10, the full raster will + be retrieved in four chunks of (10, 10). This helps adapt to + non-regular grids that curve over large distances, by default 20 input_handler : str data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will @@ -84,13 +93,19 @@ class will output a topography raster corresponding to the parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. - + t_enhance : int + Factor by which the Sup3rGan model will enhance the temporal + dimension of low resolution data from file_paths input. For + example, if getting sza data, file_paths has hourly data, and + t_enhance is 4, this class will output a sza raster + corresponding to the file_paths temporally enhanced 4x to 15 min """ - logger.info('Initializing TopoExtract utility.') + logger.info(f'Initializing {self.__class__.__name__} utility.') - self._topo_source = topo_source + self._exo_source = exo_source self._s_enhance = s_enhance + self._t_enhance = t_enhance self._agg_factor = agg_factor self._tree = None self.ti_workers = ti_workers @@ -103,29 +118,25 @@ class will output a topography raster corresponding to the elif in_type == 'h5': input_handler = DataHandlerH5 else: - msg = 'Did not recognize input type "{}" for file paths: {}'.format( - in_type, file_paths - ) + msg = (f'Did not recognize input type "{in_type}" for file ' + f'paths: {file_paths}') logger.error(msg) raise RuntimeError(msg) elif isinstance(input_handler, str): - input_handler = getattr( - sup3r.preprocessing.data_handling, input_handler, None - ) + input_handler = getattr(sup3r.preprocessing.data_handling, + input_handler, None) if input_handler is None: - msg = ( - 'Could not find requested data handler class ' - f'"{input_handler}" in ' - 'sup3r.preprocessing.data_handling.' 
- ) + msg = ('Could not find requested data handler class ' + f'"{input_handler}" in ' + 'sup3r.preprocessing.data_handling.') logger.error(msg) raise KeyError(msg) self.input_handler = input_handler( - file_paths, - [], + file_paths, [], target=target, shape=shape, + temporal_slice=temporal_slice, raster_file=raster_file, max_delta=max_delta, worker_kwargs=dict(ti_workers=ti_workers), @@ -133,30 +144,30 @@ class will output a topography raster corresponding to the @property @abstractmethod - def source_elevation(self): - """Get the 1D array of elevation data from the topo_source_h5""" + def source_data(self): + """Get the 1D array of source data from the exo_source_h5""" @property @abstractmethod def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the topo_source_h5""" + """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" @property def lr_shape(self): """Get the low-resolution spatial shape tuple""" - return (self.lr_lat_lon.shape[0], self.lr_lat_lon.shape[1]) + return (self.lr_lat_lon.shape[0], self.lr_lat_lon.shape[1], + len(self.input_handler.time_index)) @property def hr_shape(self): """Get the high-resolution spatial shape tuple""" - return ( - self._s_enhance * self.lr_lat_lon.shape[0], - self._s_enhance * self.lr_lat_lon.shape[1], - ) + return (self._s_enhance * self.lr_lat_lon.shape[0], + self._s_enhance * self.lr_lat_lon.shape[1], + self._t_enhance * len(self.input_handler.time_index)) @property def lr_lat_lon(self): - """lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon + """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last dimension. This corresponds to the raw low-resolution meta data from the file_paths input. @@ -168,7 +179,7 @@ def lr_lat_lon(self): @property def hr_lat_lon(self): - """lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon + """Lat lon grid for data in format (spatial_1, spatial_2, 2) Lat/Lon array with same ordering in last dimension. This corresponds to the enhanced high-res meta data from the file_paths input * s_enhance. @@ -179,8 +190,7 @@ def hr_lat_lon(self): if self._hr_lat_lon is None: if self._s_enhance > 1: self._hr_lat_lon = OutputHandler.get_lat_lon( - self.lr_lat_lon, self.hr_shape - ) + self.lr_lat_lon, self.hr_shape) else: self._hr_lat_lon = self.lr_lat_lon return self._hr_lat_lon @@ -196,57 +206,52 @@ def tree(self): def nn(self): """Get the nearest neighbor indices""" ll2 = np.vstack( - ( - self.hr_lat_lon[:, :, 0].flatten(), - self.hr_lat_lon[:, :, 1].flatten(), - ) - ).T + (self.hr_lat_lon[:, :, 0].flatten(), self.hr_lat_lon[:, :, + 1].flatten(), + )).T _, nn = self.tree.query(ll2, k=self._agg_factor) if len(nn.shape) == 1: nn = np.expand_dims(nn, 1) return nn @property - def hr_elev(self): - """Get a raster of elevation values corresponding to the + def data(self): + """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance). 
The shape is (rows, cols, temporal)
         """
         nn = self.nn
-        hr_elev = []
+        hr_data = []
         for j in range(self._agg_factor):
-            elev = self.source_elevation[nn[:, j]]
-            elev = elev.reshape(self.hr_shape)
-            hr_elev.append(elev)
-        hr_elev = np.dstack(hr_elev).mean(axis=-1)
-        logger.info(
-            'Finished mapping topo raster from {}'.format(self._topo_source)
-        )
-        return hr_elev
+            out = self.source_data[nn[:, j]]
+            out = out.reshape(self.hr_shape)
+            hr_data.append(out)
+        # average over the agg_factor neighbors without collapsing the new
+        # temporal axis (np.dstack + mean would fold time into the average)
+        hr_data = np.stack(hr_data, axis=-1).mean(axis=-1)
+        logger.info('Finished mapping raster from {}'.format(self._exo_source))
+        return hr_data
 
     @classmethod
-    def get_topo_raster(
-        cls,
-        file_paths,
-        topo_source,
-        s_enhance,
-        agg_factor,
-        target=None,
-        shape=None,
-        raster_file=None,
-        max_delta=20,
-        input_handler=None,
-    ):
-        """Get the topography raster corresponding to the spatially enhanced
+    def get_exo_raster(cls,
+                       file_paths,
+                       exo_source,
+                       s_enhance,
+                       agg_factor,
+                       target=None,
+                       shape=None,
+                       raster_file=None,
+                       max_delta=20,
+                       input_handler=None,
+                       t_enhance=1):
+        """Get the exo feature raster corresponding to the spatially enhanced
         grid from the file_paths input
 
         Parameters
         ----------
         file_paths : str | list
-            A single source h5 wind file to extract raster data from or a list
+            A single source h5 file to extract raster data from or a list
             of netcdf files with identical grid. The string can be a unix-style
             file path which will be passed through glob.glob
-        topo_source : str
+        exo_source : str
             Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
             4km) data from which will be mapped to the enhanced grid of the
             file_paths input
@@ -274,54 +279,65 @@ class will output a topography raster corresponding to the
             exists or written to the file if it does not yet exist. If None
             raster_index will be calculated directly. Either need target+shape
             or raster_file.
+        max_delta : int, optional
+            Optional maximum limit on the raster shape that is retrieved at
+            once. If shape is (20, 20) and max_delta=10, the full raster will
+            be retrieved in four chunks of (10, 10). This helps adapt to
+            non-regular grids that curve over large distances, by default 20
         input_handler : str
             data handler class to use for input data. Provide a string name to
             match a class in data_handling.py. If None the correct handler will
             be guessed based on file type and time series properties.
+        t_enhance : int
+            Factor by which the Sup3rGan model will enhance the temporal
+            dimension of low resolution data from file_paths input. For
+            example, if getting sza data, file_paths has hourly data, and
+            t_enhance is 4, this class will output a sza raster
+            corresponding to the file_paths temporally enhanced 4x to 15 min
 
         Returns
         -------
-        topo_raster : np.ndarray
-            Elevation raster with shape (hr_rows, hr_cols) corresponding to the
-            shape of the spatially enhanced grid from file_paths * s_enhance.
-            The elevation units correspond to the source units in
-            topo_source_h5, usually meters.
+        exo_raster : np.ndarray
+            Exo feature raster with shape (hr_rows, hr_cols, hr_temporal)
+            corresponding to the shape of the spatiotemporally enhanced data
+            from file_paths * s_enhance * t_enhance. The data units correspond
+            to the source units in exo_source_h5.
This is usually meters when
+            feature='topography'
         """
-        te = cls(
-            file_paths,
-            topo_source,
-            s_enhance,
-            agg_factor,
-            target=target,
-            shape=shape,
-            raster_file=raster_file,
-            max_delta=max_delta,
-            input_handler=input_handler,
-        )
+        exo = cls(file_paths,
+                  exo_source,
+                  s_enhance,
+                  agg_factor,
+                  target=target,
+                  shape=shape,
+                  raster_file=raster_file,
+                  max_delta=max_delta,
+                  input_handler=input_handler,
+                  t_enhance=t_enhance)
 
-        return te.hr_elev
+        return exo.data
 
 
-class TopoExtractH5(TopoExtract):
+class TopoExtractH5(ExoExtract):
     """TopoExtract for H5 files"""
 
     @property
-    def source_elevation(self):
-        """Get the 1D array of elevation data from the topo_source_h5"""
-        with Resource(self._topo_source) as res:
+    def source_data(self):
+        """Get the 1D array of elevation data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
             elev = res.get_meta_arr('elevation')
         return elev
 
     @property
     def source_lat_lon(self):
-        """Get the 2D array (n, 2) of lat, lon data from the topo_source_h5"""
-        with Resource(self._topo_source) as res:
+        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
             source_lat_lon = res.lat_lon
         return source_lat_lon
 
 
-class TopoExtractNC(TopoExtract):
+class TopoExtractNC(ExoExtract):
     """TopoExtract for netCDF files"""
 
     def __init__(self, *args, **kwargs):
@@ -335,24 +351,77 @@ def __init__(self, *args, **kwargs):
         """
 
         super().__init__(*args, **kwargs)
-        logger.info(
-            'Getting topography for full domain from ' f'{self._topo_source}'
+        logger.info('Getting topography for full domain from '
+                    f'{self._exo_source}')
+        self.source_handler = DataHandlerNC(
+            self._exo_source,
+            features=['topography'],
+            worker_kwargs=dict(ti_workers=self.ti_workers),
+            val_split=0.0,
         )
+
+    @property
+    def source_data(self):
+        """Get the 1D array of elevation data from the exo_source_h5"""
+        elev = self.source_handler.data.reshape(-1)
+        return elev
+
+    @property
+    def source_lat_lon(self):
+        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
+        source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2))
+        return source_lat_lon
+
+
+class SzaExtractH5(ExoExtract):
+    """SzaExtract for H5 files"""
+
+    @property
+    def source_data(self):
+        """Get the 1D array of sza data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
+            elev = res.get_meta_arr('elevation')
+        return elev
+
+    @property
+    def source_lat_lon(self):
+        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
+        with Resource(self._exo_source) as res:
+            source_lat_lon = res.lat_lon
+        return source_lat_lon
+
+
+class SzaExtractNC(ExoExtract):
+    """SzaExtract for netCDF files"""
+
+    def __init__(self, *args, **kwargs):
+        """
+        Parameters
+        ----------
+        args : list
+            Same positional arguments as ExoExtract
+        kwargs : dict
+            Same keyword arguments as ExoExtract
+        """
+
+        super().__init__(*args, **kwargs)
+        logger.info('Getting topography for full domain from '
+                    f'{self._exo_source}')
         self.source_handler = DataHandlerNC(
-            self._topo_source,
+            self._exo_source,
             features=['topography'],
             worker_kwargs=dict(ti_workers=self.ti_workers),
             val_split=0.0,
         )
 
     @property
-    def source_elevation(self):
-        """Get the 1D array of elevation data from the topo_source_h5"""
-        elev = self.source_handler.data.reshape((-1))
+    def source_data(self):
+        """Get the 1D array of elevation data from the exo_source_h5"""
+        elev = self.source_handler.data.reshape(-1)
         return elev
 
     @property
     def source_lat_lon(self):
-        """Get the 2D array (n, 2) of
lat, lon data from the topo_source_h5""" + """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2)) return source_lat_lon diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index c54b44acb..68baa591b 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -1,14 +1,18 @@ """Sup3r exogenous data handling""" -import os -import shutil import logging -import numpy as np +import os import pickle +import shutil +from typing import ClassVar from warnings import warn -from sup3r.utilities.topo import TopoExtractH5, TopoExtractNC -import sup3r.preprocessing.data_handling -import sup3r.utilities.topo +import numpy as np + +from sup3r.preprocessing.data_handling.exo_extraction import (SzaExtractH5, + SzaExtractNC, + TopoExtractH5, + TopoExtractNC, + ) from sup3r.utilities.utilities import get_source_type logger = logging.getLogger(__name__) @@ -19,10 +23,32 @@ class ExogenousDataHandler: Multiple topography arrays at different resolutions for multiple spatial enhancement steps.""" - def __init__(self, file_paths, features, source_file, s_enhancements, - agg_factors, target=None, shape=None, raster_file=None, - max_delta=20, input_handler=None, topo_handler=None, - exo_steps=None, cache_data=True): + AVAILABLE_HANDLERS: ClassVar[dict] = { + 'topography': { + 'h5': TopoExtractH5, + 'nc': TopoExtractNC + }, + 'sza': { + 'h5': SzaExtractH5, + 'nc': SzaExtractNC + } + } + + def __init__(self, + file_paths, + feature, + source_file, + s_enhancements, + agg_factors, + target=None, + shape=None, + raster_file=None, + max_delta=20, + input_handler=None, + exo_handler=None, + exo_steps=None, + cache_data=True, + t_enhancements=None): """ Parameters ---------- @@ -32,8 +58,8 @@ def __init__(self, file_paths, features, source_file, s_enhancements, through glob.glob. This is typically low-res WRF output or GCM netcdf data that is source low-resolution data intended to be sup3r resolved. - features : list - List of exogenous features to extract from source_h5 + feature : str + Exogenous feature to extract from source_h5 source_file : str Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or 4km) data from which will be mapped to the enhanced grid of the @@ -68,14 +94,20 @@ def __init__(self, file_paths, features, source_file, s_enhancements, exists or written to the file if it does not yet exist. If None raster_index will be calculated directly. Either need target+shape or raster_file. + max_delta : int, optional + Optional maximum limit on the raster shape that is retrieved at + once. If shape is (20, 20) and max_delta=10, the full raster will + be retrieved in four chunks of (10, 10). This helps adapt to + non-regular grids that curve over large distances, by default 20 input_handler : str data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. - topo_handler : str - topo extract class to use for source data. Provide a string name to - match a class in topo.py. If None the correct handler will - be guessed based on file type and time series properties. + exo_handler : str + Feature extract class to use for source data. For example, if + feature='topography' this should be either TopoExtractH5 or + TopoExtractNC. 
If None the correct handler will be guessed based on + file type and time series properties. exo_steps : list List of model step indices for which exogenous data is required. e.g. If we have two model steps which take exo data and one which @@ -84,14 +116,19 @@ def __init__(self, file_paths, features, source_file, s_enhancements, cache_data : bool Flag to cache exogeneous data in ./exo_cache/ this can speed up forward passes with large temporal extents + t_enhancements : list + List of factors by which the Sup3rGan model will enhance the + temporal dimension of low resolution data from file_paths input + where the total temporal enhancement is the product of these + factors. """ - self.features = features + self.feature = feature self.s_enhancements = s_enhancements self.agg_factors = agg_factors self.source_file = source_file self.file_paths = file_paths - self.topo_handler = topo_handler + self.exo_handler = exo_handler self.target = target self.shape = shape self.raster_file = raster_file @@ -99,6 +136,8 @@ def __init__(self, file_paths, features, source_file, s_enhancements, self.input_handler = input_handler self.cache_data = cache_data self.data = [] + self.t_enhancements = (t_enhancements if t_enhancements is not None + else [1] * len(self.s_enhancements)) exo_steps = exo_steps or np.arange(len(self.s_enhancements)) if self.s_enhancements[0] != 1: @@ -108,11 +147,20 @@ def __init__(self, file_paths, features, source_file, s_enhancements, 's_enhancements: {}'.format(self.s_enhancements)) logger.warning(msg) warn(msg) + if self.t_enhancements[0] != 1: + msg = ('t_enhancements typically starts with 1 so the first ' + 'exogenous data input matches the temporal resolution of ' + 'the source low-res input data, but received ' + 't_enhancements: {}'.format(self.t_enhancements)) + logger.warning(msg) + warn(msg) msg = ('Need to provide the same number of enhancement factors and ' - f'agg factors. Received s_enhancements={s_enhancements} and ' + f'agg factors. Received s_enhancements={self.s_enhancements}, ' + f't_enhancements={self.t_enhancements}, and ' f'agg_factors={agg_factors}.') assert len(self.s_enhancements) == len(self.agg_factors), msg + assert len(self.t_enhancements) == len(self.agg_factors), msg msg = ('Need to provide an integer enhancement factor for each model' 'step. If the step is temporal enhancement then s_enhance=1') @@ -120,28 +168,37 @@ def __init__(self, file_paths, features, source_file, s_enhancements, for i in range(len(self.s_enhancements)): s_enhance = np.product(self.s_enhancements[:i + 1]) + t_enhance = np.product(self.t_enhancements[:i + 1]) agg_factor = self.agg_factors[i] fdata = [] if i in exo_steps: - for f in features: - if f == 'topography': - data = self.get_topo_data(s_enhance, agg_factor) - fdata.append(data) - else: - msg = (f"Can only extract topography. Recived {f}.") - raise NotImplementedError(msg) + if feature in list(self.AVAILABLE_HANDLERS): + data = self.get_exo_data(feature=feature, + s_enhance=s_enhance, + t_enhance=t_enhance, + agg_factor=agg_factor) + fdata.append(data) + else: + msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}." 
+                           f" Received {feature}.")
+                    raise NotImplementedError(msg)
                 self.data.append(np.stack(fdata, axis=-1))
             else:
                 self.data.append(None)
 
-    def get_topo_data(self, s_enhance, agg_factor):
-        """Get the exogenous topography data
+    def get_cache_file(self, feature, s_enhance, t_enhance, agg_factor):
+        """Get cache file name
 
         Parameters
         ----------
+        feature : str
+            Name of feature to get cache file for
         s_enhance : int
            Spatial enhancement for this exogenous data step (cumulative for
            all model steps up to the current step).
+        t_enhance : int
+            Temporal enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
         agg_factor : int
             Factor by which to aggregate the topo_source_h5 elevation
             data to the resolution of the file_paths input enhanced by
             s_enhance.
 
         Returns
         -------
-        data : np.ndarray
-            2D array of elevation data with shape (lat, lon)
+        cache_fp : str
+            Name of cache file
         """
         cache_dir = './exo_cache/'
-        fn = f'exo_{self.target}_{self.shape}_agg{agg_factor}_{s_enhance}x.pkl'
+        fn = f'exo_{feature}_{self.target}_{self.shape}_agg{agg_factor}_'
+        fn += f'{s_enhance}x_{t_enhance}x.pkl'
         fn = fn.replace('(', '').replace(')', '')
         fn = fn.replace('[', '').replace(']', '')
         fn = fn.replace(',', 'x').replace(' ', '')
         cache_fp = os.path.join(cache_dir, fn)
-        temp_fp = cache_fp + '.tmp'
+        if self.cache_data:
+            os.makedirs(cache_dir, exist_ok=True)
+        return cache_fp
+
+    def get_exo_data(self, feature, s_enhance, t_enhance, agg_factor):
+        """Get exogenous data for the given feature
+
+        Parameters
+        ----------
+        feature : str
+            Name of feature to get exo data for
+        s_enhance : int
+            Spatial enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
+        t_enhance : int
+            Temporal enhancement for this exogenous data step (cumulative for
+            all model steps up to the current step).
+        agg_factor : int
+            Factor by which to aggregate the topo_source_h5 elevation
+            data to the resolution of the file_paths input enhanced by
+            s_enhance.
+
+        Returns
+        -------
+        data : np.ndarray
+            2D or 3D array of exo data with shape (lat, lon) or (lat,
+            lon, temporal)
+        """
+        cache_fp = self.get_cache_file(feature=feature,
+                                       s_enhance=s_enhance,
+                                       t_enhance=t_enhance,
+                                       agg_factor=agg_factor)
+        tmp_fp = cache_fp + '.tmp'
         if os.path.exists(cache_fp):
             with open(cache_fp, 'rb') as f:
                 data = pickle.load(f)
 
         else:
-            topo_handler = self.get_topo_handler(self.source_file,
-                                                 self.topo_handler)
-            data = topo_handler(self.file_paths, self.source_file, s_enhance,
-                                agg_factor, target=self.target,
-                                shape=self.shape,
-                                raster_file=self.raster_file,
-                                max_delta=self.max_delta,
-                                input_handler=self.input_handler)
-            data = data.hr_elev
+            exo_handler = self.get_exo_handler(feature, self.source_file,
+                                               self.exo_handler)
+            dh = exo_handler(self.file_paths,
+                             self.source_file,
+                             s_enhance=s_enhance,
+                             t_enhance=t_enhance,
+                             agg_factor=agg_factor,
+                             target=self.target,
+                             shape=self.shape,
+                             raster_file=self.raster_file,
+                             max_delta=self.max_delta,
+                             input_handler=self.input_handler)
+            data = dh.data
 
             if self.cache_data:
-                os.makedirs(cache_dir, exist_ok=True)
-                with open(temp_fp, 'wb') as f:
-                    pickle.dump(data, f)
-                shutil.move(temp_fp, cache_fp)
-
+                with open(tmp_fp, 'wb') as f:
+                    pickle.dump(data, f)
+                shutil.move(tmp_fp, cache_fp)
         return data
 
-    @staticmethod
-    def get_topo_handler(source_file, topo_handler):
-        """Get topo extraction class for source file
+    @classmethod
+    def get_exo_handler(cls, feature, source_file, exo_handler):
+        """Get exogenous feature extraction class for source file
 
         Parameters
         ----------
+        feature : str
+            Name of feature to get exo handler for
         source_file : str
             Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
             4km) data from which will be mapped to the enhanced grid of the
             file_paths input
-        topo_handler : str
-            topo extract class to use for source data. Provide a string name to
-            match a class in topo.py. If None the correct handler will
-            be guessed based on file type and time series properties.
+        exo_handler : str
+            Feature extract class to use for source data. For example, if
+            feature='topography' this should be either TopoExtractH5 or
+            TopoExtractNC. If None the correct handler will be guessed based on
+            file type and time series properties.
 
         Returns
         -------
-        topo_handler : str
-            topo extract class to use for source data.
+        exo_handler : str
+            Exogenous feature extraction class to use for source data.
         """
-        if topo_handler is None:
+        if exo_handler is None:
             in_type = get_source_type(source_file)
-            if in_type == 'nc':
-                topo_handler = TopoExtractNC
-            elif in_type == 'h5':
-                topo_handler = TopoExtractH5
-            else:
-                msg = ('Did not recognize input type "{}" for file paths: {}'
-                       .format(in_type, source_file))
+            if in_type not in ('h5', 'nc'):
+                msg = ('Did not recognize input type "{}" for file paths: {}'.
+ format(in_type, source_file)) logger.error(msg) raise RuntimeError(msg) - elif isinstance(topo_handler, str): - topo_handler = getattr(sup3r.utilities.topo, topo_handler, None) - if topo_handler is None: - msg = ('Could not find requested topo handler class ' - f'"{topo_handler}" in ' - 'sup3r.utilities.topo.') + check = (feature in cls.AVAILABLE_HANDLERS + and in_type in cls.AVAILABLE_HANDLERS[feature]) + if check: + exo_handler = cls.AVAILABLE_HANDLERS[feature][in_type] + else: + msg = ('Could not find exo handler class for ' + f'feature={feature} and input_type={in_type}.') logger.error(msg) raise KeyError(msg) - - return topo_handler + return exo_handler diff --git a/sup3r/preprocessing/data_handling/nc_data_handling.py b/sup3r/preprocessing/data_handling/nc_data_handling.py index afa2f79ac..92ae0dc0e 100644 --- a/sup3r/preprocessing/data_handling/nc_data_handling.py +++ b/sup3r/preprocessing/data_handling/nc_data_handling.py @@ -19,34 +19,13 @@ from sup3r.preprocessing.data_handling.base import DataHandler, DataHandlerDC from sup3r.preprocessing.feature_handling import ( - BVFreqMon, - BVFreqSquaredNC, - ClearSkyRatioCC, - Feature, - InverseMonNC, - LatLonNC, - PotentialTempNC, - PressureNC, - Rews, - Shear, - Tas, - TasMax, - TasMin, - TempNC, - TempNCforCC, - UWind, - UWindPowerLaw, - VWind, - VWindPowerLaw, - WinddirectionNC, - WindspeedNC, -) + BVFreqMon, BVFreqSquaredNC, ClearSkyRatioCC, Feature, InverseMonNC, + LatLonNC, PotentialTempNC, PressureNC, Rews, Shear, Tas, TasMax, TasMin, + TempNC, TempNCforCC, UWind, UWindPowerLaw, VWind, VWindPowerLaw, + WinddirectionNC, WindspeedNC) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.utilities import ( - estimate_max_workers, - get_time_dim_name, - np_to_pd_times, -) +from sup3r.utilities.utilities import (estimate_max_workers, get_time_dim_name, + np_to_pd_times) np.random.seed(42) @@ -86,26 +65,6 @@ class DataHandlerNC(DataHandler): Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" - def __init__(self, *args, xr_chunks=None, **kwargs): - """Initialize NETCDF data handler. - - Parameters - ---------- - *args : list - Same ordered required arguments as DataHandler parent class. - xr_chunks : int | "auto" | tuple | dict | None - kwarg that goes to xr.DataArray.chunk(chunks=xr_chunks). Chunk - sizes that approximately match the data volume being extracted - typically results in the most efficient IO. If not provided, this - defaults to the class CHUNKS attribute. - **kwargs : list - Same optional keyword arguments as DataHandler parent class. - """ - if xr_chunks is not None: - self.CHUNKS = xr_chunks - - super().__init__(*args, **kwargs) - @property def extract_workers(self): """Get upper bound for extract workers based on memory limits. 
Used to @@ -288,9 +247,9 @@ def extract_feature(cls, fdata = cls.direct_extract(handle, feat_key, raster_index, time_slice) - elif interp_height is not None and (cls.has_multilevel_feature( - feature, handle) or cls.has_surrounding_features( - feature, handle)): + elif interp_height is not None and ( + cls.has_multilevel_feature(feature, handle) + or cls.has_surrounding_features(feature, handle)): fdata = Interpolator.interp_var_to_height( handle, feature, raster_index, np.float32(interp_height), time_slice) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index d6c54612f..2c7d55a95 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -15,8 +15,9 @@ from sup3r.models import LinearInterp, Sup3rGan, WindGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC -from sup3r.utilities.pytest import (make_fake_nc_files, - make_fake_multi_time_nc_files) +from sup3r.utilities.pytest import (make_fake_multi_time_nc_files, + make_fake_nc_files, + ) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -43,10 +44,12 @@ def test_fwp_nc_cc(log=False): Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - input_files = [os.path.join(TEST_DATA_DIR, 'ua_test.nc'), - os.path.join(TEST_DATA_DIR, 'va_test.nc'), - os.path.join(TEST_DATA_DIR, 'orog_test.nc'), - os.path.join(TEST_DATA_DIR, 'zg_test.nc')] + input_files = [ + os.path.join(TEST_DATA_DIR, 'ua_test.nc'), + os.path.join(TEST_DATA_DIR, 'va_test.nc'), + os.path.join(TEST_DATA_DIR, 'orog_test.nc'), + os.path.join(TEST_DATA_DIR, 'zg_test.nc') + ] features = ['U_100m', 'V_100m'] target = (13.67, 125.0) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) @@ -63,16 +66,20 @@ def test_fwp_nc_cc(log=False): # 1st forward pass max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, cache_pattern=cache_pattern, overwrite_cache=True, worker_kwargs=dict(max_workers=max_workers)) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), input_handler='DataHandlerNCforCC') forward_pass = ForwardPass(handler) @@ -84,14 +91,14 @@ def test_fwp_nc_cc(log=False): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) def test_fwp_single_ts_vs_multi_ts_input_files(): @@ -117,16 +124,20 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): max_workers = 1 input_handler_kwargs = dict( - target=target, 
shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) single_ts_handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) single_ts_forward_pass = ForwardPass(single_ts_handler) single_ts_forward_pass.run(single_ts_handler, node_index=0) @@ -138,16 +149,20 @@ def test_fwp_single_ts_vs_multi_ts_input_files(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) multi_ts_handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) multi_ts_forward_pass = ForwardPass(multi_ts_handler) multi_ts_forward_pass.run(multi_ts_handler, node_index=0) @@ -185,16 +200,20 @@ def test_fwp_spatial_only(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -205,14 +224,12 @@ def test_fwp_spatial_only(): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - len(handler.time_index), - 2 * fwp_chunk_shape[0], - 2 * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - len(handler.time_index), - 2 * fwp_chunk_shape[0], - 2 * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) def test_fwp_nc(): @@ -238,16 +255,20 @@ def test_fwp_nc(): max_workers = 1 input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, 
+ input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -258,14 +279,14 @@ def test_fwp_nc(): forward_pass.run(handler, node_index=0) with xr.open_dataset(handler.out_files[0]) as fh: - assert fh[FEATURES[0]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) - assert fh[FEATURES[1]].shape == ( - t_enhance * len(handler.time_index), - s_enhance * fwp_chunk_shape[0], - s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[0]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == (t_enhance + * len(handler.time_index), + s_enhance * fwp_chunk_shape[0], + s_enhance * fwp_chunk_shape[1]) def test_fwp_temporal_slice(): @@ -295,16 +316,20 @@ def test_fwp_temporal_slice(): raw_time_index = np.arange(20) n_tsteps = len(raw_time_index[temporal_slice]) input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) assert forward_pass.output_workers == max_workers @@ -315,11 +340,10 @@ def test_fwp_temporal_slice(): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * n_tsteps, - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * n_tsteps, s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -355,15 +379,18 @@ def test_fwp_handler(): max_workers = 1 cache_pattern = os.path.join(td, 'cache') input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), cache_pattern=cache_pattern, overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, worker_kwargs=dict(max_workers=max_workers)) forward_pass = ForwardPass(handler) @@ -405,32 +432,39 @@ def test_fwp_chunking(log=False, plot=False): temporal_pad = 20 cache_pattern = os.path.join(td, 'cache') fwp_shape = (4, 4, len(input_files) // 2) - handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, - fwp_chunk_shape=fwp_shape, - worker_kwargs=dict(max_workers=1), - spatial_pad=spatial_pad, temporal_pad=temporal_pad, - input_handler_kwargs=dict(target=target, shape=shape, - temporal_slice=temporal_slice, - 
cache_pattern=cache_pattern, - overwrite_cache=True, - worker_kwargs=dict(max_workers=1))) - data_chunked = np.zeros((shape[0] * s_enhance, shape[1] * s_enhance, - len(input_files) * t_enhance, - len(model.output_features))) - handlerNC = DataHandlerNC(input_files, FEATURES, target=target, - val_split=0.0, shape=shape, + handler = ForwardPassStrategy(input_files, + model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_shape, + worker_kwargs=dict(max_workers=1), + spatial_pad=spatial_pad, + temporal_pad=temporal_pad, + input_handler_kwargs=dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + cache_pattern=cache_pattern, + overwrite_cache=True, + worker_kwargs=dict(max_workers=1))) + data_chunked = np.zeros( + (shape[0] * s_enhance, shape[1] * s_enhance, + len(input_files) * t_enhance, len(model.output_features))) + handlerNC = DataHandlerNC(input_files, + FEATURES, + target=target, + val_split=0.0, + shape=shape, worker_kwargs=dict(ti_workers=1)) pad_width = ((spatial_pad, spatial_pad), (spatial_pad, spatial_pad), (temporal_pad, temporal_pad), (0, 0)) hr_crop = (slice(s_enhance * spatial_pad, -s_enhance * spatial_pad), slice(s_enhance * spatial_pad, -s_enhance * spatial_pad), - slice(t_enhance * temporal_pad, -t_enhance * temporal_pad), - slice(None)) - input_data = np.pad(handlerNC.data, pad_width=pad_width, + slice(t_enhance * temporal_pad, + -t_enhance * temporal_pad), slice(None)) + input_data = np.pad(handlerNC.data, + pad_width=pad_width, mode='reflect') - data_nochunk = model.generate( - np.expand_dims(input_data, axis=0))[0][hr_crop] + data_nochunk = model.generate(np.expand_dims(input_data, + axis=0))[0][hr_crop] for i in range(handler.chunks): fwp = ForwardPass(handler, chunk_index=i) out = fwp.run_chunk() @@ -448,17 +482,23 @@ def test_fwp_chunking(log=False, plot=False): ax3 = fig.add_subplot(133) vmin = np.min(data_nochunk) vmax = np.max(data_nochunk) - nc = ax1.imshow(data_nochunk[..., 0, ifeature], vmin=vmin, + nc = ax1.imshow(data_nochunk[..., 0, ifeature], + vmin=vmin, vmax=vmax) - ch = ax2.imshow(data_chunked[..., 0, ifeature], vmin=vmin, + ch = ax2.imshow(data_chunked[..., 0, ifeature], + vmin=vmin, vmax=vmax) diff = ax3.imshow(err[..., 0, ifeature]) ax1.set_title('Non chunked output') ax2.set_title('Chunked output') ax3.set_title('Difference') - fig.colorbar(nc, ax=ax1, shrink=0.6, + fig.colorbar(nc, + ax=ax1, + shrink=0.6, label=f'{model.output_features[ifeature]}') - fig.colorbar(ch, ax=ax2, shrink=0.6, + fig.colorbar(ch, + ax=ax2, + shrink=0.6, label=f'{model.output_features[ifeature]}') fig.colorbar(diff, ax=ax3, shrink=0.6, label='Difference') plt.savefig(f'./chunk_vs_nochunk_{ifeature}.png') @@ -489,23 +529,27 @@ def test_fwp_nochunking(): model.save(out_dir) cache_pattern = os.path.join(td, 'cache') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - cache_pattern=cache_pattern, - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + cache_pattern=cache_pattern, + overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': out_dir}, + input_files, + model_kwargs={'model_dir': out_dir}, fwp_chunk_shape=(shape[0], shape[1], list_chunk_size), - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, worker_kwargs=dict(max_workers=1)) forward_pass = ForwardPass(handler) 
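+        # With spatial_pad=0, temporal_pad=0, and a chunk shape covering
+        # the full spatial extent being processed, run_chunk() should
+        # exactly reproduce the output of running the generator on the
+        # whole unchunked input, as asserted below.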
data_chunked = forward_pass.run_chunk() - handlerNC = DataHandlerNC(input_files, FEATURES, - target=target, shape=shape, + handlerNC = DataHandlerNC(input_files, + FEATURES, + target=target, + shape=shape, temporal_slice=temporal_slice, cache_pattern=None, time_chunk_size=100, @@ -513,8 +557,8 @@ def test_fwp_nochunking(): val_split=0.0, worker_kwargs=dict(max_workers=1)) - data_nochunk = model.generate( - np.expand_dims(handlerNC.data, axis=0))[0] + data_nochunk = model.generate(np.expand_dims(handlerNC.data, + axis=0))[0] assert np.array_equal(data_chunked, data_nochunk) @@ -567,31 +611,38 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): s_enhance = 12 t_enhance = 4 - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 16], - 'exo_steps': [0, 1] - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 16], + 'exo_steps': [0, 1] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, input_handler_kwargs=input_handler_kwargs, - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), exo_kwargs=exo_kwargs, @@ -611,11 +662,10 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -625,8 +675,9 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] def test_fwp_multi_step_spatial_model_topo_noskip(): @@ -662,28 +713,33 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): s_enhancements = [2, 2, 1] s_enhance = np.product(s_enhancements) - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [12, 4, 2] - } + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': 
[12, 4, 2] + } + } model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), @@ -694,11 +750,10 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -708,8 +763,133 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 2 # two step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_model_multi_exo(): + """Test the forward pass with a multi step model class using 2 exogenous + data features""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 3] + s_enhance = np.product(s_enhancements) + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 
'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12] + }, + 'sza': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] def test_fwp_multi_step_model_topo_noskip(): @@ -757,29 +937,36 @@ def test_fwp_multi_step_model_topo_noskip(): s_enhance = np.product(s_enhancements) t_enhance = 4 - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), @@ -797,11 +984,10 @@ def test_fwp_multi_step_model_topo_noskip(): 
forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -811,8 +997,9 @@ def test_fwp_multi_step_model_topo_noskip(): assert 'gan_meta' in fh.global_attrs gan_meta = json.loads(fh.global_attrs['gan_meta']) assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == ['U_100m', 'V_100m', - 'topography'] + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] def test_fwp_multi_step_model(): @@ -851,27 +1038,32 @@ def test_fwp_multi_step_model(): s_enhance = 6 t_enhance = 4 - model_kwargs = {'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': st_out_dir} + model_kwargs = { + 'spatial_model_dirs': s_out_dir, + 'temporal_model_dirs': st_out_dir + } input_handler_kwargs = dict( - target=target, shape=shape, + target=target, + shape=shape, temporal_slice=temporal_slice, worker_kwargs=dict(max_workers=max_workers), overwrite_cache=True) handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=max_workers), max_nodes=1) forward_pass = ForwardPass(handler) - ones = np.ones((fwp_chunk_shape[2], fwp_chunk_shape[0], - fwp_chunk_shape[1], 2)) + ones = np.ones( + (fwp_chunk_shape[2], fwp_chunk_shape[0], fwp_chunk_shape[1], 2)) out = forward_pass.model.generate(ones) assert out.shape == (1, 24, 24, 32, 2) @@ -884,11 +1076,10 @@ def test_fwp_multi_step_model(): forward_pass.run(handler, node_index=0) with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == ( - t_enhance * len(input_files), - s_enhance**2 * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs for f in ('windspeed_100m', - 'winddirection_100m')) + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) assert fh.global_attrs['package'] == 'sup3r' assert fh.global_attrs['version'] == __version__ @@ -928,22 +1119,26 @@ def test_slicing_no_pad(log=False): st_out_dir = os.path.join(td, 'st_gan') st_model.save(st_out_dir) - handler = DataHandlerNC(input_files, features, - target=target, shape=shape, + handler = DataHandlerNC(input_files, + features, + target=target, + shape=shape, sample_shape=(1, 1, 1), val_split=0.0, worker_kwargs=dict(max_workers=1)) - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) strategy = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': st_out_dir}, + input_files, + model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(3, 
2, 4), - spatial_pad=0, temporal_pad=0, + spatial_pad=0, + temporal_pad=0, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -953,8 +1148,7 @@ def test_slicing_no_pad(log=False): forward_pass = ForwardPass(strategy, chunk_index=ichunk) s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass.ti_pad_slice, - slice(None)) + forward_pass.ti_pad_slice, slice(None)) truth = handler.data[lr_data_slice] assert np.allclose(forward_pass.input_data, truth) @@ -987,23 +1181,27 @@ def test_slicing_pad(log=False): st_out_dir = os.path.join(td, 'st_gan') st_model.save(st_out_dir) - handler = DataHandlerNC(input_files, features, - target=target, shape=shape, + handler = DataHandlerNC(input_files, + features, + target=target, + shape=shape, sample_shape=(1, 1, 1), val_split=0.0, worker_kwargs=dict(max_workers=1)) - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) strategy = ForwardPassStrategy( - input_files, model_kwargs={'model_dir': st_out_dir}, + input_files, + model_kwargs={'model_dir': st_out_dir}, model_class='Sup3rGan', fwp_chunk_shape=(2, 1, 4), input_handler_kwargs=input_handler_kwargs, - spatial_pad=2, temporal_pad=2, + spatial_pad=2, + temporal_pad=2, out_pattern=out_files, worker_kwargs=dict(max_workers=1), max_nodes=1) @@ -1027,8 +1225,7 @@ def test_slicing_pad(log=False): s_slices = strategy.lr_pad_slices[forward_pass.spatial_chunk_index] lr_data_slice = (s_slices[0], s_slices[1], - forward_pass.ti_pad_slice, - slice(None)) + forward_pass.ti_pad_slice, slice(None)) # do a manual calculation of what the padding should be. 
# s1 and t axes should have padding of 2 and the borders and @@ -1053,8 +1250,8 @@ def test_slicing_pad(log=False): pad_t_end = end_t_pad_lookup.get(idt, 0) pad_width = ((pad_s1_start, pad_s1_end), - (pad_s2_start, pad_s2_end), - (pad_t_start, pad_t_end), (0, 0)) + (pad_s2_start, pad_s2_end), (pad_t_start, + pad_t_end), (0, 0)) truth = handler.data[lr_data_slice] padded_truth = np.pad(truth, pad_width, mode='reflect') @@ -1068,39 +1265,67 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): requiring high-resolution topography input from the exogenous_data feature.""" Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, "kernel_size": 3, - "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"class": "SpatioTemporalExpansion", "temporal_mult": 2, - "temporal_method": "nearest"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"class": "SpatioTemporalExpansion", "spatial_mult": 2}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}, - {"alpha": 0.2, "class": "LeakyReLU"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv3D", "filters": 2, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping3D", "cropping": 2}] + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "temporal_mult": 2, + "temporal_method": "nearest" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "spatial_mult": 2 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }] fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') model = WindGan(gen_model, fp_disc, learning_rate=1e-4) @@ -1117,31 +1342,34 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): st_out_dir = os.path.join(td, 'st_gan') model.save(st_out_dir) - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2], - 'agg_factors': [2, 4], - } + exo_kwargs = { + 'file_paths': input_files, + 'features': ['topography'], + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 
's_enhancements': [1, 2], + 'agg_factors': [2, 4], + } model_kwargs = {'model_dir': st_out_dir} out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) # should get an error on a bad tensorflow concatenation with pytest.raises(RuntimeError): exo_kwargs['s_enhancements'] = [1, 1] handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='WindGan', fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -1152,10 +1380,12 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): exo_kwargs['s_enhancements'] = [1, 2] handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='WindGan', fwp_chunk_shape=(8, 8, 8), - spatial_pad=4, temporal_pad=4, + spatial_pad=4, + temporal_pad=4, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -1170,7 +1400,8 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): vmin = np.min(forward_pass.input_data[..., ifeature]) vmax = np.max(forward_pass.input_data[..., ifeature]) nc = ax1.imshow(forward_pass.input_data[..., 0, ifeature], - vmin=vmin, vmax=vmax) + vmin=vmin, + vmax=vmax) fig.colorbar(nc, ax=ax1, shrink=0.6, label=f'{feature}') plt.savefig(f'./input_{feature}.png') plt.close() @@ -1181,41 +1412,249 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): assert os.path.exists(fp) +def test_fwp_multi_step_exo_hi_res_topo_and_sza(): + """Test the forward pass with multiple ExoGan models requiring + high-resolution topography and sza input from the exogenous_data feature.""" + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "class": "Activation", + "activation": "relu" + }, { + "class": "Sup3rConcat" + }, { + "class": "Sup3rConcat" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] 
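+    # Note: the generator above stacks two Sup3rConcat layers, so this
+    # model presumably expects two separate hi-res exogenous inputs, one
+    # per concat layer (topography and sza here).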
+ s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + _ = s1_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=(None, np.ones((4, 20, 20, 1)))) + + s2_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + _ = s2_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=(None, np.ones((4, 20, 20, 1)))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12] + }, + 'sza': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + # should get an error on a bad tensorflow concatenation + with pytest.raises(RuntimeError): + exo_kwargs['s_enhancements'] = [1, 1, 1] + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + exo_kwargs['s_enhancements'] = [1, 2, 2] + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + + def test_fwp_multi_step_wind_hi_res_topo(): """Test the forward pass with multiple WindGan models requiring high-resolution topograph input from the exogenous_data feature.""" Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, "kernel_size": 3, - "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", 
"cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}, - {"class": "SpatialExpansion", "spatial_mult": 2}, - {"class": "Activation", "activation": "relu"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 2, - "kernel_size": 3, "strides": 1, "activation": "relu"}, - {"class": "Cropping2D", "cropping": 4}] + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "class": "Activation", + "activation": "relu" + }, { + "class": "Sup3rConcat" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }] fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s1_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) @@ -1253,32 +1692,37 @@ def test_fwp_multi_step_wind_hi_res_topo(): s1_model.save(s1_out_dir) s2_model.save(s2_out_dir) - exo_kwargs = {'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12], - } - - model_kwargs = {'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir} + exo_kwargs = { + 'file_paths': input_files, + 'features': ['topography'], + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 's_enhancements': [1, 2, 2], + 'agg_factors': [2, 4, 12], + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) # should get an error on a bad tensorflow concatenation with pytest.raises(RuntimeError): exo_kwargs['s_enhancements'] = [1, 1, 1] handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + 
model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -1289,10 +1733,12 @@ def test_fwp_multi_step_wind_hi_res_topo(): exo_kwargs['s_enhancements'] = [1, 2, 2] handler = ForwardPassStrategy( - input_files, model_kwargs=model_kwargs, + input_files, + model_kwargs=model_kwargs, model_class='SpatialThenTemporalGan', fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, temporal_pad=1, + spatial_pad=1, + temporal_pad=1, input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, worker_kwargs=dict(max_workers=1), @@ -1311,37 +1757,63 @@ def test_fwp_wind_hi_res_topo_plus_linear(): temporal enhancement.""" Sup3rGan.seed() - gen_model = [{"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, "kernel_size": 3, - "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 64, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}, - {"class": "SpatialExpansion", "spatial_mult": 2}, - {"alpha": 0.2, "class": "LeakyReLU"}, - - {"class": "Sup3rConcat"}, - - {"class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, - {"class": "Conv2DTranspose", "filters": 2, - "kernel_size": 3, "strides": 1}, - {"class": "Cropping2D", "cropping": 4}] + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }] fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') s_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) @@ -1352,7 +1824,8 @@ def test_fwp_wind_hi_res_topo_plus_linear(): _ = s_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=(None, np.ones((4, 20, 20, 1)))) - t_model = LinearInterp(features=['U_100m', 'V_100m'], s_enhance=1, + t_model = LinearInterp(features=['U_100m', 'V_100m'], + s_enhance=1, t_enhance=4) with tempfile.TemporaryDirectory() as td: @@ -1363,30 +1836,35 @@ def test_fwp_wind_hi_res_topo_plus_linear(): s_model.save(s_out_dir) 
         t_model.save(t_out_dir)

-        exo_kwargs = {'file_paths': input_files,
-                      'features': ['topography'],
-                      'source_file': FP_WTK,
-                      'target': target,
-                      'shape': shape,
-                      's_enhancements': [1, 2],
-                      'agg_factors': [2, 4],
-                      }
-
-        model_kwargs = {'spatial_model_dirs': s_out_dir,
-                        'temporal_model_dirs': t_out_dir}
+        exo_kwargs = {
+            'file_paths': input_files,
+            'features': ['topography'],
+            'source_file': FP_WTK,
+            'target': target,
+            'shape': shape,
+            's_enhancements': [1, 2],
+            'agg_factors': [2, 4],
+        }
+
+        model_kwargs = {
+            'spatial_model_dirs': s_out_dir,
+            'temporal_model_dirs': t_out_dir
+        }
         out_files = os.path.join(td, 'out_{file_id}.h5')
-        input_handler_kwargs = dict(
-            target=target, shape=shape,
-            temporal_slice=temporal_slice,
-            worker_kwargs=dict(max_workers=1),
-            overwrite_cache=True)
+        input_handler_kwargs = dict(target=target,
+                                    shape=shape,
+                                    temporal_slice=temporal_slice,
+                                    worker_kwargs=dict(max_workers=1),
+                                    overwrite_cache=True)

         exo_kwargs['s_enhancements'] = [1, 2]
         handler = ForwardPassStrategy(
-            input_files, model_kwargs=model_kwargs,
+            input_files,
+            model_kwargs=model_kwargs,
             model_class='SpatialThenTemporalGan',
             fwp_chunk_shape=(4, 4, 8),
-            spatial_pad=1, temporal_pad=1,
+            spatial_pad=1,
+            temporal_pad=1,
             input_handler_kwargs=input_handler_kwargs,
             out_pattern=out_files,
             worker_kwargs=dict(max_workers=1),

From dd9c48226d9489b2a86cd430dab498b146578713 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Tue, 12 Sep 2023 15:29:36 -0600
Subject: [PATCH 02/15] linter fix

---
 tests/forward_pass/test_forward_pass.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py
index 2c7d55a95..043043aab 100644
--- a/tests/forward_pass/test_forward_pass.py
+++ b/tests/forward_pass/test_forward_pass.py
@@ -16,8 +16,7 @@
 from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy
 from sup3r.preprocessing.data_handling import DataHandlerNC
 from sup3r.utilities.pytest import (make_fake_multi_time_nc_files,
-                                    make_fake_nc_files,
-                                    )
+                                    make_fake_nc_files)

 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
 TARGET_COORD = (39.01, -105.15)
@@ -472,7 +471,7 @@ def test_fwp_chunking(log=False, plot=False):
                         fwp.ti_slice.stop * t_enhance)
             data_chunked[fwp.hr_slice][..., t_hr_slice, :] = out

-        err = (data_chunked - data_nochunk)
+        err = data_chunked - data_nochunk
         err /= data_nochunk
         if plot:
             for ifeature in range(data_nochunk.shape[-1]):
@@ -1414,13 +1413,13 @@
 def test_fwp_multi_step_exo_hi_res_topo_and_sza():
     """Test the forward pass with multiple ExoGan models requiring
-    high-resolution topography and sza input from the exogenous_data feature."""
+    high-resolution topography and sza input from the exogenous_data
+    feature."""
     Sup3rGan.seed()
     gen_model = [{
         "class": "FlexiblePadding",
         "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
-        "mode": "REFLECT"
-    }, {
+        "mode": "REFLECT"}, {
         "class": "Conv2DTranspose",
         "filters": 64,
         "kernel_size": 3,

From 4a295c4b027f99f23c6a23b3c1c6a9d999e8d945 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Wed, 13 Sep 2023 13:48:29 -0600
Subject: [PATCH 03/15] Edited exo data methods in forward_pass to accommodate
 temporal dependence. t_agg_factor added to ExoExtract. Prelim tests working
 for single exo feature.
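
An illustrative sketch of what a per-feature exo_kwargs entry could look
like once temporal enhancement/aggregation is wired through. The plural
temporal key names below are assumptions based on the new ExoExtract
signature (t_enhance, s_agg_factor, t_agg_factor) and the existing
s_enhancements/agg_factors pattern, not the final API:

    exo_kwargs = {
        'topography': {
            'file_paths': input_files,
            'source_file': FP_WTK,
            'target': target,
            'shape': shape,
            's_enhancements': [1, 2, 2],
            't_enhancements': [1, 1, 4],   # assumed key name
            's_agg_factors': [2, 4, 12],   # assumed key name
            't_agg_factors': [1, 1, 4],    # assumed key name
        }
    }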
--- sup3r/models/abstract.py | 5 +- sup3r/pipeline/forward_pass.py | 154 ++++++++---------- .../data_handling/exo_extraction.py | 140 ++++++++++------ .../data_handling/exogenous_data_handling.py | 123 ++++++++------ sup3r/preprocessing/data_handling/mixin.py | 12 +- tests/forward_pass/test_forward_pass.py | 57 +++++-- 6 files changed, 286 insertions(+), 205 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 70a23ea78..86275044e 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -437,7 +437,7 @@ def norm_input(self, low_res): if any(self._stdevs == 0): stdevs = np.where(self._stdevs == 0, 1, self._stdevs) - msg = ('Some standard deviations are zero.') + msg = 'Some standard deviations are zero.' logger.warning(msg) warn(msg) else: @@ -961,6 +961,7 @@ def run_gradient_descent(self, lr_chunks = np.array_split(low_res, len(self.gpu_list)) hr_true_chunks = np.array_split(hi_res_true, len(self.gpu_list)) split_mask = False + mask_chunks = None if 'mask' in calc_loss_kwargs: split_mask = True mask_chunks = np.array_split(calc_loss_kwargs['mask'], @@ -977,7 +978,7 @@ def run_gradient_descent(self, training_weights, device_name=f'/gpu:{i}', **calc_loss_kwargs)) - for i, future in enumerate(futures): + for _, future in enumerate(futures): grad, loss_details = future.result() optimizer.apply_gradients(zip(grad, training_weights)) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 7731f0f95..02be38f0c 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -19,18 +19,21 @@ import sup3r.bias.bias_transforms import sup3r.models -from sup3r.postprocessing.file_handling import (OutputHandler, OutputHandlerH5, - OutputHandlerNC, - ) +from sup3r.postprocessing.file_handling import ( + OutputHandler, + OutputHandlerH5, + OutputHandlerNC, +) from sup3r.preprocessing.data_handling import ExogenousDataHandler from sup3r.preprocessing.data_handling.base import InputMixIn from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI from sup3r.utilities.execution import DistributedProcess -from sup3r.utilities.utilities import (get_chunk_slices, - get_input_handler_class, - get_source_type, - ) +from sup3r.utilities.utilities import ( + get_chunk_slices, + get_input_handler_class, + get_source_type, +) np.random.seed(42) @@ -40,8 +43,15 @@ class ForwardPassSlicer: """Get slices for sending data chunks through model.""" - def __init__(self, coarse_shape, time_steps, temporal_slice, chunk_shape, - s_enhancements, t_enhancements, spatial_pad, temporal_pad, + def __init__(self, + coarse_shape, + time_steps, + temporal_slice, + chunk_shape, + s_enhancements, + t_enhancements, + spatial_pad, + temporal_pad, ): """ Parameters @@ -212,7 +222,10 @@ def t_lr_pad_slices(self): """ if self._t_lr_pad_slices is None: self._t_lr_pad_slices = self.get_padded_slices( - self.t_lr_slices, self.time_steps, 1, self.temporal_pad, + self.t_lr_slices, + self.time_steps, + 1, + self.temporal_pad, self.temporal_slice.step, ) return self._t_lr_pad_slices @@ -307,13 +320,17 @@ def s_lr_crop_slices(self): if self._s_lr_crop_slices is None: self._s_lr_crop_slices = [] s1_crop_slices = self.get_cropped_slices(self.s1_lr_slices, - self.s1_lr_pad_slices, 1) + self.s1_lr_pad_slices, + 1) s2_crop_slices = self.get_cropped_slices(self.s2_lr_slices, - self.s2_lr_pad_slices, 1) + self.s2_lr_pad_slices, + 1) for i, _ in enumerate(self.s1_lr_slices): for j, _ in enumerate(self.s2_lr_slices): - lr_crop_slice = 
(s1_crop_slices[i], s2_crop_slices[j], - slice(None), slice(None), + lr_crop_slice = (s1_crop_slices[i], + s2_crop_slices[j], + slice(None), + slice(None), ) self._s_lr_crop_slices.append(lr_crop_slice) return self._s_lr_crop_slices @@ -763,11 +780,15 @@ def __init__(self, self.t_enhance = np.product(self.t_enhancements) self.output_features = model.output_features - self.fwp_slicer = ForwardPassSlicer( - self.grid_shape, self.raw_tsteps, self.temporal_slice, - self.fwp_chunk_shape, self.s_enhancements, self.t_enhancements, - self.spatial_pad, self.temporal_pad, - ) + self.fwp_slicer = ForwardPassSlicer(self.grid_shape, + self.raw_tsteps, + self.temporal_slice, + self.fwp_chunk_shape, + self.s_enhancements, + self.t_enhancements, + self.spatial_pad, + self.temporal_pad, + ) DistributedProcess.__init__(self, max_nodes=max_nodes, @@ -1090,10 +1111,13 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.input_data = self.bias_correct_source_data( self.input_data, self.strategy.lr_lat_lon) - exo_s_en = self.exo_kwargs.get('s_enhancements', None) - self.exo_kwargs.get('s_enhancements', None) - out = self.pad_source_data(self.input_data, self.pad_width, - self.exogenous_data, exo_s_en) + exo_s_en = [1, *self.strategy.s_enhancements] + exo_t_en = [1, *self.strategy.t_enhancements] + out = self.pad_source_data(self.input_data, + self.pad_width, + self.exogenous_data, + exo_s_en, + exo_t_en) self.input_data, self.exogenous_data = out self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], self.lr_slice[1]] @@ -1107,22 +1131,33 @@ def load_exo_data(self): exo_data : dict Dictionary of data arrays with keys for each exogeneous feature """ - exo_data = {} + exo_data_dict = {} if self.exo_kwargs: self.features = [ f for f in self.features if f not in self.exo_features ] for feature in self.exo_features: exo_kwargs = copy.deepcopy(self.exo_kwargs[feature]) + exo_kwargs['feature'] = feature exo_kwargs['target'] = self.target exo_kwargs['shape'] = self.shape - exo_data[feature] = ExogenousDataHandler(**exo_kwargs).data + exo_kwargs['temporal_slice'] = self.ti_pad_slice + exo_data_dict[feature] = ExogenousDataHandler( + **exo_kwargs).data shapes = [ - None if d is None else d.shape for d in exo_data[feature] + None if d is None else d.shape + for d in exo_data_dict[feature] ] logger.info( 'Got exogenous_data of length {} with shapes: {}'.format( - len(exo_data[feature]), shapes)) + len(exo_data_dict[feature]), shapes)) + exo_data = [] + for i in range(len(exo_data_dict[self.exo_features[0]])): + exo_data_f = [] + for feature in self.exo_features: + exo_data_f.append(exo_data_dict[feature][i]) + exo_data.append( + np.vstack([arr for arr in exo_data_f if arr is not None])) return exo_data def update_input_handler_kwargs(self, strategy): @@ -1459,7 +1494,8 @@ def pad_source_data(input_data, logger.info('Padded input data shape from {} to {} using mode "{}" ' 'with padding argument: {}'.format(input_data.shape, - out.shape, mode, + out.shape, + mode, pad_width)) if exo_data is not None: @@ -1472,7 +1508,7 @@ def pad_source_data(input_data, total_s_enhance = np.product(total_s_enhance) total_t_enhance = exo_t_enhancements[:i + 1] total_t_enhance = [ - t for t in total_s_enhance if t is not None + t for t in total_t_enhance if t is not None ] total_t_enhance = np.product(total_t_enhance) exo_pad_width = ((total_s_enhance * pad_width[0][0], @@ -1480,9 +1516,9 @@ def pad_source_data(input_data, (total_s_enhance * pad_width[1][0], total_s_enhance * pad_width[1][1]), (total_t_enhance * 
pad_width[2][0], - total_t_enhance * pad_width[2][1])) + total_t_enhance * pad_width[2][1]), + (0, 0)) exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode) - return out, exo_data def bias_correct_source_data(self, data, lat_lon): @@ -1522,52 +1558,12 @@ def bias_correct_source_data(self, data, lat_lon): 'using function: {} with kwargs: {}'.format( feature, idf, method, feature_kwargs)) - data[..., idf] = method(data[..., idf], lat_lon, + data[..., idf] = method(data[..., idf], + lat_lon, **feature_kwargs) return data - def _prep_exogenous_input(self, chunk_shape): - """Shape exogenous data according to model type and model steps - - Parameters - ---------- - chunk_shape : tuple - Shape of data chunk going through forward pass - - Returns - ------- - exo_data : dict - Dictionary of list of arrays with keys for each exogenous feature - and arrays of data for each feature with array length depending on - model steps. If there are 2 spatial enhancement steps this is a - list of 3 arrays each with the appropriate shape based on the - enhancement factor - """ - exo_data = [] - for i in range(len(self.exogenous_data[self.exo_features[0]])): - exo_data_f = [] - for feature in self.exo_features: - arr = self.exogenous_data[feature][i] - if arr is not None: - og_shape = arr.shape - arr = np.expand_dims(arr, axis=2) - arr = np.repeat(arr, chunk_shape[2], axis=2) - - target_shape = (arr.shape[0], arr.shape[1], chunk_shape[2], - arr.shape[-1], - ) - msg = ('Target shape for exogenous data in forward ' - 'pass chunk was {}, but something went wrong ' - 'and i resized original data shape from {} to ' - '{}'.format(target_shape, og_shape, arr.shape)) - assert arr.shape == target_shape, msg - - exo_data_f.append(arr) - exo_data.append( - np.vstack([arr for arr in exo_data_f if arr is not None])) - return exo_data - @classmethod def _run_generator(cls, data_chunk, @@ -1904,8 +1900,7 @@ def _run_parallel(cls, strategy, node_index): chunk_index=chunk_index, ) futures[fut] = { - 'chunk_index': chunk_index, - 'start_time': dt.now(), + 'chunk_index': chunk_index, 'start_time': dt.now(), } logger.info(f'Started {len(futures)} forward pass runs in ' @@ -1942,20 +1937,15 @@ def run_chunk(self): f'{self.strategy.temporal_pad}.') logger.info(msg) - data_chunk = self.input_data - exo_data = None - if self.exogenous_data is not None: - exo_data = self._prep_exogenous_input(data_chunk.shape) - self.output_data = self._run_generator( - data_chunk, + self.input_data, hr_crop_slices=self.hr_crop_slice, model=self.model, model_kwargs=self.model_kwargs, model_class=self.model_class, s_enhance=self.s_enhance, t_enhance=self.t_enhance, - exo_data=exo_data, + exo_data=self.exogenous_data, ) self._constant_output_check(self.output_data) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 80ccd915b..45b4baadf 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -27,15 +27,16 @@ def __init__(self, file_paths, exo_source, s_enhance, - agg_factor, + t_enhance, + s_agg_factor, + t_agg_factor, target=None, shape=None, temporal_slice=None, raster_file=None, max_delta=20, input_handler=None, - ti_workers=1, - t_enhance=1): + ti_workers=1): """ Parameters ---------- @@ -55,14 +56,26 @@ def __init__(self, example, if getting topography data, file_paths has 100km data, and s_enhance is 4, this class will output a topography raster corresponding to the file_paths grid enhanced 4x to ~25km - 
agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by - s_enhance. For example, if getting topography data, file_paths has - 100km data, and s_enhance is 4 resulting in a desired resolution of - ~25km and topo_source_h5 has a resolution of 4km, the agg_factor - should be 36 so that 6x6 4km cells are averaged to the ~25km - enhanced grid. + t_enhance : int + Factor by which the Sup3rGan model will enhance the temporal + dimension of low resolution data from file_paths input. For + example, if getting sza data, file_paths has hourly data, and + t_enhance is 4, this class will output a sza raster + corresponding to the file_paths temporally enhanced 4x to 15 min + s_agg_factor : int + Factor by which to aggregate the exo_source data to the resolution + of the file_paths input enhanced by s_enhance. For example, if + getting topography data, file_paths have 100km data, and s_enhance + is 4 resulting in a desired resolution of ~25km and topo_source_h5 + has a resolution of 4km, the s_agg_factor should be 36 so that 6x6 + 4km cells are averaged to the ~25km enhanced grid. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the resolution + of the file_paths input enhanced by t_enhance. For example, if + getting sza data, file_paths have hourly data, and t_enhance + is 4 resulting in a desired resolution of 5 min and exo_source + has a resolution of 5 min, the t_agg_factor should be 4 so that + every fourth timestep in the exo_source data is skipped. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. @@ -93,12 +106,6 @@ def __init__(self, parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. - t_enhance : int - Factor by which the Sup3rGan model will enhance the temporal - dimension of low resolution data from file_paths input. 
For - example, if getting sza data, file_paths has hourly data, and - t_enhance is 4, this class will output a sza raster - corresponding to the file_paths temporally enhanced 4x to 15 min """ logger.info(f'Initializing {self.__class__.__name__} utility.') @@ -106,7 +113,8 @@ def __init__(self, self._exo_source = exo_source self._s_enhance = s_enhance self._t_enhance = t_enhance - self._agg_factor = agg_factor + self._s_agg_factor = s_agg_factor + self._t_agg_factor = t_agg_factor self._tree = None self.ti_workers = ti_workers self._hr_lat_lon = None @@ -147,6 +155,21 @@ def __init__(self, def source_data(self): """Get the 1D array of source data from the exo_source_h5""" + @property + @abstractmethod + def hr_time_index(self): + """Get the full time index of the exo_source data""" + + @property + def hr_temporal_slice(self): + """Get the temporal slice fr the exo_source data corresponding to the + input file temporal slice""" + start_index = self.hr_time_index.get_loc( + self.input_handler.time_index[0], method='nearest') + end_index = self.hr_time_index.get_loc( + self.input_handler.time_index[-1], method='nearest') + return slice(start_index, end_index + 1) + @property @abstractmethod def source_lat_lon(self): @@ -190,7 +213,7 @@ def hr_lat_lon(self): if self._hr_lat_lon is None: if self._s_enhance > 1: self._hr_lat_lon = OutputHandler.get_lat_lon( - self.lr_lat_lon, self.hr_shape) + self.lr_lat_lon, self.hr_shape[:-1]) else: self._hr_lat_lon = self.lr_lat_lon return self._hr_lat_lon @@ -206,10 +229,9 @@ def tree(self): def nn(self): """Get the nearest neighbor indices""" ll2 = np.vstack( - (self.hr_lat_lon[:, :, 0].flatten(), self.hr_lat_lon[:, :, - 1].flatten(), - )).T - _, nn = self.tree.query(ll2, k=self._agg_factor) + (self.hr_lat_lon[:, :, 0].flatten(), + self.hr_lat_lon[:, :, 1].flatten())).T + _, nn = self.tree.query(ll2, k=self._s_agg_factor) if len(nn.shape) == 1: nn = np.expand_dims(nn, 1) return nn @@ -217,16 +239,16 @@ def nn(self): @property def data(self): """Get a raster of source values corresponding to the - high-resolution grid (the file_paths input grid * s_enhance). The shape - is (rows, cols) + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). The shape is (rows, cols, temporal) """ nn = self.nn hr_data = [] - for j in range(self._agg_factor): - out = self.source_data[nn[:, j]] + for j in range(self._s_agg_factor): + out = self.source_data[nn[:, j], ::self._t_agg_factor] out = out.reshape(self.hr_shape) - hr_data.append(out) - hr_data = np.dstack(hr_data).mean(axis=-1) + hr_data.append(out[..., np.newaxis]) + hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) return hr_data @@ -235,13 +257,15 @@ def get_exo_raster(cls, file_paths, exo_source, s_enhance, - agg_factor, + t_enhance, + s_agg_factor, + t_agg_factor, target=None, shape=None, + temporal_slice=None, raster_file=None, max_delta=20, - input_handler=None, - t_enhance=1): + input_handler=None): """Get the exo feature raster corresponding to the spatially enhanced grid from the file_paths input @@ -261,18 +285,34 @@ def get_exo_raster(cls, example, if file_paths has 100km data and s_enhance is 4, this class will output a topography raster corresponding to the file_paths grid enhanced 4x to ~25km - agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation data to - the resolution of the file_paths input enhanced by s_enhance. 
For - example, if file_paths has 100km data and s_enhance is 4 resulting - in a desired resolution of ~25km and topo_source_h5 has a - resolution of 4km, the agg_factor should be 36 so that 6x6 4km - cells are averaged to the ~25km enhanced grid. + t_enhance : int + Factor by which the Sup3rGan model will enhance the temporal + dimension of low resolution data from file_paths input. For + example, if getting sza data, file_paths has hourly data, and + t_enhance is 4, this class will output a sza raster + corresponding to the file_paths temporally enhanced 4x to 15 min + s_agg_factor : int + Factor by which to aggregate the exo_source data to the resolution + of the file_paths input enhanced by s_enhance. For example, if + getting topography data, file_paths have 100km data, and s_enhance + is 4 resulting in a desired resolution of ~25km and topo_source_h5 + has a resolution of 4km, the s_agg_factor should be 36 so that 6x6 + 4km cells are averaged to the ~25km enhanced grid. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the resolution + of the file_paths input enhanced by t_enhance. For example, if + getting sza data, file_paths have hourly data, and t_enhance + is 4 resulting in a desired resolution of 5 min and exo_source + has a resolution of 5 min, the t_agg_factor should be 4 so that + every fourth timestep in the exo_source data is skipped. target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. shape : tuple (rows, cols) grid size. Either need target+shape or raster_file. + temporal_slice : slice | None + slice used to extract interval from temporal dimension for input + data and source data raster_file : str | None File for raster_index array for the corresponding target and shape. If specified the raster_index will be loaded from the file if it @@ -288,12 +328,6 @@ class will output a topography raster corresponding to the data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. - t_enhance : int - Factor by which the Sup3rGan model will enhance the temporal - dimension of low resolution data from file_paths input. For - example, if getting sza data, file_paths has hourly data, and - t_enhance is 4, this class will output a sza raster - corresponding to the file_paths temporally enhanced 4x to 15 min Returns ------- @@ -304,18 +338,18 @@ class will output a topography raster corresponding to the to the source units in exo_source_h5. 
This is usually meters when feature='topography' """ - exo = cls(file_paths, exo_source, s_enhance, - agg_factor, + t_enhance, + s_agg_factor, + t_agg_factor, target=target, shape=shape, + temporal_slice=temporal_slice, raster_file=raster_file, max_delta=max_delta, - input_handler=input_handler, - t_enhance=t_enhance) - + input_handler=input_handler) return exo.data @@ -327,6 +361,7 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') + elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) return elev @property @@ -362,13 +397,13 @@ def __init__(self, *args, **kwargs): @property def source_data(self): - """Get the 1D array of elevation data from the exo_source_h5""" + """Get the 1D array of elevation data from the exo_source_nc""" elev = self.source_handler.data.reshape(-1) return elev @property def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" + """Get the 2D array (n, 2) of lat, lon data from the exo_source_nc""" source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2)) return source_lat_lon @@ -381,6 +416,7 @@ def source_data(self): """Get the 1D array of sza data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') + elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) return elev @property diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 68baa591b..7962c3c8e 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -11,8 +11,7 @@ from sup3r.preprocessing.data_handling.exo_extraction import (SzaExtractH5, SzaExtractNC, TopoExtractH5, - TopoExtractNC, - ) + TopoExtractNC) from sup3r.utilities.utilities import get_source_type logger = logging.getLogger(__name__) @@ -39,16 +38,18 @@ def __init__(self, feature, source_file, s_enhancements, - agg_factors, + t_enhancements, + s_agg_factors, + t_agg_factors, target=None, shape=None, + temporal_slice=None, raster_file=None, max_delta=20, input_handler=None, exo_handler=None, exo_steps=None, - cache_data=True, - t_enhancements=None): + cache_data=True): """ Parameters ---------- @@ -78,16 +79,29 @@ def __init__(self, receive a high-res input feature (e.g. WindGan). The length of this list should be equal to the number of agg_factors and the number of exo_steps - agg_factors : list - List of factors by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by + t_enhancements : list + List of factors by which the Sup3rGan model will enhance the + temporal dimension of low resolution data from file_paths input + where the total temporal enhancement is the product of these + factors. + s_agg_factors : list + List of factors by which to aggregate the exo_source + data to the spatial resolution of the file_paths input enhanced by s_enhance. The length of this list should be equal to the number of s_enhancements and the number of exo_steps + t_agg_factors : list + List of factors by which to aggregate the exo_source + data to the temporal resolution of the file_paths input enhanced by + t_enhance. The length of this list should be equal to the number of + t_enhancements and the number of exo_steps target : tuple (lat, lon) lower left corner of raster. 
Either need target+shape or raster_file. shape : tuple (rows, cols) grid size. Either need target+shape or raster_file. + temporal_slice : slice | None + slice used to extract interval from temporal dimension for input + data and source data raster_file : str | None File for raster_index array for the corresponding target and shape. If specified the raster_index will be loaded from the file if it @@ -112,23 +126,22 @@ def __init__(self, List of model step indices for which exogenous data is required. e.g. If we have two model steps which take exo data and one which does not exo_steps = [0, 1]. The length of this list should be - equal to the number of s_enhancements and the number of agg_factors + equal to the number of s/t_enhancements and the number of + s/t_agg_factors cache_data : bool Flag to cache exogeneous data in ./exo_cache/ this can speed up forward passes with large temporal extents - t_enhancements : list - List of factors by which the Sup3rGan model will enhance the - temporal dimension of low resolution data from file_paths input - where the total temporal enhancement is the product of these - factors. """ self.feature = feature self.s_enhancements = s_enhancements - self.agg_factors = agg_factors + self.t_enhancements = t_enhancements + self.s_agg_factors = s_agg_factors + self.t_agg_factors = t_agg_factors self.source_file = source_file self.file_paths = file_paths self.exo_handler = exo_handler + self.temporal_slice = temporal_slice self.target = target self.shape = shape self.raster_file = raster_file @@ -136,8 +149,6 @@ def __init__(self, self.input_handler = input_handler self.cache_data = cache_data self.data = [] - self.t_enhancements = (t_enhancements if t_enhancements is not None - else [1] * len(self.s_enhancements)) exo_steps = exo_steps or np.arange(len(self.s_enhancements)) if self.s_enhancements[0] != 1: @@ -157,10 +168,12 @@ def __init__(self, msg = ('Need to provide the same number of enhancement factors and ' f'agg factors. Received s_enhancements={self.s_enhancements}, ' - f't_enhancements={self.t_enhancements}, and ' - f'agg_factors={agg_factors}.') - assert len(self.s_enhancements) == len(self.agg_factors), msg - assert len(self.t_enhancements) == len(self.agg_factors), msg + f'and s_agg_factors={self.s_agg_factors}.') + assert len(self.s_enhancements) == len(self.s_agg_factors), msg + msg = ('Need to provide the same number of enhancement factors and ' + f'agg factors. Received t_enhancements={self.t_enhancements}, ' + f'and t_agg_factors={self.t_agg_factors}.') + assert len(self.t_enhancements) == len(self.t_agg_factors), msg msg = ('Need to provide an integer enhancement factor for each model' 'step. If the step is temporal enhancement then s_enhance=1') @@ -169,14 +182,16 @@ def __init__(self, for i in range(len(self.s_enhancements)): s_enhance = np.product(self.s_enhancements[:i + 1]) t_enhance = np.product(self.t_enhancements[:i + 1]) - agg_factor = self.agg_factors[i] + s_agg_factor = self.s_agg_factors[i] + t_agg_factor = self.t_agg_factors[i] fdata = [] if i in exo_steps: if feature in list(self.AVAILABLE_HANDLERS): data = self.get_exo_data(feature=feature, s_enhance=s_enhance, t_enhance=t_enhance, - agg_factor=agg_factor) + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor) fdata.append(data) else: msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}." 
@@ -186,7 +201,8 @@ def __init__(self, else: self.data.append(None) - def get_cache_file(self, feature, s_enhance, t_enhance, agg_factor): + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): """Get cache file name Parameters @@ -199,10 +215,12 @@ def get_cache_file(self, feature, s_enhance, t_enhance, agg_factor): t_enhance : int Temporal enhancement for this exogeneous data step (cumulative for all model steps up to the current step). - agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by - s_enhance. + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. Returns ------- @@ -210,8 +228,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, agg_factor): Name of cache file """ cache_dir = './exo_cache/' - fn = f'exo_{feature}_{self.target}_{self.shape}_agg{agg_factor}_' - fn += f'{s_enhance}x_{t_enhance}x.pkl' + fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_' + fn += f'tagg_{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') fn = fn.replace(',', 'x').replace(' ', '') @@ -220,7 +238,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, agg_factor): os.makedirs(cache_dir, exist_ok=True) return cache_fp - def get_exo_data(self, feature, s_enhance, t_enhance, agg_factor): + def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): """Get the exogenous topography data Parameters @@ -233,10 +252,12 @@ def get_exo_data(self, feature, s_enhance, t_enhance, agg_factor): t_enhance : int Temporal enhancement for this exogeneous data step (cumulative for all model steps up to the current step). - agg_factor : int - Factor by which to aggregate the topo_source_h5 elevation - data to the resolution of the file_paths input enhanced by - s_enhance. + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. 
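To make the aggregation arithmetic concrete, a minimal sketch with hypothetical resolutions (illustrative only, not part of this patch):

    # spatial: the number of exo_source cells averaged per enhanced input
    # cell is the squared ratio of the grid spacings
    enhanced_input_km = 20   # file_paths spacing after s_enhance (hypothetical)
    exo_source_km = 4        # exo_source spacing (hypothetical)
    s_agg_factor = int((enhanced_input_km / exo_source_km) ** 2)  # 25 -> 5x5 cells

    # temporal: the plain ratio of time steps, per the description above
    enhanced_input_min = 15  # file_paths timestep after t_enhance (hypothetical)
    exo_source_min = 5       # exo_source timestep (hypothetical)
    t_agg_factor = enhanced_input_min // exo_source_min  # 3 -> every 3rd step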
Returns ------- @@ -245,9 +266,13 @@ def get_exo_data(self, feature, s_enhance, t_enhance, agg_factor): lon, temporal) """ - cache_fp = self.get_cache_file(feature=feature) + cache_fp = self.get_cache_file(feature=feature, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor) tmp_fp = cache_fp + '.tmp' - + print(cache_fp) if os.path.exists(cache_fp): with open(cache_fp, 'rb') as f: data = pickle.load(f) @@ -255,19 +280,21 @@ def get_exo_data(self, feature, s_enhance, t_enhance, agg_factor): else: exo_handler = self.get_exo_handler(feature, self.source_file, self.exo_handler) - dh = exo_handler(self.file_paths, - self.source_file, - s_enhance=s_enhance, - t_enhance=t_enhance, - agg_factor=agg_factor, - target=self.target, - shape=self.shape, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler) + data = exo_handler(self.file_paths, + self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler).data if self.cache_data: with open(tmp_fp, 'wb') as f: - pickle.dump(dh.data, f) + pickle.dump(data, f) shutil.move(tmp_fp, cache_fp) return data diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 6f29d178d..c52c3640d 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -15,10 +15,12 @@ import pandas as pd from scipy.stats import mode -from sup3r.utilities.utilities import (get_source_type, ignore_case_path_fetch, - uniform_box_sampler, - uniform_time_sampler, - ) +from sup3r.utilities.utilities import ( + get_source_type, + ignore_case_path_fetch, + uniform_box_sampler, + uniform_time_sampler, +) np.random.seed(42) @@ -601,6 +603,8 @@ def temporal_slice(self, temporal_slice): elements and no more than three, corresponding to the inputs of slice() """ + if temporal_slice is None: + temporal_slice = slice(None) msg = 'temporal_slice must be tuple, list, or slice' assert isinstance(temporal_slice, (tuple, list, slice)), msg if isinstance(temporal_slice, slice): diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 043043aab..56c1c054d 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -15,8 +15,10 @@ from sup3r.models import LinearInterp, Sup3rGan, WindGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC -from sup3r.utilities.pytest import (make_fake_multi_time_nc_files, - make_fake_nc_files) +from sup3r.utilities.pytest import ( + make_fake_multi_time_nc_files, + make_fake_nc_files, +) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -617,7 +619,9 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): 'target': target, 'shape': shape, 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 16], + 't_enhancements': [1, 1, 1], + 's_agg_factors': [16, 4, 2], + 't_agg_factors': [1, 1, 1], 'exo_steps': [0, 1] } } @@ -719,7 +723,9 @@ def test_fwp_multi_step_spatial_model_topo_noskip(): 'target': target, 'shape': shape, 's_enhancements': [1, 2, 2], - 'agg_factors': [12, 4, 2] + 't_enhancements': [1, 1, 1], + 's_agg_factors': [12, 4, 2], + 
't_agg_factors': [1, 1, 1] } } @@ -824,16 +830,20 @@ def test_fwp_multi_step_model_multi_exo(): 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] + 's_enhancements': [1, 2, 2, 3], + 't_enhancements': [1, 1, 1, 4], + 's_agg_factors': [12, 4, 2, 1], + 't_agg_factors': [4, 4, 4, 1] }, 'sza': { 'file_paths': input_files, 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] + 's_enhancements': [1, 2, 2, 1], + 't_enhancements': [1, 1, 1, 4], + 's_agg_factors': [12, 4, 2, 1], + 't_agg_factors': [4, 4, 4, 1] } } @@ -942,8 +952,11 @@ def test_fwp_multi_step_model_topo_noskip(): 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] + 's_enhancements': [1, 2, 2, 3], + 't_enhancements': [1, 1, 1, 4], + 's_agg_factors': [12, 4, 2, 1], + 't_agg_factors': [4, 4, 4, 1], + 'cache_data': False } } @@ -1348,7 +1361,9 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): 'target': target, 'shape': shape, 's_enhancements': [1, 2], - 'agg_factors': [2, 4], + 't_enhancements': [1, 1], + 's_agg_factors': [4, 2], + 't_agg_factors': [1, 1], } model_kwargs = {'model_dir': st_out_dir} @@ -1527,16 +1542,20 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] + 's_enhancements': [1, 2, 2, 3], + 't_enhancements': [1, 1, 1, 4], + 's_agg_factors': [12, 4, 2, 1], + 't_agg_factors': [4, 4, 4, 1] }, 'sza': { 'file_paths': input_files, 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12] + 's_enhancements': [1, 2, 2, 3], + 't_enhancements': [1, 1, 1, 4], + 's_agg_factors': [12, 4, 2, 1], + 't_agg_factors': [4, 4, 4, 1] } } @@ -1698,7 +1717,9 @@ def test_fwp_multi_step_wind_hi_res_topo(): 'target': target, 'shape': shape, 's_enhancements': [1, 2, 2], - 'agg_factors': [2, 4, 12], + 't_enhancements': [1, 1, 1], + 's_agg_factors': [12, 4, 2], + 't_agg_factors': [1, 1, 1], } model_kwargs = { @@ -1842,7 +1863,9 @@ def test_fwp_wind_hi_res_topo_plus_linear(): 'target': target, 'shape': shape, 's_enhancements': [1, 2], - 'agg_factors': [2, 4], + 't_enhancements': [1, 1], + 's_agg_factors': [4, 2], + 't_agg_factors': [1, 1] } model_kwargs = { From e8fe0376ecd07438024f111958fcc3e49b15e5d9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 19 Sep 2023 11:23:36 -0600 Subject: [PATCH 04/15] forward pass methods for parsing exo_kwargs dict, computing agg factors, and extracting exo data. --- sup3r/models/abstract.py | 36 +++- sup3r/models/base.py | 7 +- sup3r/pipeline/forward_pass.py | 187 ++++++++++++++---- .../data_handling/exogenous_data_handling.py | 51 ++--- 4 files changed, 208 insertions(+), 73 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 86275044e..71b351e1d 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -6,6 +6,7 @@ import logging import os import pprint +import re import time from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor @@ -118,6 +119,27 @@ def t_enhance(self): model training during high res coarsening""" return self.meta.get('t_enhance', None) + @property + def input_resolution(self): + """Resolution of input data. 
Given as a dictionary {'spatial':...,
+        'temporal':...}"""
+        return self.meta.get('input_resolution', None)
+
+    @property
+    def output_resolution(self):
+        """Resolution of output data. Given as a dictionary {'spatial':...,
+        'temporal':...}"""
+        input_res = self.input_resolution
+        output_res = {} if input_res is None else input_res.copy()
+        if input_res is not None:
+            input_temporal = re.search(r'\d+', input_res['temporal']).group(0)
+            input_spatial = re.search(r'\d+', input_res['spatial']).group(0)
+            output_temporal = int(int(input_temporal) / self.t_enhance)
+            output_spatial = int(int(input_spatial) / self.s_enhance)
+            output_res['temporal'] = output_res['temporal'].replace(
+                input_temporal, str(output_temporal))
+            output_res['spatial'] = output_res['spatial'].replace(
+                input_spatial, str(output_spatial))
+        return output_res
+
     @property
     def needs_hr_exo(self):
         """Determine whether or not the sup3r model needs hi-res exogenous data
@@ -229,8 +251,8 @@ def set_model_params(self, **kwargs):
             'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
         """
-        keys = ('training_features', 'output_features', 'smoothed_features',
-                's_enhance', 't_enhance', 'smoothing')
+        keys = ('input_resolution', 'training_features', 'output_features',
+                'smoothed_features', 's_enhance', 't_enhance', 'smoothing')
         keys = [k for k in keys if k in kwargs]

         for var in keys:
@@ -1004,10 +1026,10 @@ def set_model_params(**kwargs):
         Parameters
         ----------
         kwargs : dict
-            Keyword arguments including 'training_features', 'output_features',
-            'smoothed_features', 's_enhance', 't_enhance', 'smoothing'. For the
-            Wind classes, the last entry in "output_features" must be
-            "topography"
+            Keyword arguments including 'input_resolution',
+            'training_features', 'output_features', 'smoothed_features',
+            's_enhance', 't_enhance', 'smoothing'. For the Wind classes, the
+            last entry in "output_features" must be "topography"

         Returns
         -------
@@ -1019,7 +1041,7 @@
         """
         output_features = kwargs['output_features']
         msg = ('Last output feature from the data handler must be topography '
-               'to train the WindCC model, but received output features: {}'.
+               'to train the Wind model, but received output features: {}'.
               format(output_features))
         assert output_features[-1] == 'topography', msg
         output_features.remove('topography')
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 9e97b18ec..b5b664358 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -639,7 +639,7 @@ def calc_loss(self,
         loss_gen_content = self.calc_loss_gen_content(hi_res_true,
                                                       hi_res_gen)
         loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
-        loss_gen = (loss_gen_content + weight_gen_advers * loss_gen_advers)
+        loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers

         loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)

@@ -851,6 +851,7 @@ def update_adversarial_weights(self, history, adaptive_update_fraction,

     def train(self,
               batch_handler,
+              input_resolution,
               n_epoch,
               weight_gen_advers=0.001,
               train_gen=True,
@@ -870,6 +871,9 @@ def train(self,
         ----------
         batch_handler : sup3r.data_handling.preprocessing.BatchHandler
             BatchHandler object to iterate through
+        input_resolution : dict
+            Dictionary specifying spatiotemporal input resolution. e.g.
+ {'temporal': '60min', 'spatial': '30km'} n_epoch : int Number of epochs to train on weight_gen_advers : float @@ -925,6 +929,7 @@ def train(self, self.set_norm_stats(batch_handler.means, batch_handler.stds) self.set_model_params( + input_resolution=input_resolution, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 02be38f0c..c9571c664 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -7,6 +7,7 @@ import copy import logging import os +import re import warnings from concurrent.futures import as_completed from datetime import datetime as dt @@ -92,12 +93,6 @@ def __init__(self, passes for subsequent temporal stitching. This overlap will pad both sides of the fwp_chunk_shape. Note that the first and last chunks in the temporal dimension will not be padded. - exo_s_enhancements : list - List of spatial enhancement steps specific to the exogenous_data - inputs. This differs from s_enhancements in that s_enhancements[0] - will be the spatial enhancement of the first model, but - exo_s_enhancements[0] may be 1 to signify exo data is required for - the first non-enhanced spatial input resolution. """ self.grid_shape = coarse_shape self.time_steps = time_steps @@ -697,7 +692,14 @@ def __init__(self, exo_kwargs : dict | None Dictionary of args to pass to ExogenousDataHandler for extracting exogenous features such as topography for future multistep foward - pass + pass. This should be a nested dictionary with keys for each + exogeneous feature. The dictionaries corresponding to the feature + names should include the path to exogenous data source, the + resolution of the exogenous data, and how the exogenous data should + be used in the model. e.g. {'topography': {'file_paths': 'path to + input files', 'source_file': 'path to exo data', 'exo_resolution': + {'spatial': '1km', 'temporal': None}, 'steps': [{'model': 0, + 'concat_type': 'input'}, {'model': 0, 'concat_type': 'layer'}]} bias_correct_method : str | None Optional bias correction function name that can be imported from the :mod:`sup3r.bias.bias_transforms` module. 
This will transform @@ -727,8 +729,7 @@ def __init__(self, shape=grid_shape, raster_file=raster_file, raster_index=raster_index, - temporal_slice=temporal_slice, - ) + temporal_slice=temporal_slice) self.file_paths = file_paths self.model_kwargs = model_kwargs @@ -787,14 +788,12 @@ def __init__(self, self.s_enhancements, self.t_enhancements, self.spatial_pad, - self.temporal_pad, - ) + self.temporal_pad) DistributedProcess.__init__(self, max_nodes=max_nodes, max_chunks=self.fwp_slicer.n_chunks, - incremental=self.incremental, - ) + incremental=self.incremental) self.preflight() @@ -1090,7 +1089,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.max_workers = strategy.max_workers self.pass_workers = strategy.pass_workers self.output_workers = strategy.output_workers - self.exo_kwargs = strategy.exo_kwargs + self.exo_kwargs = self._prep_exo_extract_kwargs(strategy.exo_kwargs) self.exo_features = ([] if not self.exo_kwargs else list(self.exo_kwargs)) self.exogenous_data = self.load_exo_data() @@ -1122,6 +1121,127 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], self.lr_slice[1]] + def get_agg_factor(self, input_res, exo_res): + """Compute agg factor for exo data given input and output resolution + + Parameters + ---------- + input_res : str | None + Input resolution. e.g. 30km or 60min + exo_res : str | None + Exogenous data resolution. e.g. 1km or 5min + + Returns + ------- + agg_factor : int + Aggregation factor for exogenous data extraction. + """ + ires_num = (None if input_res is None + else re.search(r'\d+', input_res).group(0)) + eres_num = (None if exo_res is None + else re.search(r'\d+', exo_res).group(0)) + i_units = (None if input_res is None + else input_res.replace(ires_num, '')) + e_units = None if exo_res is None else exo_res.replace(eres_num, '') + msg = 'Received conflicting units for input and exo resolution' + if e_units is not None: + assert i_units == e_units, msg + if ires_num is not None and eres_num is not None: + agg_factor = (int(ires_num) / int(eres_num)) ** 2 + else: + agg_factor = None + return agg_factor + + def _prep_exo_extract_single_step(self, step_dict, exo_resolution): + """Compute agg and enhancement factors for exogenous data extraction + using exo_kwargs single model step. These factors are computed using + exo_resolution and the input/output resolution of each model step. + + Parameters + ---------- + step_dict : dict + Model step dictionary. e.g. {'model': 0, 'concat_type': 'input'} + exo_resolution : dict + Resolution of exogenous data. e.g. 
{'temporal': '15min', 'spatial':
+            '1km'}

+        Returns
+        -------
+        updated_step_dict : dict
+            Same as input dictionary with s_agg_factor, t_agg_factor,
+            s_enhance, t_enhance added
+        """
+        model_step = step_dict['model']
+        exo_res_t = exo_resolution['temporal']
+        exo_res_s = exo_resolution['spatial']
+        s_enhance = None
+        t_enhance = None
+        s_agg_factor = None
+        t_agg_factor = None
+        models = getattr(self.model, 'models', [self.model])
+        msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
+               f'of model steps ({len(models)})')
+        assert len(models) > model_step, msg
+        model = models[model_step]
+        input_res_t = model.input_resolution['temporal']
+        input_res_s = model.input_resolution['spatial']
+        output_res_t = model.output_resolution['temporal']
+        output_res_s = model.output_resolution['spatial']
+        concat_type = step_dict.get('concat_type', None)
+
+        if concat_type.lower() == 'input':
+            if model_step == 0:
+                s_enhance = 1
+                t_enhance = 1
+            else:
+                s_enhance = self.strategy.s_enhancements[model_step - 1]
+                t_enhance = self.strategy.t_enhancements[model_step - 1]
+            s_agg_factor = self.get_agg_factor(input_res_s, exo_res_s)
+            t_agg_factor = self.get_agg_factor(input_res_t, exo_res_t)
+
+        elif concat_type.lower() in ('output', 'layer'):
+            s_enhance = self.strategy.s_enhancements[model_step]
+            t_enhance = self.strategy.t_enhancements[model_step]
+            s_agg_factor = self.get_agg_factor(output_res_s, exo_res_s)
+            t_agg_factor = self.get_agg_factor(output_res_t, exo_res_t)
+
+        else:
+            msg = 'Received exo_kwargs entry without valid concat_type'
+            raise OSError(msg)
+
+        updated_dict = step_dict.copy()
+        updated_dict.update({'s_enhance': s_enhance, 't_enhance': t_enhance,
+                             's_agg_factor': s_agg_factor,
+                             't_agg_factor': t_agg_factor})
+        return updated_dict
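For orientation, a minimal sketch of the dictionary this method returns, assuming a hypothetical first model step with 30km/60min input resolution, s_enhance=4, t_enhance=1, and 1km spatial exo data with no temporal resolution:

    step = {'model': 0, 'concat_type': 'layer'}
    # a 'layer' step keys off the model *output* resolution (~7km here),
    # so the spatial agg factor is roughly (7 / 1) ** 2 = 49
    prepped = {'model': 0, 'concat_type': 'layer',
               's_enhance': 4, 't_enhance': 1,
               's_agg_factor': 49, 't_agg_factor': None}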
+    def _prep_exo_extract_kwargs(self, exo_kwargs):
+        """Compute agg and enhancement factors for all model steps for all
+        features.
+
+        Parameters
+        ----------
+        exo_kwargs: dict
+            Full exo_kwargs dictionary with all feature entries.
+            e.g. {'topography': {'exo_resolution': {'spatial': '1km',
+            'temporal': None}, 'steps': [{'model': 0, 'concat_type': 'input'},
+            {'model': 0, 'concat_type': 'layer'}]}}
+
+        Returns
+        -------
+        updated_dict : dict
+            Same as input dictionary with s_agg_factor, t_agg_factor,
+            s_enhance, t_enhance added to each step entry for all features
+        """
+        if exo_kwargs:
+            for feature, v in exo_kwargs.items():
+                exo_resolution = v['exo_resolution']
+                for i, step in enumerate(v['steps']):
+                    out = self._prep_exo_extract_single_step(
+                        step, exo_resolution)
+                    exo_kwargs[feature]['steps'][i] = out
+        return exo_kwargs
+
     def load_exo_data(self):
         """Extract exogenous data for each exo feature and store data in
         dictionary with key for each exo feature
@@ -1129,35 +1249,36 @@ def load_exo_data(self):
         Returns
         -------
         exo_data : dict
-            Dictionary of data arrays with keys for each exogeneous feature
+            Same as exo_kwargs dictionary with data arrays added to a 'data'
+            key for each feature
         """
-        exo_data_dict = {}
+        exo_data = None
         if self.exo_kwargs:
-            self.features = [
-                f for f in self.features if f not in self.exo_features
-            ]
+            exo_data = self.exo_kwargs.copy()
+            self.features = [f for f in self.features
+                             if f not in self.exo_features]
             for feature in self.exo_features:
                 exo_kwargs = copy.deepcopy(self.exo_kwargs[feature])
                 exo_kwargs['feature'] = feature
                 exo_kwargs['target'] = self.target
                 exo_kwargs['shape'] = self.shape
                 exo_kwargs['temporal_slice'] = self.ti_pad_slice
-                exo_data_dict[feature] = ExogenousDataHandler(
-                    **exo_kwargs).data
-                shapes = [
-                    None if d is None else d.shape
-                    for d in exo_data_dict[feature]
-                ]
+                steps = exo_kwargs.pop('steps')
+                exo_kwargs['s_agg_factors'] = [step['s_agg_factor']
+                                               for step in steps]
+                exo_kwargs['t_agg_factors'] = [step['t_agg_factor']
+                                               for step in steps]
+                exo_kwargs['s_enhancements'] = [step['s_enhance']
+                                                for step in steps]
+                exo_kwargs['t_enhancements'] = [step['t_enhance']
+                                                for step in steps]
+                data = ExogenousDataHandler(**exo_kwargs).data
+                for i, _ in enumerate(steps):
+                    exo_data[feature]['steps'][i]['data'] = data[i]
+                shapes = [None if d is None else d.shape for d in data]
                 logger.info(
                     'Got exogenous_data of length {} with shapes: {}'.format(
-                        len(exo_data_dict[feature]), shapes))
-            exo_data = []
-            for i in range(len(exo_data_dict[self.exo_features[0]])):
-                exo_data_f = []
-                for feature in self.exo_features:
-                    exo_data_f.append(exo_data_dict[feature][i])
-                exo_data.append(
-                    np.vstack([arr for arr in exo_data_f if arr is not None]))
+                        len(data), shapes))
         return exo_data

diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
index 7962c3c8e..aaf6878a7 100644
--- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py
+++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
@@ -6,12 +6,12 @@
 from typing import ClassVar
 from warnings import warn

-import numpy as np
-
-from sup3r.preprocessing.data_handling.exo_extraction import (SzaExtractH5,
-                                                              SzaExtractNC,
-                                                              TopoExtractH5,
-                                                              TopoExtractNC)
+from sup3r.preprocessing.data_handling.exo_extraction import (
+    SzaExtractH5,
+    SzaExtractNC,
+    TopoExtractH5,
+    TopoExtractNC,
+)
 from sup3r.utilities.utilities import get_source_type

 logger = logging.getLogger(__name__)
@@ -48,7 +48,6 @@ def __init__(self,
                  max_delta=20,
                  input_handler=None,
                  exo_handler=None,
-                 exo_steps=None,
                  cache_data=True):
         """
         Parameters
@@ -122,12 +121,6 @@ def __init__(self,
             feature='topography' this should
be either TopoExtractH5 or TopoExtractNC. If None the correct handler will be guessed based on file type and time series properties. - exo_steps : list - List of model step indices for which exogenous data is required. - e.g. If we have two model steps which take exo data and one which - does not exo_steps = [0, 1]. The length of this list should be - equal to the number of s/t_enhancements and the number of - s/t_agg_factors cache_data : bool Flag to cache exogeneous data in ./exo_cache/ this can speed up forward passes with large temporal extents @@ -149,7 +142,6 @@ def __init__(self, self.input_handler = input_handler self.cache_data = cache_data self.data = [] - exo_steps = exo_steps or np.arange(len(self.s_enhancements)) if self.s_enhancements[0] != 1: msg = ('s_enhancements typically starts with 1 so the first ' @@ -179,27 +171,22 @@ def __init__(self, 'step. If the step is temporal enhancement then s_enhance=1') assert not any(s is None for s in self.s_enhancements), msg - for i in range(len(self.s_enhancements)): - s_enhance = np.product(self.s_enhancements[:i + 1]) - t_enhance = np.product(self.t_enhancements[:i + 1]) + for i, _ in enumerate(self.s_enhancements): + s_enhance = self.s_enhancements[i] + t_enhance = self.t_enhancements[i] s_agg_factor = self.s_agg_factors[i] t_agg_factor = self.t_agg_factors[i] - fdata = [] - if i in exo_steps: - if feature in list(self.AVAILABLE_HANDLERS): - data = self.get_exo_data(feature=feature, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor) - fdata.append(data) - else: - msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}." - f" Received {feature}.") - raise NotImplementedError(msg) - self.data.append(np.stack(fdata, axis=-1)) + if feature in list(self.AVAILABLE_HANDLERS): + data = self.get_exo_data(feature=feature, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor) + self.data.append(data) else: - self.data.append(None) + msg = (f"Can only extract {list(self.AVAILABLE_HANDLERS)}." 
+ f" Received {feature}.") + raise NotImplementedError(msg) def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): From 21ed39137479d9754580c32bad31bb5448911cd3 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 20 Sep 2023 11:35:33 -0600 Subject: [PATCH 05/15] _reshape_data_chunk in forward_pass.py modified to use new exo_data dictionary --- sup3r/pipeline/forward_pass.py | 63 ++++++++++++++++++++-------------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index c9571c664..6c19b1b43 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -1198,12 +1198,14 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): t_enhance = self.strategy.t_enhancements[model_step - 1] s_agg_factor = self.get_agg_factor(input_res_s, exo_res_s) t_agg_factor = self.get_agg_factor(input_res_t, exo_res_t) + resolution = {'spatial': input_res_s, 'temporal': input_res_t} elif concat_type.lower() in ('output', 'layer'): s_enhance = self.strategy.s_enhancements[model_step] t_enhance = self.strategy.t_enhancements[model_step] s_agg_factor = self.get_agg_factor(output_res_s, exo_res_s) t_agg_factor = self.get_agg_factor(output_res_t, exo_res_t) + resolution = {'spatial': output_res_s, 'temporal': output_res_t} else: msg = 'Received exo_kwargs entry without valid concat_type' @@ -1212,7 +1214,8 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): updated_dict = step_dict.copy() updated_dict.update({'s_enhance': s_enhance, 't_enhance': t_enhance, 's_agg_factor': s_agg_factor, - 't_agg_factor': t_agg_factor}) + 't_agg_factor': t_agg_factor, + 'resolution': resolution}) return updated_dict def _prep_exo_extract_kwargs(self, exo_kwargs): @@ -1694,8 +1697,7 @@ def _run_generator(cls, model_class=None, s_enhance=None, t_enhance=None, - exo_data=None, - ): + exo_data=None): """Run forward pass of the generator on smallest data chunk. Each chunk has a maximum shape given by self.strategy.fwp_chunk_shape. @@ -1729,10 +1731,15 @@ def _run_generator(cls, Factor by which to enhance temporal resolution s_enhance : int Factor by which to enhance spatial resolution - exo_data : list | None - List of arrays of exogenous data for each model step. - If there are two spatial enhancement steps this is a list of length - 3 with arrays for each intermediate spatial resolution. + exo_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. + {'topography': {'steps': [ + {'concat_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'concat_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} Returns ------- @@ -1790,9 +1797,15 @@ def _reshape_data_chunk(model, data_chunk, exo_data): data_chunk : np.ndarray Low resolution data for a single spatiotemporal chunk that is going to be passed to the model generate function. - exo_data : list | None - Optional exogenous data which can be a list of arrays of exogenous - inputs to complement data_chunk + exo_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. 
+ {'topography': {'steps': [ + {'concat_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'concat_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} Returns ------- @@ -1801,27 +1814,27 @@ def _reshape_data_chunk(model, data_chunk, exo_data): features) if the model is a spatial-first model or (n_obs, spatial_1, spatial_2, temporal, features) if the model is spatiotemporal - exo_data : list | None - Same reshaping procedure as for data_chunk + exo_data : dict | None + Same reshaping procedure as for data_chunk applied to + exo_data[feature]['steps'][...]['data'] i_lr_t : int Axis index for the low-resolution temporal dimension i_lr_s : int Axis index for the low-resolution spatial_1 dimension """ - current_model = None if exo_data is not None: - for i, arr in enumerate(exo_data): - if arr is not None: - if not hasattr(model, 'models'): - current_model = model - elif i < len(model.models): - current_model = model.models[i] - - if current_model is not None: - if current_model.input_dims == 4: - exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3)) - else: - exo_data[i] = np.expand_dims(arr, axis=0) + for feature in exo_data: + for i, entry in enumerate(exo_data[feature]['steps']): + models = getattr(model, 'models', [model]) + msg = (f'model index ({entry["model"]}) for exo step {i} ' + 'exceeds the number of model steps') + assert entry['model'] < len(models), msg + current_model = models[entry['model']] + if current_model.input_dims == 4: + out = np.transpose(entry['data'], axes=(2, 0, 1, 3)) + else: + out = np.expand_dims(entry['data'], axis=0) + exo_data[feature]['steps'][i]['data'] = out if model.input_dims == 4: i_lr_t = 0 From 509b1c12d9236fdbc6a8e8a2be89eec7e4c1fb0e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 20 Sep 2023 15:28:15 -0600 Subject: [PATCH 06/15] WindGan -> MultiExoGan. AbstractWindInterface -> AbstractExoInterface. Both generalized for multiple exogenous features --- sup3r/models/abstract.py | 233 +++++++++++++++-------- sup3r/models/{wind.py => multi_exo.py} | 58 +++--- sup3r/models/multi_step.py | 178 ++++++++++++++--- sup3r/models/wind_conditional_moments.py | 8 +- sup3r/pipeline/forward_pass.py | 24 +-- 5 files changed, 349 insertions(+), 152 deletions(-) rename sup3r/models/{wind.py => multi_exo.py} (67%) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 71b351e1d..99470c8b3 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -1012,15 +1012,14 @@ def run_gradient_descent(self, # pylint: disable=E1101,W0201,E0203 -class AbstractWindInterface(ABC): +class AbstractExoInterface(ABC): """ Abstract class to define the required training interface - for Sup3r wind model subclasses + for Sup3r model subclasses with exogenous features """ # pylint: disable=E0211 - @staticmethod - def set_model_params(**kwargs): + def set_model_params(self, **kwargs): """Set parameters used for training the model Parameters @@ -1034,21 +1033,25 @@ def set_model_params(**kwargs): Returns ------- kwargs : dict - Same as input but with topography removed from "output_features", - this is because topography is concatenated mid-network in the - WindGan generators and is not an output feature but is required in - the hi-res training set. + Same as input but with exogenous features removed from + "output_features", this is because the exo features are + concatenated mid-network in the ExoGan generators and are not + output features but are required in the hi-res training set. 
""" output_features = kwargs['output_features'] - msg = ('Last output feature from the data handler must be topography ' - 'to train the Wind model, but received output features: {}'. + msg = (f'Last {len(self.exogenous_features)} output features from the ' + f'data handler must be {self.exogenous_features} ' + 'to train the Exo model, but received output features: {}'. format(output_features)) - assert output_features[-1] == 'topography', msg - output_features.remove('topography') + check = (output_features[-len(self.exogenous_features)] + == self.exogenous_features) + assert check, msg + for f in self.exogenous_features: + output_features.remove(f) kwargs['output_features'] = output_features return kwargs - def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): + def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): """Reshape the hi_res_topo to match the hi_res tensor (if necessary) and normalize (if requested). @@ -1059,16 +1062,18 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): array with shape: (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) - hi_res_topo : np.ndarray + hi_res_exo : np.ndarray This should be a 4D array for spatial enhancement model or 5D array for a spatiotemporal enhancement model (obs, spatial_1, spatial_2, (temporal), features) corresponding to the high-resolution - spatial_1 and spatial_2. This data will be input to the custom - phygnn Sup3rAdder or Sup3rConcat layer if found in the generative - network. This differs from the exogenous_data input in that - exogenous_data always matches the low-res input. For this function, - hi_res_topo can also be a 2D array (spatial_1, spatial_2). Note - that this input gets normalized if norm_in=True. + spatial_1, spatial_2, temporal. This data will be input to the + custom phygnn Sup3rAdder or Sup3rConcat layer if found in the + generative network. This differs from the exogenous_data input in + that exogenous_data always matches the low-res input. For this + function, hi_res_exo can also be a 3D array (spatial_1, spatial_2, + 1). Note that this input gets normalized if norm_in=True. + exo_name : str + Name of feature corresponding to hi_res_exo data. norm_in : bool Flag to normalize low_res input data if the self._means, self._stdevs attributes are available. 
The generator should always @@ -1081,35 +1086,96 @@ def _reshape_norm_topo(self, hi_res, hi_res_topo, norm_in=True): Same as input but reshaped to match hi_res (if necessary) and normalized (if requested) """ - if hi_res_topo is None: - return hi_res_topo + if hi_res_exo is None: + return hi_res_exo if norm_in and self._means is not None: - idf = self.training_features.index('topography') - hi_res_topo = ((hi_res_topo.copy() - self._means[idf]) - / self._stdevs[idf]) - - if len(hi_res_topo.shape) > 2: - slicer = [0] * len(hi_res_topo.shape) - slicer[1] = slice(None) - slicer[2] = slice(None) - hi_res_topo = hi_res_topo[tuple(slicer)] - - if len(hi_res.shape) == 4: - hi_res_topo = np.expand_dims(hi_res_topo, axis=(0, 3)) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[0], axis=0) - elif len(hi_res.shape) == 5: - hi_res_topo = np.expand_dims(hi_res_topo, axis=(0, 3, 4)) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[0], axis=0) - hi_res_topo = np.repeat(hi_res_topo, hi_res.shape[3], axis=3) - - if len(hi_res_topo.shape) != len(hi_res.shape): - msg = ('hi_res and hi_res_topo arrays are not of the same rank: ' - '{} and {}'.format(hi_res.shape, hi_res_topo.shape)) + idf = self.training_features.index(exo_name) + hi_res_exo = ((hi_res_exo.copy() - self._means[idf]) + / self._stdevs[idf]) + + if len(hi_res_exo.shape) == 3: + hi_res_exo = np.expand_dims(hi_res_exo, axis=0) + hi_res_exo = np.repeat(hi_res_exo, hi_res.shape[0], axis=0) + if len(hi_res_exo.shape) == 4 and len(hi_res.shape) == 5: + hi_res_exo = np.expand_dims(hi_res_exo, axis=3) + hi_res_exo = np.repeat(hi_res_exo, hi_res.shape[3], axis=3) + + if len(hi_res_exo.shape) != len(hi_res.shape): + msg = ('hi_res and hi_res_exo arrays are not of the same rank: ' + '{} and {}'.format(hi_res.shape, hi_res_exo.shape)) logger.error(msg) raise RuntimeError(msg) - return hi_res_topo + return hi_res_exo + + def _combine_input(self, low_res, exogenous_data=None): + """Combine exogenous_data at input resolution with low_res data + + Parameters + ---------- + low_res : np.ndarray + Low-resolution input data, usually a 4D or 5D array of shape: + (n_obs, spatial_1, spatial_2, n_features) + (n_obs, spatial_1, spatial_2, n_temporal, n_features) + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. 
{'topography': {'steps': [
+                {'combine_type': 'input', 'data': ..., 'resolution': ...},
+                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
+
+        Returns
+        -------
+        low_res : np.ndarray
+            Low-resolution input data combined with exogenous_data, usually a
+            4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+        """
+        for feature in self.exogenous_features:
+            msg = f'Did not find {feature} in exogenous_data'
+            assert feature in exogenous_data, msg
+            for step in exogenous_data[feature]['steps']:
+                if step['combine_type'] == 'input':
+                    low_res = np.concatenate((low_res, step['data']), axis=-1)
+        return low_res
+
+    def _combine_output(self, hi_res, exogenous_data=None):
+        """Combine exogenous_data at output resolution with hi_res data
+
+        Parameters
+        ----------
+        hi_res : np.ndarray
+            High-resolution output data, usually a 4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+        exogenous_data : dict | None
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. This doesn't have to include the 'model' key since
+            this data is for a single step model. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'data': ..., 'resolution': ...},
+                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
+
+        Returns
+        -------
+        hi_res : np.ndarray
+            High-resolution output data combined with exogenous_data, usually
+            a 4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+        """
+        for feature in self.exogenous_features:
+            msg = f'Did not find {feature} in exogenous_data'
+            assert feature in exogenous_data, msg
+            for step in exogenous_data[feature]['steps']:
+                if step['combine_type'] == 'output':
+                    hi_res = np.concatenate((hi_res, step['data']), axis=-1)
+        return hi_res
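The single-step dictionary format consumed by these helpers can be sketched as follows (shapes and values are hypothetical):

    import numpy as np

    low_res = np.zeros((4, 10, 10, 2))  # (obs, spatial_1, spatial_2, features)
    exogenous_data = {'topography': {'steps': [
        {'combine_type': 'input', 'data': np.zeros((4, 10, 10, 1))}]}}
    # _combine_input appends the 'input' step along the feature channel,
    # yielding shape (4, 10, 10, 3); 'layer' steps are instead consumed
    # by Sup3rAdder/Sup3rConcat layers inside generate(), and 'output'
    # steps are appended to hi_res by _combine_output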
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'data': ..., 'resolution': ...},
+                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
 
         Returns
         -------
@@ -1148,18 +1216,7 @@ def generate(self,
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-        low_res_topo = None
-        hi_res_topo = None
-        if isinstance(exogenous_data, np.ndarray):
-            low_res_topo = exogenous_data
-        elif isinstance(exogenous_data, (list, tuple)):
-            low_res_topo = exogenous_data[0]
-            if len(exogenous_data) > 1:
-                hi_res_topo = exogenous_data[1]
-
-        exo_check = (low_res is None or not self._needs_lr_exo(low_res))
-        low_res = (low_res if exo_check else np.concatenate(
-            (low_res, low_res_topo), axis=-1))
+        low_res = self._combine_input(low_res, exogenous_data)
 
         if norm_in and self._means is not None:
             low_res = self.norm_input(low_res)
@@ -1167,12 +1224,15 @@ def generate(self,
         hi_res = self.generator.layers[0](low_res)
         for i, layer in enumerate(self.generator.layers[1:]):
             try:
-                if (isinstance(layer, (Sup3rAdder, Sup3rConcat))
-                        and hi_res_topo is not None):
-                    hi_res_topo = self._reshape_norm_topo(hi_res,
-                                                          hi_res_topo,
-                                                          norm_in=norm_in)
-                    hi_res = layer(hi_res, hi_res_topo)
+                if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
+                    exo_name = layer.name
+                    steps = exogenous_data[exo_name]['steps']
+                    hi_res_exo = [step['data'] for step in steps
+                                  if step['combine_type'] == 'layer']
+                    hi_res_exo = self._reshape_norm_exo(hi_res,
+                                                        hi_res_exo,
+                                                        exo_name,
+                                                        norm_in=norm_in)
+                    hi_res = layer(hi_res, hi_res_exo)
                 else:
                     hi_res = layer(hi_res)
             except Exception as e:
@@ -1186,10 +1246,12 @@ def generate(self,
         if un_norm_out and self._means is not None:
             hi_res = self.un_norm_output(hi_res)
 
+        hi_res = self._combine_output(hi_res, exogenous_data)
+
         return hi_res
 
     @tf.function
-    def _tf_generate(self, low_res, hi_res_topo):
+    def _tf_generate(self, low_res, hi_res_exo):
         """Use the generator model to generate high res data from los res input
 
         Parameters
         ----------
         low_res : np.ndarray
             Real low-resolution data. The generator should always
             received normalized data with mean=0 stdev=1.
-        hi_res_topo : np.ndarray
-            This should be a 4D array for spatial enhancement model or 5D array
-            for a spatiotemporal enhancement model (obs, spatial_1, spatial_2,
-            (temporal), features) corresponding to the high-resolution
-            spatial_1 and spatial_2. This data will be input to the custom
-            phygnn Sup3rAdder or Sup3rConcat layer if found in the generative
-            network. This differs from the exogenous_data input in that
-            exogenous_data always matches the low-res input.
+        hi_res_exo : dict
+            Dictionary of exogenous_data with same resolution as high_res data
+            e.g. {'topography': np.array}
+            The arrays in this dictionary should be a 4D array for spatial
+            enhancement model or 5D array for a spatiotemporal enhancement
+            model (obs, spatial_1, spatial_2, (temporal), features)
+            corresponding to the high-resolution spatial_1 and spatial_2. This
+            data will be input to the custom phygnn Sup3rAdder or Sup3rConcat
+            layer if found in the generative network. This differs from the
+            exogenous_data input in that exogenous_data always matches the
+            low-res input.
Returns ------- @@ -1214,10 +1279,9 @@ def _tf_generate(self, low_res, hi_res_topo): hi_res = self.generator.layers[0](low_res) for i, layer in enumerate(self.generator.layers[1:]): try: - if (isinstance(layer, (Sup3rAdder, Sup3rConcat)) - and hi_res_topo is not None): - hi_res = layer(hi_res, hi_res_topo) - + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + hr_exo = hi_res_exo[layer.name] + hi_res = layer(hi_res, hr_exo) else: hi_res = layer(hi_res) except Exception as e: @@ -1268,13 +1332,16 @@ def get_single_grad(self, Namespace of the breakdown of loss components """ - hi_res_topo = hi_res_true[..., -1:] + hi_res_exo = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + hi_res_exo[feature] = hi_res_true[..., f_idx: f_idx + 1] with tf.device(device_name): with tf.GradientTape(watch_accessed_variables=False) as tape: tape.watch(training_weights) - hi_res_gen = self._tf_generate(low_res, hi_res_topo) + hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) loss, loss_details = loss_out diff --git a/sup3r/models/wind.py b/sup3r/models/multi_exo.py similarity index 67% rename from sup3r/models/wind.py rename to sup3r/models/multi_exo.py index 9aa0c2f9a..a393521c7 100644 --- a/sup3r/models/wind.py +++ b/sup3r/models/multi_exo.py @@ -1,30 +1,31 @@ # -*- coding: utf-8 -*- """Wind super resolution GAN with handling of low and high res topography inputs.""" -import numpy as np import logging + +import numpy as np import tensorflow as tf +from sup3r.models.abstract import AbstractExoInterface from sup3r.models.base import Sup3rGan -from sup3r.models.abstract import AbstractWindInterface - logger = logging.getLogger(__name__) -class WindGan(AbstractWindInterface, Sup3rGan): - """Wind super resolution GAN with handling of low and high res topography - inputs. +class MultiExoGan(AbstractExoInterface, Sup3rGan): + """Super resolution GAN with handling of low and high res exogenous feature + inputs. This exogenous data is commonly just topography. Modifications to standard Sup3rGan: - - Hi res topography is expected as the last feature channel in the true - data in the true batch observation. This topo channel is appended to - the generated output so the discriminator can look at the wind fields - compared to the associated hi res topo. + - Hi res exogenous features are expected as the last + len(self.exogenous_features) channels in the true data in the true + batch observation. These channels are appended to the generated + output so the discriminator can look at the super resolved fields + compared to the associated hi res exogenous feature data. - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present - in the network, the hi-res topography will be added or concatenated - to the data at that point in the network during either training or - the forward pass. + in the network, the hi-res exogenous feature matching layer.name will + be added or concatenated to the data at that point in the network + during either training or the forward pass. 
""" def init_weights(self, lr_shape, hr_shape, device=None): @@ -53,11 +54,14 @@ def init_weights(self, lr_shape, hr_shape, device=None): low_res = np.ones(lr_shape).astype(np.float32) hi_res = np.ones(hr_shape).astype(np.float32) - hr_topo_shape = hr_shape[:-1] + (1,) - hr_topo = np.ones(hr_topo_shape).astype(np.float32) + hr_exo_shape = hr_shape[:-1] + (1,) + hr_exo = np.ones(hr_exo_shape).astype(np.float32) with tf.device(device): - _ = self._tf_generate(low_res, hr_topo) + hr_exo_data = {} + for feature in self.exogenous_features: + hr_exo_data[feature] = hr_exo + _ = self._tf_generate(low_res, hr_exo_data) _ = self._tf_discriminate(hi_res) def set_model_params(self, **kwargs): @@ -66,10 +70,11 @@ def set_model_params(self, **kwargs): Parameters ---------- kwargs : dict - Keyword arguments including 'training_features', 'output_features', - 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + Keyword arguments including 'input_resolution', + 'training_features', 'output_features', 'smoothed_features', + 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(**kwargs) + AbstractExoInterface.set_model_params(self, **kwargs) Sup3rGan.set_model_params(self, **kwargs) @tf.function @@ -96,9 +101,10 @@ def calc_loss(self, hi_res_true, hi_res_gen, **kwargs): loss_details : dict Namespace of the breakdown of loss components """ - - # append the true topography to the generated synthetic wind data - hi_res_gen = tf.concat((hi_res_gen, hi_res_true[..., -1:]), axis=-1) + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_data = hi_res_true[..., f_idx: f_idx + 1] + hi_res_gen = tf.concat((hi_res_gen, exo_data), axis=-1) return super().calc_loss(hi_res_true, hi_res_gen, **kwargs) @@ -123,8 +129,12 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - high_res_gen = self._tf_generate(val_batch.low_res, - val_batch.high_res[..., -1:]) + val_exo_data = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_data = val_batch.high_res[..., f_idx: f_idx + 1] + val_exo_data[feature] = exo_data + high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.high_res, high_res_gen, weight_gen_advers=weight_gen_advers, diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 3a0d82ff5..e9d235761 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -111,6 +111,43 @@ def seed(s=0): """ Sup3rGan.seed(s=s) + def _get_model_step_exo(self, model_step, exogenous_data=None): + """Get the exogenous data for the given model_step from the full + exogenous data dictionary + + Parameters + ---------- + model_step : int + Index of the model to get exogenous data for. + exogenous_data : dict + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. 
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'model': 0, 'data': ...,
+                 'resolution': ...},
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
+                 'resolution': ...}]}}
+            Each array in the 'data' key has 3D or 4D shape:
+            (spatial_1, spatial_2, 1)
+            (spatial_1, spatial_2, n_temporal, 1)
+
+        Returns
+        -------
+        exogenous_data : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step
+        """
+        model_step_exo = None
+        if exogenous_data is not None:
+            model_step_exo = {}
+            for feature in exogenous_data:
+                steps = [step for step in exogenous_data[feature]['steps']
+                         if step['model'] == model_step]
+                if steps:
+                    model_step_exo[feature] = {'steps': steps}
+        return model_step_exo
+
     def generate(self, low_res, norm_in=True, un_norm_out=True,
                  exogenous_data=None):
         """Use the generator model to generate high res data from low res
@@ -129,16 +166,18 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         un_norm_out : bool
             Flag to un-normalize synthetically generated output data to
             physical units
-        exogenous_data : list
-            List of arrays of exogenous_data with length equal to the
-            number of model steps. e.g. If we want to include topography as
-            an exogenous feature in a spatial + temporal multistep model then
-            we need to provide a list of length=2 with topography at the low
-            spatial resolution and at the high resolution. If we include more
-            than one exogenous feature the ordering must be consistent.
-            Each array in the list has 3D or 4D shape:
-            (spatial_1, spatial_2, n_features)
-            (spatial_1, spatial_2, n_temporal, n_features)
+        exogenous_data : dict
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'model': 0, 'data': ...,
+                 'resolution': ...},
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
+                 'resolution': ...}]}}
+            Each array in the 'data' key has 3D or 4D shape:
+            (spatial_1, spatial_2, 1)
+            (spatial_1, spatial_2, n_temporal, 1)
 
         Returns
         -------
@@ -148,10 +187,6 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-
-        exo_data = ([None] * len(self.models) if not exogenous_data
-                    else exogenous_data)
-
         hi_res = low_res.copy()
         for i, model in enumerate(self.models):
 
@@ -161,9 +196,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
                          if (i + 1 == len(self.models) and not un_norm_out)
                          else True)
 
-            i_exo_data = exo_data[i]
-            if model.needs_hr_exo:
-                i_exo_data = [exo_data[i], exo_data[i + 1]]
+            i_exo_data = self._get_model_step_exo(i, exogenous_data)
 
             try:
                 logger.debug('Data input to model #{} of {} has shape {}'
@@ -361,6 +394,52 @@ def output_features(self):
         interpolation model in this SpatialThenTemporalGan model outputs."""
         return self.temporal_models.output_features
 
+    def _split_exo_spatial_temporal(self, exogenous_data=None):
+        """Split exogenous_data into spatial_exo and temporal_exo, each of
+        which is then passed through MultiStepGan models
+
+        Parameters
+        ----------
+        exogenous_data : dict
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'model': 0, 'data': ...,
+                 'resolution': ...},
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
+                 'resolution': ...}]}}
+            Each array in the 'data' key has 3D or 4D shape:
+            (spatial_1, spatial_2, 1)
+            (spatial_1, spatial_2, n_temporal, 1)
+
+        Returns
+        -------
+        spatial_exo : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step corresponds to a spatial model step
+        temporal_exo : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step corresponds to a temporal model step
+        """
+        spatial_exo = None
+        temporal_exo = None
+        if exogenous_data is not None:
+            spatial_exo = {}
+            temporal_exo = {}
+            for feature in exogenous_data:
+                steps = [step for step in exogenous_data[feature]['steps']
+                         if step['model'] < len(self.spatial_models)]
+                if steps:
+                    spatial_exo[feature] = {'steps': steps}
+                steps = [step for step in exogenous_data[feature]['steps']
+                         if step['model'] >= len(self.spatial_models)]
+                t_shift = len(self.spatial_models)
+                # shift model indices so they index into the temporal
+                # MultiStepGan. dict.update() returns None, so shift the
+                # steps in place rather than in a list comprehension
+                for step in steps:
+                    step['model'] -= t_shift
+                if steps:
+                    temporal_exo[feature] = {'steps': steps}
+        return spatial_exo, temporal_exo
+
     def generate(self, low_res, norm_in=True, un_norm_out=True,
                  exogenous_data=None):
         """Use the generator model to generate high res data from low res
@@ -398,14 +477,11 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         """
         logger.debug('Data input to the 1st step spatial-only '
                      'enhancement has shape {}'.format(low_res.shape))
-        t_exogenous = None
-        if exogenous_data is not None:
-            t_exogenous = exogenous_data[len(self.spatial_models):]
-
+        s_exo, t_exo = self._split_exo_spatial_temporal(exogenous_data)
         try:
             hi_res = self.spatial_models.generate(
                 low_res, norm_in=norm_in, un_norm_out=True,
-                exogenous_data=exogenous_data)
+                exogenous_data=s_exo)
         except Exception as e:
             msg = ('Could not run the 1st step spatial-only GAN on input '
                    'shape {}'.format(low_res.shape))
@@ -422,7 +498,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         try:
             hi_res = self.temporal_models.generate(
                 hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=t_exogenous)
+                exogenous_data=t_exo)
         except Exception as e:
             msg = ('Could not run the 2nd step (spatio)temporal GAN on input '
                    'shape {}'.format(low_res.shape))
@@ -488,6 +564,52 @@ def output_features(self):
         interpolation model in this TemporalThenSpatialGan model outputs."""
         return self.spatial_models.output_features
 
+    def _split_exo_temporal_spatial(self, exogenous_data=None):
+        """Split exogenous_data into spatial_exo and temporal_exo, each of
+        which is then passed through MultiStepGan models
+
+        Parameters
+        ----------
+        exogenous_data : dict
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'model': 0, 'data': ...,
+                 'resolution': ...},
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
+                 'resolution': ...}]}}
+            Each array in the 'data' key has 3D or 4D shape:
+            (spatial_1, spatial_2, 1)
+            (spatial_1, spatial_2, n_temporal, 1)
+
+        Returns
+        -------
+        temporal_exo : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step corresponds to a temporal model step
+        spatial_exo : dict
+            Same as input dictionary but with only entries with 'model':
+            model_step where model_step corresponds to a spatial model step
+        """
+        spatial_exo = None
+        temporal_exo = None
+        if exogenous_data is not None:
+            temporal_exo = {}
+            spatial_exo = {}
+            for feature in exogenous_data:
+                steps = [step for step in exogenous_data[feature]['steps']
+                         if step['model'] < len(self.temporal_models)]
+                if steps:
+                    temporal_exo[feature] = {'steps': steps}
+                steps = [step for step in exogenous_data[feature]['steps']
+                         if step['model'] >= len(self.temporal_models)]
+                s_shift = len(self.temporal_models)
+                # shift model indices so they index into the spatial
+                # MultiStepGan. dict.update() returns None, so shift the
+                # steps in place rather than in a list comprehension
+                for step in steps:
+                    step['model'] -= s_shift
+                if steps:
+                    spatial_exo[feature] = {'steps': steps}
+        return temporal_exo, spatial_exo
+
     def generate(self, low_res, norm_in=True, un_norm_out=True,
                  exogenous_data=None):
         """Use the generator model to generate high res data from low res
@@ -525,16 +647,14 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         """
         logger.debug('Data input to the 1st step (spatio)temporal '
                      'enhancement has shape {}'.format(low_res.shape))
-        s_exogenous = None
-        if exogenous_data is not None:
-            s_exogenous = exogenous_data[len(self.temporal_models):]
+        t_exo, s_exo = self._split_exo_temporal_spatial(exogenous_data)
 
         assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!'
try: hi_res = self.temporal_models.generate( low_res, norm_in=norm_in, un_norm_out=True, - exogenous_data=exogenous_data) + exogenous_data=t_exo) except Exception as e: msg = ('Could not run the 1st step (spatio)temporal GAN on input ' 'shape {}'.format(low_res.shape)) @@ -550,7 +670,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, try: hi_res = self.spatial_models.generate( hi_res, norm_in=True, un_norm_out=un_norm_out, - exogenous_data=s_exogenous) + exogenous_data=s_exo) except Exception as e: msg = ('Could not run the 2nd step spatial GAN on input ' 'shape {}'.format(low_res.shape)) @@ -622,7 +742,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, logger.debug('Data input to the 1st step spatial-only ' 'enhancement has shape {}'.format(low_res.shape)) - msg = ('MultiStepSurfaceMetGan needs exogenous_data with two ' + msg = ('MultiStepSurfaceMetGan needs exogenous_data with two ' 'entries for low and high res topography inputs.') assert exogenous_data is not None, msg assert isinstance(exogenous_data, (list, tuple)), msg diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py index cf3a3dc2c..0cbb99b75 100644 --- a/sup3r/models/wind_conditional_moments.py +++ b/sup3r/models/wind_conditional_moments.py @@ -2,16 +2,16 @@ """Wind conditional moment estimator with handling of low and high res topography inputs.""" import logging + import tensorflow as tf -from sup3r.models.abstract import AbstractWindInterface +from sup3r.models.abstract import AbstractExoInterface from sup3r.models.conditional_moments import Sup3rCondMom - logger = logging.getLogger(__name__) -class WindCondMom(AbstractWindInterface, Sup3rCondMom): +class WindCondMom(AbstractExoInterface, Sup3rCondMom): """Wind conditional moment estimator with handling of low and high res topography inputs. @@ -33,7 +33,7 @@ def set_model_params(self, **kwargs): Keyword arguments including 'training_features', 'output_features', 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' """ - AbstractWindInterface.set_model_params(**kwargs) + AbstractExoInterface.set_model_params(self, **kwargs) Sup3rCondMom.set_model_params(self, **kwargs) @tf.function diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 6c19b1b43..1d37ecedc 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -699,7 +699,7 @@ def __init__(self, be used in the model. e.g. {'topography': {'file_paths': 'path to input files', 'source_file': 'path to exo data', 'exo_resolution': {'spatial': '1km', 'temporal': None}, 'steps': [{'model': 0, - 'concat_type': 'input'}, {'model': 0, 'concat_type': 'layer'}]} + 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]} bias_correct_method : str | None Optional bias correction function name that can be imported from the :mod:`sup3r.bias.bias_transforms` module. This will transform @@ -1160,7 +1160,7 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): Parameters ---------- step_dict : dict - Model step dictionary. e.g. {'model': 0, 'concat_type': 'input'} + Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'} exo_resolution : dict Resolution of exogenous data. e.g. 
{'temporal': '15min', 'spatial': '1km'}
@@ -1187,9 +1187,9 @@
         input_res_t = model.input_resolution['temporal']
         input_res_s = model.input_resolution['spatial']
         output_res_t = model.output_resolution['temporal']
         output_res_s = model.output_resolution['spatial']
-        concat_type = step_dict.get('concat_type', None)
+        combine_type = step_dict.get('combine_type', None)
 
-        if concat_type.lower() == 'input':
+        if combine_type.lower() == 'input':
             if model_step == 0:
                 s_enhance = 1
                 t_enhance = 1
@@ -1200,7 +1200,7 @@
                 t_agg_factor = self.get_agg_factor(input_res_t, exo_res_t)
             resolution = {'spatial': input_res_s,
                           'temporal': input_res_t}
-        elif concat_type.lower() in ('output', 'layer'):
+        elif combine_type.lower() in ('output', 'layer'):
             s_enhance = self.strategy.s_enhancements[model_step]
             t_enhance = self.strategy.t_enhancements[model_step]
             s_agg_factor = self.get_agg_factor(output_res_s, exo_res_s)
@@ -1208,7 +1208,7 @@
             resolution = {'spatial': output_res_s,
                           'temporal': output_res_t}
         else:
-            msg = 'Received exo_kwargs entry without valid concat_type'
+            msg = 'Received exo_kwargs entry without valid combine_type'
             raise OSError(msg)
 
         updated_dict = step_dict.copy()
@@ -1227,8 +1227,8 @@ def _prep_exo_extract_kwargs(self, exo_kwargs):
         exo_kwargs: dict
             Full exo_kwargs dictionary with all feature entries. e.g.
             {'topography': {'exo_resolution': {'spatial': '1km',
-            'temporal': None}, 'steps': [{'model': 0, 'concat_type': 'input'},
-            {'model': 0, 'concat_type': 'layer'}]}}
+            'temporal': None}, 'steps': [{'model': 0, 'combine_type': 'input'},
+            {'model': 0, 'combine_type': 'layer'}]}}
 
         Returns
         -------
@@ -1736,9 +1736,9 @@ def _run_generator(cls,
             whether features should be combined at input, a mid network layer,
             or with output. e.g.
             {'topography': {'steps': [
-                {'concat_type': 'input', 'model': 0, 'data': ...,
+                {'combine_type': 'input', 'model': 0, 'data': ...,
                  'resolution': ...},
-                {'concat_type': 'layer', 'model': 0, 'data': ...,
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
                  'resolution': ...}]}}
 
         Returns
@@ -1802,9 +1802,9 @@ def _reshape_data_chunk(model, data_chunk, exo_data):
             whether features should be combined at input, a mid network layer,
             or with output. e.g.
             {'topography': {'steps': [
-                {'concat_type': 'input', 'model': 0, 'data': ...,
+                {'combine_type': 'input', 'model': 0, 'data': ...,
                  'resolution': ...},
-                {'concat_type': 'layer', 'model': 0, 'data': ...,
+                {'combine_type': 'layer', 'model': 0, 'data': ...,
                  'resolution': ...}]}}
 
         Returns

From 9e6b384500d0e1d137eac8a63f7f7fbb7704b18a Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 21 Sep 2023 14:57:44 -0600
Subject: [PATCH 07/15] generalized tf_generate and generate methods, which
 can be used with or without exo data, defined in abstract single model and
 removed from base. forward pass tests working.
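For reference, a minimal, self-contained sketch of the nested exogenous_data
format this series converges on, plus a standalone version of the
per-model-step filtering that MultiStepGan._get_model_step_exo performs. The
array shapes, resolutions, and the helper name below are illustrative
assumptions for demonstration only, not code from this patch:

    import numpy as np

    # Hypothetical topography rasters standing in for extracted exo data at
    # the input resolution and the mid-network (layer) resolution of model
    # step 0. Shapes and resolutions are made up for illustration.
    lr_topo = np.random.rand(10, 10, 1)
    hr_topo = np.random.rand(20, 20, 1)

    exogenous_data = {
        'topography': {
            'steps': [
                {'model': 0, 'combine_type': 'input',
                 'resolution': {'spatial': '12km', 'temporal': '60min'},
                 'data': lr_topo},
                {'model': 0, 'combine_type': 'layer',
                 'resolution': {'spatial': '4km', 'temporal': '60min'},
                 'data': hr_topo},
            ]
        }
    }

    def get_model_step_exo(model_step, exogenous_data):
        # Standalone sketch of the per-step filtering: keep only the steps
        # tagged with the requested model index for each feature.
        out = {}
        for feature, entry in exogenous_data.items():
            steps = [s for s in entry['steps'] if s['model'] == model_step]
            if steps:
                out[feature] = {'steps': steps}
        return out

    step0 = get_model_step_exo(0, exogenous_data)
    assert len(step0['topography']['steps']) == 2
    assert get_model_step_exo(1, exogenous_data) == {}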
--- sup3r/models/__init__.py | 18 +- sup3r/models/abstract.py | 330 +++++++----------- sup3r/models/base.py | 89 +---- sup3r/models/conditional_moments.py | 12 +- sup3r/models/data_centric.py | 10 +- sup3r/pipeline/forward_pass.py | 92 ++--- .../data_handling/exo_extraction.py | 13 +- .../data_handling/exogenous_data_handling.py | 27 +- tests/forward_pass/test_forward_pass.py | 107 +++--- ..._train_wind.py => test_train_multi_exo.py} | 34 +- 10 files changed, 276 insertions(+), 456 deletions(-) rename tests/training/{test_train_wind.py => test_train_multi_exo.py} (95%) diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 930440439..1c3e01a3f 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -1,15 +1,19 @@ # -*- coding: utf-8 -*- """Sup3r Model Software""" from .base import Sup3rGan -from .wind import WindGan -from .solar_cc import SolarCC +from .conditional_moments import Sup3rCondMom from .data_centric import Sup3rGanDC -from .multi_step import (MultiStepGan, - SpatialThenTemporalGan, TemporalThenSpatialGan, - MultiStepSurfaceMetGan, SolarMultiStepGan) -from .surface import SurfaceSpatialMetModel from .linear import LinearInterp -from .conditional_moments import Sup3rCondMom +from .multi_exo import MultiExoGan +from .multi_step import ( + MultiStepGan, + MultiStepSurfaceMetGan, + SolarMultiStepGan, + SpatialThenTemporalGan, + TemporalThenSpatialGan, +) +from .solar_cc import SolarCC +from .surface import SurfaceSpatialMetModel from .wind_conditional_moments import WindCondMom SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan, diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 99470c8b3..974b6ec2c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -67,27 +67,6 @@ def seed(s=0): """ CustomNetwork.seed(s=s) - @abstractmethod - def generate(self, low_res): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - hi_res : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - @property def input_dims(self): """Get dimension of model generator input. 
This is usually 4D for
@@ -136,42 +115,99 @@ def output_resolution(self):
         input_spatial = re.search(r'\d+', input_res['spatial']).group(0)
         output_temporal = int(self.t_enhance * input_temporal)
         output_spatial = int(self.s_enhance * input_spatial)
-        output_res['temporal'].replace(input_temporal, output_temporal)
-        output_res['spatial'].replace(input_spatial, output_spatial)
+        # str.replace returns a new string, so assign the result back
+        output_res['temporal'] = output_res['temporal'].replace(
+            input_temporal, str(output_temporal))
+        output_res['spatial'] = output_res['spatial'].replace(
+            input_spatial, str(output_spatial))
         return output_res
 
-    @property
-    def needs_hr_exo(self):
-        """Determine whether or not the sup3r model needs hi-res exogenous data
+    def _combine_input(self, low_res, exogenous_data=None):
+        """Combine exogenous_data at input resolution with low_res data
+
+        Parameters
+        ----------
+        low_res : np.ndarray
+            Low-resolution input data, usually a 4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+        exogenous_data : dict | None
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. This doesn't have to include the 'model' key since
+            this data is for a single step model. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'data': ..., 'resolution': ...},
+                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
 
         Returns
         -------
-        needs_hr_exo : bool
-            True if the model requires high-resolution exogenous data,
-            typically because of the use of Sup3rAdder or Sup3rConcat layers.
+        low_res : np.ndarray
+            Low-resolution input data combined with exogenous_data, usually a
+            4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-        # pylint: disable=E1101
-        return (hasattr(self, '_gen') and any(
-            isinstance(layer, (Sup3rAdder, Sup3rConcat))
-            for layer in self._gen.layers))
+        check = (exogenous_data is not None
+                 and low_res.shape[-1] < len(self.training_features))
+        if check:
+            # low_res grows as features are concatenated, so index into
+            # training_features using the original channel count
+            n_in = low_res.shape[-1]
+            for i, (feature, entry) in enumerate(exogenous_data.items()):
+                f_idx = n_in + i
+                training_feature = self.training_features[f_idx]
+                msg = ('The ordering of features in exogenous_data conflicts '
+                       'with the ordering of training features. Received '
+                       f'{feature} instead of {training_feature}.')
+                assert feature == training_feature, msg
+                combine_types = [step['combine_type']
+                                 for step in entry['steps']]
+                if 'input' in combine_types:
+                    idx = combine_types.index('input')
+                    low_res = np.concatenate((low_res,
+                                              entry['steps'][idx]['data']),
+                                             axis=-1)
+        return low_res
 
-    def _needs_lr_exo(self, low_res):
-        """Determine whether or not the sup3r model needs low-res exogenous
-        data
+    def _combine_output(self, hi_res, exogenous_data=None):
+        """Combine exogenous_data at output resolution with hi_res data
 
         Parameters
         ----------
-        low_res : np.ndarray
-            Low-resolution input data, usually a 4D or 5D array of shape:
+        hi_res : np.ndarray
+            High-resolution output data, usually a 4D or 5D array of shape:
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+        exogenous_data : dict | None
+            Dictionary of exogenous feature data with entries describing
+            whether features should be combined at input, a mid network layer,
+            or with output. This doesn't have to include the 'model' key since
+            this data is for a single step model. e.g.
+            {'topography': {'steps': [
+                {'combine_type': 'input', 'data': ..., 'resolution': ...},
+                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
 
         Returns
         -------
-        needs_lr_exo : bool
-            True if the model requires low-resolution exogenous data.
+        hi_res : np.ndarray
+            High-resolution output data combined with exogenous_data, usually a
+            4D or 5D array of shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-        return low_res.shape[-1] < len(self.training_features)
+        if exogenous_data is not None:
+            # hi_res grows as features are concatenated, so index into
+            # training_features using the original channel count
+            n_out = hi_res.shape[-1]
+            for i, (feature, entry) in enumerate(exogenous_data.items()):
+                f_idx = n_out + i
+                training_feature = self.training_features[f_idx]
+                msg = ('The ordering of features in exogenous_data conflicts '
+                       'with the ordering of training features. Received '
+                       f'{feature} instead of {training_feature}.')
+                assert feature == training_feature, msg
+                combine_types = [step['combine_type']
+                                 for step in entry['steps']]
+                if 'output' in combine_types:
+                    idx = combine_types.index('output')
+                    hi_res = np.concatenate((hi_res,
+                                             entry['steps'][idx]['data']),
+                                            axis=-1)
+        return hi_res
 
     @property
     def exogenous_features(self):
@@ -868,59 +904,6 @@ def finish_epoch(self,
 
         return stop
 
-    @tf.function()
-    def get_single_grad(self,
-                        low_res,
-                        hi_res_true,
-                        training_weights,
-                        device_name=None,
-                        **calc_loss_kwargs):
-        """Run gradient descent for one mini-batch of (low_res, hi_res_true),
-        do not update weights, just return gradient details.
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Real low-resolution data in a 4D or 5D array:
-            (n_observations, spatial_1, spatial_2, features)
-            (n_observations, spatial_1, spatial_2, temporal, features)
-        hi_res_true : np.ndarray
-            Real high-resolution data in a 4D or 5D array:
-            (n_observations, spatial_1, spatial_2, features)
-            (n_observations, spatial_1, spatial_2, temporal, features)
-        training_weights : list
-            A list of layer weights that are to-be-trained based on the
-            current loss weight values.
-        device_name : None | str
-            Optional tensorflow device name for GPU placement. Note that if a
-            GPU is available, variables will be placed on that GPU even if
-            device_name=None.
-        calc_loss_kwargs : dict
-            Kwargs to pass to the self.calc_loss() method
-
-        Returns
-        -------
-        grad : list
-            a list or nested structure of Tensors (or IndexedSlices, or None,
-            or CompositeTensor) representing the gradients for the
-            training_weights
-        loss_details : dict
-            Namespace of the breakdown of loss components
-        """
-
-        with tf.device(device_name):
-            with tf.GradientTape(watch_accessed_variables=False) as tape:
-                tape.watch(training_weights)
-
-                hi_res_gen = self._tf_generate(low_res)
-                loss_out = self.calc_loss(hi_res_true, hi_res_gen,
-                                          **calc_loss_kwargs)
-                loss, loss_details = loss_out
-
-                grad = tape.gradient(loss, training_weights)
-
-        return grad, loss_details
-
     def run_gradient_descent(self,
                              low_res,
                              hi_res_true,
@@ -1010,47 +993,6 @@ def run_gradient_descent(self,
 
         return loss_details
 
-
-# pylint: disable=E1101,W0201,E0203
-class AbstractExoInterface(ABC):
-    """
-    Abstract class to define the required training interface
-    for Sup3r model subclasses with exogenous features
-    """
-
-    # pylint: disable=E0211
-    def set_model_params(self, **kwargs):
-        """Set parameters used for training the model
-
-        Parameters
-        ----------
-        kwargs : dict
-            Keyword arguments including 'input_resolution',
-            'training_features', 'output_features', 'smoothed_features',
-            's_enhance', 't_enhance', 'smoothing'. 
For the Wind classes, the - last entry in "output_features" must be "topography" - - Returns - ------- - kwargs : dict - Same as input but with exogenous features removed from - "output_features", this is because the exo features are - concatenated mid-network in the ExoGan generators and are not - output features but are required in the hi-res training set. - """ - output_features = kwargs['output_features'] - msg = (f'Last {len(self.exogenous_features)} output features from the ' - f'data handler must be {self.exogenous_features} ' - 'to train the Exo model, but received output features: {}'. - format(output_features)) - check = (output_features[-len(self.exogenous_features)] - == self.exogenous_features) - assert check, msg - for f in self.exogenous_features: - output_features.remove(f) - kwargs['output_features'] = output_features - return kwargs - def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): """Reshape the hi_res_topo to match the hi_res tensor (if necessary) and normalize (if requested). @@ -1109,74 +1051,6 @@ def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): return hi_res_exo - def _combine_input(self, low_res, exogenous_data=None): - """Combine exogenous_data at input resolution with low_res data - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | None - Dictionary of exogenous feature data with entries describing - whether features should be combined at input, a mid network layer, - or with output. This doesn't have to include the 'model' key since - this data is for a single step model. e.g. - {'topography': {'steps': [ - {'combine_type': 'input', 'data': ..., 'resolution': ...}, - {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} - - Returns - ------- - low_res : np.ndarray - Low-resolution input data combined with exogenous_data, usually a - 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - for feature in self.exogenous_features: - msg = f'Did not find {feature} in exogenous_data' - assert feature in exogenous_data, msg - for step in exogenous_data['steps']: - if step['combine_type'] == 'input': - low_res = np.concatenate(low_res, step['data'], axis=-1) - return low_res - - def _combine_output(self, hi_res, exogenous_data=None): - """Combine exogenous_data at input resolution with low_res data - - Parameters - ---------- - hi_res : np.ndarray - High-resolution output data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - exogenous_data : dict | None - Dictionary of exogenous feature data with entries describing - whether features should be combined at input, a mid network layer, - or with output. This doesn't have to include the 'model' key since - this data is for a single step model. e.g. 
-            {'topography': {'steps': [
-                {'combine_type': 'input', 'data': ..., 'resolution': ...},
-                {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}}
-
-        Returns
-        -------
-        hi_res : np.ndarray
-            High-resolution output data combined with exogenous_data, usually a
-            4D or 5D array of shape:
-            (n_obs, spatial_1, spatial_2, n_features)
-            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
-        """
-        for feature in self.exogenous_features:
-            msg = f'Did not find {feature} in exogenous_data'
-            assert feature in exogenous_data, msg
-            for step in exogenous_data['steps']:
-                if step['combine_type'] == 'output':
-                    hi_res = np.concatenate(hi_res, step['data'], axis=-1)
-        return hi_res
-
     def generate(self,
                  low_res,
                  norm_in=True,
@@ -1227,8 +1101,8 @@ def generate(self,
                 if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
                     exo_name = layer.name
                     steps = exogenous_data[exo_name]['steps']
-                    hi_res_exo = [step['data'] for step in steps
-                                  if step['combine_type'] == 'layer']
+                    combine_types = [step['combine_type'] for step in steps]
+                    idx = combine_types.index('layer')
+                    hi_res_exo = steps[idx]['data']
                     hi_res_exo = self._reshape_norm_exo(hi_res,
                                                         hi_res_exo,
                                                         exo_name,
                                                         norm_in=norm_in)
@@ -1251,8 +1125,8 @@ def generate(self,
         return hi_res
 
     @tf.function
-    def _tf_generate(self, low_res, hi_res_exo):
-        """Use the generator model to generate high res data from los res input
+    def _tf_generate(self, low_res, hi_res_exo=None):
+        """Use the generator model to generate high res data from low res input
 
         Parameters
         ----------
@@ -1331,7 +1205,6 @@ def get_single_grad(self,
         loss_details : dict
             Namespace of the breakdown of loss components
         """
-
        hi_res_exo = {}
         for feature in self.exogenous_features:
             f_idx = self.training_features.index(feature)
@@ -1349,3 +1222,44 @@ def get_single_grad(self,
 
         grad = tape.gradient(loss, training_weights)
 
         return grad, loss_details
+
+
+# pylint: disable=E1101,W0201,E0203
+class AbstractExoInterface(AbstractInterface):
+    """
+    Abstract class to define the required training interface
+    for Sup3r model subclasses with exogenous features
+    """
+
+    # pylint: disable=E0211
+    def set_model_params(self, **kwargs):
+        """Set parameters used for training the model
+
+        Parameters
+        ----------
+        kwargs : dict
+            Keyword arguments including 'input_resolution',
+            'training_features', 'output_features', 'smoothed_features',
+            's_enhance', 't_enhance', 'smoothing'. For the Wind classes, the
+            last entry in "output_features" must be "topography"
+
+        Returns
+        -------
+        kwargs : dict
+            Same as input but with exogenous features removed from
+            "output_features", this is because the exo features are
+            concatenated mid-network in the ExoGan generators and are not
+            output features but are required in the hi-res training set.
+        """
+        output_features = kwargs['output_features']
+        msg = (f'Last {len(self.exogenous_features)} output features from the '
+               f'data handler must be {self.exogenous_features} '
+               'to train the Exo model, but received output features: {}'. 
format(output_features))
+        check = (output_features[-len(self.exogenous_features):]
+                 == self.exogenous_features)
+        assert check, msg
+        for f in self.exogenous_features:
+            output_features.remove(f)
+        kwargs['output_features'] = output_features
+        return kwargs
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index b5b664358..ac481af31 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -179,91 +179,6 @@ def load(cls, model_dir, verbose=True):
 
         return cls(fp_gen, fp_disc, **params)
 
-    def generate(self,
-                 low_res,
-                 norm_in=True,
-                 un_norm_out=True,
-                 exogenous_data=None):
-        """Use the generator model to generate high res data from low res
-        input. This is the public generate function.
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Low-resolution input data, usually a 4D or 5D array of shape:
-            (n_obs, spatial_1, spatial_2, n_features)
-            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
-        norm_in : bool
-            Flag to normalize low_res input data if the self._means,
-            self._stdevs attributes are available. The generator should always
-            received normalized data with mean=0 stdev=1.
-        un_norm_out : bool
-            Flag to un-normalize synthetically generated output data to physical
-            units
-        exogenous_data : ndarray | None
-            Exogenous data array, usually a 4D or 5D array with shape:
-            (n_obs, spatial_1, spatial_2, n_features)
-            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
-
-        Returns
-        -------
-        hi_res : ndarray
-            Synthetically generated high-resolution data, usually a 4D or 5D
-            array with shape:
-            (n_obs, spatial_1, spatial_2, n_features)
-            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
-        """
-        exo_check = (exogenous_data is None or not self._needs_lr_exo(low_res))
-        low_res = (low_res if exo_check else np.concatenate(
-            (low_res, exogenous_data), axis=-1))
-
-        if norm_in and self._means is not None:
-            low_res = self.norm_input(low_res)
-
-        hi_res = self.generator.layers[0](low_res)
-        for i, layer in enumerate(self.generator.layers[1:]):
-            try:
-                hi_res = layer(hi_res)
-            except Exception as e:
-                msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
-                       format(i + 1, layer, hi_res.shape))
-                logger.error(msg)
-                raise RuntimeError(msg) from e
-
-        hi_res = hi_res.numpy()
-
-        if un_norm_out and self._means is not None:
-            hi_res = self.un_norm_output(hi_res)
-
-        return hi_res
-
-    @tf.function
-    def _tf_generate(self, low_res):
-        """Use the generator model to generate high res data from los res input
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Real low-resolution data. The generator should always
-            received normalized data with mean=0 stdev=1.
-
-        Returns
-        -------
-        hi_res : tf.Tensor
-            Synthetically generated high-resolution data
-        """
-        hi_res = self.generator.layers[0](low_res)
-        for i, layer in enumerate(self.generator.layers[1:]):
-            try:
-                hi_res = layer(hi_res)
-            except Exception as e:
-                msg = ('Could not run layer #{} "{}" on tensor of shape {}'.
-                       format(i + 1, layer, hi_res.shape))
-                logger.error(msg)
-                raise RuntimeError(msg) from e
-
-        return hi_res
-
     @property
     def discriminator(self):
         """Get the discriminator model. 
@@ -679,10 +594,10 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - output_gen = self._tf_generate(val_batch.low_res) + high_res_gen = self._tf_generate(val_batch.low_res) _, v_loss_details = self.calc_loss( val_batch.high_res, - output_gen, + high_res_gen, weight_gen_advers=weight_gen_advers, train_gen=False, train_disc=False) diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index 4a3a2fc46..347aaeeee 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- """Sup3r conditional moment model software""" +import logging import os +import pprint import time -import logging + import numpy as np -import pprint import pandas as pd import tensorflow as tf from tensorflow.keras import optimizers @@ -12,7 +13,6 @@ from sup3r.models.abstract import AbstractInterface, AbstractSingleModel from sup3r.utilities import VERSION_RECORD - logger = logging.getLogger(__name__) @@ -173,9 +173,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - exo_check = (exogenous_data is None or not self._needs_lr_exo(low_res)) - low_res = (low_res if exo_check - else np.concatenate((low_res, exogenous_data), axis=-1)) + low_res = self._combine_input(low_res, exogenous_data) if norm_in and self._means is not None: low_res = self.norm_input(low_res) @@ -195,6 +193,8 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, if un_norm_out and self._means is not None: output = self.un_norm_output(output) + output = self._combine_output(output, exogenous_data) + return output @tf.function diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py index a91b95785..b41b7611d 100644 --- a/sup3r/models/data_centric.py +++ b/sup3r/models/data_centric.py @@ -6,7 +6,7 @@ import numpy as np from sup3r.models.base import Sup3rGan -from sup3r.models.wind import WindGan +from sup3r.models.multi_exo import MultiExoGan from sup3r.utilities.utilities import round_array logger = logging.getLogger(__name__) @@ -138,7 +138,7 @@ def calc_temporal_losses(total_losses, content_losses, batch_handler): f'{round_array(new_temporal_weights)}') -class WindGanDC(WindGan, Sup3rGanDC): +class MultiExoGanDC(MultiExoGan, Sup3rGanDC): """Data-centric model using loss across time bins to select training observations with handling of low and high res topography inputs.""" @@ -149,8 +149,8 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): validation set has 10 bins then this will get a list of losses across step 0 to 10, 10 to 20, etc. Use this to determine performance within bins and to update how observations are selected from these - bins. Use the _tf_generate function from WindGan to include the high - resolution topography. + bins. Use the _tf_generate function from MultiExoGan to include the + high resolution topography. Parameters ---------- @@ -181,7 +181,7 @@ def calc_val_loss_gen_content(self, batch_handler): validation set has 10 bins then this will get a list of losses across step 0 to 10, 10 to 20, etc. Use this to determine performance within bins and to update how observations are selected from these - bins. Use the _tf_generate function from WindGan to include high + bins. 
Use the _tf_generate function from MultiExoGan to include high
        resolution topography.
 
         Parameters
diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py
index 1d37ecedc..84d1fdf35 100644
--- a/sup3r/pipeline/forward_pass.py
+++ b/sup3r/pipeline/forward_pass.py
@@ -1110,13 +1110,9 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
             self.input_data = self.bias_correct_source_data(
                 self.input_data, self.strategy.lr_lat_lon)
 
-        exo_s_en = [1, *self.strategy.s_enhancements]
-        exo_t_en = [1, *self.strategy.t_enhancements]
         out = self.pad_source_data(self.input_data,
                                    self.pad_width,
-                                   self.exogenous_data,
-                                   exo_s_en,
-                                   exo_t_en)
+                                   self.exogenous_data)
         self.input_data, self.exogenous_data = out
         self.unpadded_input_data = self.data_handler.data[self.lr_slice[0],
                                                           self.lr_slice[1]]
@@ -1137,17 +1133,18 @@ def get_agg_factor(self, input_res, exo_res):
             Aggregation factor for exogenous data extraction.
         """
         ires_num = (None if input_res is None
-                    else re.search(r'\d+', input_res).group(0))
+                    else int(re.search(r'\d+', input_res).group(0)))
         eres_num = (None if exo_res is None
-                    else re.search(r'\d+', exo_res).group(0))
+                    else int(re.search(r'\d+', exo_res).group(0)))
         i_units = (None if input_res is None
-                   else input_res.replace(ires_num, ''))
-        e_units = None if exo_res is None else exo_res.replace(eres_num, '')
+                   else input_res.replace(str(ires_num), ''))
+        e_units = (None if exo_res is None
+                   else exo_res.replace(str(eres_num), ''))
         msg = 'Received conflicting units for input and exo resolution'
         if e_units is not None:
             assert i_units == e_units, msg
         if ires_num is not None and eres_num is not None:
-            agg_factor = (int(ires_num) / int(eres_num)) ** 2
+            agg_factor = int((ires_num / eres_num) ** 2)
         else:
             agg_factor = None
         return agg_factor
@@ -1180,7 +1177,7 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution):
         t_agg_factor = None
         models = getattr(self.model, 'models', [self.model])
         msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
-               f'of model steps ({len(self.model.models)})')
+               f'of model steps ({len(models)})')
         assert len(models) > model_step, msg
         model = models[model_step]
         input_res_t = model.input_resolution['temporal']
@@ -1237,9 +1234,9 @@ def _prep_exo_extract_kwargs(self, exo_kwargs):
             s_enhance, t_enhance added to each step entry for all features
         """
         if exo_kwargs:
-            for feature, v in exo_kwargs.items():
-                exo_resolution = v['exo_resolution']
-                for i, step in enumerate(v['steps']):
+            for feature in exo_kwargs:
+                exo_resolution = exo_kwargs[feature]['exo_resolution']
+                for i, step in enumerate(exo_kwargs[feature]['steps']):
                     out = self._prep_exo_extract_single_step(
                         step, exo_resolution)
                     exo_kwargs[feature]['steps'][i] = out
@@ -1265,8 +1262,8 @@ def load_exo_data(self):
                 exo_kwargs['feature'] = feature
                 exo_kwargs['target'] = self.target
                 exo_kwargs['shape'] = self.shape
+                steps = exo_kwargs['steps']
                 exo_kwargs['temporal_slice'] = self.ti_pad_slice
-                steps = exo_kwargs.pop('steps')
                 exo_kwargs['s_agg_factors'] = [step['s_agg_factor']
                                                for step in steps]
                 exo_kwargs['t_agg_factors'] = [step['t_agg_factor']
@@ -1275,9 +1272,12 @@ def load_exo_data(self):
                                               for step in steps]
                 exo_kwargs['t_enhancements'] = [step['t_enhance']
                                                 for step in steps]
+                sig = signature(ExogenousDataHandler)
+                exo_kwargs = {k: v for k, v in exo_kwargs.items()
+                              if k in sig.parameters}
                 data = ExogenousDataHandler(**exo_kwargs).data
                 for i, _ in enumerate(steps):
-                    exo_data[feature]['steps']['data'] = data[i]
+                    exo_data[feature]['steps'][i]['data'] = data[i]
                 shapes = [None 
if d is None else d.shape for d in data] logger.info( 'Got exogenous_data of length {} with shapes: {}'.format( @@ -1573,8 +1573,6 @@ def pad_width(self): def pad_source_data(input_data, pad_width, exo_data, - exo_s_enhancements, - exo_t_enhancements, mode='reflect'): """Pad the edges of the source data from the data handler. @@ -1583,36 +1581,27 @@ def pad_source_data(input_data, input_data : np.ndarray Source input data from data handler class, shape is: (spatial_1, spatial_2, temporal, features) - spatial_pad : int - Size of spatial overlap between coarse chunks passed to forward - passes for subsequent spatial stitching. This overlap will pad both - sides of the fwp_chunk_shape. Note that the first and last chunks - in any of the spatial dimension will not be padded. pad_width : tuple Tuple of tuples with padding width for spatial and temporal dimensions. Each tuple includes the start and end of padding for that dimension. Ordering is spatial_1, spatial_2, temporal. - exo_data : None | list - List of exogenous data arrays for each step of the sup3r resolution - model. List entries can be None if not exo data is requested for a - given model step. - exo_s_enhancements : list - List of spatial enhancement factors for each step of the sup3r - resolution model corresponding to the exo_data order. - exo_t_enhancements : list - List of temporal enhancement factors for each step of the sup3r - resolution model corresponding to the exo_data order. + exo_data: dict + Full exo_kwargs dictionary with all feature entries. + e.g. {'topography': {'exo_resolution': {'spatial': '1km', + 'temporal': None}, 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}]}} mode : str - Padding mode for np.pad(). Reflect is a good default for the - convolutional sup3r work. + Mode to use for padding. e.g. 'reflect'. Returns ------- out : np.ndarray Padded copy of source input data from data handler class, shape is: (spatial_1, spatial_2, temporal, features) - exo_data : list | None - Padded copy of exo_data input. 
+        exo_data : dict
+            Same as input dictionary but with padded data arrays in each
+            step entry for all features
+
         """
         out = np.pad(input_data, (*pad_width, (0, 0)), mode=mode)
 
         logger.info('Padded input data shape from {} to {} using mode "{}" '
                     'with padding argument: {}'.format(input_data.shape,
                                                        out.shape, mode,
                                                        pad_width))
 
         if exo_data is not None:
-            for i, i_exo_data in enumerate(exo_data):
-                if i_exo_data is not None:
-                    total_s_enhance = exo_s_enhancements[:i + 1]
-                    total_s_enhance = [
-                        s for s in total_s_enhance if s is not None
-                    ]
-                    total_s_enhance = np.product(total_s_enhance)
-                    total_t_enhance = exo_t_enhancements[:i + 1]
-                    total_t_enhance = [
-                        t for t in total_t_enhance if t is not None
-                    ]
-                    total_t_enhance = np.product(total_t_enhance)
-                    exo_pad_width = ((total_s_enhance * pad_width[0][0],
-                                      total_s_enhance * pad_width[0][1]),
-                                     (total_s_enhance * pad_width[1][0],
-                                      total_s_enhance * pad_width[1][1]),
-                                     (total_t_enhance * pad_width[2][0],
-                                      total_t_enhance * pad_width[2][1]),
+            for feature in exo_data:
+                for i, step in enumerate(exo_data[feature]['steps']):
+                    exo_pad_width = ((step['s_enhance'] * pad_width[0][0],
+                                      step['s_enhance'] * pad_width[0][1]),
+                                     (step['s_enhance'] * pad_width[1][0],
+                                      step['s_enhance'] * pad_width[1][1]),
+                                     (step['t_enhance'] * pad_width[2][0],
+                                      step['t_enhance'] * pad_width[2][1]),
                                      (0, 0))
-                    exo_data[i] = np.pad(i_exo_data, exo_pad_width, mode=mode)
+                    new_exo = np.pad(step['data'], exo_pad_width, mode=mode)
+                    exo_data[feature]['steps'][i]['data'] = new_exo
         return out, exo_data
 
     def bias_correct_source_data(self, data, lat_lon):
diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py
index 45b4baadf..5d1442b61 100644
--- a/sup3r/preprocessing/data_handling/exo_extraction.py
+++ b/sup3r/preprocessing/data_handling/exo_extraction.py
@@ -118,6 +118,7 @@ def __init__(self,
         self._tree = None
         self.ti_workers = ti_workers
         self._hr_lat_lon = None
+        self._hr_time_index = None
 
         if input_handler is None:
             in_type = get_source_type(file_paths)
@@ -240,7 +241,7 @@ def nn(self):
     def data(self):
         """Get a raster of source values corresponding to the
         high-resolution grid (the file_paths input grid * s_enhance *
-        t_enhance). The shape is (rows, cols, temporal)
+        t_enhance). 
The shape is (lats, lons, temporal, 1)
         """
         nn = self.nn
         hr_data = []
@@ -250,7 +251,7 @@ def data(self):
             hr_data.append(out[..., np.newaxis])
         hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1)
         logger.info('Finished mapping raster from {}'.format(self._exo_source))
-        return hr_data
+        return hr_data[..., np.newaxis]
 
     @classmethod
     def get_exo_raster(cls,
@@ -371,6 +372,14 @@ def source_lat_lon(self):
             source_lat_lon = res.lat_lon
         return source_lat_lon
 
+    @property
+    def hr_time_index(self):
+        """Time index of the high-resolution exo data"""
+        if self._hr_time_index is None:
+            with Resource(self._exo_source) as res:
+                self._hr_time_index = res.time_index
+        return self._hr_time_index
+
 
 class TopoExtractNC(ExoExtract):
     """TopoExtract for netCDF files"""
diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
index aaf6878a7..3111ffff6 100644
--- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py
+++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
@@ -48,7 +48,8 @@ def __init__(self,
                  max_delta=20,
                  input_handler=None,
                  exo_handler=None,
-                 cache_data=True):
+                 cache_data=True,
+                 cache_dir='./exo_cache'):
         """
         Parameters
         ----------
@@ -75,9 +76,8 @@ def __init__(self,
             at 100km (s_enhance=1, exo_step=0), the input to the 5x model gets
             exogenous data at 25km (s_enhance=4, exo_step=1), and there is a
             20x (1*4*5) exogeneous data layer available if the second model can
-            receive a high-res input feature (e.g. WindGan). The length of this
-            list should be equal to the number of agg_factors and the number of
-            exo_steps
+            receive a high-res input feature (e.g. MultiExoGan). The length of
+            this list should be equal to the number of s_agg_factors
         t_enhancements : list
             List of factors by which the Sup3rGan model will enhance the
             temporal dimension of low resolution data from file_paths input
@@ -87,12 +87,12 @@ def __init__(self,
             List of factors by which to aggregate the exo_source data to the
             spatial resolution of the file_paths input enhanced by s_enhance.
             The length of this list should be equal to the number of
-            s_enhancements and the number of exo_steps
+            s_enhancements
         t_agg_factors : list
             List of factors by which to aggregate the exo_source data to the
             temporal resolution of the file_paths input enhanced by t_enhance.
             The length of this list should be equal to the number of
-            t_enhancements and the number of exo_steps
+            t_enhancements
         target : tuple
             (lat, lon) lower left corner of raster. Either need target+shape
             or raster_file.
@@ -122,8 +122,10 @@ def __init__(self,
             TopoExtractNC. If None the correct handler will be guessed based
             on file type and time series properties.
         cache_data : bool
-            Flag to cache exogeneous data in ./exo_cache/ this can speed up
-            forward passes with large temporal extents
+            Flag to cache exogenous data in cache_dir; this can
+            speed up forward passes with large temporal extents
+        cache_dir : str
+            Directory for storing cache data. 
Default is './exo_cache' """ self.feature = feature @@ -141,6 +143,7 @@ def __init__(self, self.max_delta = max_delta self.input_handler = input_handler self.cache_data = cache_data + self.cache_dir = cache_dir self.data = [] if self.s_enhancements[0] != 1: @@ -214,15 +217,14 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, cache_fp : str Name of cache file """ - cache_dir = './exo_cache/' fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_' - fn += f'tagg_{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' + fn += f'tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') fn = fn.replace(',', 'x').replace(' ', '') - cache_fp = os.path.join(cache_dir, fn) + cache_fp = os.path.join(self.cache_dir, fn) if self.cache_data: - os.makedirs(cache_dir, exist_ok=True) + os.makedirs(self.cache_dir, exist_ok=True) return cache_fp def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, @@ -259,7 +261,6 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, s_agg_factor=s_agg_factor, t_agg_factor=t_agg_factor) tmp_fp = cache_fp + '.tmp' - print(cache_fp) if os.path.exists(cache_fp): with open(cache_fp, 'rb') as f: data = pickle.load(f) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 56c1c054d..8c4d5db25 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -12,7 +12,7 @@ from rex import ResourceX, init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ -from sup3r.models import LinearInterp, Sup3rGan, WindGan +from sup3r.models import LinearInterp, MultiExoGan, Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC from sup3r.utilities.pytest import ( @@ -579,6 +579,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): s1_model.meta['output_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} _ = s1_model.generate(np.ones((4, 10, 10, 3))) s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) @@ -586,6 +588,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): s2_model.meta['output_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} _ = s2_model.generate(np.ones((4, 10, 10, 3))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') @@ -595,6 +599,8 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): st_model.meta['output_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) with tempfile.TemporaryDirectory() as td: @@ -618,11 +624,12 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): 'source_file': FP_WTK, 'target': target, 'shape': shape, - 's_enhancements': [1, 2, 2], - 't_enhancements': [1, 1, 1], - 's_agg_factors': [16, 4, 2], - 't_agg_factors': [1, 1, 1], - 'exo_steps': [0, 1] + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'} + ] } } @@ -689,14 +696,14 @@ def 
test_fwp_multi_step_spatial_model_topo_noskip(): Sup3rGan.seed() fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] s1_model.meta['output_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 s1_model.meta['t_enhance'] = 1 _ = s1_model.generate(np.ones((4, 10, 10, 3))) - s2_model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] s2_model.meta['output_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 @@ -1273,7 +1280,7 @@ def test_slicing_pad(log=False): def test_fwp_single_step_wind_hi_res_topo(plot=False): - """Test the forward pass with a single spatiotemporal WindGan model + """Test the forward pass with a single spatiotemporal MultiExoGan model requiring high-resolution topography input from the exogenous_data feature.""" Sup3rGan.seed() @@ -1324,7 +1331,8 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): "alpha": 0.2, "class": "LeakyReLU" }, { - "class": "Sup3rConcat" + "class": "Sup3rConcat", + "name": "topography" }, { "class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], @@ -1340,13 +1348,17 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): }] fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] model.meta['output_features'] = ['U_100m', 'V_100m'] model.meta['s_enhance'] = 2 model.meta['t_enhance'] = 2 - _ = model.generate(np.random.rand(4, 10, 10, 6, 3), - exogenous_data=(None, np.random.rand(4, 20, 20, 6, 1))) + model.meta['input_resolution'] = {'spatial': '8km', + 'temporal': '60min'} + exo_tmp = {'topography': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.random.rand(4, 20, 20, 12, 1)}]}} + _ = model.generate(np.random.rand(4, 10, 10, 6, 3), exogenous_data=exo_tmp) with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 8) @@ -1355,16 +1367,17 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): model.save(st_out_dir) exo_kwargs = { - 'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2], - 't_enhancements': [1, 1], - 's_agg_factors': [4, 2], - 't_agg_factors': [1, 1], - } + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'} + ]}} model_kwargs = {'model_dir': st_out_dir} out_files = os.path.join(td, 'out_{file_id}.h5') @@ -1374,29 +1387,10 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): worker_kwargs=dict(max_workers=1), overwrite_cache=True) - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='WindGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - 
input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - exo_kwargs['s_enhancements'] = [1, 2] handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='WindGan', + model_class='MultiExoGan', fwp_chunk_shape=(8, 8, 8), spatial_pad=4, temporal_pad=4, @@ -1476,9 +1470,11 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): "class": "Activation", "activation": "relu" }, { - "class": "Sup3rConcat" + "class": "Sup3rConcat", + "name": "topography" }, { - "class": "Sup3rConcat" + "class": "Sup3rConcat", + "name": "sza" }, { "class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -1495,7 +1491,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): }] fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) s1_model.meta['training_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] @@ -1505,7 +1501,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): _ = s1_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=(None, np.ones((4, 20, 20, 1)))) - s2_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s2_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) s2_model.meta['training_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] @@ -1609,7 +1605,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): def test_fwp_multi_step_wind_hi_res_topo(): - """Test the forward pass with multiple WindGan models requiring + """Test the forward pass with multiple MultiExoGan models requiring high-resolution topograph input from the exogenous_data feature.""" Sup3rGan.seed() gen_model = [{ @@ -1658,7 +1654,8 @@ def test_fwp_multi_step_wind_hi_res_topo(): "class": "Activation", "activation": "relu" }, { - "class": "Sup3rConcat" + "class": "Sup3rConcat", + "name": "topography" }, { "class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -1675,7 +1672,7 @@ def test_fwp_multi_step_wind_hi_res_topo(): }] fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] s1_model.meta['output_features'] = ['U_100m', 'V_100m'] s1_model.meta['s_enhance'] = 2 @@ -1683,7 +1680,7 @@ def test_fwp_multi_step_wind_hi_res_topo(): _ = s1_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=(None, np.ones((4, 20, 20, 1)))) - s2_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s2_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] s2_model.meta['output_features'] = ['U_100m', 'V_100m'] s2_model.meta['s_enhance'] = 2 @@ -1772,9 +1769,9 @@ def test_fwp_multi_step_wind_hi_res_topo(): def test_fwp_wind_hi_res_topo_plus_linear(): - """Test the forward pass with a WindGan model requiring high-res topo input - from exo data for spatial enhancement and a linear interpolation model for - temporal enhancement.""" + """Test the forward pass with a MultiExoGan model requiring high-res topo + input from exo data for spatial enhancement and a linear interpolation + model for temporal enhancement.""" Sup3rGan.seed() gen_model = [{ @@ -1836,7 +1833,7 @@ def 
test_fwp_wind_hi_res_topo_plus_linear(): }] fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s_model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + s_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) s_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] s_model.meta['output_features'] = ['U_100m', 'V_100m'] s_model.meta['s_enhance'] = 2 diff --git a/tests/training/test_train_wind.py b/tests/training/test_train_multi_exo.py similarity index 95% rename from tests/training/test_train_wind.py rename to tests/training/test_train_multi_exo.py index 0eff72943..f40e671e3 100644 --- a/tests/training/test_train_wind.py +++ b/tests/training/test_train_multi_exo.py @@ -8,8 +8,8 @@ from rex import init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.models import WindGan -from sup3r.models.data_centric import WindGanDC +from sup3r.models import MultiExoGan +from sup3r.models.data_centric import MultiExoGanDC from sup3r.preprocessing.batch_handling import ( BatchHandlerCC, BatchHandlerDC, @@ -60,8 +60,8 @@ def test_wind_cc_model(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_4x_24x_3f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - WindGan.seed() - model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + MultiExoGan.seed() + model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -71,7 +71,7 @@ def test_wind_cc_model(log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'MultiExoGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features assert len(model.output_features) == len(FEATURES_W) - 1 @@ -108,8 +108,8 @@ def test_wind_cc_model_spatial(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - WindGan.seed() - model = WindGan(fp_gen, fp_disc, learning_rate=1e-4) + MultiExoGan.seed() + model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -120,7 +120,7 @@ def test_wind_cc_model_spatial(log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'MultiExoGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -188,8 +188,8 @@ def test_wind_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - WindGan.seed() - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + MultiExoGan.seed() + model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -200,7 +200,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'MultiExoGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -273,8 +273,8 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - WindGan.seed() - model = WindGan(gen_model, fp_disc, learning_rate=1e-4) + MultiExoGan.seed() + model = MultiExoGan(gen_model, 
fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -285,7 +285,7 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGan' + assert model.meta['class'] == 'MultiExoGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -358,8 +358,8 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - WindGanDC.seed() - model = WindGanDC(gen_model, fp_disc, learning_rate=1e-4) + MultiExoGanDC.seed() + model = MultiExoGanDC(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -370,7 +370,7 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindGanDC' + assert model.meta['class'] == 'MultiExoGanDC' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features From 6630d309e4bfcd08cecf32a7e5b6f15eb2515775 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 22 Sep 2023 14:11:13 -0600 Subject: [PATCH 08/15] Combined Sup3rGan and WindGan to remove lots of duplicate code. Sup3rGan/Sup3rCondMom/Sup3rGanDC now works with/without mid network exo injection. Transpose method added to MultiStepGan to match input data shape with model.input_dims. This should accomplish the same as SpatialThenTemporal/TemporalThenSpatial models in addition to a st model + spatial + st model, for example. --- sup3r/models/__init__.py | 2 - sup3r/models/abstract.py | 143 +- sup3r/models/base.py | 20 +- sup3r/models/conditional_moments.py | 87 +- sup3r/models/data_centric.py | 73 +- sup3r/models/multi_exo.py | 147 --- sup3r/models/multi_step.py | 113 +- sup3r/models/surface.py | 33 +- sup3r/models/wind_conditional_moments.py | 98 -- sup3r/pipeline/forward_pass.py | 83 +- .../data_handling/exo_extraction.py | 4 +- .../data_handling/exogenous_data_handling.py | 4 +- tests/forward_pass/test_forward_pass.py | 1077 +-------------- tests/forward_pass/test_forward_pass_exo.py | 1158 +++++++++++++++++ tests/forward_pass/test_multi_step.py | 15 +- tests/training/test_train_multi_exo.py | 34 +- 16 files changed, 1450 insertions(+), 1641 deletions(-) delete mode 100644 sup3r/models/multi_exo.py delete mode 100644 sup3r/models/wind_conditional_moments.py create mode 100644 tests/forward_pass/test_forward_pass_exo.py diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 1c3e01a3f..779b65cf0 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -4,7 +4,6 @@ from .conditional_moments import Sup3rCondMom from .data_centric import Sup3rGanDC from .linear import LinearInterp -from .multi_exo import MultiExoGan from .multi_step import ( MultiStepGan, MultiStepSurfaceMetGan, @@ -14,7 +13,6 @@ ) from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel -from .wind_conditional_moments import WindCondMom SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan, MultiStepSurfaceMetGan, diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 974b6ec2c..3b898c3d2 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -121,8 +121,9 @@ def output_resolution(self): str(output_spatial)) return output_res - def _combine_input(self, low_res, exogenous_data=None): - 
"""Combine exogenous_data at input resolution with low_res data + def _combine_fwp_input(self, low_res, exogenous_data=None): + """Combine exogenous_data at input resolution with low_res data prior + to forward pass through generator Parameters ---------- @@ -166,8 +167,9 @@ def _combine_input(self, low_res, exogenous_data=None): axis=-1) return low_res - def _combine_output(self, hi_res, exogenous_data=None): - """Combine exogenous_data at input resolution with low_res data + def _combine_fwp_output(self, hi_res, exogenous_data=None): + """Combine exogenous_data at output resolution with generated hi_res + data following forward pass output. Parameters ---------- @@ -209,6 +211,50 @@ def _combine_output(self, hi_res, exogenous_data=None): axis=-1) return hi_res + def _combine_loss_input(self, high_res_true, high_res_gen): + """Combine exogenous feature data from high_res_true with high_res_gen + for loss calculation + + Parameters + ---------- + high_res_true : tf.Tensor + Ground truth high resolution spatiotemporal data. + high_res_gen : tf.Tensor + Superresolved high resolution spatiotemporal data generated by the + generative model. + + Returns + ------- + high_res_gen : tf.Tensor + Same as input with exogenous data combined with high_res input + """ + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_data = high_res_true[..., f_idx: f_idx + 1] + high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) + return high_res_gen + + def _get_exo_val_loss_input(self, high_res): + """Get exogenous feature data from high_res + + Parameters + ---------- + high_res : tf.Tensor + Ground truth high resolution spatiotemporal data. + + Returns + ------- + exo_data : dict + Dictionary of exogenous feature data used as input to tf_generate. + e.g. {'topography': np.ndarray(...)} + """ + exo_data = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_fdata = high_res[..., f_idx: f_idx + 1] + exo_data[feature] = exo_fdata + return exo_data + @property def exogenous_features(self): """Get list of exogenous filter names the model uses. If the model has @@ -219,12 +265,12 @@ def exogenous_features(self): [topo, sza]. Topo will then be used in the first concat layer and sza will be used in the second""" # pylint: disable=E1101 - layer_count = 0 + features = [] if hasattr(self, '_gen'): - layer_count = sum( - isinstance(layer, (Sup3rAdder, Sup3rConcat)) - for layer in self._gen.layers) - return self.training_features[-layer_count:] + for layer in self._gen.layers: + if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + features.append(layer.name) + return features @property @abstractmethod @@ -277,6 +323,35 @@ def version_record(self): """ return VERSION_RECORD + def _check_exo_features(self, **kwargs): + """Make sure exogenous features have the correct ordering and are + included in training_features + + Parameters + ---------- + kwargs : dict + Keyword arguments including 'training_features', 'output_features', + 'smoothed_features', 's_enhance', 't_enhance', 'smoothing' + + Returns + ------- + kwargs : dict + Same as input but with exogenous_features removed from output + features + """ + output_features = kwargs['output_features'] + msg = (f'Last {len(self.exogenous_features)} output features from the ' + f'data handler must be {self.exogenous_features} ' + 'to train the Exo model, but received output features: {}'. 
+               format(output_features))
+        check = (output_features[-len(self.exogenous_features):]
+                 == self.exogenous_features)
+        assert check, msg
+        for f in self.exogenous_features:
+            output_features.remove(f)
+        kwargs['output_features'] = output_features
+        return kwargs
+
     def set_model_params(self, **kwargs):
         """Set parameters used for training the model

@@ -286,6 +361,7 @@ def set_model_params(self, **kwargs):
             Keyword arguments including 'training_features', 'output_features',
             'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
         """
+        kwargs = self._check_exo_features(**kwargs)
         keys = ('input_resolution', 'training_features', 'output_features',
                 'smoothed_features', 's_enhance', 't_enhance', 'smoothing')

@@ -1090,7 +1166,7 @@ def generate(self,
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-        low_res = self._combine_input(low_res, exogenous_data)
+        low_res = self._combine_fwp_input(low_res, exogenous_data)
         if norm_in and self._means is not None:
             low_res = self.norm_input(low_res)

@@ -1101,10 +1177,12 @@ def generate(self,
             if isinstance(layer, (Sup3rAdder, Sup3rConcat)):
                 exo_name = layer.name
                 steps = exogenous_data[exo_name]['steps']
-                idx = [step['combine_type'] for step in steps]
+                combine_types = [step['combine_type'] for step in steps]
+                idx = combine_types.index('layer')
                 hi_res_exo = steps[idx]['data']
                 hi_res_exo = self._reshape_norm_exo(hi_res,
                                                     hi_res_exo,
+                                                    exo_name,
                                                     norm_in=norm_in)
                 hi_res = layer(hi_res, hi_res_exo)
             else:
@@ -1120,7 +1198,7 @@ def generate(self,
         if un_norm_out and self._means is not None:
             hi_res = self.un_norm_output(hi_res)

-        hi_res = self._combine_output(hi_res, exogenous_data)
+        hi_res = self._combine_fwp_output(hi_res, exogenous_data)
         return hi_res

@@ -1222,44 +1300,3 @@ def get_single_grad(self,
             grad = tape.gradient(loss, training_weights)

         return grad, loss_details
-
-
-# pylint: disable=E1101,W0201,E0203
-class AbstractExoInterface(AbstractInterface):
-    """
-    Abstract class to define the required training interface
-    for Sup3r model subclasses with exogenous features
-    """
-
-    # pylint: disable=E0211
-    def set_model_params(self, **kwargs):
-        """Set parameters used for training the model
-
-        Parameters
-        ----------
-        kwargs : dict
-            Keyword arguments including 'input_resolution',
-            'training_features', 'output_features', 'smoothed_features',
-            's_enhance', 't_enhance', 'smoothing'. For the Wind classes, the
-            last entry in "output_features" must be "topography"
-
-        Returns
-        -------
-        kwargs : dict
-            Same as input but with exogenous features removed from
-            "output_features", this is because the exo features are
-            concatenated mid-network in the ExoGan generators and are not
-            output features but are required in the hi-res training set.
-        """
-        output_features = kwargs['output_features']
-        msg = (f'Last {len(self.exogenous_features)} output features from the '
-               f'data handler must be {self.exogenous_features} '
-               'to train the Exo model, but received output features: {}'.
- format(output_features)) - check = (output_features[-len(self.exogenous_features)] - == self.exogenous_features) - assert check, msg - for f in self.exogenous_features: - output_features.remove(f) - kwargs['output_features'] = output_features - return kwargs diff --git a/sup3r/models/base.py b/sup3r/models/base.py index ac481af31..7c1b04233 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -377,8 +377,15 @@ def init_weights(self, lr_shape, hr_shape, device=None): logger.info('Initializing model weights on device "{}"'.format(device)) low_res = np.ones(lr_shape).astype(np.float32) hi_res = np.ones(hr_shape).astype(np.float32) + + hr_exo_shape = hr_shape[:-1] + (1,) + hr_exo = np.ones(hr_exo_shape).astype(np.float32) + with tf.device(device): - _ = self._tf_generate(low_res) + hr_exo_data = {} + for feature in self.exogenous_features: + hr_exo_data[feature] = hr_exo + _ = self._tf_generate(low_res, hr_exo_data) _ = self._tf_discriminate(hi_res) @staticmethod @@ -539,6 +546,7 @@ def calc_loss(self, loss_details : dict Namespace of the breakdown of loss components """ + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) if hi_res_gen.shape != hi_res_true.shape: msg = ('The tensor shapes of the synthetic output {} and ' @@ -594,19 +602,17 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - high_res_gen = self._tf_generate(val_batch.low_res) + val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( - val_batch.high_res, - high_res_gen, + val_batch.high_res, high_res_gen, weight_gen_advers=weight_gen_advers, - train_gen=False, - train_disc=False) + train_gen=False, train_disc=False) loss_details = self.update_loss_details(loss_details, v_loss_details, len(val_batch), prefix='val_') - return loss_details def train_epoch(self, diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index 347aaeeee..b0d1cd4a3 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -142,89 +142,6 @@ def load(cls, model_dir, verbose=True): params = cls.load_saved_params(model_dir, verbose=verbose) return cls(fp_gen, **params) - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data, usually a 4D or 5D array of shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - norm_in : bool - Flag to normalize low_res input data if the self._means, - self._stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. 
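The init_weights change above can be exercised without real data; a minimal
sketch, assuming one single-channel hi-res exo array of ones per named exo
feature (make_dummy_exo is an illustrative helper, not part of this patch):

    import numpy as np

    def make_dummy_exo(hr_shape, exo_features):
        # One hi-res exo array of ones per exo feature, single channel,
        # matching the hi-res batch shape except for the feature axis
        hr_exo = np.ones(hr_shape[:-1] + (1,), dtype=np.float32)
        return {feature: hr_exo for feature in exo_features}

    hr_exo_data = make_dummy_exo((4, 20, 20, 12, 2), ['topography', 'sza'])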
- un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : ndarray | None - Exogenous data array, usually a 4D or 5D array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - - Returns - ------- - output : ndarray - Synthetically generated high-resolution data, usually a 4D or 5D - array with shape: - (n_obs, spatial_1, spatial_2, n_features) - (n_obs, spatial_1, spatial_2, n_temporal, n_features) - """ - low_res = self._combine_input(low_res, exogenous_data) - - if norm_in and self._means is not None: - low_res = self.norm_input(low_res) - - output = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - output = layer(output) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, output.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - output = output.numpy() - - if un_norm_out and self._means is not None: - output = self.un_norm_output(output) - - output = self._combine_output(output, exogenous_data) - - return output - - @tf.function - def _tf_generate(self, low_res): - """Use the generator model to generate high res data from los res input - - Parameters - ---------- - low_res : np.ndarray - Real low-resolution data. The generator should always - received normalized data with mean=0 stdev=1. - - Returns - ------- - output : tf.Tensor - Synthetically generated high-resolution data - """ - - output = self.generator.layers[0](low_res) - for i, layer in enumerate(self.generator.layers[1:]): - try: - output = layer(output) - except Exception as e: - msg = ('Could not run layer #{} "{}" on tensor of shape {}' - .format(i + 1, layer, output.shape)) - logger.error(msg) - raise RuntimeError(msg) from e - - return output - def update_optimizer(self, **kwargs): """Update optimizer by changing current configuration @@ -331,6 +248,7 @@ def calc_loss(self, output_true, output_gen, mask): loss_details : dict Namespace of the breakdown of loss components """ + output_gen = self._combine_loss_input(output_true, output_gen) if output_gen.shape != output_true.shape: msg = ('The tensor shapes of the synthetic output {} and ' @@ -365,7 +283,8 @@ def calc_val_loss(self, batch_handler, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - output_gen = self._tf_generate(val_batch.low_res) + val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + output_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.output, output_gen, val_batch.mask) diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py index b41b7611d..2afc75896 100644 --- a/sup3r/models/data_centric.py +++ b/sup3r/models/data_centric.py @@ -6,7 +6,6 @@ import numpy as np from sup3r.models.base import Sup3rGan -from sup3r.models.multi_exo import MultiExoGan from sup3r.utilities.utilities import round_array logger = logging.getLogger(__name__) @@ -39,7 +38,8 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): """ losses = [] for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res) + exo_data = self._get_exo_val_loss_input(obs.high_res) + gen = self._tf_generate(obs.low_res, exo_data) loss, _ = self.calc_loss(obs.high_res, gen, weight_gen_advers=weight_gen_advers, train_gen=True, train_disc=True) @@ -66,7 +66,8 @@ 
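The validation calls in these data-centric methods all rely on the channel
slicing from _get_exo_val_loss_input; a standalone sketch of that pattern
(function name illustrative, not part of this patch):

    import numpy as np

    def exo_from_high_res(high_res, training_features, exo_features):
        # Pull each exo feature channel from the true hi-res batch, keeping
        # a trailing singleton channel axis, keyed by feature name
        exo_data = {}
        for feature in exo_features:
            f_idx = training_features.index(feature)
            exo_data[feature] = high_res[..., f_idx:f_idx + 1]
        return exo_data

    batch = np.random.rand(4, 20, 20, 4)  # channels: U, V, topography, sza
    exo = exo_from_high_res(batch, ['U_100m', 'V_100m', 'topography', 'sza'],
                            ['topography', 'sza'])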
def calc_val_loss_gen_content(self, batch_handler): """ losses = [] for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res) + exo_data = self._get_exo_val_loss_input(obs.high_res) + gen = self._tf_generate(obs.low_res, exo_data) loss = self.calc_loss_gen_content(obs.high_res, gen) losses.append(float(loss)) return losses @@ -93,7 +94,6 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): Updated loss_details with mean validation loss calculated using the validation samples across the time bins """ - total_losses = self.calc_val_loss_gen(batch_handler, weight_gen_advers) content_losses = self.calc_val_loss_gen_content(batch_handler) @@ -138,71 +138,6 @@ def calc_temporal_losses(total_losses, content_losses, batch_handler): f'{round_array(new_temporal_weights)}') -class MultiExoGanDC(MultiExoGan, Sup3rGanDC): - """Data-centric model using loss across time bins to select training - observations with handling of low and high res topography - inputs.""" - - def calc_val_loss_gen(self, batch_handler, weight_gen_advers): - """Calculate the validation total loss across the validation - samples. e.g. If the sample domain has 100 steps and the - validation set has 10 bins then this will get a list of losses across - step 0 to 10, 10 to 20, etc. Use this to determine performance - within bins and to update how observations are selected from these - bins. Use the _tf_generate function from MultiExoGan to include the - high resolution topography. - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandlerDC - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. - - Returns - ------- - list - List of total losses for all sample bins - """ - losses = [] - for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res, - obs.high_res[..., -1:]) - loss, _ = self.calc_loss(obs.high_res, gen, - weight_gen_advers=weight_gen_advers, - train_gen=True, train_disc=True) - losses.append(float(loss)) - return losses - - def calc_val_loss_gen_content(self, batch_handler): - """Calculate the validation content loss across the validation - samples. e.g. If the sample domain has 100 steps and the - validation set has 10 bins then this will get a list of losses across - step 0 to 10, 10 to 20, etc. Use this to determine performance - within bins and to update how observations are selected from these - bins. Use the _tf_generate function from MultiExoGan to include high - resolution topography. 
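On the calc_temporal_losses hand-off above: one plausible reweighting is to
normalize the per-bin validation losses into sampling probabilities, so that
poorly-resolved time bins are drawn more often. A sketch under that
assumption (the authoritative update lives in calc_temporal_losses):

    import numpy as np

    def bin_weights_from_losses(bin_losses):
        # Sample time bins in proportion to validation loss
        losses = np.asarray(bin_losses, dtype=np.float32)
        return losses / losses.sum()

    weights = bin_weights_from_losses([2.0, 5.0, 3.0])  # -> [0.2, 0.5, 0.3]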
- - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandlerDC - BatchHandler object to iterate through - - Returns - ------- - list - List of content losses for all sample bins - """ - losses = [] - for obs in batch_handler.val_data: - gen = self._tf_generate(obs.low_res, - obs.high_res[..., -1:]) - loss = self.calc_loss_gen_content(obs.high_res, gen) - losses.append(float(loss)) - return losses - - class Sup3rGanSpatialDC(Sup3rGanDC): """Data-centric model using loss across time bins to select training observations""" diff --git a/sup3r/models/multi_exo.py b/sup3r/models/multi_exo.py deleted file mode 100644 index a393521c7..000000000 --- a/sup3r/models/multi_exo.py +++ /dev/null @@ -1,147 +0,0 @@ -# -*- coding: utf-8 -*- -"""Wind super resolution GAN with handling of low and high res topography -inputs.""" -import logging - -import numpy as np -import tensorflow as tf - -from sup3r.models.abstract import AbstractExoInterface -from sup3r.models.base import Sup3rGan - -logger = logging.getLogger(__name__) - - -class MultiExoGan(AbstractExoInterface, Sup3rGan): - """Super resolution GAN with handling of low and high res exogenous feature - inputs. This exogenous data is commonly just topography. - - Modifications to standard Sup3rGan: - - Hi res exogenous features are expected as the last - len(self.exogenous_features) channels in the true data in the true - batch observation. These channels are appended to the generated - output so the discriminator can look at the super resolved fields - compared to the associated hi res exogenous feature data. - - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present - in the network, the hi-res exogenous feature matching layer.name will - be added or concatenated to the data at that point in the network - during either training or the forward pass. - """ - - def init_weights(self, lr_shape, hr_shape, device=None): - """Initialize the generator and discriminator weights with device - placement. - - Parameters - ---------- - lr_shape : tuple - Shape of one batch of low res input data for sup3r resolution. Note - that the batch size (axis=0) must be included, but the actual batch - size doesnt really matter. - hr_shape : tuple - Shape of one batch of high res input data for sup3r resolution. - Note that the batch size (axis=0) must be included, but the actual - batch size doesnt really matter. - device : str | None - Option to place model weights on a device. If None, - self.default_device will be used. 
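The name-based wiring that supersedes this deleted class: exo features are now
read off the named Sup3rAdder/Sup3rConcat layers rather than counted from the
tail of training_features. A sketch, assuming phygnn's layers accept a name
keyword as the configs in this patch imply:

    from phygnn.layers.custom_layers import Sup3rAdder, Sup3rConcat

    def exo_features_from_layers(layers):
        # Ordered exo feature names from the named Adder/Concat layers; the
        # injection order follows layer order, not training_features order
        return [layer.name for layer in layers
                if isinstance(layer, (Sup3rAdder, Sup3rConcat))]

    layers = [Sup3rConcat(name='topography'), Sup3rConcat(name='sza')]
    assert exo_features_from_layers(layers) == ['topography', 'sza']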
- """ - - if device is None: - device = self.default_device - - logger.info('Initializing model weights on device "{}"'.format(device)) - low_res = np.ones(lr_shape).astype(np.float32) - hi_res = np.ones(hr_shape).astype(np.float32) - - hr_exo_shape = hr_shape[:-1] + (1,) - hr_exo = np.ones(hr_exo_shape).astype(np.float32) - - with tf.device(device): - hr_exo_data = {} - for feature in self.exogenous_features: - hr_exo_data[feature] = hr_exo - _ = self._tf_generate(low_res, hr_exo_data) - _ = self._tf_discriminate(hi_res) - - def set_model_params(self, **kwargs): - """Set parameters used for training the model - - Parameters - ---------- - kwargs : dict - Keyword arguments including 'input_resolution', - 'training_features', 'output_features', 'smoothed_features', - 's_enhance', 't_enhance', 'smoothing' - """ - AbstractExoInterface.set_model_params(self, **kwargs) - Sup3rGan.set_model_params(self, **kwargs) - - @tf.function - def calc_loss(self, hi_res_true, hi_res_gen, **kwargs): - """Calculate the GAN loss function using generated and true high - resolution data. - - Parameters - ---------- - hi_res_true : tf.Tensor - Ground truth high resolution spatiotemporal data. - hi_res_gen : tf.Tensor - Superresolved high resolution spatiotemporal data generated by the - generative model. - kwargs : dict - Key word arguments for: - Sup3rGan.calc_loss(hi_res_true, hi_res_gen, **kwargs) - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - for feature in self.exogenous_features: - f_idx = self.training_features.index(feature) - exo_data = hi_res_true[..., f_idx: f_idx + 1] - hi_res_gen = tf.concat((hi_res_gen, exo_data), axis=-1) - - return super().calc_loss(hi_res_true, hi_res_gen, **kwargs) - - def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandler - BatchHandler object to iterate through - weight_gen_advers : float - Weight factor for the adversarial loss component of the generator - vs. the discriminator. 
-        loss_details : dict
-            Namespace of the breakdown of loss components
-
-        Returns
-        -------
-        loss_details : dict
-            Same as input but now includes val_* loss info
-        """
-        logger.debug('Starting end-of-epoch validation loss calculation...')
-        loss_details['n_obs'] = 0
-        for val_batch in batch_handler.val_data:
-            val_exo_data = {}
-            for feature in self.exogenous_features:
-                f_idx = self.training_features.index(feature)
-                exo_data = val_batch.high_res[..., f_idx: f_idx + 1]
-                val_exo_data[feature] = exo_data
-            high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data)
-            _, v_loss_details = self.calc_loss(
-                val_batch.high_res, high_res_gen,
-                weight_gen_advers=weight_gen_advers,
-                train_gen=False, train_disc=False)
-
-            loss_details = self.update_loss_details(loss_details,
-                                                    v_loss_details,
-                                                    len(val_batch),
-                                                    prefix='val_')
-        return loss_details
diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py
index e9d235761..3cc7557be 100644
--- a/sup3r/models/multi_step.py
+++ b/sup3r/models/multi_step.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 """Sup3r multi step model frameworks"""
+import copy
 import json
 import logging
 import os
@@ -148,6 +149,45 @@ def _get_model_step_exo(self, model_step, exogenous_data=None):
             model_step_exo[feature] = {'steps': steps}
         return model_step_exo

+    def _transpose_model_input(self, model, hi_res):
+        """Transpose input data according to model input dimensions.
+
+        NOTE: If len(hi_res.shape) == 4, it is assumed that the dimensions
+        have the ordering (n_obs, spatial_1, spatial_2, features)
+
+        If len(hi_res.shape) == 5, it is assumed that the dimensions have the
+        ordering (1, spatial_1, spatial_2, temporal, features)
+
+        Parameters
+        ----------
+        model : Sup3rGan
+            A single step model with the attribute model.input_dims
+        hi_res : ndarray
+            Synthetically generated high-resolution data, usually a 4D or 5D
+            array with shape:
+            (n_obs, spatial_1, spatial_2, n_features)
+            (n_obs, spatial_1, spatial_2, n_temporal, n_features)
+
+        Returns
+        -------
+        hi_res : ndarray
+            Synthetically generated high-resolution data transposed according
+            to the number of model input dimensions
+        """
+        if model.input_dims == 5 and len(hi_res.shape) == 4:
+            hi_res = np.transpose(
+                hi_res, axes=(1, 2, 0, 3))[np.newaxis]
+        elif model.input_dims == 4 and len(hi_res.shape) == 5:
+            msg = ('Received 5D input data with shape '
                   f'{hi_res.shape} to a 4D model.')
+            assert hi_res.shape[0] == 1, msg
+            hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3))
+        else:
+            msg = ('Received input data with shape '
+                   f'{hi_res.shape} to a {model.input_dims}D model.')
+            assert model.input_dims == len(hi_res.shape), msg
+        return hi_res
+
     def generate(self, low_res, norm_in=True, un_norm_out=True,
                  exogenous_data=None):
         """Use the generator model to generate high res data from low res
@@ -199,6 +239,7 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
             i_exo_data = self._get_model_step_exo(i, exogenous_data)

             try:
+                hi_res = self._transpose_model_input(model, hi_res)
                 logger.debug('Data input to model #{} of {} has shape {}'
                              .format(i + 1, len(self.models), hi_res.shape))
                 hi_res = model.generate(hi_res, norm_in=i_norm_in,
@@ -422,20 +463,20 @@ def _split_exo_spatial_temporal(self, exogenous_data=None):
             Same as input dictionary but with only entries with 'model':
             model_step where model_step corresponds to a temporal model step
         """
-        spatial_exo = None
-        temporal_exo = None
+        spatial_exo = {}
+        temporal_exo = {}
         if exogenous_data is not None:
-            spatial_exo = {}
-            for feature in exogenous_data:
-                steps
= [step for step in exogenous_data[feature]['steps'] + exo_data = copy.deepcopy(exogenous_data) + for feature in exo_data: + steps = [step for step in exo_data[feature]['steps'] if step['model'] < len(self.spatial_models)] if steps: spatial_exo[feature] = {'steps': steps} - steps = [step for step in exogenous_data[feature]['steps'] + steps = [step for step in exo_data[feature]['steps'] if step['model'] >= len(self.spatial_models)] t_shift = len(self.spatial_models) - steps = [step.update({'model': step['model'] - t_shift}) - for step in steps] + for step in steps: + step.update({'model': step['model'] - t_shift}) if steps: temporal_exo[feature] = {'steps': steps} return spatial_exo, temporal_exo @@ -592,20 +633,20 @@ def _split_exo_temporal_spatial(self, exogenous_data=None): Same as input dictionary but with only entries with 'model': model_step where model_step corresponds to a spatial model step """ - spatial_exo = None - temporal_exo = None + spatial_exo = {} + temporal_exo = {} if exogenous_data is not None: - temporal_exo = {} - for feature in exogenous_data: - steps = [step for step in exogenous_data[feature]['steps'] + exo_data = copy.deepcopy(exogenous_data) + for feature in exo_data: + steps = [step for step in exo_data[feature]['steps'] if step['model'] < len(self.temporal_models)] if steps: temporal_exo[feature] = {'steps': steps} - steps = [step for step in exogenous_data[feature]['steps'] + steps = [step for step in exo_data[feature]['steps'] if step['model'] >= len(self.temporal_models)] s_shift = len(self.temporal_models) - steps = [step.update({'model': step['model'] - s_shift}) - for step in steps] + for step in steps: + step.update({'model': step['model'] - s_shift}) if steps: spatial_exo[feature] = {'steps': steps} return temporal_exo, spatial_exo @@ -723,12 +764,15 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, un_norm_out : bool Flag to un-normalize synthetically generated output data to physical units - exogenous_data : list - For the MultiStepSurfaceMetGan model, this must be a 2-entry list - where the first entry is a 2D (lat, lon) array of low-resolution - surface elevation data in meters (must match spatial_1, spatial_2 - from low_res), and the second entry is a 2D (lat, lon) array of - high-resolution surface elevation data in meters. + exogenous_data : dict + For the MultiStepSurfaceMetGan, this must be a nested dictionary + with a main 'topography' key and two entries for + exogenous_data['topography']['steps']. The first entry includes a + 2D (lat, lon) array of low-resolution surface elevation data in + meters (must match spatial_1, spatial_2 from low_res), and the + second entry includes a 2D (lat, lon) array of high-resolution + surface elevation data in meters. e.g. 
+            {'topography': {'steps': [{'data': lr_topo}, {'data': hr_topo}]}}

         Returns
         -------
@@ -745,18 +789,10 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         msg = ('MultiStepSurfaceMetGan needs exogenous_data with two '
               'entries for low and high res topography inputs.')
         assert exogenous_data is not None, msg
-        assert isinstance(exogenous_data, (list, tuple)), msg
-        exogenous_data = [d for d in exogenous_data if d is not None]
-        assert len(exogenous_data) == 2, msg
-
-        # SurfaceSpatialMetModel needs a 2D array for exo topography input
-        for i, i_exo in enumerate(exogenous_data):
-            if len(i_exo.shape) == 3:
-                exogenous_data[i] = i_exo[:, :, 0]
-            elif len(i_exo.shape) == 4:
-                exogenous_data[i] = i_exo[0, :, :, 0]
-            elif len(i_exo.shape) == 5:
-                exogenous_data[i] = i_exo[0, :, :, 0, 0]
+        exo_data = [step['data']
+                    for step in exogenous_data['topography']['steps']]
+        assert isinstance(exo_data, (list, tuple)), msg
+        assert len(exo_data) == 2, msg

         try:
             hi_res = self.spatial_models.generate(
@@ -1062,15 +1098,12 @@ def generate(self, low_res, norm_in=True, un_norm_out=True,
         logger.debug('Data input to the SolarMultiStepGan has shape {} which '
                      'will be split up for solar- and wind-only features.'
                      .format(low_res.shape))
-        t_exogenous = None
-        if exogenous_data is not None:
-            t_exogenous = exogenous_data[len(self.spatial_wind_models):]
-
+        s_exo, t_exo = self._split_exo_spatial_temporal(exogenous_data)
         try:
             hi_res_wind = self.spatial_wind_models.generate(
                 low_res[..., self.idf_wind],
                 norm_in=norm_in, un_norm_out=True,
-                exogenous_data=exogenous_data)
+                exogenous_data=s_exo)
         except Exception as e:
             msg = ('Could not run the 1st step spatial-wind-only GAN on '
                    'input shape {}'.format(low_res.shape))
@@ -1106,7 +1139,7 @@
         try:
             hi_res = self.temporal_solar_models.generate(
                 hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=t_exogenous)
+                exogenous_data=t_exo)
         except Exception as e:
             msg = ('Could not run the 2nd step (spatio)temporal solar GAN on '
                    'input shape {}'.format(low_res.shape))
diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py
index 0d9881b0c..a11ce0ea9 100644
--- a/sup3r/models/surface.py
+++ b/sup3r/models/surface.py
@@ -2,10 +2,11 @@
 """Special models for surface meteorological data."""
 import logging
 from fnmatch import fnmatch
+from warnings import warn
+
 import numpy as np
 from PIL import Image
 from sklearn import linear_model
-from warnings import warn

 from sup3r.models.linear import LinearInterp
 from sup3r.utilities.utilities import spatial_coarsening
@@ -470,12 +471,15 @@ def generate(self, low_res, norm_in=False, un_norm_out=False,
         un_norm_out : bool
             This doesnt do anything for this SurfaceSpatialMetModel, but is
             kept to keep the same interface as Sup3rGan
-        exogenous_data : list
-            For the SurfaceSpatialMetModel, this must be a 2-entry list where
-            the first entry is a 2D (lat, lon) array of low-resolution surface
-            elevation data in meters (must match spatial_1, spatial_2 from
-            low_res), and the second entry is a 2D (lat, lon) array of
-            high-resolution surface elevation data in meters.
+        exogenous_data : dict
+            For the SurfaceSpatialMetModel, this must be a nested dictionary
+            with a main 'topography' key and two entries for
+            exogenous_data['topography']['steps'].
The first entry includes a
+            2D (lat, lon) array of low-resolution surface elevation data in
+            meters (must match spatial_1, spatial_2 from low_res), and the
+            second entry includes a 2D (lat, lon) array of high-resolution
+            surface elevation data in meters. e.g.
+            {'topography': {'steps': [{'data': lr_topo}, {'data': hr_topo}]}}

         Returns
         -------
@@ -485,16 +489,17 @@
             channel can include temperature_*m, relativehumidity_*m, and/or
             pressure_*m
         """
-
+        exo_data = [step['data']
+                    for step in exogenous_data['topography']['steps']]
         msg = ('exogenous_data is of a bad type {}!'
-               .format(type(exogenous_data)))
-        assert isinstance(exogenous_data, (list, tuple)), msg
+               .format(type(exo_data)))
+        assert isinstance(exo_data, (list, tuple)), msg
         msg = ('exogenous_data is of a bad length {}!'
-               .format(len(exogenous_data)))
-        assert len(exogenous_data) == 2, msg
+               .format(len(exo_data)))
+        assert len(exo_data) == 2, msg

-        topo_lr = exogenous_data[0]
-        topo_hr = exogenous_data[1]
+        topo_lr = exo_data[0]
+        topo_hr = exo_data[1]
         logger.debug('SurfaceSpatialMetModel received low/high res topo '
                      'shapes of {} and {}'
                      .format(topo_lr.shape, topo_hr.shape))
diff --git a/sup3r/models/wind_conditional_moments.py b/sup3r/models/wind_conditional_moments.py
deleted file mode 100644
index 0cbb99b75..000000000
--- a/sup3r/models/wind_conditional_moments.py
+++ /dev/null
@@ -1,98 +0,0 @@
-# -*- coding: utf-8 -*-
-"""Wind conditional moment estimator with handling of low and
-high res topography inputs."""
-import logging
-
-import tensorflow as tf
-
-from sup3r.models.abstract import AbstractExoInterface
-from sup3r.models.conditional_moments import Sup3rCondMom
-
-logger = logging.getLogger(__name__)
-
-
-class WindCondMom(AbstractExoInterface, Sup3rCondMom):
-    """Wind conditional moment estimator with handling of low and
-    high res topography inputs.
-
-    Modifications to standard Sup3rCondMom:
-        - Hi res topography is expected as the last feature channel in the true
-          data in the true batch observation.
-        - If a custom Sup3rAdder or Sup3rConcat layer (from phygnn) is present
-          in the network, the hi-res topography will be added or concatenated
-          to the data at that point in the network during either training or
-          the forward pass.
-    """
-
-    def set_model_params(self, **kwargs):
-        """Set parameters used for training the model
-
-        Parameters
-        ----------
-        kwargs : dict
-            Keyword arguments including 'training_features', 'output_features',
-            'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
-        """
-        AbstractExoInterface.set_model_params(self, **kwargs)
-        Sup3rCondMom.set_model_params(self, **kwargs)
-
-    @tf.function
-    def calc_loss(self, hi_res_true, hi_res_gen, mask, **kwargs):
-        """Calculate the loss function using generated and true high
-        resolution data.
-
-        Parameters
-        ----------
-        hi_res_true : tf.Tensor
-            Ground truth high resolution spatiotemporal data.
-        hi_res_gen : tf.Tensor
-            Superresolved high resolution spatiotemporal data generated by the
-            generative model.
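Example call matching the new SurfaceSpatialMetModel signature above, with
random arrays standing in for real elevation rasters (2x spatial enhancement
assumed; `model` would be a loaded SurfaceSpatialMetModel):

    import numpy as np

    lr_topo = np.random.rand(10, 10)  # must match low_res spatial_1/2
    hr_topo = np.random.rand(20, 20)  # enhanced-resolution elevation
    exogenous_data = {'topography': {'steps': [{'data': lr_topo},
                                               {'data': hr_topo}]}}
    # hi_res = model.generate(low_res, exogenous_data=exogenous_data)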
- mask : tf.Tensor - Mask to apply - kwargs : dict - Key word arguments for: - Sup3rGan.calc_loss(hi_res_true, hi_res_gen, **kwargs) - - Returns - ------- - loss : tf.Tensor - 0D tensor representing the loss value for the network being trained - (either generator or one of the discriminators) - loss_details : dict - Namespace of the breakdown of loss components - """ - - # append the true topography to the generated synthetic wind data - hi_res_gen = tf.concat((hi_res_gen, hi_res_true[..., -1:]), axis=-1) - - return super().calc_loss(hi_res_true, hi_res_gen, mask, **kwargs) - - def calc_val_loss(self, batch_handler, loss_details): - """Calculate the validation loss at the current state of model training - - Parameters - ---------- - batch_handler : sup3r.data_handling.preprocessing.BatchHandler - BatchHandler object to iterate through - loss_details : dict - Namespace of the breakdown of loss components - - Returns - ------- - loss_details : dict - Same as input but now includes val_* loss info - """ - logger.debug('Starting end-of-epoch validation loss calculation...') - loss_details['n_obs'] = 0 - for val_batch in batch_handler.val_data: - high_res_gen = self._tf_generate(val_batch.low_res, - val_batch.high_res[..., -1:]) - _, v_loss_details = self.calc_loss( - val_batch.output, high_res_gen, val_batch.mask) - - loss_details = self.update_loss_details(loss_details, - v_loss_details, - len(val_batch), - prefix='val_') - return loss_details diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 84d1fdf35..83d2583ab 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -1117,20 +1117,20 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], self.lr_slice[1]] - def get_agg_factor(self, input_res, exo_res): - """Compute agg factor for exo data given input and output resolution + def _get_res_ratio(self, input_res, exo_res): + """Compute resolution ratio given input and output resolution Parameters ---------- input_res : str | None - Input resolution. e.g. 30km or 60min + Input resolution. e.g. '30km' or '60min' exo_res : str | None - Exogenous data resolution. e.g. 1km or 5min + Exo resolution. e.g. '1km' or '5min' Returns ------- - agg_factor : int - Aggregation factor for exogenous data extraction. + res_ratio : int | None + Ratio of input / exo resolution """ ires_num = (None if input_res is None else int(re.search(r'\d+', input_res).group(0))) @@ -1144,10 +1144,38 @@ def get_agg_factor(self, input_res, exo_res): if e_units is not None: assert i_units == e_units, msg if ires_num is not None and eres_num is not None: - agg_factor = int((ires_num / eres_num) ** 2) + res_ratio = int(ires_num / eres_num) else: - agg_factor = None - return agg_factor + res_ratio = None + return res_ratio + + def get_agg_factors(self, input_res, exo_res): + """Compute aggregation ratio for exo data given input and output + resolution + + Parameters + ---------- + input_res : dict | None + Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'} + exo_res : dict | None + Exogenous data resolution. e.g. + {'spatial': '1km', 'temporal': '5min'} + + Returns + ------- + s_agg_factor : int + Spatial aggregation factor for exogenous data extraction. + t_agg_factor : int + Temporal aggregation factor for exogenous data extraction. 
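The arithmetic in _get_res_ratio/get_agg_factors, restated as a minimal
sketch (res_ratio is an illustrative name; the spatial factor is the ratio
squared because aggregation is over a 2D pixel neighborhood):

    import re

    def res_ratio(input_res, exo_res):
        # e.g. res_ratio('30km', '1km') -> 30; res_ratio('60min', '5min') -> 12
        i_num = int(re.search(r'\d+', input_res).group(0))
        e_num = int(re.search(r'\d+', exo_res).group(0))
        i_units = input_res.replace(str(i_num), '')
        e_units = exo_res.replace(str(e_num), '')
        assert i_units == e_units, 'input and exo resolutions need same units'
        return i_num // e_num

    s_agg_factor = res_ratio('12km', '4km') ** 2  # 9 exo pixels per input cell
    t_agg_factor = res_ratio('60min', '60min')    # 1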
+ """ + input_s_res = None if input_res is None else input_res['spatial'] + exo_s_res = None if exo_res is None else exo_res['spatial'] + s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res) + s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2 + input_t_res = None if input_res is None else input_res['temporal'] + exo_t_res = None if exo_res is None else exo_res['temporal'] + t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res) + return s_agg_factor, t_agg_factor def _prep_exo_extract_single_step(self, step_dict, exo_resolution): """Compute agg and enhancement factors for exogenous data extraction @@ -1169,8 +1197,6 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): s_enhance, t_enhance added """ model_step = step_dict['model'] - exo_res_t = exo_resolution['temporal'] - exo_res_s = exo_resolution['spatial'] s_enhance = None t_enhance = None s_agg_factor = None @@ -1180,10 +1206,8 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): f'of model steps ({len(models)})') assert len(models) > model_step, msg model = models[model_step] - input_res_t = model.input_resolution['temporal'] - input_res_s = model.input_resolution['spatial'] - output_res_t = model.output_resolution['temporal'] - output_res_s = model.output_resolution['spatial'] + input_res = model.input_resolution + output_res = model.output_resolution combine_type = step_dict.get('combine_type', None) if combine_type.lower() == 'input': @@ -1191,21 +1215,26 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): s_enhance = 1 t_enhance = 1 else: - s_enhance = self.strategy.s_enhancements[model_step - 1] - t_enhance = self.strategy.t_enhancements[model_step - 1] - s_agg_factor = self.get_agg_factor(input_res_s, exo_res_s) - t_agg_factor = self.get_agg_factor(input_res_t, exo_res_t) - resolution = {'spatial': input_res_s, 'temporal': input_res_t} + s_enhance = np.product( + self.strategy.s_enhancements[:model_step]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step]) + s_agg_factor, t_agg_factor = self.get_agg_factors( + input_res, exo_resolution) + resolution = input_res elif combine_type.lower() in ('output', 'layer'): - s_enhance = self.strategy.s_enhancements[model_step] - t_enhance = self.strategy.t_enhancements[model_step] - s_agg_factor = self.get_agg_factor(output_res_s, exo_res_s) - t_agg_factor = self.get_agg_factor(output_res_t, exo_res_t) - resolution = {'spatial': output_res_s, 'temporal': output_res_t} + s_enhance = np.product( + self.strategy.s_enhancements[:model_step + 1]) + t_enhance = np.product( + self.strategy.t_enhancements[:model_step + 1]) + s_agg_factor, t_agg_factor = self.get_agg_factors( + output_res, exo_resolution) + resolution = output_res else: - msg = 'Received exo_kwargs entry without valid combine_type' + msg = ('Received exo_kwargs entry without valid combine_type ' + '(input/layer/output)') raise OSError(msg) updated_dict = step_dict.copy() @@ -1233,7 +1262,7 @@ def _prep_exo_extract_kwargs(self, exo_kwargs): Same as input dictionary with s_agg_factor, t_agg_factor, s_enhance, t_enhance added to each step entry for all features """ - if exo_kwargs: + if exo_kwargs is not None: for feature in exo_kwargs: exo_resolution = exo_kwargs[feature]['exo_resolution'] for i, step in enumerate(exo_kwargs[feature]['steps']): diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 5d1442b61..97cf843b2 100644 --- 
a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -362,7 +362,9 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') - elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) + elev = np.repeat(elev[:, np.newaxis], + self.hr_shape[-1] * self._t_agg_factor, + axis=-1) return elev @property diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 3111ffff6..9b08b5aaf 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -76,8 +76,8 @@ def __init__(self, at 100km (s_enhance=1, exo_step=0), the input to the 5x model gets exogenous data at 25km (s_enhance=4, exo_step=1), and there is a 20x (1*4*5) exogeneous data layer available if the second model can - receive a high-res input feature (e.g. MultiExoGan). The length of - this list should be equal to the number of s_agg_factors + receive a high-res input feature. The length of this list should be + equal to the number of s_agg_factors t_enhancements : list List of factors by which the Sup3rGan model will enhance the temporal dimension of low resolution data from file_paths input diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 8c4d5db25..31ce6bdf2 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -6,13 +6,12 @@ import matplotlib.pyplot as plt import numpy as np -import pytest import tensorflow as tf import xarray as xr from rex import ResourceX, init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ -from sup3r.models import LinearInterp, MultiExoGan, Sup3rGan +from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC from sup3r.utilities.pytest import ( @@ -564,463 +563,6 @@ def test_fwp_nochunking(): assert np.array_equal(data_chunked, data_nochunk) -def test_fwp_multi_step_model_topo_exoskip(log=False): - """Test the forward pass with a multi step model class using exogenous data - for the first two steps and not the last""" - - if log: - init_logger('sup3r', log_level='DEBUG') - - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - s1_model.meta['input_resolution'] = {'spatial': '48km', - 'temporal': '60min'} - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - s2_model.meta['input_resolution'] = {'spatial': '24km', - 'temporal': '60min'} - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, 
learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - st_model.meta['input_resolution'] = {'spatial': '12km', - 'temporal': '60min'} - _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhance = 12 - t_enhance = 4 - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, - 'steps': [ - {'model': 0, 'combine_type': 'input'}, - {'model': 1, 'combine_type': 'input'} - ] - } - } - - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=fwp_chunk_shape, - input_handler_kwargs=input_handler_kwargs, - spatial_pad=0, - temporal_pad=0, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - - assert forward_pass.output_workers == max_workers - assert forward_pass.pass_workers == max_workers - assert forward_pass.max_workers == max_workers - assert forward_pass.data_handler.max_workers == max_workers - assert forward_pass.data_handler.compute_workers == max_workers - assert forward_pass.data_handler.load_workers == max_workers - assert forward_pass.data_handler.norm_workers == max_workers - assert forward_pass.data_handler.extract_workers == max_workers - - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == (t_enhance * len(input_files), s_enhance**2 - * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs - for f in ('windspeed_100m', 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ - 'U_100m', 'V_100m', 'topography' - ] - - -def test_fwp_multi_step_spatial_model_topo_noskip(): - """Test the forward pass with a multi step spatial only model class using - exogenous data for all model steps""" - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - 
s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhancements = [2, 2, 1] - s_enhance = np.product(s_enhancements) - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 't_enhancements': [1, 1, 1], - 's_agg_factors': [12, 4, 2], - 't_agg_factors': [1, 1, 1] - } - } - - model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='MultiStepGan', - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == (len(input_files), s_enhance**2 - * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs - for f in ('windspeed_100m', 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 2 # two step model - assert gan_meta[0]['training_features'] == [ - 'U_100m', 'V_100m', 'topography' - ] - - -def test_fwp_multi_step_model_multi_exo(): - """Test the forward pass with a multi step model class using 2 exogenous - data features""" - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, 
fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhancements = [2, 2, 3] - s_enhance = np.product(s_enhancements) - t_enhance = 4 - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2, 3], - 't_enhancements': [1, 1, 1, 4], - 's_agg_factors': [12, 4, 2, 1], - 't_agg_factors': [4, 4, 4, 1] - }, - 'sza': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2, 1], - 't_enhancements': [1, 1, 1, 4], - 's_agg_factors': [12, 4, 2, 1], - 't_agg_factors': [4, 4, 4, 1] - } - } - - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - - assert forward_pass.output_workers == max_workers - assert forward_pass.data_handler.compute_workers == max_workers - assert forward_pass.data_handler.load_workers == max_workers - assert forward_pass.data_handler.norm_workers == max_workers - assert forward_pass.data_handler.extract_workers == max_workers - - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == (t_enhance * len(input_files), s_enhance**2 - * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs - for f in ('windspeed_100m', 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ - 'U_100m', 'V_100m', 'topography' - ] - - -def test_fwp_multi_step_model_topo_noskip(): - """Test the forward pass with a multi step model class using exogenous data - for all model steps""" - Sup3rGan.seed() - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - 
s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3))) - - s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - max_workers = 1 - fwp_chunk_shape = (4, 4, 8) - s_enhancements = [2, 2, 3] - s_enhance = np.product(s_enhancements) - t_enhance = 4 - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2, 3], - 't_enhancements': [1, 1, 1, 4], - 's_agg_factors': [12, 4, 2, 1], - 't_agg_factors': [4, 4, 4, 1], - 'cache_data': False - } - } - - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } - - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict( - target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=max_workers), - overwrite_cache=True) - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=fwp_chunk_shape, - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=max_workers), - exo_kwargs=exo_kwargs, - max_nodes=1) - - forward_pass = ForwardPass(handler) - - assert forward_pass.output_workers == max_workers - assert forward_pass.data_handler.compute_workers == max_workers - assert forward_pass.data_handler.load_workers == max_workers - assert forward_pass.data_handler.norm_workers == max_workers - assert forward_pass.data_handler.extract_workers == max_workers - - forward_pass.run(handler, node_index=0) - - with ResourceX(handler.out_files[0]) as fh: - assert fh.shape == (t_enhance * len(input_files), s_enhance**2 - * fwp_chunk_shape[0] * fwp_chunk_shape[1]) - assert all(f in fh.attrs - for f in ('windspeed_100m', 'winddirection_100m')) - - assert fh.global_attrs['package'] == 'sup3r' - assert fh.global_attrs['version'] == __version__ - assert 'full_version_record' in fh.global_attrs - version_record = json.loads(fh.global_attrs['full_version_record']) - assert version_record['tensorflow'] == tf.__version__ - assert 'gan_meta' in fh.global_attrs - gan_meta = json.loads(fh.global_attrs['gan_meta']) - assert len(gan_meta) == 3 # three step model - assert gan_meta[0]['training_features'] == [ - 'U_100m', 'V_100m', 'topography' - ] - - def test_fwp_multi_step_model(): """Test the forward pass with a multi step model class""" Sup3rGan.seed() @@ -1277,620 +819,3 @@ 
def test_slicing_pad(log=False): assert forward_pass.input_data.shape == padded_truth.shape assert np.allclose(forward_pass.input_data, padded_truth) - - -def test_fwp_single_step_wind_hi_res_topo(plot=False): - """Test the forward pass with a single spatiotemporal MultiExoGan model - requiring high-resolution topography input from the exogenous_data - feature.""" - Sup3rGan.seed() - gen_model = [{ - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv3D", - "filters": 64, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping3D", - "cropping": 2 - }, { - "class": "SpatioTemporalExpansion", - "temporal_mult": 2, - "temporal_method": "nearest" - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv3D", - "filters": 64, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping3D", - "cropping": 2 - }, { - "class": "SpatioTemporalExpansion", - "spatial_mult": 2 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv3D", - "filters": 64, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping3D", - "cropping": 2 - }, { - "alpha": 0.2, - "class": "LeakyReLU" - }, { - "class": "Sup3rConcat", - "name": "topography" - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv3D", - "filters": 2, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping3D", - "cropping": 2 - }] - - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - model.meta['output_features'] = ['U_100m', 'V_100m'] - model.meta['s_enhance'] = 2 - model.meta['t_enhance'] = 2 - model.meta['input_resolution'] = {'spatial': '8km', - 'temporal': '60min'} - exo_tmp = {'topography': { - 'steps': [{'model': 0, 'combine_type': 'layer', - 'data': np.random.rand(4, 20, 20, 12, 1)}]}} - _ = model.generate(np.random.rand(4, 10, 10, 6, 3), exogenous_data=exo_tmp) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - model.save(st_out_dir) - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': None}, - 'steps': [ - {'model': 0, 'combine_type': 'input'}, - {'model': 0, 'combine_type': 'layer'} - ]}} - - model_kwargs = {'model_dir': st_out_dir} - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict(target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='MultiExoGan', - fwp_chunk_shape=(8, 8, 8), - spatial_pad=4, - temporal_pad=4, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - - if plot: - for ifeature, feature in enumerate(forward_pass.output_features): - fig = plt.figure(figsize=(15, 5)) - ax1 = fig.add_subplot(111) - vmin = np.min(forward_pass.input_data[..., ifeature]) - vmax = np.max(forward_pass.input_data[..., 
ifeature]) - nc = ax1.imshow(forward_pass.input_data[..., 0, ifeature], - vmin=vmin, - vmax=vmax) - fig.colorbar(nc, ax=ax1, shrink=0.6, label=f'{feature}') - plt.savefig(f'./input_{feature}.png') - plt.close() - - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) - - -def test_fwp_multi_step_exo_hi_res_topo_and_sza(): - """Test the forward pass with multiple ExoGan models requiring - high-resolution topography and sza input from the exogenous_data - feature.""" - Sup3rGan.seed() - gen_model = [{ - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "SpatialExpansion", - "spatial_mult": 2 - }, { - "class": "Activation", - "activation": "relu" - }, { - "class": "Sup3rConcat", - "name": "topography" - }, { - "class": "Sup3rConcat", - "name": "sza" - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 2, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }] - - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = s1_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - s2_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' - ] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - exo_kwargs = { - 'topography': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 
's_enhancements': [1, 2, 2, 3], - 't_enhancements': [1, 1, 1, 4], - 's_agg_factors': [12, 4, 2, 1], - 't_agg_factors': [4, 4, 4, 1] - }, - 'sza': { - 'file_paths': input_files, - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2, 3], - 't_enhancements': [1, 1, 1, 4], - 's_agg_factors': [12, 4, 2, 1], - 't_agg_factors': [4, 4, 4, 1] - } - } - - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict(target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1, 1] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - exo_kwargs['s_enhancements'] = [1, 2, 2] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) - - -def test_fwp_multi_step_wind_hi_res_topo(): - """Test the forward pass with multiple MultiExoGan models requiring - high-resolution topograph input from the exogenous_data feature.""" - Sup3rGan.seed() - gen_model = [{ - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "SpatialExpansion", - "spatial_mult": 2 - }, { - "class": "Activation", - "activation": "relu" - }, { - "class": "Sup3rConcat", - "name": "topography" - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 2, - "kernel_size": 3, - "strides": 1, - "activation": "relu" - }, { - "class": "Cropping2D", - "cropping": 4 - }] - - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s1_model.meta['output_features'] = ['U_100m', 'V_100m'] - s1_model.meta['s_enhance'] = 2 - s1_model.meta['t_enhance'] = 1 - _ = 
s1_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - s2_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s2_model.meta['output_features'] = ['U_100m', 'V_100m'] - s2_model.meta['s_enhance'] = 2 - s2_model.meta['t_enhance'] = 1 - _ = s2_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - st_out_dir = os.path.join(td, 'st_gan') - s1_out_dir = os.path.join(td, 's1_gan') - s2_out_dir = os.path.join(td, 's2_gan') - st_model.save(st_out_dir) - s1_model.save(s1_out_dir) - s2_model.save(s2_out_dir) - - exo_kwargs = { - 'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2, 2], - 't_enhancements': [1, 1, 1], - 's_agg_factors': [12, 4, 2], - 't_agg_factors': [1, 1, 1], - } - - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict(target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1, 1] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - exo_kwargs['s_enhancements'] = [1, 2, 2] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) - - -def test_fwp_wind_hi_res_topo_plus_linear(): - """Test the forward pass with a MultiExoGan model requiring high-res topo - input from exo data for spatial enhancement and a linear interpolation - model for temporal enhancement.""" - - Sup3rGan.seed() - gen_model = [{ - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - 
"strides": 1 - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 64, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping2D", - "cropping": 4 - }, { - "class": "SpatialExpansion", - "spatial_mult": 2 - }, { - "alpha": 0.2, - "class": "LeakyReLU" - }, { - "class": "Sup3rConcat" - }, { - "class": "FlexiblePadding", - "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT" - }, { - "class": "Conv2DTranspose", - "filters": 2, - "kernel_size": 3, - "strides": 1 - }, { - "class": "Cropping2D", - "cropping": 4 - }] - - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s_model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) - s_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] - s_model.meta['output_features'] = ['U_100m', 'V_100m'] - s_model.meta['s_enhance'] = 2 - s_model.meta['t_enhance'] = 1 - _ = s_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - t_model = LinearInterp(features=['U_100m', 'V_100m'], - s_enhance=1, - t_enhance=4) - - with tempfile.TemporaryDirectory() as td: - input_files = make_fake_nc_files(td, INPUT_FILE, 8) - - s_out_dir = os.path.join(td, 's_gan') - t_out_dir = os.path.join(td, 't_interp') - s_model.save(s_out_dir) - t_model.save(t_out_dir) - - exo_kwargs = { - 'file_paths': input_files, - 'features': ['topography'], - 'source_file': FP_WTK, - 'target': target, - 'shape': shape, - 's_enhancements': [1, 2], - 't_enhancements': [1, 1], - 's_agg_factors': [4, 2], - 't_agg_factors': [1, 1] - } - - model_kwargs = { - 'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': t_out_dir - } - out_files = os.path.join(td, 'out_{file_id}.h5') - input_handler_kwargs = dict(target=target, - shape=shape, - temporal_slice=temporal_slice, - worker_kwargs=dict(max_workers=1), - overwrite_cache=True) - - exo_kwargs['s_enhancements'] = [1, 2] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - - for fp in handler.out_files: - assert os.path.exists(fp) diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py new file mode 100644 index 000000000..ecfd51fb4 --- /dev/null +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -0,0 +1,1158 @@ +# -*- coding: utf-8 -*- +"""pytests for data handling""" +import json +import os +import tempfile + +import matplotlib.pyplot as plt +import numpy as np +import pytest +import tensorflow as tf +from rex import ResourceX, init_logger + +from sup3r import CONFIG_DIR, TEST_DATA_DIR, __version__ +from sup3r.models import LinearInterp, Sup3rGan +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy +from sup3r.utilities.pytest import make_fake_nc_files + +FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') +TARGET_COORD = (39.01, -105.15) +FEATURES = ['U_100m', 'V_100m', 'BVF2_200m'] +INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') +target = (19.3, -123.5) +shape = (8, 8) +sample_shape = (8, 8, 6) +temporal_slice = slice(None, None, 1) +list_chunk_size = 10 
+fwp_chunk_shape = (4, 4, 150) +s_enhance = 3 +t_enhance = 4 + + +def test_fwp_multi_step_model_topo_exoskip(log=False): + """Test the forward pass with a multi step model class using exogenous data + for the first two steps and not the last""" + + if log: + init_logger('sup3r', log_level='DEBUG') + + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhance = 12 + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'} + ] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=fwp_chunk_shape, + input_handler_kwargs=input_handler_kwargs, + spatial_pad=0, + temporal_pad=0, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.pass_workers == max_workers + assert forward_pass.max_workers == max_workers + assert forward_pass.data_handler.max_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert 
forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_spatial_model_topo_noskip(): + """Test the forward pass with a multi step spatial only model class using + exogenous data for all model steps""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '16km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '8km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 1] + s_enhance = np.product(s_enhancements) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [ + {'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}, + ] + } + } + + model_kwargs = {'model_dirs': [s1_out_dir, s2_out_dir]} + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='MultiStepGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in 
('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 2 # two step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_model_topo_noskip(): + """Test the forward pass with a multi step model class using exogenous data + for all model steps""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 3] + s_enhance = np.product(s_enhancements) + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 2, 'combine_type': 'input'}] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + 
worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_single_step_wind_hi_res_topo(plot=False): + """Test the forward pass with a single spatiotemporal Sup3rGan model + requiring high-resolution topography input from the exogenous_data + feature.""" + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "temporal_mult": 2, + "temporal_method": "nearest" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", + "spatial_mult": 2 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping3D", + "cropping": 2 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['s_enhance'] = 2 + model.meta['t_enhance'] = 2 + model.meta['input_resolution'] = {'spatial': '8km', + 'temporal': '60min'} + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', + 'data': np.random.rand(4, 20, 20, 12, 1)}]}} + _ = model.generate(np.random.rand(4, 10, 10, 6, 3), exogenous_data=exo_tmp) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + model.save(st_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, 
+            'target': target,
+            'shape': shape,
+            'cache_dir': td,
+            'exo_resolution': {'spatial': '4km', 'temporal': '15min'},
+            'steps': [
+                {'model': 0, 'combine_type': 'input'},
+                {'model': 0, 'combine_type': 'layer'}
+            ]}}
+
+        model_kwargs = {'model_dir': st_out_dir}
+        out_files = os.path.join(td, 'out_{file_id}.h5')
+        input_handler_kwargs = dict(target=target,
+                                    shape=shape,
+                                    temporal_slice=temporal_slice,
+                                    worker_kwargs=dict(max_workers=1),
+                                    overwrite_cache=True)
+
+        handler = ForwardPassStrategy(
+            input_files,
+            model_kwargs=model_kwargs,
+            model_class='Sup3rGan',
+            fwp_chunk_shape=(8, 8, 8),
+            spatial_pad=4,
+            temporal_pad=4,
+            input_handler_kwargs=input_handler_kwargs,
+            out_pattern=out_files,
+            worker_kwargs=dict(max_workers=1),
+            exo_kwargs=exo_kwargs,
+            max_nodes=1)
+        forward_pass = ForwardPass(handler)
+
+        if plot:
+            for ifeature, feature in enumerate(forward_pass.output_features):
+                fig = plt.figure(figsize=(15, 5))
+                ax1 = fig.add_subplot(111)
+                vmin = np.min(forward_pass.input_data[..., ifeature])
+                vmax = np.max(forward_pass.input_data[..., ifeature])
+                nc = ax1.imshow(forward_pass.input_data[..., 0, ifeature],
+                                vmin=vmin,
+                                vmax=vmax)
+                fig.colorbar(nc, ax=ax1, shrink=0.6, label=f'{feature}')
+                plt.savefig(f'./input_{feature}.png')
+                plt.close()
+
+        forward_pass.run(handler, node_index=0)
+
+        for fp in handler.out_files:
+            assert os.path.exists(fp)
+
+
+def test_fwp_multi_step_wind_hi_res_topo():
+    """Test the forward pass with multiple Sup3rGan models requiring
+    high-resolution topography input from the exogenous_data feature."""
+    Sup3rGan.seed()
+    gen_model = [{
+        "class": "FlexiblePadding",
+        "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
+        "mode": "REFLECT"
+    }, {
+        "class": "Conv2DTranspose",
+        "filters": 64,
+        "kernel_size": 3,
+        "strides": 1,
+        "activation": "relu"
+    }, {
+        "class": "Cropping2D",
+        "cropping": 4
+    }, {
+        "class": "FlexiblePadding",
+        "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
+        "mode": "REFLECT"
+    }, {
+        "class": "Conv2DTranspose",
+        "filters": 64,
+        "kernel_size": 3,
+        "strides": 1,
+        "activation": "relu"
+    }, {
+        "class": "Cropping2D",
+        "cropping": 4
+    }, {
+        "class": "FlexiblePadding",
+        "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
+        "mode": "REFLECT"
+    }, {
+        "class": "Conv2DTranspose",
+        "filters": 64,
+        "kernel_size": 3,
+        "strides": 1,
+        "activation": "relu"
+    }, {
+        "class": "Cropping2D",
+        "cropping": 4
+    }, {
+        "class": "SpatialExpansion",
+        "spatial_mult": 2
+    }, {
+        "class": "Activation",
+        "activation": "relu"
+    }, {
+        "class": "Sup3rConcat",
+        "name": "topography"
+    }, {
+        "class": "FlexiblePadding",
+        "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]],
+        "mode": "REFLECT"
+    }, {
+        "class": "Conv2DTranspose",
+        "filters": 2,
+        "kernel_size": 3,
+        "strides": 1,
+        "activation": "relu"
+    }, {
+        "class": "Cropping2D",
+        "cropping": 4
+    }]
+
+    fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json')
+    s1_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4)
+    s1_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
+    s1_model.meta['output_features'] = ['U_100m', 'V_100m']
+    s1_model.meta['s_enhance'] = 2
+    s1_model.meta['t_enhance'] = 1
+    s1_model.meta['input_resolution'] = {'spatial': '48km',
+                                         'temporal': '60min'}
+
+    exo_tmp = {
+        'topography': {
+            'steps': [
+                {'model': 0, 'combine_type': 'layer',
+                 'data': np.random.rand(4, 20, 20, 1)}]}}
+    _ = s1_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=exo_tmp)
+
+    s2_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4)
+    s2_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
+    s2_model.meta['output_features'] = ['U_100m', 'V_100m']
+    s2_model.meta['s_enhance'] = 2
+    s2_model.meta['t_enhance'] = 1
+    s2_model.meta['input_resolution'] = {'spatial': '24km',
+                                         'temporal': '60min'}
+    _ = s2_model.generate(np.ones((4, 10, 10, 3)), exogenous_data=exo_tmp)
+
+    fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
+    fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
+    st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
+    st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography']
+    st_model.meta['output_features'] = ['U_100m', 'V_100m']
+    st_model.meta['s_enhance'] = 3
+    st_model.meta['t_enhance'] = 4
+    st_model.meta['input_resolution'] = {'spatial': '12km',
+                                         'temporal': '60min'}
+    _ = st_model.generate(np.ones((4, 10, 10, 6, 3)))
+
+    with tempfile.TemporaryDirectory() as td:
+        input_files = make_fake_nc_files(td, INPUT_FILE, 8)
+
+        st_out_dir = os.path.join(td, 'st_gan')
+        s1_out_dir = os.path.join(td, 's1_gan')
+        s2_out_dir = os.path.join(td, 's2_gan')
+        st_model.save(st_out_dir)
+        s1_model.save(s1_out_dir)
+        s2_model.save(s2_out_dir)
+
+        exo_kwargs = {
+            'topography': {
+                'file_paths': input_files,
+                'source_file': FP_WTK,
+                'target': target,
+                'shape': shape,
+                'cache_dir': td,
+                'exo_resolution': {'spatial': '4km', 'temporal': '15min'},
+            }
+        }
+
+        model_kwargs = {
+            'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir]
+        }
+        out_files = os.path.join(td, 'out_{file_id}.h5')
+        input_handler_kwargs = dict(target=target,
+                                    shape=shape,
+                                    temporal_slice=temporal_slice,
+                                    worker_kwargs=dict(max_workers=1),
+                                    overwrite_cache=True)
+
+        with pytest.raises(RuntimeError):
+            # should raise error since steps doesn't include
+            # {'model': 2, 'combine_type': 'input'}
+            steps = [{'model': 0, 'combine_type': 'input'},
+                     {'model': 0, 'combine_type': 'layer'},
+                     {'model': 1, 'combine_type': 'input'},
+                     {'model': 1, 'combine_type': 'layer'}]
+            exo_kwargs['topography']['steps'] = steps
+            handler = ForwardPassStrategy(
+                input_files,
+                model_kwargs=model_kwargs,
+                model_class='MultiStepGan',
+                fwp_chunk_shape=(4, 4, 8),
+                spatial_pad=1,
+                temporal_pad=1,
+                input_handler_kwargs=input_handler_kwargs,
+                out_pattern=out_files,
+                worker_kwargs=dict(max_workers=1),
+                exo_kwargs=exo_kwargs,
+                max_nodes=1)
+            forward_pass = ForwardPass(handler)
+            forward_pass.run(handler, node_index=0)
+
+        steps = [{'model': 0, 'combine_type': 'input'},
+                 {'model': 0, 'combine_type': 'layer'},
+                 {'model': 1, 'combine_type': 'input'},
+                 {'model': 1, 'combine_type': 'layer'},
+                 {'model': 2, 'combine_type': 'input'}]
+        exo_kwargs['topography']['steps'] = steps
+        handler = ForwardPassStrategy(
+            input_files,
+            model_kwargs=model_kwargs,
+            model_class='MultiStepGan',
+            fwp_chunk_shape=(4, 4, 8),
+            spatial_pad=1,
+            temporal_pad=1,
+            input_handler_kwargs=input_handler_kwargs,
+            out_pattern=out_files,
+            worker_kwargs=dict(max_workers=1),
+            exo_kwargs=exo_kwargs,
+            max_nodes=1)
+        forward_pass = ForwardPass(handler)
+        forward_pass.run(handler, node_index=0)
+
+        for fp in handler.out_files:
+            assert os.path.exists(fp)
+
+
+def test_fwp_wind_hi_res_topo_plus_linear():
+    """Test the forward pass with a Sup3rGan model requiring high-res topo
+    input from exo data for spatial enhancement and a linear interpolation
+    model for temporal enhancement."""
+
+    Sup3rGan.seed()
+    gen_model = [{
+        "class": "FlexiblePadding",
+        "paddings": [[0, 0], [3,
3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "alpha": 0.2, + "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1 + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + s_model.meta['output_features'] = ['U_100m', 'V_100m'] + s_model.meta['s_enhance'] = 2 + s_model.meta['t_enhance'] = 1 + s_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + exo_tmp = { + 'topography': { + 'steps': [ + {'combine_type': 'layer', 'data': np.ones((4, 20, 20, 1))}]}} + _ = s_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=exo_tmp) + + t_model = LinearInterp(features=['U_100m', 'V_100m'], + s_enhance=1, + t_enhance=4) + t_model.meta['input_resolution'] = {'spatial': '4km', + 'temporal': '60min'} + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + s_out_dir = os.path.join(td, 's_gan') + t_out_dir = os.path.join(td, 't_interp') + s_model.save(s_out_dir) + t_model.save(t_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}] + } + } + + model_kwargs = { + 'spatial_model_dirs': s_out_dir, + 'temporal_model_dirs': t_out_dir + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + + +def test_fwp_multi_step_model_multi_exo(): + """Test the forward pass with a multi step model class using 2 exogenous + data features""" + Sup3rGan.seed() + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography' + ] + 
s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3))) + + s2_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + max_workers = 1 + fwp_chunk_shape = (4, 4, 8) + s_enhancements = [2, 2, 3] + s_enhance = np.product(s_enhancements) + t_enhance = 4 + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}] + }, + 'sza': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'steps': [{'model': 2, 'combine_type': 'input'}] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict( + target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers), + exo_kwargs=exo_kwargs, + max_nodes=1) + + forward_pass = ForwardPass(handler) + + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + + forward_pass.run(handler, node_index=0) + + with ResourceX(handler.out_files[0]) as fh: + assert fh.shape == (t_enhance * len(input_files), s_enhance**2 + * fwp_chunk_shape[0] * fwp_chunk_shape[1]) + assert all(f in fh.attrs + for f in ('windspeed_100m', 'winddirection_100m')) + + assert fh.global_attrs['package'] == 'sup3r' + assert 
fh.global_attrs['version'] == __version__ + assert 'full_version_record' in fh.global_attrs + version_record = json.loads(fh.global_attrs['full_version_record']) + assert version_record['tensorflow'] == tf.__version__ + assert 'gan_meta' in fh.global_attrs + gan_meta = json.loads(fh.global_attrs['gan_meta']) + assert len(gan_meta) == 3 # three step model + assert gan_meta[0]['training_features'] == [ + 'U_100m', 'V_100m', 'topography' + ] + + +def test_fwp_multi_step_exo_hi_res_topo_and_sza(): + """Test the forward pass with multiple ExoGan models requiring + high-resolution topography and sza input from the exogenous_data + feature.""" + Sup3rGan.seed() + gen_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT"}, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 64, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }, { + "class": "SpatialExpansion", + "spatial_mult": 2 + }, { + "class": "Activation", + "activation": "relu" + }, { + "class": "Sup3rConcat", + "name": "topography" + }, { + "class": "Sup3rConcat", + "name": "sza" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv2DTranspose", + "filters": 2, + "kernel_size": 3, + "strides": 1, + "activation": "relu" + }, { + "class": "Cropping2D", + "cropping": 4 + }] + + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + s1_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s1_model.meta['output_features'] = ['U_100m', 'V_100m'] + s1_model.meta['s_enhance'] = 2 + s1_model.meta['t_enhance'] = 1 + s1_model.meta['input_resolution'] = {'spatial': '48km', + 'temporal': '60min'} + _ = s1_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=(None, np.ones((4, 20, 20, 1)))) + + s2_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s2_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + s2_model.meta['output_features'] = ['U_100m', 'V_100m'] + s2_model.meta['s_enhance'] = 2 + s2_model.meta['t_enhance'] = 1 + s2_model.meta['input_resolution'] = {'spatial': '24km', + 'temporal': '60min'} + _ = s2_model.generate(np.ones((4, 10, 10, 3)), + exogenous_data=(None, np.ones((4, 20, 20, 1)))) + + fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') + st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model.meta['training_features'] = [ + 'U_100m', 'V_100m', 'topography', 'sza' + ] + st_model.meta['output_features'] = ['U_100m', 'V_100m'] + st_model.meta['s_enhance'] = 3 + st_model.meta['t_enhance'] = 4 + st_model.meta['input_resolution'] = {'spatial': '12km', + 'temporal': '60min'} + _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, 
INPUT_FILE, 8) + + st_out_dir = os.path.join(td, 'st_gan') + s1_out_dir = os.path.join(td, 's1_gan') + s2_out_dir = os.path.join(td, 's2_gan') + st_model.save(st_out_dir) + s1_model.save(s1_out_dir) + s2_model.save(s2_out_dir) + + exo_kwargs = { + 'topography': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'input'}] + }, + 'sza': { + 'file_paths': input_files, + 'source_file': FP_WTK, + 'target': target, + 'shape': shape, + 'cache_dir': td, + 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'steps': [{'model': 2, 'combine_type': 'input'}] + } + } + + model_kwargs = { + 'spatial_model_dirs': [s1_out_dir, s2_out_dir], + 'temporal_model_dirs': st_out_dir + } + out_files = os.path.join(td, 'out_{file_id}.h5') + input_handler_kwargs = dict(target=target, + shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=1), + overwrite_cache=True) + + # should get an error on a bad tensorflow concatenation + with pytest.raises(RuntimeError): + exo_kwargs['s_enhancements'] = [1, 1, 1] + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + handler = ForwardPassStrategy( + input_files, + model_kwargs=model_kwargs, + model_class='SpatialThenTemporalGan', + fwp_chunk_shape=(4, 4, 8), + spatial_pad=1, + temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, + out_pattern=out_files, + worker_kwargs=dict(max_workers=1), + exo_kwargs=exo_kwargs, + max_nodes=1) + forward_pass = ForwardPass(handler) + forward_pass.run(handler, node_index=0) + + for fp in handler.out_files: + assert os.path.exists(fp) + diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index 6c9901e71..e55bebcd8 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -1,14 +1,20 @@ # -*- coding: utf-8 -*- """Test forward passes through multi-step GAN models""" import os +import tempfile + import numpy as np import pytest -import tempfile from sup3r import CONFIG_DIR -from sup3r.models import (Sup3rGan, MultiStepGan, - SpatialThenTemporalGan, TemporalThenSpatialGan, - SolarMultiStepGan, LinearInterp) +from sup3r.models import ( + LinearInterp, + MultiStepGan, + SolarMultiStepGan, + SpatialThenTemporalGan, + Sup3rGan, + TemporalThenSpatialGan, +) FEATURES = ['U_100m', 'V_100m'] @@ -237,3 +243,4 @@ def test_solar_multistep(): x = np.ones((3, 10, 10, len(features1 + features2))) out = ms_model.generate(x) assert out.shape == (1, 20, 20, 24, 1) + diff --git a/tests/training/test_train_multi_exo.py b/tests/training/test_train_multi_exo.py index f40e671e3..d5b922a27 100644 --- a/tests/training/test_train_multi_exo.py +++ b/tests/training/test_train_multi_exo.py @@ -8,8 +8,8 @@ from rex import init_logger from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.models import MultiExoGan -from sup3r.models.data_centric import MultiExoGanDC +from sup3r.models import Sup3rGan +from sup3r.models.data_centric import Sup3rGanDC from sup3r.preprocessing.batch_handling import ( 
BatchHandlerCC, BatchHandlerDC, @@ -60,8 +60,8 @@ def test_wind_cc_model(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_4x_24x_3f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - MultiExoGan.seed() - model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -71,7 +71,7 @@ def test_wind_cc_model(log=False): out_dir=os.path.join(td, 'test_{epoch}')) assert 'test_0' in os.listdir(td) - assert model.meta['class'] == 'MultiExoGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features assert len(model.output_features) == len(FEATURES_W) - 1 @@ -108,8 +108,8 @@ def test_wind_cc_model_spatial(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - MultiExoGan.seed() - model = MultiExoGan(fp_gen, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -120,7 +120,7 @@ def test_wind_cc_model_spatial(log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'MultiExoGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -188,8 +188,8 @@ def test_wind_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - MultiExoGan.seed() - model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -200,7 +200,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'MultiExoGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -273,8 +273,8 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - MultiExoGan.seed() - model = MultiExoGan(gen_model, fp_disc, learning_rate=1e-4) + Sup3rGan.seed() + model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -285,7 +285,7 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'MultiExoGan' + assert model.meta['class'] == 'Sup3rGan' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features @@ -358,8 +358,8 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - MultiExoGanDC.seed() - model = MultiExoGanDC(gen_model, fp_disc, learning_rate=1e-4) + Sup3rGanDC.seed() + model = Sup3rGanDC(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: model.train(batcher, n_epoch=1, @@ -370,7 +370,7 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 
'V_100m']
-    assert model.meta['class'] == 'MultiExoGanDC'
+    assert model.meta['class'] == 'Sup3rGanDC'
     assert 'topography' in batcher.output_features
     assert 'topography' not in model.output_features


From 6afd315756fdb27f7edc4296055675cb637f5459 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Mon, 25 Sep 2023 16:34:57 -0600
Subject: [PATCH 09/15] SzaExtract using rex SolarPosition added to
 exo_extraction.py. All tests modified to handle new exogenous_data input
 format.

---
 sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json  |   2 +-
 sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json  |   2 +-
 sup3r/models/abstract.py                      |  34 ++--
 sup3r/models/base.py                          |   2 +-
 sup3r/models/conditional_moments.py           |   8 +-
 sup3r/models/linear.py                        |  12 +-
 sup3r/models/surface.py                       |  15 +-
 .../data_handling/exo_extraction.py           | 160 +++++++++---------
 .../data_handling/exogenous_data_handling.py  |  36 ++--
 .../wind_conditional_moment_batch_handling.py |  36 +++-
 tests/forward_pass/test_forward_pass_exo.py   | 134 ++++++++++-----
 tests/forward_pass/test_multi_step.py         |   1 -
 tests/forward_pass/test_surface_model.py      |  23 ++-
 .../test_train_conditional_moments.py         |  83 ++++++---
 tests/training/test_train_gan.py              |  25 ++-
 ...ain_multi_exo.py => test_train_gan_exo.py} |  64 ++++---
 tests/training/test_train_gan_lr_era.py       |   2 +
 tests/training/test_train_solar.py            |  28 +--
 .../test_train_wind_conditional_moments.py    |  76 +++++----
 19 files changed, 459 insertions(+), 284 deletions(-)
 rename tests/training/{test_train_multi_exo.py => test_train_gan_exo.py} (88%)

diff --git a/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json b/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json
index 870c363f2..9785abf0e 100644
--- a/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json
+++ b/sup3r/configs/sup3rcc/gen_wind_3x_4x_2f.json
@@ -34,7 +34,7 @@
     {"class": "SpatioTemporalExpansion", "spatial_mult": 3},
     {"alpha": 0.2, "class": "LeakyReLU"},
 
-    {"class": "Sup3rConcat"},
+    {"class": "Sup3rConcat", "name": "topography"},
 
     {"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
     {"class": "Conv3D", "filters": 2, "kernel_size": 3, "strides": 1},
diff --git a/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json b/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json
index 53ce739a0..b90b9f39c 100644
--- a/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json
+++ b/sup3r/configs/sup3rcc/gen_wind_5x_1x_6f.json
@@ -31,7 +31,7 @@
     {"class": "SpatialExpansion", "spatial_mult": 5},
     {"alpha": 0.2, "class": "LeakyReLU"},
 
-    {"class": "Sup3rConcat"},
+    {"class": "Sup3rConcat", "name": "topography"},
     {"class": "FlexiblePadding", "paddings": [[0,0], [3,3], [3,3], [0,0]], "mode": "REFLECT"},
     {"class": "Conv2D", "filters": 64, "kernel_size": 3, "strides": 1},
     {"class": "Cropping2D", "cropping": 2},
diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 3b898c3d2..5e205c860 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -148,11 +148,14 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
+        low_res_shape = low_res.shape
         check = (exogenous_data is not None
                  and low_res.shape[-1] < len(self.training_features))
+        exo_data = {k: v for k, v in (exogenous_data or {}).items()
+                    if k in self.training_features}
         if check:
-            for i, (feature, entry) in enumerate(exogenous_data.items()):
-                f_idx = low_res.shape[-1] + i
+            for i, (feature, entry) in enumerate(exo_data.items()):
+                f_idx = low_res_shape[-1] + i
                 training_feature = self.training_features[f_idx]
                 msg = ('The ordering of features in exogenous_data conflicts '
                        'with the ordering of training features. Received '
@@ -194,9 +197,14 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
-        if exogenous_data is not None:
-            for i, (feature, entry) in enumerate(exogenous_data.items()):
-                f_idx = hi_res.shape[-1] + i
+        hi_res_shape = hi_res.shape
+        check = (exogenous_data is not None
+                 and hi_res.shape[-1] < len(self.output_features))
+        exo_data = {k: v for k, v in (exogenous_data or {}).items()
+                    if k in self.training_features}
+        if check:
+            for i, (feature, entry) in enumerate(exo_data.items()):
+                f_idx = hi_res_shape[-1] + i
                 training_feature = self.training_features[f_idx]
                 msg = ('The ordering of features in exogenous_data conflicts '
                        'with the ordering of training features. Received '
@@ -228,10 +236,11 @@ def _combine_loss_input(self, high_res_true, high_res_gen):
         high_res_gen : tf.Tensor
             Same as input with exogenous data combined with high_res input
         """
-        for feature in self.exogenous_features:
-            f_idx = self.training_features.index(feature)
-            exo_data = high_res_true[..., f_idx: f_idx + 1]
-            high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)
+        if high_res_true.shape[-1] > high_res_gen.shape[-1]:
+            for feature in self.exogenous_features:
+                f_idx = self.training_features.index(feature)
+                exo_data = high_res_true[..., f_idx: f_idx + 1]
+                high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)
         return high_res_gen
 
     def _get_exo_val_loss_input(self, high_res):
@@ -344,9 +353,9 @@ def _check_exo_features(self, **kwargs):
                f'data handler must be {self.exogenous_features} '
                'to train the Exo model, but received output features: {}'.
                format(output_features))
-        check = (output_features[-len(self.exogenous_features)]
-                 == self.exogenous_features)
-        assert check, msg
+        exo_features = ([] if len(self.exogenous_features) == 0
+                        else output_features[-len(self.exogenous_features):])
+        assert exo_features == self.exogenous_features, msg
         for f in self.exogenous_features:
             output_features.remove(f)
         kwargs['output_features'] = output_features
@@ -1166,6 +1175,7 @@ def generate(self,
             (n_obs, spatial_1, spatial_2, n_features)
             (n_obs, spatial_1, spatial_2, n_temporal, n_features)
         """
+
         low_res = self._combine_fwp_input(low_res, exogenous_data)
         if norm_in and self._means is not None:
 
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index 7c1b04233..30d6ab22d 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -447,7 +447,7 @@ def calc_loss_gen_content(self, hi_res_true, hi_res_gen):
             0D tensor generator model loss for the content loss comparing
             the hi res ground truth to the hi res synthetically generated
             output.
""" - + hi_res_gen = self._combine_loss_input(hi_res_true, hi_res_gen) loss_gen_content = self.loss_fun(hi_res_true, hi_res_gen) return loss_gen_content diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index b0d1cd4a3..af7835e9d 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -339,7 +339,9 @@ def train_epoch(self, batch_handler, multi_gpu=False): return loss_details - def train(self, batch_handler, n_epoch, + def train(self, batch_handler, + input_resolution, + n_epoch, checkpoint_int=None, out_dir='./condMom_{epoch}', early_stop_on=None, @@ -352,6 +354,9 @@ def train(self, batch_handler, n_epoch, ---------- batch_handler : sup3r.data_handling.preprocessing.BatchHandler BatchHandler object to iterate through + input_resolution : dict + Dictionary specifying spatiotemporal input resolution. e.g. + {'temporal': '60min', 'spatial': '30km'} n_epoch : int Number of epochs to train on checkpoint_int : int | None @@ -386,6 +391,7 @@ def train(self, batch_handler, n_epoch, """ self.set_norm_stats(batch_handler.means, batch_handler.stds) self.set_model_params( + input_resolution=input_resolution, s_enhance=batch_handler.s_enhance, t_enhance=batch_handler.t_enhance, smoothing=batch_handler.smoothing, diff --git a/sup3r/models/linear.py b/sup3r/models/linear.py index e049a3efd..378b068c3 100644 --- a/sup3r/models/linear.py +++ b/sup3r/models/linear.py @@ -1,12 +1,14 @@ # -*- coding: utf-8 -*- """Simple models for super resolution such as linear interp models.""" -import numpy as np +import json import logging -from inspect import signature import os -import json -from sup3r.utilities.utilities import st_interp +from inspect import signature + +import numpy as np + from sup3r.models.abstract import AbstractInterface +from sup3r.utilities.utilities import st_interp logger = logging.getLogger(__name__) @@ -59,7 +61,7 @@ class init args. """ fp_params = os.path.join(model_dir, 'model_params.json') assert os.path.exists(fp_params), f'Could not find: {fp_params}' - with open(fp_params, 'r') as f: + with open(fp_params) as f: params = json.load(f) meta = params['meta'] diff --git a/sup3r/models/surface.py b/sup3r/models/surface.py index a11ce0ea9..f0b93b8a5 100644 --- a/sup3r/models/surface.py +++ b/sup3r/models/surface.py @@ -107,6 +107,7 @@ def __init__(self, features, s_enhance, noise_adders=None, self._pres_div = pres_div or self.PRES_DIV self._pres_exp = pres_exp or self.PRES_EXP self._fix_bias = fix_bias + self._input_resolution = None self._interp_method = getattr(Image.Resampling, interp_method) if isinstance(self._noise_adders, (int, float)): @@ -561,6 +562,7 @@ def meta(self): 's_enhance': self._s_enhance, 't_enhance': 1, 'noise_adders': self._noise_adders, + 'input_resolution': self._input_resolution, 'weight_for_delta_temp': self._w_delta_temp, 'weight_for_delta_topo': self._w_delta_topo, 'pressure_divisor': self._pres_div, @@ -572,10 +574,10 @@ def meta(self): 'class': self.__class__.__name__, } - def train(self, true_hr_temp, true_hr_rh, true_hr_topo): - """This method trains the relative humidity linear model. The - temperature and surface lapse rate models are parameterizations taken - from the NSRDB and are not trained. + def train(self, true_hr_temp, true_hr_rh, true_hr_topo, input_resolution): + """Trains the relative humidity linear model. The temperature and + surface lapse rate models are parameterizations taken from the NSRDB + and are not trained. 
 
         Parameters
         ----------
@@ -588,6 +590,9 @@ def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
         true_hr_topo : np.ndarray
             High-resolution surface elevation data in meters with shape
             (lat, lon)
+        input_resolution : dict
+            Dictionary of spatial and temporal input resolution. e.g.
+            {'spatial': '20km', 'temporal': '60min'}
 
         Returns
         -------
@@ -598,7 +603,7 @@ def train(self, true_hr_temp, true_hr_rh, true_hr_topo):
             Weight for the delta-topography feature for the relative humidity
             linear regression model.
         """
-
+        self._input_resolution = input_resolution
         assert len(true_hr_temp.shape) == 3, 'Bad true_hr_temp shape'
         assert len(true_hr_rh.shape) == 3, 'Bad true_hr_rh shape'
         assert len(true_hr_topo.shape) == 2, 'Bad true_hr_topo shape'
 
diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py
index 97cf843b2..6d8c6bdfe 100644
--- a/sup3r/preprocessing/data_handling/exo_extraction.py
+++ b/sup3r/preprocessing/data_handling/exo_extraction.py
@@ -5,6 +5,7 @@
 
 import numpy as np
 from rex import Resource
+from rex.utilities.solar_position import SolarPosition
 from scipy.spatial import KDTree
 
 import sup3r.preprocessing.data_handling
@@ -110,15 +111,17 @@ def __init__(self,
 
         logger.info(f'Initializing {self.__class__.__name__} utility.')
 
+        self.ti_workers = ti_workers
         self._exo_source = exo_source
         self._s_enhance = s_enhance
         self._t_enhance = t_enhance
         self._s_agg_factor = s_agg_factor
         self._t_agg_factor = t_agg_factor
         self._tree = None
-        self.ti_workers = ti_workers
         self._hr_lat_lon = None
+        self._source_lat_lon = None
         self._hr_time_index = None
+        self._src_time_index = None
 
         if input_handler is None:
             in_type = get_source_type(file_paths)
@@ -157,24 +160,25 @@ def source_data(self):
         """Get the 1D array of source data from the exo_source_h5"""
 
     @property
-    @abstractmethod
-    def hr_time_index(self):
-        """Get the full time index of the exo_source data"""
-
-    @property
-    def hr_temporal_slice(self):
-        """Get the temporal slice fr the exo_source data corresponding to the
+    def source_temporal_slice(self):
+        """Get the temporal slice for the exo_source data corresponding to the
         input file temporal slice"""
-        start_index = self.hr_time_index.get_loc(
-            self.input_handler.time_index[0], method='nearest')
-        end_index = self.hr_time_index.get_loc(
-            self.input_handler.time_index[-1], method='nearest')
-        return slice(start_index, end_index + 1)
+        start_index = self.source_time_index.get_indexer(
+            [self.hr_time_index[0]], method='nearest')[0]
+        end_index = self.source_time_index.get_indexer(
+            [self.hr_time_index[-1]], method='nearest')[0]
+        return slice(start_index, end_index + 1, self._t_agg_factor)
 
     @property
-    @abstractmethod
     def source_lat_lon(self):
-        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
+        """Get the 2D array (n, 2) of lat, lon data for the exo source"""
+        if self._source_lat_lon is None:
+            src_enhance = int(np.sqrt(self._s_agg_factor))
+            src_shape = (self.hr_shape[0] * src_enhance,
+                         self.hr_shape[1] * src_enhance)
+            self._source_lat_lon = OutputHandler.get_lat_lon(
+                self.lr_lat_lon, src_shape).reshape((-1, 2))
+        return self._source_lat_lon
 
     @property
     def lr_shape(self):
@@ -219,6 +223,29 @@ def hr_lat_lon(self):
                 self._hr_lat_lon = self.lr_lat_lon
         return self._hr_lat_lon
 
+    @property
+    def source_time_index(self):
+        """Get the full time index of the exo_source data"""
+        if self._src_time_index is None:
+            if self._t_agg_factor > 1:
+                self._src_time_index = OutputHandler.get_times(
self.input_handler.time_index, + self.hr_shape[-1] * self._t_agg_factor) + else: + self._src_time_index = self.hr_time_index + return self._src_time_index + + @property + def hr_time_index(self): + """Get the full time index for aggregated source data""" + if self._hr_time_index is None: + if self._t_enhance > 1: + self._hr_time_index = OutputHandler.get_times( + self.input_handler.time_index, self.hr_shape[-1]) + else: + self._hr_time_index = self.input_handler.time_index + return self._hr_time_index + @property def tree(self): """Get the KDTree built on the source lat lon data""" @@ -246,7 +273,7 @@ def data(self): nn = self.nn hr_data = [] for j in range(self._s_agg_factor): - out = self.source_data[nn[:, j], ::self._t_agg_factor] + out = self.source_data[nn[:, j], self.source_temporal_slice] out = out.reshape(self.hr_shape) hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) @@ -256,11 +283,11 @@ def data(self): @classmethod def get_exo_raster(cls, file_paths, - exo_source, s_enhance, t_enhance, s_agg_factor, t_agg_factor, + exo_source=None, target=None, shape=None, temporal_slice=None, @@ -276,10 +303,6 @@ def get_exo_raster(cls, A single source h5 file to extract raster data from or a list of netcdf files with identical grid. The string can be a unix-style file path which will be passed through glob.glob - exo_source : str - Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or - 4km) data from which will be mapped to the enhanced grid of the - file_paths input s_enhance : int Factor by which the Sup3rGan model will enhance the spatial dimensions of low resolution data from file_paths input. For @@ -306,6 +329,10 @@ class will output a topography raster corresponding to the is 4 resulting in a desired resolution of 5 min and exo_source has a resolution of 5 min, the t_agg_factor should be 4 so that every fourth timestep in the exo_source data is skipped. + exo_source : str + Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or + 4km) data from which will be mapped to the enhanced grid of the + file_paths input target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. @@ -340,11 +367,11 @@ class will output a topography raster corresponding to the feature='topography' """ exo = cls(file_paths, - exo_source, s_enhance, t_enhance, s_agg_factor, t_agg_factor, + exo_source=exo_source, target=target, shape=shape, temporal_slice=temporal_slice, @@ -362,9 +389,7 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') - elev = np.repeat(elev[:, np.newaxis], - self.hr_shape[-1] * self._t_agg_factor, - axis=-1) + elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) return elev @property @@ -375,15 +400,31 @@ def source_lat_lon(self): return source_lat_lon @property - def hr_time_index(self): - """Time index of the high-resolution exo data""" - if self._hr_time_index is None: + def source_time_index(self): + """Time index of the source exo data""" + if self._src_time_index is None: with Resource(self._exo_source) as res: - self._hr_time_index = res.time_index - return self._hr_time_index + self._src_time_index = res.time_index + return self._src_time_index + @property + def data(self): + """Get a raster of source values corresponding to the + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). 
The shape is (lats, lons, temporal, 1)
+        """
+        nn = self.nn
+        hr_data = []
+        for j in range(self._s_agg_factor):
+            out = self.source_data[nn[:, j]]
+            out = out.reshape(self.hr_shape)
+            hr_data.append(out[..., np.newaxis])
+        hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1)
+        logger.info('Finished mapping raster from {}'.format(self._exo_source))
+        return hr_data[..., np.newaxis]
+
 
-class TopoExtractNC(ExoExtract):
+class TopoExtractNC(TopoExtractH5):
     """TopoExtract for netCDF files"""
 
     def __init__(self, *args, **kwargs):
         """
         Parameters
         ----------
         args : list
             Same positional arguments as TopoExtract
         kwargs : dict
             Same keyword arguments as TopoExtract
@@ -409,7 +450,8 @@ def __init__(self, *args, **kwargs):
     @property
     def source_data(self):
-        """Get the 1D array of elevation data from the exo_source_nc"""
-        elev = self.source_handler.data.reshape(-1)
+        """Get the 2D array (n, temporal) of elevation data from the
+        exo_source_nc"""
+        elev = self.source_handler.data.reshape((-1, self.lr_shape[-1]))
         return elev
 
     @property
@@ -419,56 +460,22 @@ def source_lat_lon(self):
         source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2))
         return source_lat_lon
 
 
-class SzaExtractH5(ExoExtract):
-    """SzaExtract for H5 files"""
+class SzaExtract(ExoExtract):
+    """SzaExtract for H5 and NetCDF files. SZA is computed from the
+    high-res time index and lat/lon rather than read from an exo source
+    file."""
 
     @property
     def source_data(self):
-        """Get the 1D array of sza data from the exo_source_h5"""
-        with Resource(self._exo_source) as res:
-            elev = res.get_meta_arr('elevation')
-            elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1)
-        return elev
+        """Get the 2D array (n, temporal) of sza data computed from the
+        high-res time index and lat/lon"""
+        return SolarPosition(self.hr_time_index,
+                             self.hr_lat_lon.reshape((-1, 2))).zenith.T
 
     @property
-    def source_lat_lon(self):
-        """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5"""
-        with Resource(self._exo_source) as res:
-            source_lat_lon = res.lat_lon
-        return source_lat_lon
-
-
-class SzaExtractNC(ExoExtract):
-    """TopoExtract for netCDF files"""
-
-    def __init__(self, *args, **kwargs):
-        """
-        Parameters
-        ----------
-        args : list
-            Same positional arguments as TopoExtract
-        kwargs : dict
-            Same keyword arguments as TopoExtract
+    def data(self):
+        """Get a raster of source values corresponding to the
+        high-resolution grid (the file_paths input grid * s_enhance *
+        t_enhance).
The shape is (lats, lons, temporal, 1) """ - - super().__init__(*args, **kwargs) - logger.info('Getting topography for full domain from ' - f'{self._exo_source}') - self.source_handler = DataHandlerNC( - self._exo_source, - features=['topography'], - worker_kwargs=dict(ti_workers=self.ti_workers), - val_split=0.0, - ) - - @property - def source_data(self): - """Get the 1D array of elevation data from the exo_source_h5""" - elev = self.source_handler.data.reshape(-1) - return elev - - @property - def source_lat_lon(self): - """Get the 2D array (n, 2) of lat, lon data from the exo_source_h5""" - source_lat_lon = self.source_handler.lat_lon.reshape((-1, 2)) - return source_lat_lon + hr_data = self.source_data.reshape(self.hr_shape) + logger.info('Finished computing SZA data') + return hr_data[..., np.newaxis] diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 9b08b5aaf..884d4a7c4 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -4,11 +4,10 @@ import pickle import shutil from typing import ClassVar -from warnings import warn +from sup3r.preprocessing.data_handling import exo_extraction from sup3r.preprocessing.data_handling.exo_extraction import ( - SzaExtractH5, - SzaExtractNC, + SzaExtract, TopoExtractH5, TopoExtractNC, ) @@ -28,19 +27,19 @@ class ExogenousDataHandler: 'nc': TopoExtractNC }, 'sza': { - 'h5': SzaExtractH5, - 'nc': SzaExtractNC + 'h5': SzaExtract, + 'nc': SzaExtract } } def __init__(self, file_paths, feature, - source_file, s_enhancements, t_enhancements, s_agg_factors, t_agg_factors, + source_file=None, target=None, shape=None, temporal_slice=None, @@ -61,10 +60,6 @@ def __init__(self, sup3r resolved. feature : str Exogenous feature to extract from source_h5 - source_file : str - Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or - 4km) data from which will be mapped to the enhanced grid of the - file_paths input s_enhancements : list List of factors by which the Sup3rGan model will enhance the spatial dimensions of low resolution data from file_paths input @@ -93,6 +88,10 @@ def __init__(self, data to the temporal resolution of the file_paths input enhanced by t_enhance. The length of this list should be equal to the number of t_enhancements + source_file : str + Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or + 4km) data from which will be mapped to the enhanced grid of the + file_paths input target : tuple (lat, lon) lower left corner of raster. Either need target+shape or raster_file. @@ -146,21 +145,6 @@ def __init__(self, self.cache_dir = cache_dir self.data = [] - if self.s_enhancements[0] != 1: - msg = ('s_enhancements typically starts with 1 so the first ' - 'exogenous data input matches the spatial resolution of ' - 'the source low-res input data, but received ' - 's_enhancements: {}'.format(self.s_enhancements)) - logger.warning(msg) - warn(msg) - if self.t_enhancements[0] != 1: - msg = ('t_enhancements typically starts with 1 so the first ' - 'exogenous data input matches the temporal resolution of ' - 'the source low-res input data, but received ' - 't_enhancements: {}'.format(self.t_enhancements)) - logger.warning(msg) - warn(msg) - msg = ('Need to provide the same number of enhancement factors and ' f'agg factors. 
Received s_enhancements={self.s_enhancements}, ' f'and s_agg_factors={self.s_agg_factors}.') @@ -325,4 +309,6 @@ def get_exo_handler(cls, feature, source_file, exo_handler): f'feature={feature} and input_type={in_type}.') logger.error(msg) raise KeyError(msg) + elif isinstance(exo_handler, str): + exo_handler = getattr(exo_extraction, exo_handler, None) return exo_handler diff --git a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py index 537935d82..0b554cca8 100644 --- a/sup3r/preprocessing/wind_conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/wind_conditional_moment_batch_handling.py @@ -3,17 +3,21 @@ Sup3r wind conditional moment batch_handling module. """ import logging -import tensorflow as tf + import numpy as np +import tensorflow as tf -from sup3r.utilities.utilities import (spatial_simple_enhancing, - temporal_simple_enhancing) from sup3r.preprocessing.batch_handling import Batch from sup3r.preprocessing.conditional_moment_batch_handling import ( - SpatialBatchHandlerMom1, + BatchHandlerMom1, BatchMom1, + SpatialBatchHandlerMom1, ValidationDataMom1, - BatchHandlerMom1) +) +from sup3r.utilities.utilities import ( + spatial_simple_enhancing, + temporal_simple_enhancing, +) np.random.seed(42) @@ -122,7 +126,8 @@ def make_output(low_res, high_res, HR is high-res and LR is low-res """ # Remove first moment from HR and square it - out = model_mom1._tf_generate(low_res, high_res[..., -1:]).numpy() + out = model_mom1._tf_generate( + low_res, {'topography': high_res[..., -1:]}).numpy() out = tf.concat((out, high_res[..., -1:]), axis=-1) return (high_res - out)**2 @@ -176,7 +181,8 @@ def make_output(low_res, high_res, SF = HR - LR """ # Remove LR and first moment from HR and square it - out = model_mom1._tf_generate(low_res, high_res[..., -1:]).numpy() + out = model_mom1._tf_generate( + low_res, {'topography': high_res[..., -1:]}).numpy() out = tf.concat((out, high_res[..., -1:]), axis=-1) enhanced_lr = spatial_simple_enhancing(low_res, s_enhance=s_enhance) @@ -251,6 +257,7 @@ class WindBatchHandlerMom1(BatchHandlerMom1): class WindSpatialBatchHandlerMom1(SpatialBatchHandlerMom1): """Sup3r spatial batch handling class""" + # Classes to use for handling an individual batch obj. 
VAL_CLASS = ValidationDataMom1 BATCH_CLASS = WindBatchMom1 @@ -260,36 +267,42 @@ class WindSpatialBatchHandlerMom1(SpatialBatchHandlerMom1): class ValidationDataWindMom1SF(ValidationDataMom1): """Iterator for validation wind data for first conditional moment of subfilter velocity""" + BATCH_CLASS = WindBatchMom1SF class ValidationDataWindMom2(ValidationDataMom1): """Iterator for subfilter validation wind data for second conditional moment""" + BATCH_CLASS = WindBatchMom2 class ValidationDataWindMom2Sep(ValidationDataMom1): """Iterator for subfilter validation wind data for second conditional moment separate from first moment""" + BATCH_CLASS = WindBatchMom2Sep class ValidationDataWindMom2SF(ValidationDataMom1): """Iterator for validation wind data for second conditional moment of subfilter velocity""" + BATCH_CLASS = WindBatchMom2SF class ValidationDataWindMom2SepSF(ValidationDataMom1): """Iterator for validation wind data for second conditional moment of subfilter velocity separate from first moment""" + BATCH_CLASS = WindBatchMom2SepSF class WindBatchHandlerMom1SF(WindBatchHandlerMom1): """Sup3r batch handling class for first conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -297,6 +310,7 @@ class WindBatchHandlerMom1SF(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom1SF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for first conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom1SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -304,6 +318,7 @@ class WindSpatialBatchHandlerMom1SF(WindSpatialBatchHandlerMom1): class WindBatchHandlerMom2(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -311,6 +326,7 @@ class WindBatchHandlerMom2(WindBatchHandlerMom1): class WindBatchHandlerMom2Sep(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -318,6 +334,7 @@ class WindBatchHandlerMom2Sep(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom2(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2 BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -325,6 +342,7 @@ class WindSpatialBatchHandlerMom2(WindSpatialBatchHandlerMom1): class WindSpatialBatchHandlerMom2Sep(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2Sep BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -332,6 +350,7 @@ class WindSpatialBatchHandlerMom2Sep(WindSpatialBatchHandlerMom1): class WindBatchHandlerMom2SF(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity""" + VAL_CLASS = ValidationDataWindMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -339,6 +358,7 @@ class WindBatchHandlerMom2SF(WindBatchHandlerMom1): class WindBatchHandlerMom2SepSF(WindBatchHandlerMom1): """Sup3r batch handling class for second conditional moment of subfilter velocity separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -346,6 +366,7 @@ 
class WindBatchHandlerMom2SepSF(WindBatchHandlerMom1): class WindSpatialBatchHandlerMom2SF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity using topography as input""" + VAL_CLASS = ValidationDataWindMom2SF BATCH_CLASS = VAL_CLASS.BATCH_CLASS @@ -353,5 +374,6 @@ class WindSpatialBatchHandlerMom2SF(WindSpatialBatchHandlerMom1): class WindSpatialBatchHandlerMom2SepSF(WindSpatialBatchHandlerMom1): """Sup3r spatial batch handling class for second conditional moment of subfilter velocity separate from first moment using topography as input""" + VAL_CLASS = ValidationDataWindMom2SepSF BATCH_CLASS = VAL_CLASS.BATCH_CLASS diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index ecfd51fb4..f30b83852 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -309,7 +309,7 @@ def test_fwp_multi_step_model_topo_noskip(): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [{'model': 0, 'combine_type': 'input'}, {'model': 1, 'combine_type': 'input'}, {'model': 2, 'combine_type': 'input'}] @@ -466,7 +466,7 @@ def test_fwp_single_step_wind_hi_res_topo(plot=False): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [ {'model': 0, 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'} @@ -633,7 +633,7 @@ def test_fwp_multi_step_wind_hi_res_topo(): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, } } @@ -800,7 +800,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [{'model': 0, 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}] } @@ -900,17 +900,17 @@ def test_fwp_multi_step_model_multi_exo(): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [{'model': 0, 'combine_type': 'input'}, {'model': 1, 'combine_type': 'input'}] }, 'sza': { 'file_paths': input_files, - 'source_file': FP_WTK, 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': None}, + 'exo_handler': 'SzaExtract', + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [{'model': 2, 'combine_type': 'input'}] } } @@ -974,10 +974,11 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): high-resolution topography and sza input from the exogenous_data feature.""" Sup3rGan.seed() - gen_model = [{ + gen_s_model = [{ "class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], - "mode": "REFLECT"}, { + "mode": "REFLECT" + }, { "class": "Conv2DTranspose", "filters": 64, "kernel_size": 3, @@ -1039,8 +1040,53 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): "cropping": 4 }] + gen_t_model = [{ + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", 
"cropping": 2 + }, { + "alpha": 0.2, "class": "LeakyReLU" + }, { + "class": "SpatioTemporalExpansion", "temporal_mult": 2, + "temporal_method": "nearest" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 1, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 36, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }, { + "class": "SpatioTemporalExpansion", "spatial_mult": 3 + }, { + "alpha": 0.2, "class": "LeakyReLU" + }, { + "class": "Sup3rConcat", "name": "sza" + }, { + "class": "FlexiblePadding", + "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], + "mode": "REFLECT" + }, { + "class": "Conv3D", "filters": 2, "kernel_size": 3, "strides": 1 + }, { + "class": "Cropping3D", "cropping": 2 + }] + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - s1_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + s1_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) s1_model.meta['training_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] @@ -1049,10 +1095,17 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): s1_model.meta['t_enhance'] = 1 s1_model.meta['input_resolution'] = {'spatial': '48km', 'temporal': '60min'} - _ = s1_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) - - s2_model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) + exo_tmp = { + 'topography': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 20, 20, 1))}]}, + 'sza': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 20, 20, 1))}]} + } + _ = s1_model.generate(np.ones((4, 10, 10, 4)), exogenous_data=exo_tmp) + + s2_model = Sup3rGan(gen_s_model, fp_disc, learning_rate=1e-4) s2_model.meta['training_features'] = [ 'U_100m', 'V_100m', 'topography', 'sza' ] @@ -1061,21 +1114,24 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): s2_model.meta['t_enhance'] = 1 s2_model.meta['input_resolution'] = {'spatial': '24km', 'temporal': '60min'} - _ = s2_model.generate(np.ones((4, 10, 10, 3)), - exogenous_data=(None, np.ones((4, 20, 20, 1)))) + _ = s2_model.generate(np.ones((4, 10, 10, 4)), exogenous_data=exo_tmp) - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + st_model = Sup3rGan(gen_t_model, fp_disc, learning_rate=1e-4) st_model.meta['training_features'] = [ - 'U_100m', 'V_100m', 'topography', 'sza' + 'U_100m', 'V_100m', 'sza' ] st_model.meta['output_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 + st_model.meta['t_enhance'] = 2 st_model.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} - _ = st_model.generate(np.ones((4, 10, 10, 6, 3))) + exo_tmp = { + 'sza': { + 'steps': [{'model': 0, 'combine_type': 'layer', + 'data': np.ones((4, 30, 30, 12, 1))}]} + } + _ = st_model.generate(np.ones((4, 10, 10, 6, 3)), exogenous_data=exo_tmp) with tempfile.TemporaryDirectory() as td: input_files = make_fake_nc_files(td, INPUT_FILE, 8) @@ -1094,18 +1150,25 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': '15min'}, + 
'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, 'steps': [{'model': 0, 'combine_type': 'input'}, - {'model': 1, 'combine_type': 'input'}] + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}] }, 'sza': { 'file_paths': input_files, - 'source_file': FP_WTK, + 'exo_handler': 'SzaExtract', 'target': target, 'shape': shape, 'cache_dir': td, - 'exo_resolution': {'spatial': '4km', 'temporal': None}, - 'steps': [{'model': 2, 'combine_type': 'input'}] + 'exo_resolution': {'spatial': '4km', 'temporal': '60min'}, + 'steps': [{'model': 0, 'combine_type': 'input'}, + {'model': 0, 'combine_type': 'layer'}, + {'model': 1, 'combine_type': 'input'}, + {'model': 1, 'combine_type': 'layer'}, + {'model': 2, 'combine_type': 'input'}, + {'model': 2, 'combine_type': 'layer'}] } } @@ -1120,24 +1183,6 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): worker_kwargs=dict(max_workers=1), overwrite_cache=True) - # should get an error on a bad tensorflow concatenation - with pytest.raises(RuntimeError): - exo_kwargs['s_enhancements'] = [1, 1, 1] - handler = ForwardPassStrategy( - input_files, - model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', - fwp_chunk_shape=(4, 4, 8), - spatial_pad=1, - temporal_pad=1, - input_handler_kwargs=input_handler_kwargs, - out_pattern=out_files, - worker_kwargs=dict(max_workers=1), - exo_kwargs=exo_kwargs, - max_nodes=1) - forward_pass = ForwardPass(handler) - forward_pass.run(handler, node_index=0) - handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, @@ -1155,4 +1200,3 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): for fp in handler.out_files: assert os.path.exists(fp) - diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index e55bebcd8..8db52c8f1 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -243,4 +243,3 @@ def test_solar_multistep(): x = np.ones((3, 10, 10, len(features1 + features2))) out = ms_model.generate(x) assert out.shape == (1, 20, 20, 24, 1) - diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index 0373394e2..2c72044d0 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -4,15 +4,15 @@ import json import os import tempfile -import pytest -import numpy as np +import numpy as np +import pytest from rex import Resource -from sup3r.models import Sup3rGan from sup3r import CONFIG_DIR, TEST_DATA_DIR -from sup3r.models.surface import SurfaceSpatialMetModel +from sup3r.models import Sup3rGan from sup3r.models.multi_step import MultiStepSurfaceMetGan +from sup3r.models.surface import SurfaceSpatialMetModel from sup3r.utilities.utilities import spatial_coarsening INPUT_FILE_W = os.path.join(TEST_DATA_DIR, 'test_wtk_surface_vars.h5') @@ -61,8 +61,8 @@ def test_surface_model(s_enhance=5): json.dump(kwargs, f) model = SurfaceSpatialMetModel.load(model_dir=td) - - hi_res = model.generate(low_res, exogenous_data=[topo_lr, topo_hr]) + exo_tmp = {'topography': {'steps': [{'data': topo_lr}, {'data': topo_hr}]}} + hi_res = model.generate(low_res, exogenous_data=exo_tmp) diff = true_hi_res - hi_res @@ -86,7 +86,9 @@ def test_train_rh_model(s_enhance=10): true_hr_rh = np.transpose(true_hi_res[..., 1], axes=(1, 2, 0)) model = SurfaceSpatialMetModel(FEATURES, s_enhance=s_enhance) - w_delta_temp, w_delta_topo = model.train(true_hr_temp, true_hr_rh, topo_hr) + w_delta_temp, w_delta_topo = 
model.train( + true_hr_temp, true_hr_rh, topo_hr, + input_resolution={'spatial': '3km', 'temporal': '60min'}) # pretty generous tolerances because the training dataset is so small assert np.allclose(w_delta_temp, SurfaceSpatialMetModel.W_DELTA_TEMP, @@ -120,6 +122,8 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): model.set_norm_stats([0.3, 0.9, 0.1], [0.02, 0.07, 0.03]) model.set_model_params(training_features=FEATURES, output_features=FEATURES, + input_resolution={'spatial': '30km', + 'temporal': '60min'}, s_enhance=1, t_enhance=t_enhance) @@ -155,7 +159,10 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): topo_lr = topo_lr[:4, :4] topo_hr = topo_hr[:8, :8] - hi_res = ms_model.generate(low_res, exogenous_data=[topo_lr, topo_hr]) + exo_tmp = { + 'topography': { + 'steps': [{'data': topo_lr}, {'data': topo_hr}]}} + hi_res = ms_model.generate(low_res, exogenous_data=exo_tmp) target_shape = (1, low_res.shape[1] * s_enhance, diff --git a/tests/training/test_train_conditional_moments.py b/tests/training/test_train_conditional_moments.py index 927da3753..71fbfb281 100644 --- a/tests/training/test_train_conditional_moments.py +++ b/tests/training/test_train_conditional_moments.py @@ -1,33 +1,32 @@ # -*- coding: utf-8 -*- """Test the basic training of super resolution GAN""" import os +import tempfile + # import json import numpy as np import pytest -import tempfile import tensorflow as tf -from tensorflow.python.framework.errors_impl import InvalidArgumentError - from rex import init_logger +from tensorflow.python.framework.errors_impl import InvalidArgumentError -from sup3r import TEST_DATA_DIR -from sup3r import CONFIG_DIR +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models import Sup3rCondMom -from sup3r.preprocessing.data_handling import DataHandlerH5 from sup3r.preprocessing.conditional_moment_batch_handling import ( - SpatialBatchHandlerMom1, - SpatialBatchHandlerMom1SF, - SpatialBatchHandlerMom2, - SpatialBatchHandlerMom2Sep, - SpatialBatchHandlerMom2SF, - SpatialBatchHandlerMom2SepSF, BatchHandlerMom1, BatchHandlerMom1SF, BatchHandlerMom2, BatchHandlerMom2Sep, + BatchHandlerMom2SepSF, BatchHandlerMom2SF, - BatchHandlerMom2SepSF) - + SpatialBatchHandlerMom1, + SpatialBatchHandlerMom1SF, + SpatialBatchHandlerMom2, + SpatialBatchHandlerMom2Sep, + SpatialBatchHandlerMom2SepSF, + SpatialBatchHandlerMom2SF, +) +from sup3r.preprocessing.data_handling import DataHandlerH5 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -79,7 +78,9 @@ def test_train_s_mom1(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -167,7 +168,9 @@ def test_train_s_mom1_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -227,7 +230,10 @@ def test_train_s_mom2(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': 
'8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -288,7 +294,10 @@ def test_train_s_mom2_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -340,7 +349,10 @@ def test_train_s_mom2_sep(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -392,7 +404,10 @@ def test_train_s_mom2_sep_sf(FEATURES, TRAIN_FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '8km', + 'temporal': '30min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -436,7 +451,9 @@ def test_train_st_mom1(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -485,7 +502,9 @@ def test_train_st_mom1_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '12km', 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -541,7 +560,10 @@ def test_train_st_mom2(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -599,7 +621,10 @@ def test_train_st_mom2_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -645,7 +670,10 @@ def test_train_st_mom2_sep(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality @@ -691,7 +719,10 @@ def test_train_st_mom2_sep_sf(FEATURES, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - 
model_mom2.train(batch_handler, n_epoch=n_epoch, + model_mom2.train(batch_handler, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=2, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) # test save/load functionality diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 17bbfb75e..5aa4c96d2 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -52,8 +52,13 @@ def test_train_spatial(log=False, full_shape=(20, 20), with tempfile.TemporaryDirectory() as td: # test that training works and reduces loss - model.train(batch_handler, n_epoch=n_epoch, weight_gen_advers=0.0, - train_gen=True, train_disc=False, checkpoint_int=1, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, + weight_gen_advers=0.0, + train_gen=True, + train_disc=False, + checkpoint_int=1, out_dir=os.path.join(td, 'test_{epoch}')) assert len(model.history) == n_epoch @@ -122,7 +127,9 @@ def test_train_st_weight_update(n_epoch=2, log=False): adaptive_update_bounds = (0.9, 0.99) with tempfile.TemporaryDirectory() as td: - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, weight_gen_advers=1e-6, train_gen=True, train_disc=True, checkpoint_int=10, @@ -178,7 +185,9 @@ def test_train_spatial_dc(log=False, full_shape=(20, 20), with tempfile.TemporaryDirectory() as td: # test that the normalized number of samples from each bin is close # to the weight for that bin - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=2, @@ -226,7 +235,9 @@ def test_train_st_dc(n_epoch=2, log=False): with tempfile.TemporaryDirectory() as td: # test that the normalized number of samples from each bin is close # to the weight for that bin - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=2, @@ -274,7 +285,9 @@ def test_train_st(n_epoch=2, log=False): with tempfile.TemporaryDirectory() as td: # test that training works and reduces loss - model.train(batch_handler, n_epoch=n_epoch, + model.train(batch_handler, + input_resolution={'spatial': '8km', 'temporal': '30min'}, + n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=1, diff --git a/tests/training/test_train_multi_exo.py b/tests/training/test_train_gan_exo.py similarity index 88% rename from tests/training/test_train_multi_exo.py rename to tests/training/test_train_gan_exo.py index d5b922a27..66cda197c 100644 --- a/tests/training/test_train_multi_exo.py +++ b/tests/training/test_train_gan_exo.py @@ -48,15 +48,12 @@ def test_wind_cc_model(log=False): time_roll=-7, sample_shape=(20, 20, 96), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) - + train_only_features=['topography']) batcher = BatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=4, sub_daily_shape=None) - if log: init_logger('sup3r', log_level='DEBUG') - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_4x_24x_3f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') @@ -64,7 +61,10 
@@ def test_wind_cc_model(log=False): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -72,7 +72,6 @@ def test_wind_cc_model(log=False): assert 'test_0' in os.listdir(td) assert model.meta['class'] == 'Sup3rGan' - assert 'topography' in batcher.output_features assert 'topography' not in model.output_features assert len(model.output_features) == len(FEATURES_W) - 1 @@ -97,7 +96,7 @@ def test_wind_cc_model_spatial(log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + train_only_features=['topography']) batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -112,7 +111,10 @@ def test_wind_cc_model_spatial(log=False): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -121,7 +123,6 @@ def test_wind_cc_model_spatial(log=False): assert 'test_0' in os.listdir(td) assert model.meta['output_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rGan' - assert 'topography' in batcher.output_features assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) @@ -146,7 +147,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): val_split=0.1, sample_shape=(20, 20), worker_kwargs=dict(max_workers=1), - train_only_features=tuple()) + train_only_features=()) batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, s_enhance=2) @@ -177,7 +178,7 @@ def test_wind_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -192,7 +193,10 @@ def test_wind_hi_res_topo(custom_layer, log=False): model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -205,12 +209,16 @@ def test_wind_hi_res_topo(custom_layer, log=False): assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) - hi_res_topo = np.random.uniform(0, 1, (60, 60)) + hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): y = model.generate(x, exogenous_data=None) - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 @@ -262,7 +270,7 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 
0], [3, 3], [3, 3], [0, 0]], @@ -277,7 +285,10 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): model = Sup3rGan(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -290,12 +301,16 @@ def test_wind_non_cc_hi_res_topo(custom_layer, log=False): assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) - hi_res_topo = np.random.uniform(0, 1, (60, 60)) + hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) with pytest.raises(RuntimeError): y = model.generate(x, exogenous_data=None) - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 @@ -347,7 +362,7 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): {"class": "SpatioTemporalExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [3, 3], [0, 0]], @@ -362,7 +377,10 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): model = Sup3rGanDC(gen_model, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '16km', + 'temporal': '3600min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -380,7 +398,11 @@ def test_wind_dc_hi_res_topo(custom_layer, log=False): with pytest.raises(RuntimeError): y = model.generate(x, exogenous_data=None) - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 diff --git a/tests/training/test_train_gan_lr_era.py b/tests/training/test_train_gan_lr_era.py index 59ef88105..776dc8db3 100644 --- a/tests/training/test_train_gan_lr_era.py +++ b/tests/training/test_train_gan_lr_era.py @@ -73,6 +73,7 @@ def test_train_spatial( # test that training works and reduces loss model.train( batch_handler, + input_resolution={'spatial': '30km', 'temporal': '60min'}, n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, @@ -173,6 +174,7 @@ def test_train_st(n_epoch=3, log=True): # test that training works and reduces loss model.train( batch_handler, + input_resolution={'spatial': '30km', 'temporal': '60min'}, n_epoch=n_epoch, weight_gen_advers=0.0, train_gen=True, diff --git a/tests/training/test_train_solar.py b/tests/training/test_train_solar.py index 2ac26fc10..9b3a14bca 100644 --- a/tests/training/test_train_solar.py +++ b/tests/training/test_train_solar.py @@ -2,19 +2,19 @@ """Test the basic training of super resolution GAN for solar climate change applications""" import os -import numpy as np import tempfile -from tensorflow.keras.losses import MeanAbsoluteError +import numpy as np from rex import init_logger +from tensorflow.keras.losses import MeanAbsoluteError -from sup3r import TEST_DATA_DIR -from sup3r import CONFIG_DIR -from sup3r.models import Sup3rGan, SolarCC 
+from sup3r import CONFIG_DIR, TEST_DATA_DIR +from sup3r.models import SolarCC, Sup3rGan +from sup3r.preprocessing.batch_handling import ( + BatchHandlerCC, + SpatialBatchHandlerCC, +) from sup3r.preprocessing.data_handling import DataHandlerH5SolarCC -from sup3r.preprocessing.batch_handling import (BatchHandlerCC, - SpatialBatchHandlerCC) - SHAPE = (20, 20) @@ -54,7 +54,9 @@ def test_solar_cc_model(log=False): loss='MeanAbsoluteError') with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '4km', 'temporal': '40min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -108,7 +110,9 @@ def test_solar_cc_model_spatial(log=False): model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '25km', 'temporal': '15min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, @@ -149,7 +153,9 @@ def test_solar_custom_loss(log=False): loss='MeanAbsoluteError') with tempfile.TemporaryDirectory() as td: - model.train(batcher, n_epoch=1, + model.train(batcher, + input_resolution={'spatial': '4km', 'temporal': '40min'}, + n_epoch=1, weight_gen_advers=0.0, train_gen=True, train_disc=False, checkpoint_int=None, diff --git a/tests/training/test_train_wind_conditional_moments.py b/tests/training/test_train_wind_conditional_moments.py index 7ac2ae7a2..d33efe46e 100644 --- a/tests/training/test_train_wind_conditional_moments.py +++ b/tests/training/test_train_wind_conditional_moments.py @@ -2,30 +2,29 @@ """Test the basic training of super resolution GAN for solar climate change applications""" import os -import pytest -import numpy as np import tempfile +import numpy as np +import pytest from rex import init_logger -from sup3r import CONFIG_DIR -from sup3r import TEST_DATA_DIR -from sup3r.models import WindCondMom +from sup3r import CONFIG_DIR, TEST_DATA_DIR +from sup3r.models import Sup3rCondMom from sup3r.preprocessing.data_handling import DataHandlerH5 from sup3r.preprocessing.wind_conditional_moment_batch_handling import ( + WindBatchHandlerMom1, + WindBatchHandlerMom1SF, + WindBatchHandlerMom2, + WindBatchHandlerMom2Sep, + WindBatchHandlerMom2SepSF, + WindBatchHandlerMom2SF, WindSpatialBatchHandlerMom1, WindSpatialBatchHandlerMom1SF, WindSpatialBatchHandlerMom2, - WindSpatialBatchHandlerMom2SF, WindSpatialBatchHandlerMom2Sep, WindSpatialBatchHandlerMom2SepSF, - WindBatchHandlerMom1, - WindBatchHandlerMom1SF, - WindBatchHandlerMom2, - WindBatchHandlerMom2SF, - WindBatchHandlerMom2Sep, - WindBatchHandlerMom2SepSF) - + WindSpatialBatchHandlerMom2SF, +) SHAPE = (20, 20) @@ -67,7 +66,7 @@ def make_s_gen_model(custom_layer): {"class": "SpatialExpansion", "spatial_mult": 2}, {"class": "Activation", "activation": "relu"}, - {"class": custom_layer}, + {"class": custom_layer, "name": "topography"}, {"class": "FlexiblePadding", "paddings": [[0, 0], [3, 3], [3, 3], [0, 0]], @@ -108,24 +107,30 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class, gen_model = make_s_gen_model(custom_layer) - WindCondMom.seed() - model = WindCondMom(gen_model, learning_rate=1e-4) + Sup3rCondMom.seed() + model = Sup3rCondMom(gen_model, learning_rate=1e-4) with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model.train(batcher, n_epoch=n_epoch, + model.train(batcher, + input_resolution={'spatial': '8km', 
'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=None, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) assert f'test_{n_epoch-1}' in os.listdir(out_dir_root) assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'WindCondMom' + assert model.meta['class'] == 'Sup3rCondMom' assert 'topography' in batcher.output_features assert 'topography' not in model.output_features x = np.random.uniform(0, 1, (4, 30, 30, 3)) - hi_res_topo = np.random.uniform(0, 1, (60, 60)) + hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1)) + exo_tmp = { + 'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}} - y = model.generate(x, exogenous_data=(None, hi_res_topo)) + y = model.generate(x, exogenous_data=exo_tmp) assert y.shape[0] == x.shape[0] assert y.shape[1] == x.shape[1] * 2 @@ -159,8 +164,8 @@ def test_wind_non_cc_hi_res_st_topo_mom1(batch_class, log=False, 'sup3rcc', 'gen_wind_3x_4x_2f.json') - WindCondMom.seed() - model_mom1 = WindCondMom(fp_gen, learning_rate=1e-4) + Sup3rCondMom.seed() + model_mom1 = Sup3rCondMom(fp_gen, learning_rate=1e-4) batcher = batch_class([handler], batch_size=batch_size, @@ -171,7 +176,10 @@ def test_wind_non_cc_hi_res_st_topo_mom1(batch_class, log=False, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom1.train(batcher, n_epoch=n_epoch, + model_mom1.train(batcher, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=None, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -204,9 +212,9 @@ def test_wind_non_cc_hi_res_topo_mom2(custom_layer, batch_class, gen_model = make_s_gen_model(custom_layer) - WindCondMom.seed() - model_mom1 = WindCondMom(gen_model, learning_rate=1e-4) - model_mom2 = WindCondMom(gen_model, learning_rate=1e-4) + Sup3rCondMom.seed() + model_mom1 = Sup3rCondMom(gen_model, learning_rate=1e-4) + model_mom2 = Sup3rCondMom(gen_model, learning_rate=1e-4) batcher = batch_class([handler], batch_size=batch_size, @@ -217,7 +225,10 @@ def test_wind_non_cc_hi_res_topo_mom2(custom_layer, batch_class, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batcher, n_epoch=n_epoch, + model_mom2.train(batcher, + input_resolution={'spatial': '8km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=None, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) @@ -251,9 +262,9 @@ def test_wind_non_cc_hi_res_st_topo_mom2(batch_class, log=False, 'sup3rcc', 'gen_wind_3x_4x_2f.json') - WindCondMom.seed() - model_mom1 = WindCondMom(fp_gen, learning_rate=1e-4) - model_mom2 = WindCondMom(fp_gen, learning_rate=1e-4) + Sup3rCondMom.seed() + model_mom1 = Sup3rCondMom(fp_gen, learning_rate=1e-4) + model_mom2 = Sup3rCondMom(fp_gen, learning_rate=1e-4) batcher = batch_class([handler], batch_size=batch_size, @@ -264,7 +275,10 @@ def test_wind_non_cc_hi_res_st_topo_mom2(batch_class, log=False, with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td - model_mom2.train(batcher, n_epoch=n_epoch, + model_mom2.train(batcher, + input_resolution={'spatial': '12km', + 'temporal': '60min'}, + n_epoch=n_epoch, checkpoint_int=None, out_dir=os.path.join(out_dir_root, 'test_{epoch}')) From fa6bf27c86892e80363883095690f93a1fb016bd Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 26 Sep 2023 10:23:17 -0600 Subject: [PATCH 10/15] Removed tests from test_train_gan_exo which don't use exogenous features. 
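The training-test changes above all exercise the same two additions: an input_resolution argument to model.train and the new dictionary-based exogenous_data format for model.generate. A minimal sketch of a generate call in the new format, assuming model is a trained Sup3rGan whose generator contains one Sup3rAdder or Sup3rConcat layer named 'topography' (shapes here are illustrative only):

import numpy as np

x = np.random.uniform(0, 1, (4, 30, 30, 3))  # (obs, spatial_1, spatial_2, features)
hi_res_topo = np.random.uniform(0, 1, (4, 60, 60, 1))  # 2x spatial enhancement, 1 exo channel

# each exo feature maps to a list of steps; a step records which model index
# it applies to, how it is combined ('input', 'layer', or 'output'), and the
# data array itself
exo_tmp = {
    'topography': {
        'steps': [
            {'model': 0, 'combine_type': 'layer', 'data': hi_res_topo}]}}

y = model.generate(x, exogenous_data=exo_tmp)  # model: assumed pre-trained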
--- sup3r/models/abstract.py | 55 ++++++++--- sup3r/pipeline/forward_pass.py | 42 ++++---- tests/data_handling/test_exo_data_handling.py | 27 +++-- tests/data_handling/test_utils_topo.py | 35 ++++--- tests/forward_pass/test_multi_step.py | 5 +- tests/training/test_train_gan_exo.py | 98 ------------------- .../test_train_wind_conditional_moments.py | 2 + 7 files changed, 111 insertions(+), 153 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 5e205c860..11d55094e 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -151,9 +151,9 @@ def _combine_fwp_input(self, low_res, exogenous_data=None): low_res_shape = low_res.shape check = (exogenous_data is not None and low_res.shape[-1] < len(self.training_features)) - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} if check: + exo_data = {k: v for k, v in exogenous_data.items() + if k in self.training_features} for i, (feature, entry) in enumerate(exo_data.items()): f_idx = low_res_shape[-1] + i training_feature = self.training_features[f_idx] @@ -200,9 +200,9 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): hi_res_shape = hi_res.shape[-1] check = (exogenous_data is not None and hi_res.shape[-1] < len(self.output_features)) - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} if check: + exo_data = {k: v for k, v in exogenous_data.items() + if k in self.training_features} for i, (feature, entry) in enumerate(exo_data.items()): f_idx = hi_res_shape[-1] + i training_feature = self.training_features[f_idx] @@ -798,7 +798,7 @@ def update_loss_details(loss_details, new_data, batch_len, prefix=None): Namespace of the breakdown of loss components for a single new batch. batch_len : int - Length of the incomming batch. + Length of the incoming batch. prefix : None | str Option to prefix the names of the loss data when saving to the loss_details dictionary. @@ -1136,6 +1136,39 @@ def _reshape_norm_exo(self, hi_res, hi_res_exo, exo_name, norm_in=True): return hi_res_exo + def _get_layer_exo_input(self, layer_name, exogenous_data): + """Get the high-resolution exo data for the given layer name from the + full exogenous_data dictionary. + + Parameters + ---------- + layer_name : str + Name of Sup3rAdder or Sup3rConcat layer. This should match a + feature key in exogenous_data + exogenous_data : dict | None + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. This doesn't have to include the 'model' key since + this data is for a single step model. e.g. 
+ {'topography': {'steps': [ + {'combine_type': 'input', 'data': ..., 'resolution': ...}, + {'combine_type': 'layer', 'data': ..., 'resolution': ...}]}} + + """ + msg = (f'layer.name = {layer_name} does not match any ' + 'features in exogenous_data ' + f'({list(exogenous_data)})') + assert layer_name in exogenous_data, msg + steps = exogenous_data[layer_name]['steps'] + combine_types = [step['combine_type'] for step in steps] + msg = ('Received exogenous_data without any combine_type ' + '= "layer" steps, for a model with an Adder/Concat ' + 'layer.') + assert 'layer' in combine_types, msg + idx = combine_types.index('layer') + hi_res_exo = steps[idx]['data'] + return hi_res_exo + def generate(self, low_res, norm_in=True, @@ -1185,14 +1218,11 @@ def generate(self, for i, layer in enumerate(self.generator.layers[1:]): try: if isinstance(layer, (Sup3rAdder, Sup3rConcat)): - exo_name = layer.name - steps = exogenous_data[exo_name]['steps'] - combine_types = [step['combine_type'] for step in steps] - idx = combine_types.index('layer') - hi_res_exo = steps[idx]['data'] + hi_res_exo = self._get_layer_exo_input(layer.name, + exogenous_data) hi_res_exo = self._reshape_norm_exo(hi_res, hi_res_exo, - exo_name, + layer.name, norm_in=norm_in) hi_res = layer(hi_res, hi_res_exo) else: @@ -1242,6 +1272,9 @@ def _tf_generate(self, low_res, hi_res_exo=None): for i, layer in enumerate(self.generator.layers[1:]): try: if isinstance(layer, (Sup3rAdder, Sup3rConcat)): + msg = (f'layer.name = {layer.name} does not match any ' + f'features in exogenous_data ({list(hi_res_exo)})') + assert layer.name in hi_res_exo, msg hr_exo = hi_res_exo[layer.name] hi_res = layer(hi_res, hr_exo) else: diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 83d2583ab..c5c2a9c79 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -12,6 +12,7 @@ from concurrent.futures import as_completed from datetime import datetime as dt from inspect import signature +from typing import ClassVar import numpy as np import psutil @@ -1040,6 +1041,9 @@ class ForwardPass: through the GAN generator to produce high resolution output. """ + OUTPUT_HANDLER_CLASS: ClassVar = {'nc': OutputHandlerNC, + 'h5': OutputHandlerH5} + def __init__(self, strategy, chunk_index=0, node_index=0): """Initialize ForwardPass with ForwardPassStrategy. 
The strategy provides the data chunks to run forward passes on @@ -1089,16 +1093,15 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.max_workers = strategy.max_workers self.pass_workers = strategy.pass_workers self.output_workers = strategy.output_workers - self.exo_kwargs = self._prep_exo_extract_kwargs(strategy.exo_kwargs) + self.exo_kwargs = self.update_exo_extract_kwargs(strategy.exo_kwargs) self.exo_features = ([] if not self.exo_kwargs else list(self.exo_kwargs)) self.exogenous_data = self.load_exo_data() self.input_handler_class = strategy.input_handler_class - - if strategy.output_type == 'nc': - self.output_handler_class = OutputHandlerNC - elif strategy.output_type == 'h5': - self.output_handler_class = OutputHandlerH5 + msg = f'Received bad output type {strategy.output_type}' + assert strategy.output_type in self.OUTPUT_HANDLER_CLASS, msg + self.output_handler_class = self.OUTPUT_HANDLER_CLASS[ + strategy.output_type] input_handler_kwargs = self.update_input_handler_kwargs(strategy) @@ -1177,7 +1180,7 @@ def get_agg_factors(self, input_res, exo_res): t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res) return s_agg_factor, t_agg_factor - def _prep_exo_extract_single_step(self, step_dict, exo_resolution): + def _get_single_step_agg_and_enhancement(self, step_dict, exo_resolution): """Compute agg and enhancement factors for exogenous data extraction using a single model step from exo_kwargs. These factors are computed using exo_resolution and the input/output resolution of each model step. @@ -1244,7 +1247,7 @@ def _prep_exo_extract_single_step(self, step_dict, exo_resolution): 'resolution': resolution}) return updated_dict - def _prep_exo_extract_kwargs(self, exo_kwargs): + def update_exo_extract_kwargs(self, exo_kwargs): """Compute agg and enhancement factors for all model steps for all features.
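The hunk below collects the per-step values computed by _get_single_step_agg_and_enhancement into the list-valued kwargs that ExogenousDataHandler takes. A rough sketch of that flattening, with step values invented for illustration (they mirror the constants used in the exo-handling tests):

steps = [
    {'s_enhance': 1, 't_enhance': 1, 's_agg_factor': 4, 't_agg_factor': 1},
    {'s_enhance': 4, 't_enhance': 1, 's_agg_factor': 1, 't_agg_factor': 1}]
exo_kwargs = {'topography': {'steps': steps}}

# flatten the per-step factors into one list per kwarg
for list_key, step_key in [('s_agg_factors', 's_agg_factor'),
                           ('t_agg_factors', 't_agg_factor'),
                           ('s_enhancements', 's_enhance'),
                           ('t_enhancements', 't_enhance')]:
    exo_kwargs['topography'][list_key] = [s[step_key] for s in steps]

assert exo_kwargs['topography']['s_enhancements'] == [1, 4]
assert exo_kwargs['topography']['s_agg_factors'] == [4, 1]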
@@ -1265,10 +1268,19 @@ def _prep_exo_extract_kwargs(self, exo_kwargs): if exo_kwargs is not None: for feature in exo_kwargs: exo_resolution = exo_kwargs[feature]['exo_resolution'] - for i, step in enumerate(exo_kwargs[feature]['steps']): - out = self._prep_exo_extract_single_step( + steps = exo_kwargs[feature]['steps'] + for i, step in enumerate(steps): + out = self._get_single_step_agg_and_enhancement( step, exo_resolution) exo_kwargs[feature]['steps'][i] = out + exo_kwargs[feature]['s_agg_factors'] = [step['s_agg_factor'] + for step in steps] + exo_kwargs[feature]['t_agg_factors'] = [step['t_agg_factor'] + for step in steps] + exo_kwargs[feature]['s_enhancements'] = [step['s_enhance'] + for step in steps] + exo_kwargs[feature]['t_enhancements'] = [step['t_enhance'] + for step in steps] return exo_kwargs def load_exo_data(self): @@ -1291,16 +1303,8 @@ def load_exo_data(self): exo_kwargs['feature'] = feature exo_kwargs['target'] = self.target exo_kwargs['shape'] = self.shape - steps = exo_kwargs['steps'] exo_kwargs['temporal_slice'] = self.ti_pad_slice - exo_kwargs['s_agg_factors'] = [step['s_agg_factor'] - for step in steps] - exo_kwargs['t_agg_factors'] = [step['t_agg_factor'] - for step in steps] - exo_kwargs['s_enhancements'] = [step['s_enhance'] - for step in steps] - exo_kwargs['t_enhancements'] = [step['t_enhance'] - for step in steps] + steps = exo_kwargs['steps'] sig = signature(ExogenousDataHandler) exo_kwargs = {k: v for k, v in exo_kwargs.items() if k in sig.parameters} diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py index 0a147fdd6..5bdfd264a 100644 --- a/tests/data_handling/test_exo_data_handling.py +++ b/tests/data_handling/test_exo_data_handling.py @@ -4,6 +4,7 @@ import shutil import numpy as np +import pytest from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling import ExogenousDataHandler @@ -14,19 +15,26 @@ os.path.join(TEST_DATA_DIR, 'va_test.nc'), os.path.join(TEST_DATA_DIR, 'orog_test.nc'), os.path.join(TEST_DATA_DIR, 'zg_test.nc')] -FEATURES = ['topography'] TARGET = (13.67, 125.0) SHAPE = (8, 8) S_ENHANCE = [1, 4] -AGG_FACTORS = [4, 1] +T_ENHANCE = [1, 1] +S_AGG_FACTORS = [4, 1] +T_AGG_FACTORS = [1, 1] -def test_exo_cache(): +@pytest.mark.parametrize('feature', ['topography', 'sza']) +def test_exo_cache(feature): """Test exogenous data caching and re-load""" # no cached data try: - base = ExogenousDataHandler(FILE_PATHS, FEATURES, FP_WTK, S_ENHANCE, - AGG_FACTORS, target=TARGET, shape=SHAPE, + base = ExogenousDataHandler(FILE_PATHS, feature, + source_file=FP_WTK, + s_enhancements=S_ENHANCE, + t_enhancements=T_ENHANCE, + s_agg_factors=S_AGG_FACTORS, + t_agg_factors=T_AGG_FACTORS, + target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') for i, arr in enumerate(base.data): assert arr.shape[0] == SHAPE[0] * S_ENHANCE[i] @@ -41,8 +49,13 @@ def test_exo_cache(): # load cached data try: - cache = ExogenousDataHandler(FILE_PATHS, FEATURES, FP_WTK, S_ENHANCE, - AGG_FACTORS, target=TARGET, shape=SHAPE, + cache = ExogenousDataHandler(FILE_PATHS, feature, + source_file=FP_WTK, + s_enhancements=S_ENHANCE, + t_enhancements=T_ENHANCE, + s_agg_factors=S_AGG_FACTORS, + t_agg_factors=T_AGG_FACTORS, + target=TARGET, shape=SHAPE, input_handler='DataHandlerNCforCC') except Exception as e: if os.path.exists('./exo_cache/'): diff --git a/tests/data_handling/test_utils_topo.py b/tests/data_handling/test_utils_topo.py index 3b3fa8398..294bede19 100644 --- a/tests/data_handling/test_utils_topo.py +++ 
b/tests/data_handling/test_utils_topo.py @@ -1,14 +1,17 @@ # -*- coding: utf-8 -*- """pytests for topography utilities""" import os + +import matplotlib.pyplot as plt import numpy as np import pytest from scipy.spatial import KDTree -import matplotlib.pyplot as plt from sup3r import TEST_DATA_DIR -from sup3r.utilities.topo import TopoExtractNC, TopoExtractH5 - +from sup3r.preprocessing.data_handling.exo_extraction import ( + TopoExtractH5, + TopoExtractNC, +) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET = (39.001, -105.15) @@ -22,32 +25,33 @@ def test_topo_extraction_h5(agg_factor, plot=False): """Test the spatial enhancement of a test grid and then the lookup of the elevation data to a reference WTK file (also the same file for the test)""" - te = TopoExtractH5(FP_WTK, FP_WTK, s_enhance=2, agg_factor=agg_factor, + te = TopoExtractH5(FP_WTK, FP_WTK, s_enhance=2, t_enhance=1, + t_agg_factor=1, s_agg_factor=agg_factor, target=TARGET, shape=SHAPE) - hr_elev = te.hr_elev + hr_elev = te.data tree = KDTree(te.source_lat_lon) # bottom left _, i = tree.query(TARGET, k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[-1, 0]) # top right _, i = tree.query((39.35, -105.2), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[0, 0]) for idy in range(10, 20): for idx in range(10, 20): lat, lon = te.hr_lat_lon[idy, idx, :] _, i = tree.query((lat, lon), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[idy, idx]) if plot: a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0], - c=te.source_elevation, marker='s', s=150) + c=te.source_data, marker='s', s=150) plt.colorbar(a) plt.savefig('./source_elevation.png') plt.close() @@ -62,32 +66,33 @@ def test_topo_extraction_h5(agg_factor, plot=False): def test_topo_extraction_nc(agg_factor, plot=False): """Test the spatial enhancement of a test grid and then the lookup of the elevation data to a reference WRF file (also the same file for the test)""" - te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=2, agg_factor=agg_factor, + te = TopoExtractNC(FP_WRF, FP_WRF, s_enhance=2, t_enhance=1, + s_agg_factor=agg_factor, t_agg_factor=1, target=WRF_TARGET, shape=WRF_SHAPE) - hr_elev = te.hr_elev + hr_elev = te.data tree = KDTree(te.source_lat_lon) # bottom left _, i = tree.query(WRF_TARGET, k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[-1, 0]) # top right _, i = tree.query((19.4, -123.6), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[0, 0]) for idy in range(4, 8): for idx in range(4, 8): lat, lon = te.hr_lat_lon[idy, idx, :] _, i = tree.query((lat, lon), k=agg_factor) - elev = te.source_elevation[i].mean() + elev = te.source_data[i].mean() assert np.allclose(elev, hr_elev[idy, idx]) if plot: a = plt.scatter(te.source_lat_lon[:, 1], te.source_lat_lon[:, 0], - c=te.source_elevation, marker='s', s=150) + c=te.source_data, marker='s', s=150) plt.colorbar(a) plt.savefig('./source_elevation.png') plt.close() diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index 8db52c8f1..f610829f6 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -13,7 +13,6 @@ SolarMultiStepGan, SpatialThenTemporalGan, Sup3rGan, - 
TemporalThenSpatialGan, ) FEATURES = ['U_100m', 'V_100m'] @@ -162,11 +161,11 @@ def test_temporal_then_spatial_gan(): model1.save(fp1) model2.save(fp2) - ms_model = TemporalThenSpatialGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp2, fp1]) x = np.ones((1, 10, 10, 4, len(FEATURES))) out = ms_model.generate(x) - assert out.shape == (1, 60, 60, 16, 2) + assert out.shape == (16, 60, 60, 2) def test_spatial_gan_then_linear_interp(): diff --git a/tests/training/test_train_gan_exo.py b/tests/training/test_train_gan_exo.py index 66cda197c..741019feb 100644 --- a/tests/training/test_train_gan_exo.py +++ b/tests/training/test_train_gan_exo.py @@ -11,7 +11,6 @@ from sup3r.models import Sup3rGan from sup3r.models.data_centric import Sup3rGanDC from sup3r.preprocessing.batch_handling import ( - BatchHandlerCC, BatchHandlerDC, SpatialBatchHandler, SpatialBatchHandlerCC, @@ -36,103 +35,6 @@ TARGET_COORD = (39.01, -105.15) -def test_wind_cc_model(log=False): - """Test the wind climate change wtk super res model. - - NOTE that the full 10x model is too big to train on the 20x20 test data. - """ - - handler = DataHandlerH5WindCC(INPUT_FILE_W, FEATURES_W, - target=TARGET_W, shape=SHAPE, - temporal_slice=slice(None, None, 2), - time_roll=-7, - sample_shape=(20, 20, 96), - worker_kwargs=dict(max_workers=1), - train_only_features=['topography']) - batcher = BatchHandlerCC([handler], batch_size=2, n_batches=2, - s_enhance=4, sub_daily_shape=None) - if log: - init_logger('sup3r', log_level='DEBUG') - - fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_4x_24x_3f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') - - Sup3rGan.seed() - model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - - with tempfile.TemporaryDirectory() as td: - model.train(batcher, - input_resolution={'spatial': '16km', - 'temporal': '3600min'}, - n_epoch=1, - weight_gen_advers=0.0, - train_gen=True, train_disc=False, - checkpoint_int=None, - out_dir=os.path.join(td, 'test_{epoch}')) - - assert 'test_0' in os.listdir(td) - assert model.meta['class'] == 'Sup3rGan' - assert 'topography' not in model.output_features - assert len(model.output_features) == len(FEATURES_W) - 1 - - x = np.random.uniform(0, 1, (1, 4, 4, 4, 4)) - y = model.generate(x) - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 4 - assert y.shape[2] == x.shape[2] * 4 - assert y.shape[3] == x.shape[3] * 24 - assert y.shape[4] == x.shape[4] - 1 - - -def test_wind_cc_model_spatial(log=False): - """Test the wind climate change wtk super res model with spatial - enhancement only. 
- """ - handler = DataHandlerH5WindCC(INPUT_FILE_W, - ('U_100m', 'V_100m', 'topography'), - target=TARGET_W, shape=SHAPE, - temporal_slice=slice(None, None, 2), - time_roll=-7, - val_split=0.1, - sample_shape=(20, 20), - worker_kwargs=dict(max_workers=1), - train_only_features=['topography']) - - batcher = SpatialBatchHandlerCC([handler], batch_size=2, n_batches=2, - s_enhance=2) - - if log: - init_logger('sup3r', log_level='DEBUG') - - fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') - fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') - - Sup3rGan.seed() - model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - - with tempfile.TemporaryDirectory() as td: - model.train(batcher, - input_resolution={'spatial': '16km', - 'temporal': '3600min'}, - n_epoch=1, - weight_gen_advers=0.0, - train_gen=True, train_disc=False, - checkpoint_int=None, - out_dir=os.path.join(td, 'test_{epoch}')) - - assert 'test_0' in os.listdir(td) - assert model.meta['output_features'] == ['U_100m', 'V_100m'] - assert model.meta['class'] == 'Sup3rGan' - assert 'topography' not in model.output_features - - x = np.random.uniform(0, 1, (4, 30, 30, 3)) - y = model.generate(x) - assert y.shape[0] == x.shape[0] - assert y.shape[1] == x.shape[1] * 2 - assert y.shape[2] == x.shape[2] * 2 - assert y.shape[3] == x.shape[3] - 1 - - @pytest.mark.parametrize('custom_layer', ['Sup3rAdder', 'Sup3rConcat']) def test_wind_hi_res_topo(custom_layer, log=False): """Test a special wind cc model with the custom Sup3rAdder or Sup3rConcat diff --git a/tests/training/test_train_wind_conditional_moments.py b/tests/training/test_train_wind_conditional_moments.py index d33efe46e..20cd092d2 100644 --- a/tests/training/test_train_wind_conditional_moments.py +++ b/tests/training/test_train_wind_conditional_moments.py @@ -109,6 +109,7 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class, Sup3rCondMom.seed() model = Sup3rCondMom(gen_model, learning_rate=1e-4) + input_resolution = {'spatial': '8km', 'temporal': '60min'} with tempfile.TemporaryDirectory() as td: if out_dir_root is None: out_dir_root = td @@ -120,6 +121,7 @@ def test_wind_non_cc_hi_res_topo_mom1(custom_layer, batch_class, assert f'test_{n_epoch-1}' in os.listdir(out_dir_root) assert model.meta['output_features'] == ['U_100m', 'V_100m'] assert model.meta['class'] == 'Sup3rCondMom' + assert model.meta['input_resolution'] == input_resolution assert 'topography' in batcher.output_features assert 'topography' not in model.output_features From 4a2056e1e183b8fa777841b84feb1bbfd58cf5ee Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 27 Sep 2023 12:51:24 -0600 Subject: [PATCH 11/15] Removed tests from test_train_gan_exo which don't use exogenous features. 
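The multi_step.py changes in the patch below replace the near-duplicate _split_exo_spatial_temporal and _split_exo_temporal_spatial helpers with a single _split_exo_dict(split_step, exogenous_data). A sketch of the splitting behavior it implements, with string placeholders standing in for real arrays:

exo = {'topography': {'steps': [
    {'model': 0, 'combine_type': 'layer', 'data': 'lr_topo'},
    {'model': 1, 'combine_type': 'layer', 'data': 'hr_topo'}]}}

# with split_step=1 (e.g. one spatial model ahead of the temporal models),
# steps with model < 1 land in the first dict; the rest land in the second
# dict with their 'model' index re-based to start at zero:
# exo_1 == {'topography': {'steps': [
#     {'model': 0, 'combine_type': 'layer', 'data': 'lr_topo'}]}}
# exo_2 == {'topography': {'steps': [
#     {'model': 0, 'combine_type': 'layer', 'data': 'hr_topo'}]}}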
--- requirements.txt | 2 +- sup3r/models/abstract.py | 20 +- sup3r/models/multi_step.py | 225 ++++++------------ sup3r/utilities/pytest.py | 14 +- .../data_handling/test_dual_data_handling.py | 14 +- tests/forward_pass/test_forward_pass.py | 8 +- tests/forward_pass/test_solar_module.py | 19 +- tests/forward_pass/test_surface_model.py | 5 +- tests/output/test_qa.py | 16 +- tests/pipeline/test_cli.py | 65 ++--- tests/pipeline/test_pipeline.py | 27 ++- tests/training/test_train_gan_lr_era.py | 2 +- tests/utilities/test_utilities.py | 19 +- 13 files changed, 193 insertions(+), 243 deletions(-) diff --git a/requirements.txt b/requirements.txt index 06de3fb18..ec27570ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ matplotlib>=3.1 -NREL-rex>=0.2.82 +NREL-rex>=0.2.84 NREL-phygnn>=0.0.23 NREL-rev<0.8.0 NREL-gaps>=0.4.0 diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 11d55094e..55c1b8ffe 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -86,17 +86,31 @@ def input_dims(self): else: return 5 + # pylint: disable=E1101 @property def s_enhance(self): """Factor by which model will enhance spatial resolution. Used in model training during high res coarsening""" - return self.meta.get('s_enhance', None) - + s_enhance = self.meta.get('s_enhance', None) + if s_enhance is None and hasattr(self, '_gen'): + s_enhancements = [getattr(layer, '_spatial_mult', 1) + for layer in self._gen.layers] + s_enhance = np.product(s_enhancements) + self.meta['s_enhance'] = int(s_enhance) + return s_enhance + + # pylint: disable=E1101 @property def t_enhance(self): """Factor by which model will enhance temporal resolution. Used in model training during high res coarsening""" - return self.meta.get('t_enhance', None) + t_enhance = self.meta.get('t_enhance', None) + if t_enhance is None and hasattr(self, '_gen'): + t_enhancements = [getattr(layer, '_temporal_mult', 1) + for layer in self._gen.layers] + t_enhance = np.product(t_enhancements) + self.meta['t_enhance'] = int(t_enhance) + return t_enhance @property def input_resolution(self): diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 3cc7557be..58fe80f3e 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -374,6 +374,57 @@ def load(cls, spatial_model_dirs, temporal_model_dirs, verbose=True): return cls(s_models, t_models) + def _split_exo_dict(self, split_step, exogenous_data=None): + """Split exogenous_data into two dicts based on split_step. The first + dict has only model steps less than split_step. The second dict has + only model steps greater than or equal to split_step. + + Parameters + ---------- + split_step : int + Step index to use for splitting. If this is for a + SpatialThenTemporal model split_step should be len(spatial_models). + If this is for a TemporalThenSpatial model split_step should be + len(temporal_models). + exogenous_data : dict + Dictionary of exogenous feature data with entries describing + whether features should be combined at input, a mid network layer, + or with output. e.g. 
+ {'topography': {'steps': [ + {'combine_type': 'input', 'model': 0, 'data': ..., + 'resolution': ...}, + {'combine_type': 'layer', 'model': 0, 'data': ..., + 'resolution': ...}]}} + Each array in the 'data' key has 3D or 4D shape: + (spatial_1, spatial_2, 1) + (spatial_1, spatial_2, n_temporal, 1) + + Returns + ------- + split_exo_1 : dict + Same as input dictionary but with only entries with 'model': + model_step where model_step is less than split_step + split_exo_2 : dict + Same as input dictionary but with only entries with 'model': + model_step where model_step is greater than or equal to split_step + """ + split_exo_1 = {} + split_exo_2 = {} + if exogenous_data is not None: + exo_data = copy.deepcopy(exogenous_data) + for feature in exo_data: + steps = [step for step in exo_data[feature]['steps'] + if step['model'] < split_step] + if steps: + split_exo_1[feature] = {'steps': steps} + steps = [step for step in exo_data[feature]['steps'] + if step['model'] >= split_step] + for step in steps: + step.update({'model': step['model'] - split_step}) + if steps: + split_exo_2[feature] = {'steps': steps} + return split_exo_1, split_exo_2 + class SpatialThenTemporalGan(SpatialThenTemporalBase): """A two-step GAN where the first step is a spatial-only enhancement on a @@ -422,65 +473,6 @@ def meta(self): temporal_models = [self.temporal_models.meta] return (*spatial_models, *temporal_models) - @property - def training_features(self): - """Get the list of input feature names that the first spatial - generative model in this SpatialThenTemporalGan model requires as - input.""" - return self.spatial_models.training_features - - @property - def output_features(self): - """Get the list of output feature names that the last spatiotemporal - interpolation model in this SpatialThenTemporalGan model outputs.""" - return self.temporal_models.output_features - - def _split_exo_spatial_temporal(self, exogenous_data=None): - """Split exogenous_data into spatial_exo and temporal_exo eacho of - which are then passed through MultiStepGan models - - Parameters - ---------- - exogenous_data : dict - Dictionary of exogenous feature data with entries describing - whether features should be combined at input, a mid network layer, - or with output. e.g.
- {'topography': {'steps': [ - {'combine_type': 'input', 'model': 0, 'data': ..., - 'resolution': ...}, - {'combine_type': 'layer', 'model': 0, 'data': ..., - 'resolution': ...}]}} - Each array in in 'data' key has 3D or 4D shape: - (spatial_1, spatial_2, 1) - (spatial_1, spatial_2, n_temporal, 1) - - Returns - ------- - spatial_exo : dict - Same as input dictionary but with only entries with 'model': - model_step where model_step corresponds to a spatial model step - temporal_exo : dict - Same as input dictionary but with only entries with 'model': - model_step where model_step corresponds to a temporal model step - """ - spatial_exo = {} - temporal_exo = {} - if exogenous_data is not None: - exo_data = copy.deepcopy(exogenous_data) - for feature in exo_data: - steps = [step for step in exo_data[feature]['steps'] - if step['model'] < len(self.spatial_models)] - if steps: - spatial_exo[feature] = {'steps': steps} - steps = [step for step in exo_data[feature]['steps'] - if step['model'] >= len(self.spatial_models)] - t_shift = len(self.spatial_models) - for step in steps: - step.update({'model': step['model'] - t_shift}) - if steps: - temporal_exo[feature] = {'steps': steps} - return spatial_exo, temporal_exo - def generate(self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res @@ -518,7 +510,8 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, """ logger.debug('Data input to the 1st step spatial-only ' 'enhancement has shape {}'.format(low_res.shape)) - s_exo, t_exo = self._split_exo_spatial_temporal(exogenous_data) + s_exo, t_exo = self._split_exo_dict( + split_step=len(self.spatial_models), exogenous_data=exogenous_data) try: hi_res = self.spatial_models.generate( low_res, norm_in=norm_in, un_norm_out=True, @@ -592,65 +585,6 @@ def meta(self): return (*temporal_models, *spatial_models) - @property - def training_features(self): - """Get the list of input feature names that the first temporal - generative model in this TemporalThenSpatialGan model requires as - input.""" - return self.temporal_models.training_features - - @property - def output_features(self): - """Get the list of output feature names that the last spatial - interpolation model in this TemporalThenSpatialGan model outputs.""" - return self.spatial_models.output_features - - def _split_exo_temporal_spatial(self, exogenous_data=None): - """Split exogenous_data into spatial_exo and temporal_exo eacho of - which are then passed through MultiStepGan models - - Parameters - ---------- - exogenous_data : dict - Dictionary of exogenous feature data with entries describing - whether features should be combined at input, a mid network layer, - or with output. e.g. 
- {'topography': {'steps': [ - {'combine_type': 'input', 'model': 0, 'data': ..., - 'resolution': ...}, - {'combine_type': 'layer', 'model': 0, 'data': ..., - 'resolution': ...}]}} - Each array in in 'data' key has 3D or 4D shape: - (spatial_1, spatial_2, 1) - (spatial_1, spatial_2, n_temporal, 1) - - Returns - ------- - temporal_exo : dict - Same as input dictionary but with only entries with 'model': - model_step where model_step corresponds to a temporal model step - spatial_exo : dict - Same as input dictionary but with only entries with 'model': - model_step where model_step corresponds to a spatial model step - """ - spatial_exo = {} - temporal_exo = {} - if exogenous_data is not None: - exo_data = copy.deepcopy(exogenous_data) - for feature in exo_data: - steps = [step for step in exo_data[feature]['steps'] - if step['model'] < len(self.temporal_models)] - if steps: - temporal_exo[feature] = {'steps': steps} - steps = [step for step in exo_data[feature]['steps'] - if step['model'] >= len(self.temporal_models)] - s_shift = len(self.temporal_models) - for step in steps: - step.update({'model': step['model'] - s_shift}) - if steps: - spatial_exo[feature] = {'steps': steps} - return temporal_exo, spatial_exo - def generate(self, low_res, norm_in=True, un_norm_out=True, exogenous_data=None): """Use the generator model to generate high res data from low res @@ -688,7 +622,9 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, """ logger.debug('Data input to the 1st step (spatio)temporal ' 'enhancement has shape {}'.format(low_res.shape)) - t_exo, s_exo = self._split_exo_temporal_spatial(exogenous_data) + t_exo, s_exo = self._split_exo_dict( + split_step=len(self.temporal_models), + exogenous_data=exogenous_data) assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!' @@ -772,7 +708,10 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, meters (must match spatial_1, spatial_2 from low_res), and the second entry includes a 2D (lat, lon) array of high-resolution surface elevation data in meters. e.g. 
- {'topography': {'steps': [{'data': lr_topo}, {'data': hr_topo'}]}} + {'topography': { + 'steps': [ + {'model': 0, 'combine_type': 'input', 'data': lr_topo}, + {'model': 0, 'combine_type': 'output', 'data': hr_topo'}]}} Returns ------- @@ -787,42 +726,12 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, 'enhancement has shape {}'.format(low_res.shape)) msg = ('MultiStepSurfaceMetGan needs exogenous_data with two ' - 'entries for low and high res topography inputs.') - assert exogenous_data is not None, msg - exo_data = [step['data'] - for step in exogenous_data['topography']['steps']] - assert isinstance(exo_data, (list, tuple)), msg - assert len(exo_data) == 2, msg - - try: - hi_res = self.spatial_models.generate( - low_res, exogenous_data=exogenous_data) - except Exception as e: - msg = ('Could not run the 1st step spatial-only GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e - - logger.debug('Data output from the 1st step spatial-only ' - 'enhancement has shape {}'.format(hi_res.shape)) - hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3)) - hi_res = np.expand_dims(hi_res, axis=0) - logger.debug('Data from the 1st step spatial-only enhancement has ' - 'been reshaped to {}'.format(hi_res.shape)) - - try: - hi_res = self.temporal_models.generate( - hi_res, norm_in=True, un_norm_out=un_norm_out) - except Exception as e: - msg = ('Could not run the 2nd step (spatio)temporal GAN on input ' - 'shape {}'.format(low_res.shape)) - logger.exception(msg) - raise RuntimeError(msg) from e + 'topography steps, for low and high res topography inputs.') + exo_check = (exogenous_data is not None + and len(exogenous_data['topography']['steps']) == 2) + assert exo_check, msg - logger.debug('Final multistep GAN output has shape: {}' - .format(hi_res.shape)) - - return hi_res + return super().generate(low_res, norm_in, un_norm_out, exogenous_data) @classmethod def load(cls, surface_model_class='SurfaceSpatialMetModel', @@ -1098,7 +1007,9 @@ def generate(self, low_res, norm_in=True, un_norm_out=True, logger.debug('Data input to the SolarMultiStepGan has shape {} which ' 'will be split up for solar- and wind-only features.' 
.format(low_res.shape)) - s_exo, t_exo = self._split_exo_spatial_temporal(exogenous_data) + s_exo, t_exo = self._split_exo_dict( + split_step=len(self.spatial_models), + exogenous_data=exogenous_data) try: hi_res_wind = self.spatial_wind_models.generate( low_res[..., self.idf_wind], diff --git a/sup3r/utilities/pytest.py b/sup3r/utilities/pytest.py index 7c8581c74..c72a8c9a9 100644 --- a/sup3r/utilities/pytest.py +++ b/sup3r/utilities/pytest.py @@ -34,13 +34,13 @@ def make_fake_nc_files(td, input_file, n_files): ] fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] for i in range(n_files): - input_dset = xr.open_dataset(input_file) - with xr.Dataset(input_dset) as dset: - dset['Times'][:] = np.array( - [fake_times[i].encode('ASCII')], dtype='|S19' - ) - dset['XTIME'][:] = i - dset.to_netcdf(fake_files[i]) + with xr.open_dataset(input_file) as input_dset: + with xr.Dataset(input_dset) as dset: + dset['Times'][:] = np.array( + [fake_times[i].encode('ASCII')], dtype='|S19' + ) + dset['XTIME'][:] = i + dset.to_netcdf(fake_files[i]) return fake_files diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 13697d50a..fe42324a7 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -9,12 +9,14 @@ from sup3r import TEST_DATA_DIR from sup3r.preprocessing.data_handling.dual_data_handling import ( - DualDataHandler, ) + DualDataHandler, +) from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.preprocessing.dual_batch_handling import (DualBatchHandler, - SpatialDualBatchHandler, - ) +from sup3r.preprocessing.dual_batch_handling import ( + DualBatchHandler, + SpatialDualBatchHandler, +) from sup3r.utilities.utilities import spatial_coarsening FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') @@ -23,7 +25,7 @@ FEATURES = ['U_100m', 'V_100m'] -def test_dual_data_handler(log=True, +def test_dual_data_handler(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1), plot=True): @@ -71,7 +73,7 @@ def test_dual_data_handler(log=True, bbox_inches='tight') -def test_regrid_caching(log=True, +def test_regrid_caching(log=False, full_shape=(20, 20), sample_shape=(10, 10, 1)): """Test caching and loading of regridded data""" diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index 31ce6bdf2..1518c6bc5 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -571,8 +571,8 @@ def test_fwp_multi_step_model(): s_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) s_model.meta['training_features'] = ['U_100m', 'V_100m'] s_model.meta['output_features'] = ['U_100m', 'V_100m'] - s_model.meta['s_enhance'] = 2 - s_model.meta['t_enhance'] = 1 + assert s_model.s_enhance == 2 + assert s_model.t_enhance == 1 _ = s_model.generate(np.ones((4, 10, 10, 2))) fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') @@ -580,8 +580,8 @@ def test_fwp_multi_step_model(): st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) st_model.meta['training_features'] = ['U_100m', 'V_100m'] st_model.meta['output_features'] = ['U_100m', 'V_100m'] - st_model.meta['s_enhance'] = 3 - st_model.meta['t_enhance'] = 4 + assert st_model.s_enhance == 3 + assert st_model.t_enhance == 4 _ = st_model.generate(np.ones((4, 10, 10, 6, 2))) with tempfile.TemporaryDirectory() as td: diff --git 
a/tests/forward_pass/test_solar_module.py b/tests/forward_pass/test_solar_module.py index 29865bf78..319ca74bb 100644 --- a/tests/forward_pass/test_solar_module.py +++ b/tests/forward_pass/test_solar_module.py @@ -1,23 +1,23 @@ # -*- coding: utf-8 -*- """Test the custom sup3r solar module that converts GAN clearsky ratio outputs to irradiance data.""" -import pytest -from click.testing import CliRunner import glob import json import os -import numpy as np import tempfile +from pathlib import Path + import matplotlib.pyplot as plt +import numpy as np +import pytest +from click.testing import CliRunner from rex import Resource -from pathlib import Path from sup3r import TEST_DATA_DIR from sup3r.solar import Solar -from sup3r.utilities.utilities import pd_date_range -from sup3r.utilities.pytest import make_fake_cs_ratio_files from sup3r.solar.solar_cli import from_config as solar_main - +from sup3r.utilities.pytest import make_fake_cs_ratio_files +from sup3r.utilities.utilities import pd_date_range NSRDB_FP = os.path.join(TEST_DATA_DIR, 'test_nsrdb_clearsky_2018.h5') GAN_META = {'s_enhance': 4, 't_enhance': 24} @@ -169,13 +169,14 @@ def test_solar_cli(runner): log_file = os.path.join(td, 'logs/sup3r_solar.log') if os.path.exists(log_file): - with open(log_file, 'r') as f: + with open(log_file) as f: logs = ''.join(list(f.readlines())) msg += '\nlogs:\n{}'.format(logs) raise RuntimeError(msg) - status_files = glob.glob(os.path.join(td, 'jobstatus_*.json')) + status_files = glob.glob(os.path.join(f'{td}/.gaps/', + '*jobstatus*.json')) assert len(status_files) == len(fps) out_files = glob.glob(os.path.join(td, 'chunks/*_irradiance.h5')) diff --git a/tests/forward_pass/test_surface_model.py b/tests/forward_pass/test_surface_model.py index 2c72044d0..d4358e51e 100644 --- a/tests/forward_pass/test_surface_model.py +++ b/tests/forward_pass/test_surface_model.py @@ -161,7 +161,10 @@ def test_multi_step_surface(s_enhance=2, t_enhance=2): exo_tmp = { 'topography': { - 'steps': [{'data': topo_lr}, {'data': topo_hr}]}} + 'steps': [{'model': 0, 'combine_type': 'input', + 'data': topo_lr}, + {'model': 0, 'combine_type': 'output', + 'data': topo_hr}]}} hi_res = ms_model.generate(low_res, exogenous_data=exo_tmp) target_shape = (1, diff --git a/tests/output/test_qa.py b/tests/output/test_qa.py index d7a4747d3..06ebb0a04 100644 --- a/tests/output/test_qa.py +++ b/tests/output/test_qa.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """pytests for data handling""" import os +import pickle import tempfile -import pandas as pd + import numpy as np -from rex import Resource, init_logger +import pandas as pd import xarray as xr -import pickle +from rex import Resource, init_logger -from sup3r import TEST_DATA_DIR, CONFIG_DIR -from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy +from sup3r import CONFIG_DIR, TEST_DATA_DIR from sup3r.models import Sup3rGan -from sup3r.utilities.pytest import make_fake_nc_files +from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.qa.qa import Sup3rQa from sup3r.qa.stats import Sup3rStatsMulti from sup3r.qa.utilities import continuous_dist - +from sup3r.utilities.pytest import make_fake_nc_files FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -201,7 +201,7 @@ def test_qa_h5(): assert np.allclose(test_diff, qa_diff, atol=0.01) -def test_stats(log=True): +def test_stats(log=False): """Test the WindStats module with forward pass output to h5 file.""" if log: diff --git 
a/tests/pipeline/test_cli.py b/tests/pipeline/test_cli.py
index 396e8ccdb..547ae1c74 100644
--- a/tests/pipeline/test_cli.py
+++ b/tests/pipeline/test_cli.py
@@ -1,25 +1,24 @@
 # -*- coding: utf-8 -*-
 """pytests for sup3r cli"""
+import glob
 import json
 import os
 import tempfile
-import pytest
-import glob
-import numpy as np
-from rex import ResourceX
-from rex import init_logger
+import numpy as np
+import pytest
 from click.testing import CliRunner
+from rex import ResourceX, init_logger
 
-from sup3r.pipeline.pipeline_cli import from_config as pipe_main
+from sup3r import CONFIG_DIR, TEST_DATA_DIR
+from sup3r.models.base import Sup3rGan
 from sup3r.pipeline.forward_pass_cli import from_config as fwp_main
-from sup3r.preprocessing.data_extract_cli import from_config as dh_main
+from sup3r.pipeline.pipeline_cli import from_config as pipe_main
 from sup3r.postprocessing.data_collect_cli import from_config as dc_main
+from sup3r.preprocessing.data_extract_cli import from_config as dh_main
 from sup3r.qa.visual_qa_cli import from_config as vqa_main
-from sup3r.models.base import Sup3rGan
-from sup3r.utilities.pytest import make_fake_nc_files, make_fake_h5_chunks
+from sup3r.utilities.pytest import make_fake_h5_chunks, make_fake_nc_files
 from sup3r.utilities.utilities import correct_path
-from sup3r import TEST_DATA_DIR, CONFIG_DIR
 
 INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00')
 FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']
@@ -34,7 +33,7 @@ def runner():
     return CliRunner()
 
 
-def test_pipeline_fwp_collect(runner, log=True):
+def test_pipeline_fwp_collect(runner, log=False):
     """Test pipeline with forward pass and data collection"""
     if log:
         init_logger('sup3r', log_level='DEBUG')
@@ -97,8 +96,8 @@
     with open(pipe_config_path, 'w') as fh:
         json.dump(pipe_config, fh)
 
-    result = runner.invoke(pipe_main, ['-c', pipe_config_path,
-                                       '-v', '--monitor'])
+    result = runner.invoke(pipe_main, ['-c', pipe_config_path, '-v',
+                                       '--monitor'])
     if result.exit_code != 0:
         import traceback
         msg = ('Failed with error {}'
@@ -197,7 +196,7 @@ def test_data_collection_cli(runner):
         assert np.allclose(wd_true, fh['winddirection_100m'], atol=0.1)
 
 
-def test_fwd_pass_cli(runner, log=True):
+def test_fwd_pass_cli(runner, log=False):
     """Test cli call to run forward pass"""
     if log:
         init_logger('sup3r', log_level='DEBUG')
@@ -210,8 +209,8 @@
     _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES))))
     model.meta['training_features'] = FEATURES
     model.meta['output_features'] = FEATURES[:2]
-    model.meta['s_enhance'] = 3
-    model.meta['t_enhance'] = 4
+    assert model.s_enhance == 3
+    assert model.t_enhance == 4
 
     with tempfile.TemporaryDirectory() as td:
         input_files = make_fake_nc_files(td, INPUT_FILE, 8)
@@ -289,20 +288,27 @@ def test_data_extract_cli(runner):
     assert len(glob.glob(f'{log_file}')) == 1
 
 
-def test_pipeline_fwp_qa(runner):
+def test_pipeline_fwp_qa(runner, log=True):
     """Test the sup3r pipeline with Forward Pass and QA modules
     via pipeline cli"""
+    if log:
+        init_logger('sup3r', log_level='DEBUG')
+
     fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
     fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
 
     Sup3rGan.seed()
     model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
+    input_resolution = {'spatial': '12km', 'temporal': '60min'}
+    model.meta['input_resolution'] = input_resolution
+    assert model.input_resolution == input_resolution
+    assert model.output_resolution == {'spatial': '4km',
+                                       'temporal': '15min'}
     _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES))))
     model.meta['training_features'] = FEATURES
     model.meta['output_features'] = FEATURES[:2]
-    model.meta['s_enhance'] = 3
-    model.meta['t_enhance'] = 4
+    assert model.s_enhance == 3
+    assert model.t_enhance == 4
 
     with tempfile.TemporaryDirectory() as td:
         input_files = make_fake_nc_files(td, INPUT_FILE, 8)
@@ -354,8 +360,8 @@
         with open(pipe_config_path, 'w') as fh:
             json.dump(pipe_config, fh)
 
-        result = runner.invoke(pipe_main, ['-c', pipe_config_path,
-                                           '-v', '--monitor'])
+        result = runner.invoke(pipe_main, ['-c', pipe_config_path, '-v',
+                                           '--monitor'])
         if result.exit_code != 0:
             import traceback
             msg = ('Failed with error {}'
@@ -365,31 +371,34 @@
         assert len(glob.glob(f'{td}/fwp_log*.log')) == 1
         assert len(glob.glob(f'{td}/out*.h5')) == 1
         assert len(glob.glob(f'{td}/qa.h5')) == 1
-        assert len(glob.glob(f'{td}/*_status.json')) == 1
-
-        status_fp = glob.glob(f'{td}/*_status.json')[0]
-        with open(status_fp, 'r') as f:
+        status_fps = glob.glob(f'{td}/.gaps/*status*.json')
+        assert len(status_fps) == 1
+        status_fp = status_fps[0]
+        with open(status_fp) as f:
            status = json.load(f)
 
         assert len(status) == 2
         assert len(status['forward-pass']) == 2
         fwp_status = status['forward-pass']
         del fwp_status['pipeline_index']
-        fwp_status = list(fwp_status.values())[0]
+        fwp_status = next(iter(fwp_status.values()))
         assert fwp_status['job_status'] == 'successful'
         assert fwp_status['time'] > 0
         assert len(status['qa']) == 2
         qa_status = status['qa']
         del qa_status['pipeline_index']
-        qa_status = list(qa_status.values())[0]
+        qa_status = next(iter(qa_status.values()))
         assert qa_status['job_status'] == 'successful'
         assert qa_status['time'] > 0
 
 
-def test_visual_qa(runner):
+def test_visual_qa(runner, log=False):
     """Make sure visual qa module creates the right number of plots"""
+    if log:
+        init_logger('sup3r', log_level='DEBUG')
+
     time_step = 500
     plot_features = ['windspeed_100m', 'winddirection_100m']
 
     with ResourceX(FP_WTK) as res:
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 7274d7a39..92057cf9e 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -1,17 +1,17 @@
 """Sup3r pipeline tests"""
-import tempfile
-import os
+import glob
 import json
+import os
 import shutil
-import numpy as np
-import glob
+import tempfile
 
+import numpy as np
 from rex import ResourceX
 
-from sup3r.pipeline.pipeline import Sup3rPipeline as Pipeline
+from sup3r import CONFIG_DIR, TEST_DATA_DIR
 from sup3r.models.base import Sup3rGan
+from sup3r.pipeline.pipeline import Sup3rPipeline as Pipeline
 from sup3r.utilities.pytest import make_fake_nc_files
-from sup3r import TEST_DATA_DIR, CONFIG_DIR
 
 INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00')
 FEATURES = ['U_100m', 'V_100m', 'BVF2_200m']
@@ -26,10 +26,15 @@ def test_fwp_pipeline():
     Sup3rGan.seed()
     model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4)
     _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES))))
+    input_resolution = {'spatial': '12km', 'temporal': '60min'}
+    model.meta['input_resolution'] = input_resolution
+    assert model.input_resolution == input_resolution
+    assert model.output_resolution == {'spatial': '4km', 'temporal': '15min'}
+    _ = model.generate(np.ones((4, 8, 8, 4, len(FEATURES))))
     model.meta['training_features'] = FEATURES
     model.meta['output_features'] = FEATURES[:2]
-    model.meta['s_enhance'] = 3
-    model.meta['t_enhance'] = 4
+    assert model.s_enhance == 3
+    assert model.t_enhance == 4
 
     with tempfile.TemporaryDirectory() as td:
         input_files = make_fake_nc_files(td, INPUT_FILE, 20)
@@ -98,8 +103,10 @@ def test_fwp_pipeline():
         with ResourceX(fp_out) as f:
             assert len(f.time_index) == t_enhance * n_tsteps
 
-        status_file = glob.glob(os.path.join(td, '*_status.json'))[0]
-        with open(status_file, 'r') as fh:
+        status_files = glob.glob(os.path.join(f'{td}/.gaps/', '*status.json'))
+        assert len(status_files) == 1
+        status_file = status_files[0]
+        with open(status_file) as fh:
             status = json.load(fh)
             assert all(s in status for s in ('forward-pass', 'data-collect'))
             assert all(s not in str(status)
diff --git a/tests/training/test_train_gan_lr_era.py b/tests/training/test_train_gan_lr_era.py
index 776dc8db3..52c1974ff 100644
--- a/tests/training/test_train_gan_lr_era.py
+++ b/tests/training/test_train_gan_lr_era.py
@@ -123,7 +123,7 @@ def test_train_spatial(
     assert loss_og.numpy() < loss_dummy.numpy()
 
 
-def test_train_st(n_epoch=3, log=True):
+def test_train_st(n_epoch=3, log=False):
     """Test basic spatiotemporal model training with only gen content loss."""
     if log:
         init_logger('sup3r', log_level='DEBUG')
diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py
index ec52c93dd..93ae8be82 100644
--- a/tests/utilities/test_utilities.py
+++ b/tests/utilities/test_utilities.py
@@ -15,19 +15,22 @@
 from sup3r.postprocessing.file_handling import OutputHandler
 from sup3r.utilities.interpolate_log_profile import LogLinInterpolator
 from sup3r.utilities.regridder import RegridOutput
-from sup3r.utilities.utilities import (get_chunk_slices, spatial_coarsening,
-                                       st_interp, transform_rotate_wind,
-                                       uniform_box_sampler,
-                                       uniform_time_sampler,
-                                       weighted_box_sampler,
-                                       weighted_time_sampler,
-                                       )
+from sup3r.utilities.utilities import (
+    get_chunk_slices,
+    spatial_coarsening,
+    st_interp,
+    transform_rotate_wind,
+    uniform_box_sampler,
+    uniform_time_sampler,
+    weighted_box_sampler,
+    weighted_time_sampler,
+)
 
 FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5')
 FP_ERA = os.path.join(TEST_DATA_DIR, 'test_era5_co_2012.nc')
 
 
-def test_log_interp(log=True):
+def test_log_interp(log=False):
     """Make sure log interp generates reasonable output (e.g. between input
     levels)"""
     if log:

From fb6787f3d20247124d60b6ff0499661efd808b64 Mon Sep 17 00:00:00 2001
From: bnb32
Date: Thu, 28 Sep 2023 14:34:55 -0600
Subject: [PATCH 12/15] error catching for invalid user specified enhancement
 factors and input_resolution.

---
 sup3r/models/abstract.py                | 118 +++++++++++++++++++-----
 sup3r/utilities/pytest.py               |  30 +++---
 tests/forward_pass/test_solar_module.py |   2 +-
 tests/pipeline/test_pipeline.py         |   2 +-
 tests/training/test_train_gan.py        |  37 +++++++-
 5 files changed, 145 insertions(+), 44 deletions(-)
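The abstract.py diff below factors the layer-derived enhancement factors out
into get_s_enhance_from_layers / get_t_enhance_from_layers. A minimal sketch
of how those products compose, using a hypothetical stand-in layer class
(FakeLayer is not part of this patch) that carries the _spatial_mult /
_temporal_mult attributes the helpers look for:

import numpy as np

class FakeLayer:
    """Hypothetical stand-in for a generator layer with enhancement attrs."""
    def __init__(self, s_mult=1, t_mult=1):
        self._spatial_mult = s_mult
        self._temporal_mult = t_mult

# two 2x spatial layers and one 4x temporal layer compose multiplicatively
layers = [FakeLayer(s_mult=2), FakeLayer(s_mult=2), FakeLayer(t_mult=4)]
s_enhance = int(np.product([getattr(lyr, '_spatial_mult', 1)
                            for lyr in layers]))
t_enhance = int(np.product([getattr(lyr, '_temporal_mult', 1)
                            for lyr in layers]))
assert (s_enhance, t_enhance) == (4, 4)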
diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 55c1b8ffe..49d6fa177 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -87,29 +87,45 @@ def input_dims(self):
         return 5
 
     # pylint: disable=E1101
+    def get_s_enhance_from_layers(self):
+        """Compute factor by which model will enhance spatial resolution from
+        layer attributes. Used in model training during high res coarsening"""
+        s_enhance = None
+        if hasattr(self, '_gen'):
+            s_enhancements = [getattr(layer, '_spatial_mult', 1)
+                              for layer in self._gen.layers]
+            s_enhance = int(np.product(s_enhancements))
+        return s_enhance
+
+    # pylint: disable=E1101
+    def get_t_enhance_from_layers(self):
+        """Compute factor by which model will enhance temporal resolution from
+        layer attributes. Used in model training during high res coarsening"""
+        t_enhance = None
+        if hasattr(self, '_gen'):
+            t_enhancements = [getattr(layer, '_temporal_mult', 1)
+                              for layer in self._gen.layers]
+            t_enhance = int(np.product(t_enhancements))
+        return t_enhance
+
     @property
     def s_enhance(self):
         """Factor by which model will enhance spatial resolution. Used in
         model training during high res coarsening"""
         s_enhance = self.meta.get('s_enhance', None)
-        if s_enhance is None and hasattr(self, '_gen'):
-            s_enhancements = [getattr(layer, '_spatial_mult', 1)
-                              for layer in self._gen.layers]
-            s_enhance = np.product(s_enhancements)
-            self.meta['s_enhance'] = int(s_enhance)
+        if s_enhance is None:
+            s_enhance = self.get_s_enhance_from_layers()
+            self.meta['s_enhance'] = s_enhance
         return s_enhance
 
-    # pylint: disable=E1101
     @property
     def t_enhance(self):
         """Factor by which model will enhance temporal resolution. Used in
         model training during high res coarsening"""
         t_enhance = self.meta.get('t_enhance', None)
-        if t_enhance is None and hasattr(self, '_gen'):
-            t_enhancements = [getattr(layer, '_temporal_mult', 1)
-                              for layer in self._gen.layers]
-            t_enhance = np.product(t_enhancements)
-            self.meta['t_enhance'] = int(t_enhance)
+        if t_enhance is None:
+            t_enhance = self.get_t_enhance_from_layers()
+            self.meta['t_enhance'] = t_enhance
         return t_enhance
 
     @property
@@ -118,21 +134,66 @@ def input_resolution(self):
         'temporal':...}"""
         return self.meta.get('input_resolution', None)
 
+    def _get_numerical_resolutions(self):
+        """Get the input and output resolutions without units"""
+        ires_num = {k: int(re.search(r'\d+', v).group(0))
+                    for k, v in self.input_resolution.items()}
+        enhancements = {'spatial': self.s_enhance,
+                        'temporal': self.t_enhance}
+        ores_num = {k: v // enhancements[k] for k, v in ires_num.items()}
+        return ires_num, ores_num
+
+    def _ensure_valid_input_resolution(self):
+        """Ensure enhancement factors evenly divide input_resolution"""
+
+        if self.input_resolution is None:
+            return
+
+        ires_num, ores_num = self._get_numerical_resolutions()
+        s_enhance = self.meta['s_enhance']
+        t_enhance = self.meta['t_enhance']
+        check = (
+            ires_num['temporal'] / ores_num['temporal'] == t_enhance
+            and ires_num['spatial'] / ores_num['spatial'] == s_enhance)
+        msg = (f'Enhancement factors (s_enhance={s_enhance}, '
+               f't_enhance={t_enhance}) do not evenly divide '
+               f'input resolution ({self.input_resolution})')
+        if not check:
+            logger.error(msg)
+            raise RuntimeError(msg)
+
+    def _ensure_valid_enhancement_factors(self):
+        """Ensure user provided enhancement factors are the same as those
+        computed from layer attributes"""
+        t_enhance = self.meta.get('t_enhance', None)
+        s_enhance = self.meta.get('s_enhance', None)
+        if s_enhance is None or t_enhance is None:
+            return
+
+        layer_se = self.get_s_enhance_from_layers()
+        layer_te = self.get_t_enhance_from_layers()
+        layer_se = layer_se if layer_se is not None else self.meta['s_enhance']
+        layer_te = layer_te if layer_te is not None else self.meta['t_enhance']
+        msg = (f'Enhancement factors computed from layer attributes '
+               f'(s_enhance={layer_se}, t_enhance={layer_te}) '
+               f'conflict with user provided values (s_enhance={s_enhance}, '
+               f't_enhance={t_enhance})')
+        check = layer_se == s_enhance and layer_te == t_enhance
+        if not check:
+            logger.error(msg)
+            raise RuntimeError(msg)
+
     @property
     def output_resolution(self):
         """Resolution of output data. Given as a dictionary {'spatial':...,
         'temporal':...}"""
-        input_res = self.input_resolution
-        output_res = {} if input_res is None else input_res.copy()
-        if input_res is not None:
-            input_temporal = re.search(r'\d+', input_res['temporal']).group(0)
-            input_spatial = re.search(r'\d+', input_res['spatial']).group(0)
-            output_temporal = int(self.t_enhance * input_temporal)
-            output_spatial = int(self.s_enhance * input_spatial)
-            output_res['temporal'].replace(input_temporal,
-                                           str(output_temporal))
-            output_res['spatial'].replace(input_spatial,
-                                          str(output_spatial))
+        output_res = self.meta.get('output_resolution', None)
+        if self.input_resolution is not None and output_res is None:
+            output_res = self.input_resolution.copy()
+            ires_num, ores_num = self._get_numerical_resolutions()
+            output_res = {k: v.replace(str(ires_num[k]), str(ores_num[k]))
+                          for k, v in self.input_resolution.items()}
+            self.meta['output_resolution'] = output_res
         return output_res
 
     def _combine_fwp_input(self, low_res, exogenous_data=None):
@@ -362,6 +423,9 @@ def _check_exo_features(self, **kwargs):
             Same as input but with exogenous_features removed from output
             features
         """
+        if 'output_features' not in kwargs:
+            return kwargs
+
         output_features = kwargs['output_features']
         msg = (f'Last {len(self.exogenous_features)} output features from the '
                f'data handler must be {self.exogenous_features} '
@@ -381,8 +445,9 @@ def set_model_params(self, **kwargs):
         Parameters
         ----------
         kwargs : dict
-            Keyword arguments including 'training_features', 'output_features',
-            'smoothed_features', 's_enhance', 't_enhance', 'smoothing'
+            Keyword arguments including 'input_resolution',
+            'training_features', 'output_features', 'smoothed_features',
+            's_enhance', 't_enhance', 'smoothing'
         """
 
         kwargs = self._check_exo_features(**kwargs)
@@ -391,7 +456,7 @@ def set_model_params(self, **kwargs):
         keys = [k for k in keys if k in kwargs]
 
         for var in keys:
-            val = getattr(self, var, None)
+            val = self.meta.get(var, None)
             if val is None:
                 self.meta[var] = kwargs[var]
             elif val != kwargs[var]:
@@ -402,6 +467,9 @@ def set_model_params(self, **kwargs):
                 logger.warning(msg)
                 warn(msg)
 
+        self._ensure_valid_enhancement_factors()
+        self._ensure_valid_input_resolution()
+
     def save_params(self, out_dir):
         """
         Parameters
diff --git a/sup3r/utilities/pytest.py b/sup3r/utilities/pytest.py
index c72a8c9a9..899bd3b99 100644
--- a/sup3r/utilities/pytest.py
+++ b/sup3r/utilities/pytest.py
@@ -34,11 +34,12 @@ def make_fake_nc_files(td, input_file, n_files):
     ]
     fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates]
     for i in range(n_files):
+        if os.path.exists(fake_files[i]):
+            os.remove(fake_files[i])
         with xr.open_dataset(input_file) as input_dset:
             with xr.Dataset(input_dset) as dset:
                 dset['Times'][:] = np.array(
-                    [fake_times[i].encode('ASCII')], dtype='|S19'
-                )
+                    [fake_times[i].encode('ASCII')], dtype='|S19')
                 dset['XTIME'][:] = i
                 dset.to_netcdf(fake_files[i])
     return fake_files
@@ -68,12 +69,12 @@ def make_fake_multi_time_nc_files(td, input_file, n_steps, n_files):
     dummy_files = []
     for i, files in enumerate(fake_files):
         dummy_file = os.path.join(
-            td, f'multi_timestep_file_{str(i).zfill(3)}.nc'
-        )
+            td, f'multi_timestep_file_{str(i).zfill(3)}.nc')
+        if os.path.exists(dummy_file):
+            os.remove(dummy_file)
         dummy_files.append(dummy_file)
         with xr.open_mfdataset(
-            files, combine='nested', concat_dim='Time'
-        ) as dset:
+                files, combine='nested', concat_dim='Time') as dset:
             dset.to_netcdf(dummy_file)
     return dummy_files
@@ -104,14 +105,15 @@ def make_fake_era_files(td, input_file, n_files):
     ]
     fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates]
     for i in range(n_files):
-        input_dset = xr.open_dataset(input_file)
-        with xr.Dataset(input_dset) as dset:
-            dset['Times'][:] = np.array(
-                [fake_times[i].encode('ASCII')], dtype='|S19'
-            )
-            dset['XTIME'][:] = i
-            dset = dset.rename({'U': 'u', 'V': 'v'})
-            dset.to_netcdf(fake_files[i])
+        if os.path.exists(fake_files[i]):
+            os.remove(fake_files[i])
+        with xr.open_dataset(input_file) as input_dset:
+            with xr.Dataset(input_dset) as dset:
+                dset['Times'][:] = np.array(
+                    [fake_times[i].encode('ASCII')], dtype='|S19')
+                dset['XTIME'][:] = i
+                dset = dset.rename({'U': 'u', 'V': 'v'})
+                dset.to_netcdf(fake_files[i])
     return fake_files
diff --git a/tests/forward_pass/test_solar_module.py b/tests/forward_pass/test_solar_module.py
index 319ca74bb..428509b16 100644
--- a/tests/forward_pass/test_solar_module.py
+++ b/tests/forward_pass/test_solar_module.py
@@ -119,7 +119,7 @@ def test_chunk_file_parser():
     with tempfile.TemporaryDirectory() as td:
         for idt in id_temporal:
             for ids in id_spatial:
-                fn = ('sup3r_chunk_out_{}_{}.h5'.format(idt, ids))
+                fn = 'sup3r_chunk_out_{}_{}.h5'.format(idt, ids)
                 fp = os.path.join(td, fn)
                 Path(fp).touch()
                 all_st_ids.append('{}_{}'.format(idt, ids))
diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py
index 92057cf9e..fd03b989d 100644
--- a/tests/pipeline/test_pipeline.py
+++ b/tests/pipeline/test_pipeline.py
@@ -103,7 +103,7 @@ def test_fwp_pipeline():
         with ResourceX(fp_out) as f:
             assert len(f.time_index) == t_enhance * n_tsteps
 
-        status_files = glob.glob(os.path.join(f'{td}/.gaps/', '*status.json'))
+        status_files = glob.glob(os.path.join(td, '.gaps', '*status.json'))
         assert len(status_files) == 1
         status_file = status_files[0]
         with open(status_file) as fh:
diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py
index 5aa4c96d2..418dc3061 100644
--- a/tests/training/test_train_gan.py
+++ b/tests/training/test_train_gan.py
@@ -128,7 +128,7 @@ def test_train_st_weight_update(n_epoch=2, log=False):
     adaptive_update_bounds = (0.9, 0.99)
     with tempfile.TemporaryDirectory() as td:
         model.train(batch_handler,
-                    input_resolution={'spatial': '8km', 'temporal': '30min'},
+                    input_resolution={'spatial': '12km', 'temporal': '60min'},
                     n_epoch=n_epoch,
                     weight_gen_advers=1e-6,
                     train_gen=True, train_disc=True,
@@ -236,7 +236,7 @@ def test_train_st_dc(n_epoch=2, log=False):
     # test that the normalized number of samples from each bin is close
     # to the weight for that bin
     model.train(batch_handler,
-                input_resolution={'spatial': '8km', 'temporal': '30min'},
+                input_resolution={'spatial': '12km', 'temporal': '60min'},
                 n_epoch=n_epoch,
                 weight_gen_advers=0.0,
                 train_gen=True, train_disc=False,
@@ -286,7 +286,7 @@ def test_train_st(n_epoch=2, log=False):
     with tempfile.TemporaryDirectory() as td:
         # test that training works and reduces loss
         model.train(batch_handler,
-                    input_resolution={'spatial': '8km', 'temporal': '30min'},
+                    input_resolution={'spatial': '12km', 'temporal': '60min'},
                     n_epoch=n_epoch,
                     weight_gen_advers=0.0,
                     train_gen=True, train_disc=False,
@@ -382,3 +382,34 @@ def test_optimizer_update():
 
     assert model.optimizer.learning_rate == 0.1
     assert model.optimizer_disc.learning_rate == 0.1
+
+
+def test_input_res_check():
+    """Make sure error is raised for invalid input resolution"""
+
+    fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
+    fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
+
+    Sup3rGan.seed()
+    model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4,
+                     learning_rate_disc=4e-4)
+
+    with pytest.raises(RuntimeError):
+        model.set_model_params(
+            input_resolution={'spatial': '22km', 'temporal': '9min'})
+
+
+def test_enhancement_check():
+    """Make sure error is raised for invalid enhancement factor inputs"""
+
+    fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json')
+    fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json')
+
+    Sup3rGan.seed()
+    model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4,
+                     learning_rate_disc=4e-4)
+
+    with pytest.raises(RuntimeError):
+        model.set_model_params(
+            input_resolution={'spatial': '12km', 'temporal': '60min'},
+            s_enhance=7, t_enhance=3)
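A worked sketch of the divisibility failure these two tests exercise
(illustrative numbers only, mirroring _get_numerical_resolutions and
_ensure_valid_input_resolution above):

# for the 3x/4x model, {'spatial': '22km', 'temporal': '9min'} is invalid:
# integer-dividing by the enhancement factors does not recover them exactly
s_enhance, t_enhance = 3, 4
ires = {'spatial': 22, 'temporal': 9}   # numbers parsed from '22km' / '9min'
ores = {'spatial': ires['spatial'] // s_enhance,     # 22 // 3 -> 7
        'temporal': ires['temporal'] // t_enhance}   # 9 // 4 -> 2
assert ires['spatial'] / ores['spatial'] != s_enhance     # 22 / 7 != 3
assert ires['temporal'] / ores['temporal'] != t_enhance   # 9 / 2 != 4
# so set_model_params raises RuntimeError; a valid pairing like
# {'spatial': '12km', 'temporal': '60min'} passes: 12 / (12 // 3) == 3
# and 60 / (60 // 4) == 4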
""" import copy -import click import logging import os +import click + import sup3r.bias.bias_calc from sup3r.utilities import ModuleName -from sup3r.version import __version__ from sup3r.utilities.cli import BaseCLI - +from sup3r.version import __version__ logger = logging.getLogger(__name__) diff --git a/sup3r/models/__init__.py b/sup3r/models/__init__.py index 779b65cf0..1179231d9 100644 --- a/sup3r/models/__init__.py +++ b/sup3r/models/__init__.py @@ -4,16 +4,9 @@ from .conditional_moments import Sup3rCondMom from .data_centric import Sup3rGanDC from .linear import LinearInterp -from .multi_step import ( - MultiStepGan, - MultiStepSurfaceMetGan, - SolarMultiStepGan, - SpatialThenTemporalGan, - TemporalThenSpatialGan, -) +from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan from .solar_cc import SolarCC from .surface import SurfaceSpatialMetModel -SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan, - MultiStepSurfaceMetGan, +SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan, SolarMultiStepGan) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 49d6fa177..79dd05ede 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -130,9 +130,15 @@ def t_enhance(self): @property def input_resolution(self): - """Resolution of input data. Given as a dictionary {'spatial':..., - 'temporal':...}""" - return self.meta.get('input_resolution', None) + """Resolution of input data. Given as a dictionary {'spatial': '...km', + 'temporal': '...min'}. The numbers are required to be integers in the + units specified. The units are not strict as long as the resolution + of the exogenous data, when extracting exogenous data, is specified + in the same units.""" + input_resolution = self.meta.get('input_resolution', None) + msg = 'model.input_resolution is None. This needs to be set.' + assert input_resolution is not None, msg + return input_resolution def _get_numerical_resolutions(self): """Get the input and output resolutions without units""" @@ -185,11 +191,11 @@ def _ensure_valid_enhancement_factors(self): @property def output_resolution(self): - """Resolution of output data. Given as a dictionary {'spatial':..., - 'temporal':...}""" + """Resolution of output data. Given as a dictionary + {'spatial': '...km', 'temporal': '...min'}. This is computed from the + input resolution and the enhancement factors.""" output_res = self.meta.get('output_resolution', None) if self.input_resolution is not None and output_res is None: - output_res = self.input_resolution.copy() ires_num, ores_num = self._get_numerical_resolutions() output_res = {k: v.replace(str(ires_num[k]), str(ores_num[k])) for k, v in self.input_resolution.items()} @@ -223,19 +229,20 @@ def _combine_fwp_input(self, low_res, exogenous_data=None): (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - low_res_shape = low_res.shape - check = (exogenous_data is not None - and low_res.shape[-1] < len(self.training_features)) - if check: - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} - for i, (feature, entry) in enumerate(exo_data.items()): - f_idx = low_res_shape[-1] + i - training_feature = self.training_features[f_idx] - msg = ('The ordering of features in exogenous_data conflicts ' - 'with the ordering of training features. 
Received ' - f'{feature} instead of {training_feature}.') - assert feature == training_feature, msg + if exogenous_data is None: + return low_res + + training_features = ([] if self.training_features is None + else self.training_features) + fnum_diff = len(training_features) - low_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.training_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] combine_types = [step['combine_type'] for step in entry['steps']] if 'input' in combine_types: @@ -272,19 +279,20 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None): (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - hi_res_shape = hi_res.shape[-1] - check = (exogenous_data is not None - and hi_res.shape[-1] < len(self.output_features)) - if check: - exo_data = {k: v for k, v in exogenous_data.items() - if k in self.training_features} - for i, (feature, entry) in enumerate(exo_data.items()): - f_idx = hi_res_shape[-1] + i - training_feature = self.training_features[f_idx] - msg = ('The ordering of features in exogenous_data conflicts ' - 'with the ordering of training features. Received ' - f'{feature} instead of {training_feature}.') - assert feature == training_feature, msg + if exogenous_data is None: + return hi_res + + output_features = ([] if self.output_features is None + else self.output_features) + fnum_diff = len(output_features) - hi_res.shape[-1] + exo_feats = ([] if fnum_diff <= 0 + else self.output_features[-fnum_diff:]) + msg = ('Provided exogenous_data is missing some required features ' + f'({exo_feats})') + assert all(feature in exogenous_data for feature in exo_feats), msg + if exogenous_data is not None and fnum_diff > 0: + for feature in exo_feats: + entry = exogenous_data[feature] combine_types = [step['combine_type'] for step in entry['steps']] if 'output' in combine_types: @@ -318,27 +326,6 @@ def _combine_loss_input(self, high_res_true, high_res_gen): high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1) return high_res_gen - def _get_exo_val_loss_input(self, high_res): - """Get exogenous feature data from high_res - - Parameters - ---------- - high_res : tf.Tensor - Ground truth high resolution spatiotemporal data. - - Returns - ------- - exo_data : dict - Dictionary of exogenous feature data used as input to tf_generate. - e.g. {'topography': np.ndarray(...)} - """ - exo_data = {} - for feature in self.exogenous_features: - f_idx = self.training_features.index(feature) - exo_fdata = high_res[..., f_idx: f_idx + 1] - exo_data[feature] = exo_fdata - return exo_data - @property def exogenous_features(self): """Get list of exogenous filter names the model uses. If the model has @@ -810,6 +797,27 @@ def load_saved_params(out_dir, verbose=True): return params + def get_exo_loss_input(self, high_res): + """Get exogenous feature data from high_res + + Parameters + ---------- + high_res : tf.Tensor + Ground truth high resolution spatiotemporal data. + + Returns + ------- + exo_data : dict + Dictionary of exogenous feature data used as input to tf_generate. + e.g. 
{'topography': tf.Tensor(...)} + """ + exo_data = {} + for feature in self.exogenous_features: + f_idx = self.training_features.index(feature) + exo_fdata = high_res[..., f_idx: f_idx + 1] + exo_data[feature] = exo_fdata + return exo_data + @staticmethod def get_loss_fun(loss): """Get the initialized loss function class from the sup3r loss library @@ -1290,9 +1298,7 @@ def generate(self, (n_obs, spatial_1, spatial_2, n_features) (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ - low_res = self._combine_fwp_input(low_res, exogenous_data) - if norm_in and self._means is not None: low_res = self.norm_input(low_res) @@ -1408,15 +1414,11 @@ def get_single_grad(self, loss_details : dict Namespace of the breakdown of loss components """ - hi_res_exo = {} - for feature in self.exogenous_features: - f_idx = self.training_features.index(feature) - hi_res_exo[feature] = hi_res_true[..., f_idx: f_idx + 1] - with tf.device(device_name): with tf.GradientTape(watch_accessed_variables=False) as tape: tape.watch(training_weights) + hi_res_exo = self.get_exo_loss_input(hi_res_true) hi_res_gen = self._tf_generate(low_res, hi_res_exo) loss_out = self.calc_loss(hi_res_true, hi_res_gen, **calc_loss_kwargs) diff --git a/sup3r/models/base.py b/sup3r/models/base.py index 30d6ab22d..b6fa7bc52 100644 --- a/sup3r/models/base.py +++ b/sup3r/models/base.py @@ -602,7 +602,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + val_exo_data = self.get_exo_loss_input(val_batch.high_res) high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.high_res, high_res_gen, diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py index af7835e9d..b1415f616 100644 --- a/sup3r/models/conditional_moments.py +++ b/sup3r/models/conditional_moments.py @@ -283,7 +283,7 @@ def calc_val_loss(self, batch_handler, loss_details): logger.debug('Starting end-of-epoch validation loss calculation...') loss_details['n_obs'] = 0 for val_batch in batch_handler.val_data: - val_exo_data = self._get_exo_val_loss_input(val_batch.high_res) + val_exo_data = self.get_exo_loss_input(val_batch.high_res) output_gen = self._tf_generate(val_batch.low_res, val_exo_data) _, v_loss_details = self.calc_loss( val_batch.output, output_gen, val_batch.mask) diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py index 2afc75896..2f08fef17 100644 --- a/sup3r/models/data_centric.py +++ b/sup3r/models/data_centric.py @@ -38,7 +38,7 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers): """ losses = [] for obs in batch_handler.val_data: - exo_data = self._get_exo_val_loss_input(obs.high_res) + exo_data = self.get_exo_loss_input(obs.high_res) gen = self._tf_generate(obs.low_res, exo_data) loss, _ = self.calc_loss(obs.high_res, gen, weight_gen_advers=weight_gen_advers, @@ -66,7 +66,7 @@ def calc_val_loss_gen_content(self, batch_handler): """ losses = [] for obs in batch_handler.val_data: - exo_data = self._get_exo_val_loss_input(obs.high_res) + exo_data = self.get_exo_loss_input(obs.high_res) gen = self._tf_generate(obs.low_res, exo_data) loss = self.calc_loss_gen_content(obs.high_res, gen) losses.append(float(loss)) diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 58fe80f3e..c76fd47df 100644 --- 
a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -426,244 +426,7 @@ def _split_exo_dict(self, split_step, exogenous_data=None): return split_exo_1, split_exo_2 -class SpatialThenTemporalGan(SpatialThenTemporalBase): - """A two-step GAN where the first step is a spatial-only enhancement on a - 4D tensor and the second step is a (spatio)temporal enhancement on a 5D - tensor. - - NOTE: The low res input to the spatial enhancement should be a 4D tensor of - the shape (temporal, spatial_1, spatial_2, features) where temporal - (usually the observation index) is a series of sequential timesteps that - will be transposed to a 5D tensor of shape - (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the - 2nd-step (spatio)temporal GAN. - """ - - @property - def models(self): - """Get an ordered tuple of the Sup3rGan models that are part of this - MultiStepGan - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.models - else: - spatial_models = [self.spatial_models] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.models - else: - temporal_models = [self.temporal_models] - - return (*spatial_models, *temporal_models) - - @property - def meta(self): - """Get a tuple of meta data dictionaries for all models - - Returns - ------- - tuple - """ - if isinstance(self.spatial_models, MultiStepGan): - spatial_models = self.spatial_models.meta - else: - spatial_models = [self.spatial_models.meta] - if isinstance(self.temporal_models, MultiStepGan): - temporal_models = self.temporal_models.meta - else: - temporal_models = [self.temporal_models.meta] - return (*spatial_models, *temporal_models) - - def generate(self, low_res, norm_in=True, un_norm_out=True, - exogenous_data=None): - """Use the generator model to generate high res data from low res - input. This is the public generate function. - - Parameters - ---------- - low_res : np.ndarray - Low-resolution input data to the 1st step spatial GAN, which is a - 4D array of shape: (temporal, spatial_1, spatial_2, n_features) - norm_in : bool - Flag to normalize low_res input data if the self.means, - self.stdevs attributes are available. The generator should always - received normalized data with mean=0 stdev=1. - un_norm_out : bool - Flag to un-normalize synthetically generated output data to physical - units - exogenous_data : list - List of arrays of exogenous_data with length equal to the - number of model steps. e.g. If we want to include topography as - an exogenous feature in a spatial + temporal multistep model then - we need to provide a list of length=2 with topography at the low - spatial resolution and at the high resolution. If we include more - than one exogenous feature the ordering must be consistent. 
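The multi_step.py diff below deletes SpatialThenTemporalGan and
TemporalThenSpatialGan in favor of the flat MultiStepGan. A key piece of
logic those wrappers carried was the reshape handoff between a 4D spatial
step and a 5D (spatio)temporal step, shown here as a standalone sketch with
dummy array shapes:

import numpy as np

# 4D spatial-step output: (temporal, spatial_1, spatial_2, features)
hi_res = np.ones((24, 20, 20, 2))
# transpose and add a leading observation axis to build the 5D input
# expected by the temporal step: (1, spatial_1, spatial_2, temporal, features)
hi_res = np.expand_dims(np.transpose(hi_res, axes=(1, 2, 0, 3)), axis=0)
assert hi_res.shape == (1, 20, 20, 24, 2)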
diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py
index 58fe80f3e..c76fd47df 100644
--- a/sup3r/models/multi_step.py
+++ b/sup3r/models/multi_step.py
@@ -426,244 +426,7 @@ def _split_exo_dict(self, split_step, exogenous_data=None):
         return split_exo_1, split_exo_2
 
 
-class SpatialThenTemporalGan(SpatialThenTemporalBase):
-    """A two-step GAN where the first step is a spatial-only enhancement on a
-    4D tensor and the second step is a (spatio)temporal enhancement on a 5D
-    tensor.
-
-    NOTE: The low res input to the spatial enhancement should be a 4D tensor of
-    the shape (temporal, spatial_1, spatial_2, features) where temporal
-    (usually the observation index) is a series of sequential timesteps that
-    will be transposed to a 5D tensor of shape
-    (1, spatial_1, spatial_2, temporal, features) tensor and then fed to the
-    2nd-step (spatio)temporal GAN.
-    """
-
-    @property
-    def models(self):
-        """Get an ordered tuple of the Sup3rGan models that are part of this
-        MultiStepGan
-        """
-        if isinstance(self.spatial_models, MultiStepGan):
-            spatial_models = self.spatial_models.models
-        else:
-            spatial_models = [self.spatial_models]
-        if isinstance(self.temporal_models, MultiStepGan):
-            temporal_models = self.temporal_models.models
-        else:
-            temporal_models = [self.temporal_models]
-
-        return (*spatial_models, *temporal_models)
-
-    @property
-    def meta(self):
-        """Get a tuple of meta data dictionaries for all models
-
-        Returns
-        -------
-        tuple
-        """
-        if isinstance(self.spatial_models, MultiStepGan):
-            spatial_models = self.spatial_models.meta
-        else:
-            spatial_models = [self.spatial_models.meta]
-        if isinstance(self.temporal_models, MultiStepGan):
-            temporal_models = self.temporal_models.meta
-        else:
-            temporal_models = [self.temporal_models.meta]
-        return (*spatial_models, *temporal_models)
-
-    def generate(self, low_res, norm_in=True, un_norm_out=True,
-                 exogenous_data=None):
-        """Use the generator model to generate high res data from low res
-        input. This is the public generate function.
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Low-resolution input data to the 1st step spatial GAN, which is a
-            4D array of shape: (temporal, spatial_1, spatial_2, n_features)
-        norm_in : bool
-            Flag to normalize low_res input data if the self.means,
-            self.stdevs attributes are available. The generator should always
-            received normalized data with mean=0 stdev=1.
-        un_norm_out : bool
-            Flag to un-normalize synthetically generated output data to
-            physical units
-        exogenous_data : list
-            List of arrays of exogenous_data with length equal to the
-            number of model steps. e.g. If we want to include topography as
-            an exogenous feature in a spatial + temporal multistep model then
-            we need to provide a list of length=2 with topography at the low
-            spatial resolution and at the high resolution. If we include more
-            than one exogenous feature the ordering must be consistent.
-            Each array in the list has 3D or 4D shape:
-            (spatial_1, spatial_2, n_features)
-            (temporal, spatial_1, spatial_2, n_features)
-
-        Returns
-        -------
-        hi_res : ndarray
-            Synthetically generated high-resolution data output from the 2nd
-            step (spatio)temporal GAN with a 5D array shape:
-            (1, spatial_1, spatial_2, n_temporal, n_features)
-        """
-        logger.debug('Data input to the 1st step spatial-only '
-                     'enhancement has shape {}'.format(low_res.shape))
-        s_exo, t_exo = self._split_exo_dict(
-            split_step=len(self.spatial_models), exogenous_data=exogenous_data)
-        try:
-            hi_res = self.spatial_models.generate(
-                low_res, norm_in=norm_in, un_norm_out=True,
-                exogenous_data=s_exo)
-        except Exception as e:
-            msg = ('Could not run the 1st step spatial-only GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Data output from the 1st step spatial-only '
-                     'enhancement has shape {}'.format(hi_res.shape))
-        hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
-        hi_res = np.expand_dims(hi_res, axis=0)
-        logger.debug('Data from the 1st step spatial-only enhancement has '
-                     'been reshaped to {}'.format(hi_res.shape))
-
-        try:
-            hi_res = self.temporal_models.generate(
-                hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=t_exo)
-        except Exception as e:
-            msg = ('Could not run the 2nd step (spatio)temporal GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Final multistep GAN output has shape: {}'
-                     .format(hi_res.shape))
-
-        return hi_res
-
-
-class TemporalThenSpatialGan(SpatialThenTemporalBase):
-    """A two-step GAN where the first step is a spatiotemporal enhancement on
-    a 5D tensor and the second step is a spatial enhancement on a 4D tensor.
-    """
-
-    @property
-    def models(self):
-        """Get an ordered tuple of the Sup3rGan models that are part of this
-        MultiStepGan
-        """
-        if isinstance(self.spatial_models, MultiStepGan):
-            spatial_models = self.spatial_models.models
-        else:
-            spatial_models = [self.spatial_models]
-        if isinstance(self.temporal_models, MultiStepGan):
-            temporal_models = self.temporal_models.models
-        else:
-            temporal_models = [self.temporal_models]
-
-        return (*temporal_models, *spatial_models)
-
-    @property
-    def meta(self):
-        """Get a tuple of meta data dictionaries for all models
-
-        Returns
-        -------
-        tuple
-        """
-        if isinstance(self.spatial_models, MultiStepGan):
-            spatial_models = self.spatial_models.meta
-        else:
-            spatial_models = [self.spatial_models.meta]
-        if isinstance(self.temporal_models, MultiStepGan):
-            temporal_models = self.temporal_models.meta
-        else:
-            temporal_models = [self.temporal_models.meta]
-
-        return (*temporal_models, *spatial_models)
-
-    def generate(self, low_res, norm_in=True, un_norm_out=True,
-                 exogenous_data=None):
-        """Use the generator model to generate high res data from low res
-        input. This is the public generate function.
-
-        Parameters
-        ----------
-        low_res : np.ndarray
-            Low-resolution input data, a 5D array of shape:
-            (1, spatial_1, spatial_2, n_temporal, n_features)
-        norm_in : bool
-            Flag to normalize low_res input data if the self.means,
-            self.stdevs attributes are available. The generator should always
-            received normalized data with mean=0 stdev=1.
-        un_norm_out : bool
-            Flag to un-normalize synthetically generated output data to
-            physical units
-        exogenous_data : list
-            List of arrays of exogenous_data with length equal to the
-            number of model steps. e.g. If we want to include topography as
-            an exogenous feature in a temporal + spatial multistep model then
-            we need to provide a list of length=2 with topography at the low
-            spatial resolution and at the high resolution. If we include more
-            than one exogenous feature the ordering must be consistent.
-            Each array in the list has 3D or 4D shape:
-            (spatial_1, spatial_2, n_features)
-            (temporal, spatial_1, spatial_2, n_features)
-
-        Returns
-        -------
-        hi_res : ndarray
-            Synthetically generated high-resolution data output from the 2nd
-            step (spatio)temporal GAN with a 5D array shape:
-            (1, spatial_1, spatial_2, n_temporal, n_features)
-        """
-        logger.debug('Data input to the 1st step (spatio)temporal '
-                     'enhancement has shape {}'.format(low_res.shape))
-        t_exo, s_exo = self._split_exo_dict(
-            split_step=len(self.temporal_models),
-            exogenous_data=exogenous_data)
-
-        assert low_res.shape[0] == 1, 'Low res input can only have 1 obs!'
-
-        try:
-            hi_res = self.temporal_models.generate(
-                low_res, norm_in=norm_in, un_norm_out=True,
-                exogenous_data=t_exo)
-        except Exception as e:
-            msg = ('Could not run the 1st step (spatio)temporal GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        logger.debug('Data output from the 1st step (spatio)temporal '
-                     'enhancement has shape {}'.format(hi_res.shape))
-        hi_res = np.transpose(hi_res[0], axes=(2, 0, 1, 3))
-        logger.debug('Data from the 1st step (spatio)temporal enhancement has '
-                     'been reshaped to {}'.format(hi_res.shape))
-
-        try:
-            hi_res = self.spatial_models.generate(
-                hi_res, norm_in=True, un_norm_out=un_norm_out,
-                exogenous_data=s_exo)
-        except Exception as e:
-            msg = ('Could not run the 2nd step spatial GAN on input '
-                   'shape {}'.format(low_res.shape))
-            logger.exception(msg)
-            raise RuntimeError(msg) from e
-
-        hi_res = np.transpose(hi_res, axes=(1, 2, 0, 3))
-        hi_res = np.expand_dims(hi_res, axis=0)
-
-        logger.debug('Final multistep GAN output has shape: {}'
-                     .format(hi_res.shape))
-
-        return hi_res
-
-
-class MultiStepSurfaceMetGan(SpatialThenTemporalGan):
+class MultiStepSurfaceMetGan(MultiStepGan):
     """A two-step GAN where the first step is a spatial-only enhancement on a
     4D tensor of near-surface temperature and relative humidity data, and the
     second step is a (spatio)temporal enhancement on a 5D tensor.
@@ -777,10 +540,12 @@ def load(cls, surface_model_class='SurfaceSpatialMetModel',
         t_models = TemporalModelClass.load(verbose=verbose,
                                            **temporal_model_kwargs)
 
-        return cls(s_models, t_models)
+        s_models = getattr(s_models, 'models', [s_models])
+        t_models = getattr(t_models, 'models', [t_models])
+        return cls([*s_models, *t_models])
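# A minimal sketch (not part of this patch) of the getattr flattening used in
# load() above: a nested MultiStepGan exposes .models while a single model
# does not, so both collapse to one flat model list. _Wrapper is hypothetical.
class _Wrapper:
    def __init__(self, models):
        self.models = models

_s = _Wrapper(['s1', 's2'])  # stands in for a loaded multi-step spatial model
_t = 't1'                    # stands in for a single temporal model
assert [*getattr(_s, 'models', [_s]),
        *getattr(_t, 'models', [_t])] == ['s1', 's2', 't1']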
 
 
-class SolarMultiStepGan(SpatialThenTemporalGan):
+class SolarMultiStepGan(SpatialThenTemporalBase):
     """Special multi step model for solar clearsky ratio super resolution.
 
     This model takes in two parallel models for wind-only and solar-only
@@ -796,11 +561,11 @@ def __init__(self, spatial_solar_models, spatial_wind_models,
         ----------
         spatial_solar_models : MultiStepGan
             A loaded MultiStepGan object representing the one or more spatial
-            super resolution steps in this composite SpatialThenTemporalGan
+            super resolution steps in this composite MultiStepGan
             model that inputs and outputs clearsky_ratio
         spatial_wind_models : MultiStepGan
             A loaded MultiStepGan object representing the one or more spatial
-            super resolution steps in this composite SpatialThenTemporalGan
+            super resolution steps in this composite MultiStepGan
             model that inputs and outputs wind u/v features and must include
             U_200m + V_200m as output features.
         temporal_solar_models : MultiStepGan
@@ -869,13 +634,13 @@ def preflight(self):
 
     @property
     def spatial_models(self):
-        """Alias for spatial_solar_models to preserve SpatialThenTemporalGan
+        """Alias for spatial_solar_models to preserve MultiStepGan
         interface."""
         return self.spatial_solar_models
 
     @property
     def temporal_models(self):
-        """Alias for temporal_solar_models to preserve SpatialThenTemporalGan
+        """Alias for temporal_solar_models to preserve MultiStepGan
         interface."""
         return self.temporal_solar_models
 
diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py
index c5c2a9c79..4a96257be 100644
--- a/sup3r/pipeline/forward_pass.py
+++ b/sup3r/pipeline/forward_pass.py
@@ -7,7 +7,6 @@
 import copy
 import logging
 import os
-import re
 import warnings
 from concurrent.futures import as_completed
 from datetime import datetime as dt
@@ -53,8 +52,7 @@ def __init__(self,
                  s_enhancements,
                  t_enhancements,
                  spatial_pad,
-                 temporal_pad,
-                 ):
+                 temporal_pad):
         """
         Parameters
         ----------
@@ -610,8 +608,7 @@ def __init__(self,
                  exo_kwargs=None,
                  bias_correct_method=None,
                  bias_correct_kwargs=None,
-                 max_nodes=None,
-                 ):
+                 max_nodes=None):
         """Use these inputs to initialize data handlers on different nodes and
         to define the size of the data chunks that will be passed through the
         generator.
@@ -700,7 +697,11 @@ def __init__(self,
             be used in the model. e.g. {'topography': {'file_paths': 'path to
             input files', 'source_file': 'path to exo data', 'exo_resolution':
             {'spatial': '1km', 'temporal': None}, 'steps': [{'model': 0,
-            'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]}
+            'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]}.
+            Each step can also include s_agg_factor/t_agg_factor to manually
+            set aggregation of exogenous_data. If they are not specified they
+            will be computed based on the model resolution and the exogenous
+            data resolution.
         bias_correct_method : str | None
             Optional bias correction function name that can be imported from
             the :mod:`sup3r.bias.bias_transforms` module. This will transform
@@ -848,8 +849,7 @@ def init_handler(self):
             out = self.input_handler_class(self.file_paths[0], [],
                                            target=self.target,
                                            shape=self.grid_shape,
-                                           worker_kwargs=dict(ti_workers=1),
-                                           )
+                                           worker_kwargs=dict(ti_workers=1))
             self._init_handler = out
         return self._init_handler
 
@@ -1093,7 +1093,7 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
         self.max_workers = strategy.max_workers
         self.pass_workers = strategy.pass_workers
         self.output_workers = strategy.output_workers
-        self.exo_kwargs = self.update_exo_extract_kwargs(strategy.exo_kwargs)
+        self.exo_kwargs = strategy.exo_kwargs
         self.exo_features = ([] if not self.exo_kwargs
                              else list(self.exo_kwargs))
         self.exogenous_data = self.load_exo_data()
@@ -1120,169 +1120,6 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
         self.unpadded_input_data = self.data_handler.data[self.lr_slice[0],
                                                           self.lr_slice[1]]
 
-    def _get_res_ratio(self, input_res, exo_res):
-        """Compute resolution ratio given input and output resolution
-
-        Parameters
-        ----------
-        input_res : str | None
-            Input resolution. e.g. '30km' or '60min'
-        exo_res : str | None
-            Exo resolution. e.g. '1km' or '5min'
-
-        Returns
-        -------
-        res_ratio : int | None
-            Ratio of input / exo resolution
-        """
-        ires_num = (None if input_res is None
-                    else int(re.search(r'\d+', input_res).group(0)))
-        eres_num = (None if exo_res is None
-                    else int(re.search(r'\d+', exo_res).group(0)))
-        i_units = (None if input_res is None
-                   else input_res.replace(str(ires_num), ''))
-        e_units = (None if exo_res is None
-                   else exo_res.replace(str(eres_num), ''))
-        msg = 'Received conflicting units for input and exo resolution'
-        if e_units is not None:
-            assert i_units == e_units, msg
-        if ires_num is not None and eres_num is not None:
-            res_ratio = int(ires_num / eres_num)
-        else:
-            res_ratio = None
-        return res_ratio
-
-    def get_agg_factors(self, input_res, exo_res):
-        """Compute aggregation ratio for exo data given input and output
-        resolution
-
-        Parameters
-        ----------
-        input_res : dict | None
-            Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'}
-        exo_res : dict | None
-            Exogenous data resolution. e.g.
-            {'spatial': '1km', 'temporal': '5min'}
-
-        Returns
-        -------
-        s_agg_factor : int
-            Spatial aggregation factor for exogenous data extraction.
-        t_agg_factor : int
-            Temporal aggregation factor for exogenous data extraction.
-        """
-        input_s_res = None if input_res is None else input_res['spatial']
-        exo_s_res = None if exo_res is None else exo_res['spatial']
-        s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res)
-        s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2
-        input_t_res = None if input_res is None else input_res['temporal']
-        exo_t_res = None if exo_res is None else exo_res['temporal']
-        t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res)
-        return s_agg_factor, t_agg_factor
-
-    def _get_single_step_agg_and_enhancement(self, step_dict, exo_resolution):
-        """Compute agg and enhancement factors for exogenous data extraction
-        using exo_kwargs single model step. These factors are computed using
-        exo_resolution and the input/output resolution of each model step.
-
-        Parameters
-        ----------
-        step_dict : dict
-            Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
-        exo_resolution : dict
-            Resolution of exogenous data. e.g. {'temporal': 15min, 'spatial':
-            '1km'}
-
-        Returns
-        -------
-        updated_step_dict : dict
-            Same as input dictionary with s_agg_factor, t_agg_factor,
-            s_enhance, t_enhance added
-        """
-        model_step = step_dict['model']
-        s_enhance = None
-        t_enhance = None
-        s_agg_factor = None
-        t_agg_factor = None
-        models = getattr(self.model, 'models', [self.model])
-        msg = (f'Model index from exo_kwargs ({model_step} exceeds number '
-               f'of model steps ({len(models)})')
-        assert len(models) > model_step, msg
-        model = models[model_step]
-        input_res = model.input_resolution
-        output_res = model.output_resolution
-        combine_type = step_dict.get('combine_type', None)
-
-        if combine_type.lower() == 'input':
-            if model_step == 0:
-                s_enhance = 1
-                t_enhance = 1
-            else:
-                s_enhance = np.product(
-                    self.strategy.s_enhancements[:model_step])
-                t_enhance = np.product(
-                    self.strategy.t_enhancements[:model_step])
-            s_agg_factor, t_agg_factor = self.get_agg_factors(
-                input_res, exo_resolution)
-            resolution = input_res
-
-        elif combine_type.lower() in ('output', 'layer'):
-            s_enhance = np.product(
-                self.strategy.s_enhancements[:model_step + 1])
-            t_enhance = np.product(
-                self.strategy.t_enhancements[:model_step + 1])
-            s_agg_factor, t_agg_factor = self.get_agg_factors(
-                output_res, exo_resolution)
-            resolution = output_res
-
-        else:
-            msg = ('Received exo_kwargs entry without valid combine_type '
-                   '(input/layer/output)')
-            raise OSError(msg)
-
-        updated_dict = step_dict.copy()
-        updated_dict.update({'s_enhance': s_enhance, 't_enhance': t_enhance,
-                             's_agg_factor': s_agg_factor,
-                             't_agg_factor': t_agg_factor,
-                             'resolution': resolution})
-        return updated_dict
-
-    def update_exo_extract_kwargs(self, exo_kwargs):
-        """Compute agg and enhancement factors for all model steps for all
-        features.
-
-        Parameters
-        ----------
-        exo_kwargs: dict
-            Full exo_kwargs dictionary with all feature entries.
-            e.g. {'topography': {'exo_resolution': {'spatial': '1km',
-            'temporal': None}, 'steps': [{'model': 0, 'combine_type': 'input'},
-            {'model': 0, 'combine_type': 'layer'}]}}
-
-        Returns
-        -------
-        updated_dict : dict
-            Same as input dictionary with s_agg_factor, t_agg_factor,
-            s_enhance, t_enhance added to each step entry for all features
-        """
-        if exo_kwargs is not None:
-            for feature in exo_kwargs:
-                exo_resolution = exo_kwargs[feature]['exo_resolution']
-                steps = exo_kwargs[feature]['steps']
-                for i, step in enumerate(steps):
-                    out = self._get_single_step_agg_and_enhancement(
-                        step, exo_resolution)
-                    exo_kwargs[feature]['steps'][i] = out
-                exo_kwargs[feature]['s_agg_factors'] = [step['s_agg_factor']
-                                                        for step in steps]
-                exo_kwargs[feature]['t_agg_factors'] = [step['t_agg_factor']
-                                                        for step in steps]
-                exo_kwargs[feature]['s_enhancements'] = [step['s_enhance']
-                                                         for step in steps]
-                exo_kwargs[feature]['t_enhancements'] = [step['t_enhance']
-                                                         for step in steps]
-        return exo_kwargs
-
     def load_exo_data(self):
         """Extract exogenous data for each exo feature and store data in
         dictionary with key for each exo feature
@@ -1304,12 +1141,13 @@ def load_exo_data(self):
                 exo_kwargs['target'] = self.target
                 exo_kwargs['shape'] = self.shape
                 exo_kwargs['temporal_slice'] = self.ti_pad_slice
-                steps = exo_kwargs['steps']
+                exo_kwargs['models'] = getattr(self.model, 'models',
+                                               [self.model])
                 sig = signature(ExogenousDataHandler)
                 exo_kwargs = {k: v for k, v in exo_kwargs.items()
                              if k in sig.parameters}
                 data = ExogenousDataHandler(**exo_kwargs).data
-                for i, _ in enumerate(steps):
+                for i, _ in enumerate(exo_kwargs['steps']):
                     exo_data[feature]['steps'][i]['data'] = data[i]
                 shapes = [None if d is None else d.shape for d in data]
                 logger.info(
@@ -1602,11 +1440,41 @@ def pad_width(self):
         return ((pad_s1_start, pad_s1_end), (pad_s2_start, pad_s2_end),
                 (pad_t_start, pad_t_end))
 
-    @staticmethod
-    def pad_source_data(input_data,
-                        pad_width,
-                        exo_data,
-                        mode='reflect'):
+    def _get_step_enhance(self, step):
+        """Get enhancement factors for a given step and combine type.
+
+        Parameters
+        ----------
+        step : dict
+            Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
+
+        Returns
+        -------
+        s_enhance : int
+            Spatial enhancement factor for given step and combine type
+        t_enhance : int
+            Temporal enhancement factor for given step and combine type
+        """
+        combine_type = step['combine_type']
+        model_step = step['model']
+        if combine_type.lower() == 'input':
+            if model_step == 0:
+                s_enhance = 1
+                t_enhance = 1
+            else:
+                s_enhance = np.product(
+                    self.strategy.s_enhancements[:model_step])
+                t_enhance = np.product(
+                    self.strategy.t_enhancements[:model_step])
+
+        elif combine_type.lower() in ('output', 'layer'):
+            s_enhance = np.product(
+                self.strategy.s_enhancements[:model_step + 1])
+            t_enhance = np.product(
+                self.strategy.t_enhancements[:model_step + 1])
+        return s_enhance, t_enhance
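# A worked sketch (illustrative numbers, not from this patch) of how
# _get_step_enhance composes per-step enhancement products, and how
# pad_source_data below scales the low-res pad widths by them for exo data.
import numpy as np

s_enhancements, t_enhancements = [2, 3], [1, 4]
step = {'model': 1, 'combine_type': 'layer'}
s_enhance = int(np.product(s_enhancements[:step['model'] + 1]))  # 2 * 3 = 6
t_enhance = int(np.product(t_enhancements[:step['model'] + 1]))  # 1 * 4 = 4
pad_width = ((1, 1), (1, 1), (2, 2))  # low-res (s1, s2, temporal) padding
exo_pad_width = ((s_enhance * pad_width[0][0], s_enhance * pad_width[0][1]),
                 (s_enhance * pad_width[1][0], s_enhance * pad_width[1][1]),
                 (t_enhance * pad_width[2][0], t_enhance * pad_width[2][1]),
                 (0, 0))
assert exo_pad_width == ((6, 6), (6, 6), (8, 8), (0, 0))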
+
+    def pad_source_data(self, input_data, pad_width, exo_data, mode='reflect'):
         """Pad the edges of the source data from the data handler.
 
         Parameters
@@ -1647,12 +1515,13 @@
         if exo_data is not None:
             for feature in exo_data:
                 for i, step in enumerate(exo_data[feature]['steps']):
-                    exo_pad_width = ((step['s_enhance'] * pad_width[0][0],
-                                      step['s_enhance'] * pad_width[0][1]),
-                                     (step['s_enhance'] * pad_width[1][0],
-                                      step['s_enhance'] * pad_width[1][1]),
-                                     (step['t_enhance'] * pad_width[2][0],
-                                      step['t_enhance'] * pad_width[2][1]),
+                    s_enhance, t_enhance = self._get_step_enhance(step)
+                    exo_pad_width = ((s_enhance * pad_width[0][0],
+                                      s_enhance * pad_width[0][1]),
+                                     (s_enhance * pad_width[1][0],
+                                      s_enhance * pad_width[1][1]),
+                                     (t_enhance * pad_width[2][0],
+                                      t_enhance * pad_width[2][1]),
                                      (0, 0))
                     new_exo = np.pad(step['data'], exo_pad_width, mode=mode)
                     exo_data[feature]['steps'][i]['data'] = new_exo
diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
index 884d4a7c4..d77906544 100644
--- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py
+++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
@@ -2,9 +2,12 @@
 import logging
 import os
 import pickle
+import re
 import shutil
 from typing import ClassVar
 
+import numpy as np
+
 from sup3r.preprocessing.data_handling import exo_extraction
 from sup3r.preprocessing.data_handling.exo_extraction import (
     SzaExtract,
@@ -35,10 +38,9 @@ class ExogenousDataHandler:
     def __init__(self,
                  file_paths,
                  feature,
-                 s_enhancements,
-                 t_enhancements,
-                 s_agg_factors,
-                 t_agg_factors,
+                 steps,
+                 models=None,
+                 exo_resolution=None,
                  source_file=None,
                  target=None,
                  shape=None,
@@ -60,34 +62,19 @@ def __init__(self,
             sup3r resolved.
         feature : str
             Exogenous feature to extract from source_h5
-        s_enhancements : list
-            List of factors by which the Sup3rGan model will enhance the
-            spatial dimensions of low resolution data from file_paths input
-            where the total spatial enhancement is the product of these
-            factors. For example, if file_paths has 100km data and there are 2
-            spatial enhancement steps of 4x and 5x to a nominal resolution of
-            5km, s_enhancements should be [1, 4, 5] and exo_steps should be
-            [0, 1, 2] so that the input to the 4x model gets exogenous data
-            at 100km (s_enhance=1, exo_step=0), the input to the 5x model gets
-            exogenous data at 25km (s_enhance=4, exo_step=1), and there is a
-            20x (1*4*5) exogeneous data layer available if the second model can
-            receive a high-res input feature. The length of this list should be
-            equal to the number of s_agg_factors
-        t_enhancements : list
-            List of factors by which the Sup3rGan model will enhance the
-            temporal dimension of low resolution data from file_paths input
-            where the total temporal enhancement is the product of these
-            factors.
-        s_agg_factors : list
-            List of factors by which to aggregate the exo_source
-            data to the spatial resolution of the file_paths input enhanced by
-            s_enhance. The length of this list should be equal to the number of
-            s_enhancements
-        t_agg_factors : list
-            List of factors by which to aggregate the exo_source
-            data to the temporal resolution of the file_paths input enhanced by
-            t_enhance. The length of this list should be equal to the number of
-            t_enhancements
+        models : list
+            List of models used with the given steps list
+        steps : list
+            List of dictionaries containing info on which models to use for a
+            given step index and what type of exo data the step requires. e.g.
+            [{'model': 0, 'combine_type': 'input'},
+             {'model': 0, 'combine_type': 'layer'}]
+            Each step entry can also contain s_enhance, t_enhance,
+            s_agg_factor, t_agg_factor. If they are not included they will be
+            computed using exo_resolution and model attributes.
+        exo_resolution : dict
+            Dictionary of spatiotemporal resolution for the given exo data
+            source. e.g. {'spatial': '4km', 'temporal': '60min'}
         source_file : str
             Filepath to source wtk, nsrdb, or netcdf file to get hi-res (2km or
             4km) data from which will be mapped to the enhanced grid of the
@@ -128,10 +115,9 @@ def __init__(self,
         """
 
         self.feature = feature
-        self.s_enhancements = s_enhancements
-        self.t_enhancements = t_enhancements
-        self.s_agg_factors = s_agg_factors
-        self.t_agg_factors = t_agg_factors
+        self.steps = steps
+        self.models = models
+        self.exo_res = exo_resolution
         self.source_file = source_file
         self.file_paths = file_paths
         self.exo_handler = exo_handler
@@ -145,6 +131,13 @@ def __init__(self,
         self.cache_dir = cache_dir
         self.data = []
 
+        self.input_check()
+        agg_enhance = self._get_all_agg_and_enhancement()
+        self.s_enhancements = agg_enhance['s_enhancements']
+        self.t_enhancements = agg_enhance['t_enhancements']
+        self.s_agg_factors = agg_enhance['s_agg_factors']
+        self.t_agg_factors = agg_enhance['t_agg_factors']
+
         msg = ('Need to provide the same number of enhancement factors and '
                f'agg factors. Received s_enhancements={self.s_enhancements}, '
                f'and s_agg_factors={self.s_agg_factors}.')
@@ -175,6 +168,202 @@ def __init__(self,
                    f" Received {feature}.")
             raise NotImplementedError(msg)
 
+    def input_check(self):
+        """Make sure agg factors are provided or exo_resolution and models are
+        provided. Make sure enhancement factors are provided or models are
+        provided"""
+        agg_check = all('s_agg_factor' in v for v in self.steps)
+        agg_check = agg_check and all('t_agg_factor' in v for v in self.steps)
+        agg_check = (agg_check
+                     or self.models is not None and self.exo_res is not None)
+        msg = ("ExogenousDataHandler needs s_agg_factor and t_agg_factor "
+               "provided in each step in steps list or models and "
+               "exo_resolution")
+        assert agg_check, msg
+        en_check = all('s_enhance' in v for v in self.steps)
+        en_check = en_check and all('t_enhance' in v for v in self.steps)
+        en_check = en_check or self.models is not None
+        msg = ("ExogenousDataHandler needs s_enhance and t_enhance "
+               "provided in each step in steps list or models")
+        assert en_check, msg
+
+    def _get_res_ratio(self, input_res, exo_res):
+        """Compute resolution ratio given input and output resolution
+
+        Parameters
+        ----------
+        input_res : str | None
+            Input resolution. e.g. '30km' or '60min'
+        exo_res : str | None
+            Exo resolution. e.g. '1km' or '5min'
+
+        Returns
+        -------
+        res_ratio : int | None
+            Ratio of input / exo resolution
+        """
+        ires_num = (None if input_res is None
+                    else int(re.search(r'\d+', input_res).group(0)))
+        eres_num = (None if exo_res is None
+                    else int(re.search(r'\d+', exo_res).group(0)))
+        i_units = (None if input_res is None
+                   else input_res.replace(str(ires_num), ''))
+        e_units = (None if exo_res is None
+                   else exo_res.replace(str(eres_num), ''))
+        msg = 'Received conflicting units for input and exo resolution'
+        if e_units is not None:
+            assert i_units == e_units, msg
+        if ires_num is not None and eres_num is not None:
+            res_ratio = int(ires_num / eres_num)
+        else:
+            res_ratio = None
+        return res_ratio
+
+    def get_agg_factors(self, input_res, exo_res):
+        """Compute aggregation ratio for exo data given input and output
+        resolution
+
+        Parameters
+        ----------
+        input_res : dict | None
+            Input resolution. e.g. {'spatial': '30km', 'temporal': '60min'}
+        exo_res : dict | None
+            Exogenous data resolution. e.g.
+            {'spatial': '1km', 'temporal': '5min'}
+
+        Returns
+        -------
+        s_agg_factor : int
+            Spatial aggregation factor for exogenous data extraction.
+        t_agg_factor : int
+            Temporal aggregation factor for exogenous data extraction.
+        """
+        input_s_res = None if input_res is None else input_res['spatial']
+        exo_s_res = None if exo_res is None else exo_res['spatial']
+        s_res_ratio = self._get_res_ratio(input_s_res, exo_s_res)
+        s_agg_factor = None if s_res_ratio is None else int(s_res_ratio)**2
+        input_t_res = None if input_res is None else input_res['temporal']
+        exo_t_res = None if exo_res is None else exo_res['temporal']
+        t_agg_factor = self._get_res_ratio(input_t_res, exo_t_res)
+        return s_agg_factor, t_agg_factor
+
+    def _get_single_step_agg(self, step):
+        """Compute agg factors for exogenous data extraction
+        using exo_kwargs single model step. These factors are computed using
+        exo_resolution and the input/output resolution of each model step. If
+        agg factors are already provided in step they are not overwritten.
+
+        Parameters
+        ----------
+        step : dict
+            Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
+
+        Returns
+        -------
+        updated_step : dict
+            Same as input dictionary with s_agg_factor, t_agg_factor added
+        """
+        if all(key in step for key in ['s_agg_factor', 't_agg_factor']):
+            return step
+
+        model_step = step['model']
+        combine_type = step.get('combine_type', None)
+        msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
+               f'of model steps ({len(self.models)})')
+        assert len(self.models) > model_step, msg
+        model = self.models[model_step]
+        input_res = model.input_resolution
+        output_res = model.output_resolution
+        if combine_type.lower() == 'input':
+            s_agg_factor, t_agg_factor = self.get_agg_factors(
+                input_res, self.exo_res)
+
+        elif combine_type.lower() in ('output', 'layer'):
+            s_agg_factor, t_agg_factor = self.get_agg_factors(
+                output_res, self.exo_res)
+
+        else:
+            msg = ('Received exo_kwargs entry without valid combine_type '
+                   '(input/layer/output)')
+            raise OSError(msg)
+
+        step.update({'s_agg_factor': s_agg_factor,
+                     't_agg_factor': t_agg_factor})
+        return step
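# A worked sketch (illustrative numbers, not from this patch) of the
# aggregation factors get_agg_factors computes for 12km/60min model input
# against 4km/15min exo data: spatial ratios are squared (cells per cell),
# temporal ratios are used directly.
import re

def _num(res):
    return int(re.search(r'\d+', res).group(0))

input_res = {'spatial': '12km', 'temporal': '60min'}
exo_res = {'spatial': '4km', 'temporal': '15min'}
s_ratio = _num(input_res['spatial']) // _num(exo_res['spatial'])         # 3
s_agg_factor = s_ratio ** 2                                              # 9
t_agg_factor = _num(input_res['temporal']) // _num(exo_res['temporal'])  # 4
assert (s_agg_factor, t_agg_factor) == (9, 4)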
+
+ Parameters
+ ----------
+ step : dict
+ Model step dictionary. e.g. {'model': 0, 'combine_type': 'input'}
+
+ Returns
+ -------
+ updated_step : dict
+ Same as input dictionary with s_enhance, t_enhance added
+ """
+ if all(key in step for key in ['s_enhance', 't_enhance']):
+ return step
+
+ model_step = step['model']
+ combine_type = step.get('combine_type', '')
+ msg = (f'Model index from exo_kwargs ({model_step}) exceeds number '
+ f'of model steps ({len(self.models)})')
+ assert len(self.models) > model_step, msg
+
+ s_enhancements = [model.s_enhance for model in self.models]
+ t_enhancements = [model.t_enhance for model in self.models]
+ if combine_type.lower() == 'input':
+ if model_step == 0:
+ s_enhance = 1
+ t_enhance = 1
+ else:
+ s_enhance = np.prod(s_enhancements[:model_step])
+ t_enhance = np.prod(t_enhancements[:model_step])
+
+ elif combine_type.lower() in ('output', 'layer'):
+ s_enhance = np.prod(s_enhancements[:model_step + 1])
+ t_enhance = np.prod(t_enhancements[:model_step + 1])
+
+ else:
+ msg = ('Received exo_kwargs entry without valid combine_type '
+ '(input/layer/output)')
+ raise OSError(msg)
+
+ step.update({'s_enhance': s_enhance, 't_enhance': t_enhance})
+ return step
+
+ def _get_all_agg_and_enhancement(self):
+ """Compute agg and enhancement factors for all model steps for the
+ current feature.
+
+ Returns
+ -------
+ agg_enhance_dict : dict
+ Dictionary with lists of agg and enhancement factors for each model
+ step
+ """
+ agg_enhance_dict = {}
+ for i, step in enumerate(self.steps):
+ out = self._get_single_step_agg(step)
+ out = self._get_single_step_enhance(out)
+ self.steps[i] = out
+ agg_enhance_dict['s_agg_factors'] = [step['s_agg_factor']
+ for step in self.steps]
+ agg_enhance_dict['t_agg_factors'] = [step['t_agg_factor']
+ for step in self.steps]
+ agg_enhance_dict['s_enhancements'] = [step['s_enhance']
+ for step in self.steps]
+ agg_enhance_dict['t_enhancements'] = [step['t_enhance']
+ for step in self.steps]
+ return agg_enhance_dict
+
 def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor,
 t_agg_factor):
 """Get cache file name
diff --git a/tests/data_handling/test_exo_data_handling.py b/tests/data_handling/test_exo_data_handling.py
index 5bdfd264a..93091733c 100644
--- a/tests/data_handling/test_exo_data_handling.py
+++ b/tests/data_handling/test_exo_data_handling.py
@@ -27,13 +27,17 @@ def test_exo_cache(feature):
 """Test exogenous data caching and re-load"""
 # no cached data
+ steps = []
+ for s_en, t_en, s_agg, t_agg in zip(S_ENHANCE, T_ENHANCE, S_AGG_FACTORS,
+ T_AGG_FACTORS):
+ steps.append({'s_enhance': s_en,
+ 't_enhance': t_en,
+ 's_agg_factor': s_agg,
+ 't_agg_factor': t_agg})
 try:
 base = ExogenousDataHandler(FILE_PATHS, feature,
 source_file=FP_WTK,
- s_enhancements=S_ENHANCE,
- t_enhancements=T_ENHANCE,
- s_agg_factors=S_AGG_FACTORS,
- t_agg_factors=T_AGG_FACTORS,
+ steps=steps,
 target=TARGET, shape=SHAPE,
 input_handler='DataHandlerNCforCC')
 for i, arr in enumerate(base.data):
@@ -51,10 +55,7 @@ def test_exo_cache(feature):
 try:
 cache = ExogenousDataHandler(FILE_PATHS, feature,
 source_file=FP_WTK,
- s_enhancements=S_ENHANCE,
- t_enhancements=T_ENHANCE,
- s_agg_factors=S_AGG_FACTORS,
- t_agg_factors=T_AGG_FACTORS,
+ steps=steps,
 target=TARGET, shape=SHAPE,
 input_handler='DataHandlerNCforCC')
 except Exception as e:
diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py
index 1518c6bc5..04f99081c 100644
--- a/tests/forward_pass/test_forward_pass.py
+++ 
b/tests/forward_pass/test_forward_pass.py @@ -600,8 +600,7 @@ def test_fwp_multi_step_model(): t_enhance = 4 model_kwargs = { - 'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s_out_dir, st_out_dir] } input_handler_kwargs = dict( @@ -613,7 +612,7 @@ def test_fwp_multi_step_model(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=0, temporal_pad=0, diff --git a/tests/forward_pass/test_forward_pass_exo.py b/tests/forward_pass/test_forward_pass_exo.py index f30b83852..1bf24b286 100644 --- a/tests/forward_pass/test_forward_pass_exo.py +++ b/tests/forward_pass/test_forward_pass_exo.py @@ -60,7 +60,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): fp_gen = os.path.join(CONFIG_DIR, 'spatiotemporal/gen_3x_4x_2f.json') fp_disc = os.path.join(CONFIG_DIR, 'spatiotemporal/disc.json') st_model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) - st_model.meta['training_features'] = ['U_100m', 'V_100m', 'topography'] + st_model.meta['training_features'] = ['U_100m', 'V_100m'] st_model.meta['output_features'] = ['U_100m', 'V_100m'] st_model.meta['s_enhance'] = 3 st_model.meta['t_enhance'] = 4 @@ -99,8 +99,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -113,7 +112,7 @@ def test_fwp_multi_step_model_topo_exoskip(log=False): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, input_handler_kwargs=input_handler_kwargs, spatial_pad=0, @@ -317,8 +316,7 @@ def test_fwp_multi_step_model_topo_noskip(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -331,7 +329,7 @@ def test_fwp_multi_step_model_topo_noskip(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, @@ -637,10 +635,6 @@ def test_fwp_multi_step_wind_hi_res_topo(): } } - model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir - } model_kwargs = { 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } @@ -807,8 +801,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): } model_kwargs = { - 'spatial_model_dirs': s_out_dir, - 'temporal_model_dirs': t_out_dir + 'model_dirs': [s_out_dir, t_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict(target=target, @@ -820,7 +813,7 @@ def test_fwp_wind_hi_res_topo_plus_linear(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=(4, 4, 8), spatial_pad=1, temporal_pad=1, @@ -916,8 +909,7 @@ def test_fwp_multi_step_model_multi_exo(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') @@ -930,7 +922,7 @@ def test_fwp_multi_step_model_multi_exo(): handler = 
ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=fwp_chunk_shape, spatial_pad=1, temporal_pad=1, @@ -1173,8 +1165,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): } model_kwargs = { - 'spatial_model_dirs': [s1_out_dir, s2_out_dir], - 'temporal_model_dirs': st_out_dir + 'model_dirs': [s1_out_dir, s2_out_dir, st_out_dir] } out_files = os.path.join(td, 'out_{file_id}.h5') input_handler_kwargs = dict(target=target, @@ -1186,7 +1177,7 @@ def test_fwp_multi_step_exo_hi_res_topo_and_sza(): handler = ForwardPassStrategy( input_files, model_kwargs=model_kwargs, - model_class='SpatialThenTemporalGan', + model_class='MultiStepGan', fwp_chunk_shape=(4, 4, 8), spatial_pad=1, temporal_pad=1, diff --git a/tests/forward_pass/test_multi_step.py b/tests/forward_pass/test_multi_step.py index f610829f6..3da773b44 100644 --- a/tests/forward_pass/test_multi_step.py +++ b/tests/forward_pass/test_multi_step.py @@ -11,7 +11,6 @@ LinearInterp, MultiStepGan, SolarMultiStepGan, - SpatialThenTemporalGan, Sup3rGan, ) @@ -70,6 +69,9 @@ def test_multi_step_norm(norm_option): model2.set_norm_stats([0.1, 0.8], [0.04, 0.02]) model3.set_norm_stats([0.1, 0.8], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '27km', 'temporal': '64min'} + model2.meta['input_resolution'] = {'spatial': '9km', 'temporal': '16min'} + model3.meta['input_resolution'] = {'spatial': '3km', 'temporal': '4min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -118,6 +120,8 @@ def test_spatial_then_temporal_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -129,7 +133,7 @@ def test_spatial_then_temporal_gan(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -150,6 +154,8 @@ def test_temporal_then_spatial_gan(): model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) model2.set_norm_stats([0.3, 0.9], [0.02, 0.07]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '40min'} + model2.meta['input_resolution'] = {'spatial': '6km', 'temporal': '40min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) model2.set_model_params(training_features=FEATURES, @@ -178,6 +184,7 @@ def test_spatial_gan_then_linear_interp(): model2 = LinearInterp(features=FEATURES, s_enhance=3, t_enhance=4) model1.set_norm_stats([0.1, 0.2], [0.04, 0.02]) + model1.meta['input_resolution'] = {'spatial': '12km', 'temporal': '60min'} model1.set_model_params(training_features=FEATURES, output_features=FEATURES) @@ -187,7 +194,7 @@ def test_spatial_gan_then_linear_interp(): model1.save(fp1) model2.save(fp2) - ms_model = SpatialThenTemporalGan.load(fp1, fp2) + ms_model = MultiStepGan.load([fp1, fp2]) x = np.ones((4, 10, 10, len(FEATURES))) out = ms_model.generate(x) @@ -204,6 +211,7 @@ def test_solar_multistep(): model1 = Sup3rGan(fp_gen, fp_disc) _ = model1.generate(np.ones((4, 10, 10, len(features1)))) model1.set_norm_stats([0.7], [0.04]) + model1.meta['input_resolution'] = 
{'spatial': '8km', 'temporal': '40min'} model1.set_model_params(training_features=features1, output_features=features1) @@ -213,6 +221,7 @@ def test_solar_multistep(): model2 = Sup3rGan(fp_gen, fp_disc) _ = model2.generate(np.ones((4, 10, 10, len(features2)))) model2.set_norm_stats([4.2, 5.6], [1.1, 1.3]) + model2.meta['input_resolution'] = {'spatial': '4km', 'temporal': '40min'} model2.set_model_params(training_features=features2, output_features=features2) @@ -223,6 +232,7 @@ def test_solar_multistep(): model3 = Sup3rGan(fp_gen, fp_disc) _ = model3.generate(np.ones((4, 10, 10, 3, len(features_in_3)))) model3.set_norm_stats([0.7, 4.2, 5.6], [0.04, 1.1, 1.3]) + model3.meta['input_resolution'] = {'spatial': '2km', 'temporal': '40min'} model3.set_model_params(training_features=features_in_3, output_features=features_out_3) From fb0c50c199da293118a8fdef9260fd5cb6d37f6d Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 3 Oct 2023 12:35:19 -0700 Subject: [PATCH 14/15] updated sup3rcc configs for new exo format --- .../sup3rcc/run_configs/solar/config_fwp.json | 18 ++++++------ .../sup3rcc/run_configs/trh/config_fwp.json | 28 ++++++++----------- .../sup3rcc/run_configs/wind/config_fwp.json | 18 ++++++------ sup3r/pipeline/forward_pass.py | 23 ++++++--------- 4 files changed, 38 insertions(+), 49 deletions(-) diff --git a/examples/sup3rcc/run_configs/solar/config_fwp.json b/examples/sup3rcc/run_configs/solar/config_fwp.json index 74746f1e8..1b30ba65c 100755 --- a/examples/sup3rcc/run_configs/solar/config_fwp.json +++ b/examples/sup3rcc/run_configs/solar/config_fwp.json @@ -49,15 +49,15 @@ }, "max_nodes": 10, "exo_kwargs": { - "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], - "features": ["topography"], - "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 5, 5], - "agg_factors": [625, 25, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1, 2] + "topography": { + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "layer"}, + {"model": 1, "combine_type": "input"}], + "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], + "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/examples/sup3rcc/run_configs/trh/config_fwp.json b/examples/sup3rcc/run_configs/trh/config_fwp.json index 3fdcd2504..c423be1d5 100755 --- a/examples/sup3rcc/run_configs/trh/config_fwp.json +++ b/examples/sup3rcc/run_configs/trh/config_fwp.json @@ -1,14 +1,9 @@ { "file_paths": "PLACEHOLDER", "model_kwargs": { - "surface_model_kwargs": { - "model_dir": "./sup3rcc_models_202303/sup3rcc_trh_step1_25x_1x_2f/" - }, - "temporal_model_kwargs": { - "model_dirs": [ - "./sup3rcc_models_202303/sup3rcc_trh_step2_1x_24x_2f/" - ] - } + "model_dirs": [ + "./sup3rcc_models_202303/sup3rcc_trh_step1_25x_1x_2f/", + "./sup3rcc_models_202303/sup3rcc_trh_step2_1x_24x_2f/"] }, "model_class": "MultiStepSurfaceMetGan", "out_pattern": "./chunks/sup3r_chunk_{file_id}.h5", @@ -42,15 +37,14 @@ }, "max_nodes": 10, "exo_kwargs": { - "file_paths": ["/datasets/sup3rcc/source/temp_humi_day_MRI-ESM2-0_ssp245_r1i1p1f1_gn_20500101-20501231.nc"], - "features": ["topography"], - "source_file": 
"/datasets/sup3rcc/source/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 25], - "agg_factors": [625, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1] + "topography": { + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "output"}], + "file_paths": ["/datasets/sup3rcc/source/temp_humi_day_MRI-ESM2-0_ssp245_r1i1p1f1_gn_20500101-20501231.nc"], + "source_file": "/datasets/sup3rcc/source/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/examples/sup3rcc/run_configs/wind/config_fwp.json b/examples/sup3rcc/run_configs/wind/config_fwp.json index 16ce97b48..ae908a7a6 100755 --- a/examples/sup3rcc/run_configs/wind/config_fwp.json +++ b/examples/sup3rcc/run_configs/wind/config_fwp.json @@ -38,15 +38,15 @@ }, "max_nodes": 50, "exo_kwargs": { - "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], - "features": ["topography"], - "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", - "target": [23.2, -129], - "shape": [26, 59], - "s_enhancements": [1, 5, 5], - "agg_factors": [625, 25, 1], - "input_handler": "DataHandlerNCforCC", - "exo_steps": [0, 1, 2] + "topography": { + "file_paths": ["/scratch/gbuster/sup3r/source_gcm_data/wind_solar_day_MRI-ESM2-0_ssp585_r1i1p1f1_gn_20150101-20151231.nc"], + "source_file": "/scratch/gbuster/sup3r/source_gcm_data/nsrdb_clearsky.h5", + "target": [23.2, -129], + "shape": [26, 59], + "steps": [{"model": 0, "combine_type": "input"}, + {"model": 0, "combine_type": "layer"}, + {"model": 1, "combine_type": "input"}], + "input_handler": "DataHandlerNCforCC"} }, "execution_control": { "option": "eagle", diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 4a96257be..1cd24aa96 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -688,20 +688,15 @@ def __init__(self, used to get the full time index. Doing this is parallel can be helpful when there are a large number of input files. exo_kwargs : dict | None - Dictionary of args to pass to ExogenousDataHandler for extracting - exogenous features such as topography for future multistep foward - pass. This should be a nested dictionary with keys for each - exogeneous feature. The dictionaries corresponding to the feature - names should include the path to exogenous data source, the - resolution of the exogenous data, and how the exogenous data should - be used in the model. e.g. {'topography': {'file_paths': 'path to - input files', 'source_file': 'path to exo data', 'exo_resolution': - {'spatial': '1km', 'temporal': None}, 'steps': [{'model': 0, - 'combine_type': 'input'}, {'model': 0, 'combine_type': 'layer'}]}. - Each step can also include s_agg_factor/t_agg_factor to manually - set aggregation of exogenous_data. If they are not specified they - will be computed based on the model resolution and the exogenous - data resolution. + Dictionary of args to pass to :class:`ExogenousDataHandler` for + extracting exogenous features for multistep foward pass. This + should be a nested dictionary with keys for each exogeneous + feature. The dictionaries corresponding to the feature names + should include the path to exogenous data source, the resolution + of the exogenous data, and how the exogenous data should be used + in the model. e.g. 
{'topography': {'file_paths': 'path to input
+ files', 'source_file': 'path to exo data', 'exo_resolution':
+ {'spatial': '1km', 'temporal': None}, 'steps': [..]}}.
 bias_correct_method : str | None
 Optional bias correction function name that can be imported from
 the :mod:`sup3r.bias.bias_transforms` module. This will transform

From d78bc08177b6d296eef343d959c1e6750e3179a1 Mon Sep 17 00:00:00 2001
From: bnb32 
Date: Tue, 3 Oct 2023 13:49:38 -0700
Subject: [PATCH 15/15] get_exo_loss_input -> get_high_res_exo_input

---
 sup3r/models/abstract.py | 17 ++++++++++++++---
 sup3r/models/base.py | 4 ++--
 sup3r/models/conditional_moments.py | 4 ++--
 sup3r/models/data_centric.py | 4 ++--
 .../data_handling/exogenous_data_handling.py | 19 +++++++++++++++----
 5 files changed, 35 insertions(+), 13 deletions(-)

diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py
index 79dd05ede..a2dc0264e 100644
--- a/sup3r/models/abstract.py
+++ b/sup3r/models/abstract.py
@@ -55,6 +55,15 @@ def load(cls, model_dir, verbose=True):
 model_dir
 """

+ @abstractmethod
+ def generate(self,
+ low_res,
+ norm_in=True,
+ un_norm_out=True,
+ exogenous_data=None):
+ """Use the generator model to generate high res data from low res
+ input. This is the public generate function."""
+
 @staticmethod
 def seed(s=0):
 """
@@ -141,7 +150,9 @@ def input_resolution(self):
 return input_resolution

 def _get_numerical_resolutions(self):
- """Get the input and output resolutions without units"""
+ """Get the input and output resolutions without units. e.g. for
+ {"spatial": "30km", "temporal": "60min"} this returns
+ {"spatial": 30, "temporal": 60}"""
 ires_num = {k: int(re.search(r'\d+', v).group(0))
 for k, v in self.input_resolution.items()}
 enhancements = {'spatial': self.s_enhance,
@@ -797,7 +808,7 @@ def load_saved_params(out_dir, verbose=True):

 return params

- def get_exo_loss_input(self, high_res):
+ def get_high_res_exo_input(self, high_res):
 """Get exogenous feature data from high_res

 Parameters
@@ -1418,7 +1429,7 @@ def get_single_grad(self,
 with tf.GradientTape(watch_accessed_variables=False) as tape:
 tape.watch(training_weights)

- hi_res_exo = self.get_exo_loss_input(hi_res_true)
+ hi_res_exo = self.get_high_res_exo_input(hi_res_true)
 hi_res_gen = self._tf_generate(low_res, hi_res_exo)
 loss_out = self.calc_loss(hi_res_true, hi_res_gen,
 **calc_loss_kwargs)
diff --git a/sup3r/models/base.py b/sup3r/models/base.py
index b6fa7bc52..0878ecac2 100644
--- a/sup3r/models/base.py
+++ b/sup3r/models/base.py
@@ -18,7 +18,7 @@
 logger = logging.getLogger(__name__)

-class Sup3rGan(AbstractInterface, AbstractSingleModel):
+class Sup3rGan(AbstractSingleModel, AbstractInterface):
 """Basic sup3r GAN model."""

 def __init__(self,
@@ -602,7 +602,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details):
 logger.debug('Starting end-of-epoch validation loss calculation...')
 loss_details['n_obs'] = 0
 for val_batch in batch_handler.val_data:
- val_exo_data = self.get_exo_loss_input(val_batch.high_res)
+ val_exo_data = self.get_high_res_exo_input(val_batch.high_res)
 high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data)
 _, v_loss_details = self.calc_loss(
 val_batch.high_res, high_res_gen,
diff --git a/sup3r/models/conditional_moments.py b/sup3r/models/conditional_moments.py
index b1415f616..f97872d69 100644
--- a/sup3r/models/conditional_moments.py
+++ b/sup3r/models/conditional_moments.py
@@ -16,7 +16,7 @@
 logger = logging.getLogger(__name__)

-class Sup3rCondMom(AbstractInterface, AbstractSingleModel):
+class 
Sup3rCondMom(AbstractSingleModel, AbstractInterface):
 """Basic Sup3r conditional moments model."""

 def __init__(self, gen_layers,
@@ -283,7 +283,7 @@ def calc_val_loss(self, batch_handler, loss_details):
 logger.debug('Starting end-of-epoch validation loss calculation...')
 loss_details['n_obs'] = 0
 for val_batch in batch_handler.val_data:
- val_exo_data = self.get_exo_loss_input(val_batch.high_res)
+ val_exo_data = self.get_high_res_exo_input(val_batch.high_res)
 output_gen = self._tf_generate(val_batch.low_res, val_exo_data)
 _, v_loss_details = self.calc_loss(
 val_batch.output, output_gen, val_batch.mask)
diff --git a/sup3r/models/data_centric.py b/sup3r/models/data_centric.py
index 2f08fef17..a53f3954d 100644
--- a/sup3r/models/data_centric.py
+++ b/sup3r/models/data_centric.py
@@ -38,7 +38,7 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers):
 """
 losses = []
 for obs in batch_handler.val_data:
- exo_data = self.get_exo_loss_input(obs.high_res)
+ exo_data = self.get_high_res_exo_input(obs.high_res)
 gen = self._tf_generate(obs.low_res, exo_data)
 loss, _ = self.calc_loss(obs.high_res, gen,
 weight_gen_advers=weight_gen_advers,
@@ -66,7 +66,7 @@ def calc_val_loss_gen_content(self, batch_handler):
 """
 losses = []
 for obs in batch_handler.val_data:
- exo_data = self.get_exo_loss_input(obs.high_res)
+ exo_data = self.get_high_res_exo_input(obs.high_res)
 gen = self._tf_generate(obs.low_res, exo_data)
 loss = self.calc_loss_gen_content(obs.high_res, gen)
 losses.append(float(loss))
diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
index d77906544..41791006a 100644
--- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py
+++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py
@@ -63,18 +63,29 @@ def __init__(self,
 feature : str
 Exogenous feature to extract from source_h5
 models : list
- List of models used with the given steps list
+ List of models used with the given steps list. This list of models
+ is used to determine the input and output resolution and
+ enhancement factors for each model step, which are then used to
+ determine aggregation factors. If agg factors and enhancement
+ factors are provided in the steps list, the model list is not
+ needed.
 steps : list
 List of dictionaries containing info on which models to use for a
 given step index and what type of exo data the step requires. e.g.
 [{'model': 0, 'combine_type': 'input'},
 {'model': 0, 'combine_type': 'layer'}]
- Each step entry can also contain s_enhance, t_enhance,
- s_agg_factor, t_agg_factor. If they are not included they will be
- computed using exo_resolution and model attributes
+ Each step entry can also contain s_enhance, t_enhance,
+ s_agg_factor, t_agg_factor. e.g.
+ [{'model': 0, 'combine_type': 'input', 's_agg_factor': 900,
+ 's_enhance': 1, 't_agg_factor': 5, 't_enhance': 1},
+ {'model': 0, 'combine_type': 'layer', 's_agg_factor': 100,
+ 's_enhance': 3, 't_agg_factor': 5, 't_enhance': 1}]
+ If they are not included they will be computed using exo_resolution
+ and model attributes.
 exo_resolution : dict
 Dictionary of spatiotemporal resolution for the given exo data
- source. e.g. {'spatial': '4km', 'temporal': '60min'}
+ source. e.g. {'spatial': '4km', 'temporal': '60min'}. This is used
+ only if agg factors are not provided in the steps list.
 source_file : str
 Filepath to source wtk, nsrdb, or netcdf file to get hi-res
 (2km or 4km) data from which will be mapped to the enhanced grid of the