Skip to content

Commit

Permalink
qa test fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Feb 15, 2024
1 parent 26a8a62 commit 7c15a28
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion sup3r/qa/qa.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def features(self):
list
"""
# all lower case
ignore = ('meta', 'time_index', 'times', 'xlat', 'xlong')
ignore = ('meta', 'time_index', 'times', 'time', 'xlat', 'xlong',
'south_north', 'west_east')

if self._features is None or self._features == [None]:
if self.output_type == 'nc':
Expand Down
16 changes: 11 additions & 5 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class MaterialDerivativeLoss(tf.keras.losses.Loss):
https://en.wikipedia.org/wiki/Material_derivative
"""

MAE_LOSS = MeanAbsoluteError()
LOSS_METRIC = MeanAbsoluteError()

def _derivative(self, x, axis=1):
"""Custom derivative function for compatibility with tensorflow
Expand Down Expand Up @@ -203,13 +203,19 @@ def __call__(self, x1, x2):
"""
hub_heights = x1.shape[-1] // 2

msg = (f'The {self.__class__} is meant to be used on spatiotemporal '
'data only. Received tensor(s) that are not 5D')
assert len(x1.shape) == 5 and len(x2.shape) == 5, msg

x1_div = tf.stack(
[self._compute_md(x1, fidx=i) for i in range(2 * hub_heights)])
[self._compute_md(x1, fidx=i)
for i in range(0, 2 * hub_heights, 2)])
x2_div = tf.stack(
[self._compute_md(x2, fidx=i) for i in range(2 * hub_heights)])
[self._compute_md(x2, fidx=i)
for i in range(0, 2 * hub_heights, 2)])

mae = self.MAE_LOSS(x1, x2)
div_mae = self.MAE_LOSS(x1_div, x2_div)
mae = self.LOSS_METRIC(x1, x2)
div_mae = self.LOSS_METRIC(x1_div, x2_div)

return (mae + div_mae) / 2

Expand Down

0 comments on commit 7c15a28

Please sign in to comment.