Skip to content

Commit

Permalink
PR changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Oct 3, 2023
1 parent fb6787f commit 0e874c6
Show file tree
Hide file tree
Showing 14 changed files with 398 additions and 580 deletions.
9 changes: 4 additions & 5 deletions examples/sup3rcc/run_configs/wind/config_fwp.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
{
"file_paths": "PLACEHOLDER",
"model_kwargs": {
"spatial_model_dirs": [
"model_dirs": [
"./sup3rcc_models_202303/sup3rcc_wind_step1_5x_1x_6f/",
"./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/"
],
"temporal_model_dirs": "./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/"
"./sup3rcc_models_202303/sup3rcc_wind_step2_5x_1x_6f/",
"./sup3rcc_models_202303/sup3rcc_wind_step3_1x_24x_6f/"]
},
"model_class": "SpatialThenTemporalGan",
"model_class": "MultiStepGan",
"out_pattern": "./chunks/sup3r_chunk_{file_id}.h5",
"log_pattern": "./logs/sup3r_fwp_log_{node_index}.log",
"bias_correct_method": "monthly_local_linear_bc",
Expand Down
6 changes: 3 additions & 3 deletions sup3r/bias/bias_calc_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
sup3r bias correction calculation CLI entry points.
"""
import copy
import click
import logging
import os

import click

import sup3r.bias.bias_calc
from sup3r.utilities import ModuleName
from sup3r.version import __version__
from sup3r.utilities.cli import BaseCLI

from sup3r.version import __version__

logger = logging.getLogger(__name__)

Expand Down
11 changes: 2 additions & 9 deletions sup3r/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,9 @@
from .conditional_moments import Sup3rCondMom
from .data_centric import Sup3rGanDC
from .linear import LinearInterp
from .multi_step import (
MultiStepGan,
MultiStepSurfaceMetGan,
SolarMultiStepGan,
SpatialThenTemporalGan,
TemporalThenSpatialGan,
)
from .multi_step import MultiStepGan, MultiStepSurfaceMetGan, SolarMultiStepGan
from .solar_cc import SolarCC
from .surface import SurfaceSpatialMetModel

SPATIAL_FIRST_MODELS = (SpatialThenTemporalGan,
MultiStepSurfaceMetGan,
SPATIAL_FIRST_MODELS = (MultiStepSurfaceMetGan,
SolarMultiStepGan)
122 changes: 62 additions & 60 deletions sup3r/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,15 @@ def t_enhance(self):

@property
def input_resolution(self):
"""Resolution of input data. Given as a dictionary {'spatial':...,
'temporal':...}"""
return self.meta.get('input_resolution', None)
"""Resolution of input data. Given as a dictionary {'spatial': '...km',
'temporal': '...min'}. The numbers are required to be integers in the
units specified. The units are not strict as long as the resolution
of the exogenous data, when extracting exogenous data, is specified
in the same units."""
input_resolution = self.meta.get('input_resolution', None)
msg = 'model.input_resolution is None. This needs to be set.'
assert input_resolution is not None, msg
return input_resolution

def _get_numerical_resolutions(self):
"""Get the input and output resolutions without units"""
Expand Down Expand Up @@ -185,11 +191,11 @@ def _ensure_valid_enhancement_factors(self):

@property
def output_resolution(self):
"""Resolution of output data. Given as a dictionary {'spatial':...,
'temporal':...}"""
"""Resolution of output data. Given as a dictionary
{'spatial': '...km', 'temporal': '...min'}. This is computed from the
input resolution and the enhancement factors."""
output_res = self.meta.get('output_resolution', None)
if self.input_resolution is not None and output_res is None:
output_res = self.input_resolution.copy()
ires_num, ores_num = self._get_numerical_resolutions()
output_res = {k: v.replace(str(ires_num[k]), str(ores_num[k]))
for k, v in self.input_resolution.items()}
Expand Down Expand Up @@ -223,19 +229,20 @@ def _combine_fwp_input(self, low_res, exogenous_data=None):
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""
low_res_shape = low_res.shape
check = (exogenous_data is not None
and low_res.shape[-1] < len(self.training_features))
if check:
exo_data = {k: v for k, v in exogenous_data.items()
if k in self.training_features}
for i, (feature, entry) in enumerate(exo_data.items()):
f_idx = low_res_shape[-1] + i
training_feature = self.training_features[f_idx]
msg = ('The ordering of features in exogenous_data conflicts '
'with the ordering of training features. Received '
f'{feature} instead of {training_feature}.')
assert feature == training_feature, msg
if exogenous_data is None:
return low_res

training_features = ([] if self.training_features is None
else self.training_features)
fnum_diff = len(training_features) - low_res.shape[-1]
exo_feats = ([] if fnum_diff <= 0
else self.training_features[-fnum_diff:])
msg = ('Provided exogenous_data is missing some required features '
f'({exo_feats})')
assert all(feature in exogenous_data for feature in exo_feats), msg
if exogenous_data is not None and fnum_diff > 0:
for feature in exo_feats:
entry = exogenous_data[feature]
combine_types = [step['combine_type']
for step in entry['steps']]
if 'input' in combine_types:
Expand Down Expand Up @@ -272,19 +279,20 @@ def _combine_fwp_output(self, hi_res, exogenous_data=None):
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""
hi_res_shape = hi_res.shape[-1]
check = (exogenous_data is not None
and hi_res.shape[-1] < len(self.output_features))
if check:
exo_data = {k: v for k, v in exogenous_data.items()
if k in self.training_features}
for i, (feature, entry) in enumerate(exo_data.items()):
f_idx = hi_res_shape[-1] + i
training_feature = self.training_features[f_idx]
msg = ('The ordering of features in exogenous_data conflicts '
'with the ordering of training features. Received '
f'{feature} instead of {training_feature}.')
assert feature == training_feature, msg
if exogenous_data is None:
return hi_res

output_features = ([] if self.output_features is None
else self.output_features)
fnum_diff = len(output_features) - hi_res.shape[-1]
exo_feats = ([] if fnum_diff <= 0
else self.output_features[-fnum_diff:])
msg = ('Provided exogenous_data is missing some required features '
f'({exo_feats})')
assert all(feature in exogenous_data for feature in exo_feats), msg
if exogenous_data is not None and fnum_diff > 0:
for feature in exo_feats:
entry = exogenous_data[feature]
combine_types = [step['combine_type']
for step in entry['steps']]
if 'output' in combine_types:
Expand Down Expand Up @@ -318,27 +326,6 @@ def _combine_loss_input(self, high_res_true, high_res_gen):
high_res_gen = tf.concat((high_res_gen, exo_data), axis=-1)
return high_res_gen

def _get_exo_val_loss_input(self, high_res):
"""Get exogenous feature data from high_res
Parameters
----------
high_res : tf.Tensor
Ground truth high resolution spatiotemporal data.
Returns
-------
exo_data : dict
Dictionary of exogenous feature data used as input to tf_generate.
e.g. {'topography': np.ndarray(...)}
"""
exo_data = {}
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
exo_fdata = high_res[..., f_idx: f_idx + 1]
exo_data[feature] = exo_fdata
return exo_data

@property
def exogenous_features(self):
"""Get list of exogenous filter names the model uses. If the model has
Expand Down Expand Up @@ -810,6 +797,27 @@ def load_saved_params(out_dir, verbose=True):

return params

def get_exo_loss_input(self, high_res):
"""Get exogenous feature data from high_res
Parameters
----------
high_res : tf.Tensor
Ground truth high resolution spatiotemporal data.
Returns
-------
exo_data : dict
Dictionary of exogenous feature data used as input to tf_generate.
e.g. {'topography': tf.Tensor(...)}
"""
exo_data = {}
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
exo_fdata = high_res[..., f_idx: f_idx + 1]
exo_data[feature] = exo_fdata
return exo_data

@staticmethod
def get_loss_fun(loss):
"""Get the initialized loss function class from the sup3r loss library
Expand Down Expand Up @@ -1290,9 +1298,7 @@ def generate(self,
(n_obs, spatial_1, spatial_2, n_features)
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""

low_res = self._combine_fwp_input(low_res, exogenous_data)

if norm_in and self._means is not None:
low_res = self.norm_input(low_res)

Expand Down Expand Up @@ -1408,15 +1414,11 @@ def get_single_grad(self,
loss_details : dict
Namespace of the breakdown of loss components
"""
hi_res_exo = {}
for feature in self.exogenous_features:
f_idx = self.training_features.index(feature)
hi_res_exo[feature] = hi_res_true[..., f_idx: f_idx + 1]

with tf.device(device_name):
with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(training_weights)

hi_res_exo = self.get_exo_loss_input(hi_res_true)
hi_res_gen = self._tf_generate(low_res, hi_res_exo)
loss_out = self.calc_loss(hi_res_true, hi_res_gen,
**calc_loss_kwargs)
Expand Down
2 changes: 1 addition & 1 deletion sup3r/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ def calc_val_loss(self, batch_handler, weight_gen_advers, loss_details):
logger.debug('Starting end-of-epoch validation loss calculation...')
loss_details['n_obs'] = 0
for val_batch in batch_handler.val_data:
val_exo_data = self._get_exo_val_loss_input(val_batch.high_res)
val_exo_data = self.get_exo_loss_input(val_batch.high_res)
high_res_gen = self._tf_generate(val_batch.low_res, val_exo_data)
_, v_loss_details = self.calc_loss(
val_batch.high_res, high_res_gen,
Expand Down
2 changes: 1 addition & 1 deletion sup3r/models/conditional_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def calc_val_loss(self, batch_handler, loss_details):
logger.debug('Starting end-of-epoch validation loss calculation...')
loss_details['n_obs'] = 0
for val_batch in batch_handler.val_data:
val_exo_data = self._get_exo_val_loss_input(val_batch.high_res)
val_exo_data = self.get_exo_loss_input(val_batch.high_res)
output_gen = self._tf_generate(val_batch.low_res, val_exo_data)
_, v_loss_details = self.calc_loss(
val_batch.output, output_gen, val_batch.mask)
Expand Down
4 changes: 2 additions & 2 deletions sup3r/models/data_centric.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def calc_val_loss_gen(self, batch_handler, weight_gen_advers):
"""
losses = []
for obs in batch_handler.val_data:
exo_data = self._get_exo_val_loss_input(obs.high_res)
exo_data = self.get_exo_loss_input(obs.high_res)
gen = self._tf_generate(obs.low_res, exo_data)
loss, _ = self.calc_loss(obs.high_res, gen,
weight_gen_advers=weight_gen_advers,
Expand Down Expand Up @@ -66,7 +66,7 @@ def calc_val_loss_gen_content(self, batch_handler):
"""
losses = []
for obs in batch_handler.val_data:
exo_data = self._get_exo_val_loss_input(obs.high_res)
exo_data = self.get_exo_loss_input(obs.high_res)
gen = self._tf_generate(obs.low_res, exo_data)
loss = self.calc_loss_gen_content(obs.high_res, gen)
losses.append(float(loss))
Expand Down
Loading

0 comments on commit 0e874c6

Please sign in to comment.