Skip to content

Commit

Permalink
pr changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Feb 15, 2024
1 parent 7c15a28 commit ab5e145
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 22 deletions.
31 changes: 18 additions & 13 deletions sup3r/pipeline/forward_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,19 +727,7 @@ def __init__(self,
https://github.com/tensorflow/tensorflow/issues/51870
"""
self._input_handler_kwargs = input_handler_kwargs or {}
target = self._input_handler_kwargs.get('target', None)
grid_shape = self._input_handler_kwargs.get('shape', None)
raster_file = self._input_handler_kwargs.get('raster_file', None)
raster_index = self._input_handler_kwargs.get('raster_index', None)
temporal_slice = self._input_handler_kwargs.get(
'temporal_slice', slice(None, None, 1))
InputMixIn.__init__(self,
target=target,
shape=grid_shape,
raster_file=raster_file,
raster_index=raster_index,
temporal_slice=temporal_slice)

self.init_mixin()
self.file_paths = file_paths
self.model_kwargs = model_kwargs
self.fwp_chunk_shape = fwp_chunk_shape
Expand Down Expand Up @@ -808,6 +796,23 @@ def __init__(self,

self.preflight()

def init_mixin(self):
"""Initialize InputMixIn class"""
target = self._input_handler_kwargs.get('target', None)
grid_shape = self._input_handler_kwargs.get('shape', None)
raster_file = self._input_handler_kwargs.get('raster_file', None)
raster_index = self._input_handler_kwargs.get('raster_index', None)
temporal_slice = self._input_handler_kwargs.get(
'temporal_slice', slice(None, None, 1))
res_kwargs = self._input_handler_kwargs.get('res_kwargs', None)
InputMixIn.__init__(self,
target=target,
shape=grid_shape,
raster_file=raster_file,
raster_index=raster_index,
temporal_slice=temporal_slice,
res_kwargs=res_kwargs)

def preflight(self):
"""Prelight path name formatting and sanity checks"""

Expand Down
5 changes: 4 additions & 1 deletion sup3r/preprocessing/data_handling/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ def __init__(self,
raster_file=None,
raster_index=None,
temporal_slice=slice(None, None, 1),
res_kwargs=None,
):
"""Provide properties of the spatiotemporal data domain
Expand All @@ -484,6 +485,8 @@ def __init__(self,
Slice specifying extent and step of temporal extraction. e.g.
slice(start, stop, time_pruning). If equal to slice(None, None, 1)
the full time dimension is selected.
res_kwargs : dict | None
Dictionary of kwargs to pass to xarray.open_mfdataset.
"""
self.raster_file = raster_file
self.target = target
Expand All @@ -505,7 +508,7 @@ def __init__(self,
self._full_raw_lat_lon = None
self._single_ts_files = None
self._worker_attrs = ['ti_workers']
self.res_kwargs = {}
self.res_kwargs = res_kwargs or {}

@property
def raw_tsteps(self):
Expand Down
7 changes: 4 additions & 3 deletions sup3r/utilities/loss_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ def _derivative(self, x, axis=1):
x[..., -1:] - x[..., -2:-1]], axis=3)

else:
msg = (f'{self.__class__}._derivative received axis={axis}. '
'This is meant to compute only temporal (axis=3) or '
'spatial (axis=1/2) derivatives.')
msg = (f'{self.__class__.__name__}._derivative received '
f'axis={axis}. This is meant to compute only temporal '
'(axis=3) or spatial (axis=1/2) derivatives for tensors '
'of shape (n_obs, spatial_1, spatial_2, temporal)')
raise ValueError(msg)

def _compute_md(self, x, fidx):
Expand Down
8 changes: 4 additions & 4 deletions tests/bias/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
FP_NSRDB = os.path.join(TEST_DATA_DIR, 'test_nsrdb_co_2018.h5')
FP_CC = os.path.join(TEST_DATA_DIR, 'rsds_test.nc')

with xr.open_mfdataset(FP_CC) as fh:
MIN_LAT = np.min(fh.lat.values)
MIN_LON = np.min(fh.lon.values) - 360
TARGET = (MIN_LAT, MIN_LON)
with xr.open_dataset(FP_CC) as fh:
MIN_LAT = np.min(fh.lat.values.astype(np.float32))
MIN_LON = np.min(fh.lon.values.astype(np.float32)) - 360
TARGET = (float(MIN_LAT), float(MIN_LON))
SHAPE = (len(fh.lat.values), len(fh.lon.values))


Expand Down
33 changes: 32 additions & 1 deletion tests/training/test_custom_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
"""Test the basic training of super resolution GAN"""
import numpy as np
import tensorflow as tf
import pytest

from sup3r.utilities.loss_metrics import (MmdMseLoss, CoarseMseLoss,
TemporalExtremesLoss)
TemporalExtremesLoss,
MaterialDerivativeLoss)


def test_mmd_loss():
Expand Down Expand Up @@ -74,3 +76,32 @@ def test_tex_loss():
y[..., 25, 0] = -25
loss = loss_obj(x, y)
assert loss.numpy() > 1.5


def test_md_loss():
"""Test the material derivative calculation in the material derivative
content loss class."""

x = np.random.rand(6, 10, 10, 8, 3)
y = x.copy()

md_loss = MaterialDerivativeLoss()
u_div = md_loss._compute_md(x, fidx=0)
v_div = md_loss._compute_md(x, fidx=1)

u_div_np = np.gradient(y[..., 0], axis=3)
u_div_np += y[..., 0] * np.gradient(y[..., 0], axis=1)
u_div_np += y[..., 1] * np.gradient(y[..., 0], axis=2)

v_div_np = np.gradient(x[..., 1], axis=3)
v_div_np += y[..., 0] * np.gradient(y[..., 1], axis=1)
v_div_np += y[..., 1] * np.gradient(y[..., 1], axis=2)

with pytest.raises(ValueError):
md_loss._derivative(x, axis=0)

with pytest.raises(Exception):
md_loss(x[..., 0], y[..., 0])

assert np.allclose(u_div, u_div_np)
assert np.allclose(v_div, v_div_np)

0 comments on commit ab5e145

Please sign in to comment.