Bnb/dev #148

Merged · 6 commits · May 12, 2023
25 changes: 14 additions & 11 deletions sup3r/bias/bias_calc.py
@@ -8,6 +8,7 @@
import logging
import numpy as np
import pandas as pd
from glob import glob
from scipy.spatial import KDTree
from scipy.stats import ks_2samp
from scipy.ndimage.filters import gaussian_filter
@@ -83,9 +84,9 @@ def __init__(self, base_fps, bias_fps, base_dset, bias_feature,
bias_handler_kwargs = bias_handler_kwargs or {}

if isinstance(self.base_fps, str):
self.base_fps = [self.base_fps]
self.base_fps = sorted(glob(self.base_fps))
if isinstance(self.bias_fps, str):
self.bias_fps = [self.bias_fps]
self.bias_fps = sorted(glob(self.bias_fps))

self.base_handler = getattr(rex, base_handler)
self.bias_handler = getattr(sup3r.preprocessing.data_handling,
@@ -297,6 +298,8 @@ def get_bias_data(self, bias_gid):
1D array of temporal data at the requested gid.
"""
idx = np.where(self.bias_gid_raster == bias_gid)
if self.bias_dh.data is None:
self.bias_dh.load_cached_data()
bias_data = self.bias_dh.data[idx][0]

if bias_data.shape[-1] == 1:
@@ -423,9 +426,9 @@ def _read_base_data(res, base_dset, base_gid):
base_data = res[base_dset, :, base_gid]

if len(base_data.shape) == 2:
base_data = base_data.mean(axis=1)
base_data = np.nanmean(base_data, axis=1)
if base_cs_ghi is not None:
base_cs_ghi = base_cs_ghi.mean(axis=1)
base_cs_ghi = np.nanmean(base_cs_ghi, axis=1)

return base_data, base_cs_ghi

@@ -516,17 +519,17 @@ def get_linear_correction(bias_data, base_data, bias_feature, base_dset):
like: bias_data * scalar + adder
"""

bias_std = bias_data.std()
bias_std = np.nanstd(bias_data)
if bias_std == 0:
bias_std = base_data.std()
bias_std = np.nanstd(base_data)

scalar = base_data.std() / bias_std
adder = base_data.mean() - bias_data.mean() * scalar
scalar = np.nanstd(base_data) / bias_std
adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar

out = {f'bias_{bias_feature}_mean': bias_data.mean(),
out = {f'bias_{bias_feature}_mean': np.nanmean(bias_data),
f'bias_{bias_feature}_std': bias_std,
f'base_{base_dset}_mean': base_data.mean(),
f'base_{base_dset}_std': base_data.std(),
f'base_{base_dset}_mean': np.nanmean(base_data),
f'base_{base_dset}_std': np.nanstd(base_data),
f'{bias_feature}_scalar': scalar,
f'{bias_feature}_adder': adder,
}
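The linear correction in the hunk above is fully determined by nan-aware first and second moments, so a short numpy sketch captures the whole transform. The arrays below are hypothetical stand-ins, not project data:

import numpy as np

# Hypothetical 1D series; NaNs are tolerated by the nan-aware statistics.
base_data = np.array([4.0, 5.0, np.nan, 6.0])   # reference (e.g. observed)
bias_data = np.array([7.0, 9.0, 11.0, np.nan])  # biased (e.g. modeled)

bias_std = np.nanstd(bias_data)
if bias_std == 0:  # degenerate case: constant biased data
    bias_std = np.nanstd(base_data)

scalar = np.nanstd(base_data) / bias_std
adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar

# Applying bias_data * scalar + adder reproduces the nan-aware mean and
# standard deviation of base_data.
corrected = bias_data * scalar + adder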
13 changes: 13 additions & 0 deletions sup3r/models/abstract.py
@@ -85,6 +85,19 @@ def generate(self, low_res):
(n_obs, spatial_1, spatial_2, n_temporal, n_features)
"""

@property
def input_dims(self):
"""Get dimension of model generator input. This is usually 4D for
spatial models and 5D for spatiotemporal models. This gives the input
to the first step if the model is multi-step. Returns 5 for linear
models."""
if hasattr(self, '_gen'):
return self._gen.layers[0].rank
elif hasattr(self, 'models'):
return self.models[0]._gen.layers[0].rank
else:
return 5

@property
def s_enhance(self):
"""Factor by which model will enhance spatial resolution. Used in
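Because input_dims reads the rank of the first generator layer (4 for spatial models, 5 for spatiotemporal ones), callers can branch on the property instead of testing model classes, which is what the _reshape_data_chunk change below switches to. A hedged sketch of that pattern, assuming a hypothetical model object and a (rows, cols, time, features) chunk:

import numpy as np

def shape_for_model(model, chunk):
    """Reshape a (rows, cols, time, features) chunk to the rank the model
    generator expects; `model` is any object exposing input_dims."""
    if model.input_dims == 4:
        # spatial model: time serves as the observation axis
        return np.transpose(chunk, axes=(2, 0, 1, 3))
    # spatiotemporal model: one observation with an explicit time axis
    return np.expand_dims(chunk, axis=0)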
19 changes: 15 additions & 4 deletions sup3r/pipeline/forward_pass.py
Expand Up @@ -1061,6 +1061,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0):
out = self.pad_source_data(self.input_data, self.pad_width,
self.exogenous_data, exo_s_en)
self.input_data, self.exogenous_data = out
self.unpadded_input_data = self.data_handler.data[self.lr_slice[0],
self.lr_slice[1]]

def update_input_handler_kwargs(self, strategy):
"""Update the kwargs for the input handler for the current forward pass
@@ -1121,6 +1123,12 @@ def lr_times(self):
high-resolution time index"""
return self.data_handler.time_index[self.ti_crop_slice]

@property
def lr_lat_lon(self):
"""Get low resolution lat lon for current chunk"""
return self.strategy.lr_lat_lon[self.lr_slice[0],
self.lr_slice[1]]

@property
def hr_lat_lon(self):
"""Get high resolution lat lon for current chunk"""
@@ -1520,6 +1528,9 @@ def _run_generator(cls, data_chunk, hr_crop_slices,
logger.exception(msg)
raise RuntimeError(msg) from e

if len(hi_res.shape) == 4:
hi_res = np.expand_dims(np.transpose(hi_res, (1, 2, 0, 3)), axis=0)

if (s_enhance is not None
and hi_res.shape[1] != s_enhance * data_chunk.shape[i_lr_s]):
msg = ('The stated spatial enhancement of {}x did not match '
@@ -1572,14 +1583,14 @@ def _reshape_data_chunk(model, data_chunk, exo_data):
if exo_data is not None:
for i, arr in enumerate(exo_data):
if arr is not None:
tp = isinstance(model, sup3r.models.SPATIAL_FIRST_MODELS)
tp = tp and (i < len(model.spatial_models))
if tp:
current_model = (model if not hasattr(model, 'models')
else model.models[i])
if current_model.input_dims == 4:
exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3))
else:
exo_data[i] = np.expand_dims(arr, axis=0)

if isinstance(model, sup3r.models.SPATIAL_FIRST_MODELS):
if model.input_dims == 4:
i_lr_t = 0
i_lr_s = 1
data_chunk = np.transpose(data_chunk, axes=(2, 0, 1, 3))
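The rank-4 branch added to _run_generator above undoes the spatial-model layout on the way out: a 4D generator emits (time, rows, cols, features), and the transpose-plus-expand restores the canonical (n_obs, rows, cols, time, features). A shape check of that round trip, with illustrative sizes:

import numpy as np

hi_res = np.zeros((24, 100, 100, 2))  # (time, rows, cols, features)
if len(hi_res.shape) == 4:
    hi_res = np.expand_dims(np.transpose(hi_res, (1, 2, 0, 3)), axis=0)
assert hi_res.shape == (1, 100, 100, 24, 2)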
58 changes: 21 additions & 37 deletions sup3r/preprocessing/data_handling.py
Expand Up @@ -1036,9 +1036,10 @@ def noncached_features(self):
@property
def extract_features(self):
"""Features to extract directly from the source handler"""
lower_features = [f.lower() for f in self.handle_features]
return [f for f in self.raw_features
if self.lookup(f, 'compute') is None
or Feature.get_basename(f) in self.handle_features]
or Feature.get_basename(f.lower()) in lower_features]

@property
def derive_features(self):
@@ -1842,7 +1843,7 @@ def lin_bc(self, bc_files, threshold=0.1):
the data is a 3D array of shape (lat, lon, time) where time is
length 1 for annual correction or 12 for monthly correction.
threshold : float
Nearest neighbor euclidian distance threshold. If the DataHandler
Nearest neighbor euclidean distance threshold. If the DataHandler
coordinates are more than this value away from the bias correction
lat/lon, an error is raised.
"""
@@ -2132,18 +2133,20 @@ def extract_feature(cls, file_paths, raster_index, feature,
interp_pressure = f_info.pressure
basename = f_info.basename

if feature in handle:
fdata = cls.direct_extract(handle, feature, raster_index,
if feature in handle or feature.lower() in handle:
feat_key = feature if feature in handle else feature.lower()
fdata = cls.direct_extract(handle, feat_key, raster_index,
time_slice)

elif basename in handle:
elif basename in handle or basename.lower() in handle:
feat_key = basename if basename in handle else basename.lower()
if interp_height is not None:
fdata = Interpolator.interp_var_to_height(
handle, basename, raster_index, np.float32(interp_height),
handle, feat_key, raster_index, np.float32(interp_height),
time_slice)
elif interp_pressure is not None:
fdata = Interpolator.interp_var_to_pressure(
handle, basename, raster_index,
handle, feat_key, raster_index,
np.float32(interp_pressure), time_slice)

else:
@@ -2330,11 +2333,14 @@ def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index):
"""
if (raster_index[0].stop > lat_lon.shape[0]
or raster_index[1].stop > lat_lon.shape[1]
or raster_index[0].start < 0 or raster_index[1].start < 0):
or raster_index[0].start < 0
or raster_index[1].start < 0):
msg = (f'Invalid target {target}, shape {grid_shape}, and raster '
f'{raster_index} for data domain of size '
f'{lat_lon.shape[:-1]} with lower left corner '
f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}).')
f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) '
f' and upper right corner ({np.max(lat_lon[..., 0])}, '
f'{np.max(lat_lon[..., 1])}).')
raise ValueError(msg)

def get_raster_index(self):
@@ -2366,40 +2372,15 @@ def get_raster_index(self):
.format(raster_index))

if self.raster_file is not None:
basedir = os.path.dirname(self.raster_file)
if not os.path.exists(basedir):
os.makedirs(basedir)
logger.debug(f'Saving raster index: {self.raster_file}')
np.save(self.raster_file.replace('.txt', '.npy'), raster_index)

return raster_index


class DataHandlerNCforERA(DataHandlerNC):
"""Data Handler for NETCDF ERA5 data"""

CHUNKS = {'time': 5, 'lat': 20, 'lon': 20}
"""CHUNKS sets the chunk sizes to extract from the data in each dimension.
Chunk sizes that approximately match the data volume being extracted
typically results in the most efficient IO."""

@classmethod
def feature_registry(cls):
"""Registry of methods for computing features or extracting renamed
features

Returns
-------
dict
Method registry
"""
registry = {
'U_(.*)': 'u_(.*)',
'V_(.*)': 'v_(.*)',
'Windspeed_(.*)m': WindspeedNC,
'Winddirection_(.*)m': WinddirectionNC,
'topography': 'orog',
'lat_lon': LatLonNC}
return registry


class DataHandlerNCforCC(DataHandlerNC):
"""Data Handler for NETCDF climate change data"""

@@ -2754,6 +2735,9 @@ def get_raster_index(self):
self.grid_shape,
max_delta=self.max_delta)
if self.raster_file is not None:
basedir = os.path.dirname(self.raster_file)
if not os.path.exists(basedir):
os.makedirs(basedir)
logger.debug(f'Saving raster index: {self.raster_file}')
np.savetxt(self.raster_file, raster_index)
return raster_index
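The repeated `feature in handle or feature.lower() in handle` pattern in the hunks above lets lowercase ERA variable names satisfy WRF-style feature requests. A minimal sketch of the lookup, with a hypothetical handle dict standing in for the file handler:

def resolve_feature(feature, handle):
    """Return the key under which feature is stored: exact spelling
    first, lowercase spelling second, else None."""
    if feature in handle:
        return feature
    if feature.lower() in handle:
        return feature.lower()
    return None

handle = {'u': 0, 'v': 1, 'orog': 2}  # ERA-style lowercase variables
assert resolve_feature('U', handle) == 'u'
assert resolve_feature('orog', handle) == 'orog'
assert resolve_feature('T', handle) is None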
5 changes: 4 additions & 1 deletion sup3r/preprocessing/feature_handling.py
@@ -1675,10 +1675,13 @@ def get_inputs_recursive(cls, feature, handle_features):
"""
raw_features = []
method = cls.lookup(feature, 'inputs', handle_features=handle_features)
lower_handle_features = [f.lower() for f in handle_features]

check1 = feature not in raw_features
check2 = (cls.valid_handle_features([feature], handle_features)
check2 = (cls.valid_handle_features([feature.lower()],
lower_handle_features)
or method is None)

if check1 and check2:
raw_features.append(feature)

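The same case-insensitive idea terminates the recursion here: a feature counts as raw once its lowercase form is found among lowercased handle features or no input method is registered for it. A reduced sketch of the check2 condition above, where plain membership stands in for valid_handle_features and `method` for the registry lookup:

def is_raw_feature(feature, handle_features, method):
    # lowercase both sides so case differences do not force a derivation
    lower_handle_features = [f.lower() for f in handle_features]
    return feature.lower() in lower_handle_features or method is None

assert is_raw_feature('u_100m', ['U_100m'], method=None)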
5 changes: 4 additions & 1 deletion sup3r/utilities/interpolation.py
@@ -68,6 +68,8 @@ def calc_height(cls, data, raster_index, time_slice=slice(None)):
msg = ('Need either PHB/PH/HGT or zg/orog in data to perform '
'height interpolation')
raise ValueError(msg)
logger.debug('Spatiotemporally averaged height levels: '
f'{list(np.nanmean(np.array(hgt), axis=(0, 2, 3)))}')
return np.array(hgt)

@classmethod
@@ -162,6 +164,7 @@ def calc_pressure(cls, data, var, raster_index, time_slice=slice(None)):
p_array = np.zeros(data[var][idx].shape, dtype=np.float32)
for i in range(p_array.shape[1]):
p_array[:, i, ...] = data.plev[i]
logger.info(f'Available pressure levels: {data.plev}')

return p_array

@@ -206,7 +209,7 @@ def interp_to_level(cls, var_array, lev_array, levels):
raise RuntimeError(msg)

nans = np.isnan(lev_array)
logger.debug('lev_array.shape: {}'.format(lev_array.shape))
logger.debug('Level array shape: {}'.format(lev_array.shape))
bad_min = min(levels) < lev_array[:, 0, :, :]
bad_max = max(levels) > lev_array[:, -1, :, :]

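The new debug message in calc_height condenses a 4D height array to one value per vertical level. Assuming the usual (time, level, lat, lon) ordering (level on axis 1, consistent with the lev_array[:, 0, :, :] bounds checks above), the averaging works like this:

import numpy as np

hgt = np.random.rand(8, 10, 20, 20) * 1000.0  # (time, level, lat, lon)
mean_levels = np.nanmean(hgt, axis=(0, 2, 3))  # spatiotemporal average
assert mean_levels.shape == (10,)  # one mean height per level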
32 changes: 32 additions & 0 deletions sup3r/utilities/pytest.py
@@ -38,6 +38,38 @@ def make_fake_nc_files(td, input_file, n_files):
return fake_files


def make_fake_era_files(td, input_file, n_files):
"""Make dummy era files with increasing times. ERA files have a different
naming convention than WRF.

Parameters
----------
td : str
Temporary directory in which to write the dummy files
input_file : str
File to use as template for all dummy files
n_files : int
Number of dummy files to create

Returns
-------
fake_files : list
List of dummy files
"""
fake_dates = [f'2014-10-01_{str(i).zfill(2)}_00_00'
for i in range(n_files)]
fake_times = [f'2014-10-01 {str(i).zfill(2)}:00:00'
for i in range(n_files)]
fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates]
for i in range(n_files):
input_dset = xr.open_dataset(input_file)
with xr.Dataset(input_dset) as dset:
dset['Times'][:] = np.array([fake_times[i].encode('ASCII')],
dtype='|S19')
dset['XTIME'][:] = i
dset = dset.rename({'U': 'u', 'V': 'v'})
dset.to_netcdf(fake_files[i])
return fake_files


def make_fake_h5_chunks(td):
"""Make fake h5 chunked output files for a 5x spatial 2x temporal
multi-node forward pass output.
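A hedged usage sketch for the new helper, mirroring test_single_site_extraction below; the TEST_DATA_DIR import and the WRF template filename are assumptions taken from the test module:

import os
import tempfile

from sup3r import TEST_DATA_DIR  # assumed package-level constant
from sup3r.utilities.pytest import make_fake_era_files

INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00')

with tempfile.TemporaryDirectory() as td:
    # eight hourly files with ERA-style names and lowercase u/v variables
    input_files = make_fake_era_files(td, INPUT_FILE, 8)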
7 changes: 5 additions & 2 deletions sup3r/utilities/utilities.py
@@ -496,7 +496,9 @@ def transform_rotate_wind(ws, wd, lat_lon):

# calculate the angle from the vertical
theta = (np.pi / 2) - np.arctan2(dy, dx)
theta[0] = theta[1] # fix the roll row

if len(theta) > 1:
theta[0] = theta[1] # fix the roll row
wd = np.radians(wd)

u_rot = np.cos(theta)[:, :, np.newaxis] * ws * np.sin(wd)
@@ -550,7 +552,8 @@ def invert_uv(u, v, lat_lon):

# calculate the angle from the vertical
theta = (np.pi / 2) - np.arctan2(dy, dx)
theta[0] = theta[1] # fix the roll row
if len(theta) > 1:
theta[0] = theta[1] # fix the roll row

u_rot = np.cos(theta)[:, :, np.newaxis] * u
u_rot -= np.sin(theta)[:, :, np.newaxis] * v
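The new guard matters for the single-site case this PR enables: after the roll-based finite difference, the first row of theta is repaired from the second, but a one-row grid has no second row and the old assignment would raise an IndexError. A minimal reproduction of the guard:

import numpy as np

theta = np.zeros((1, 1))  # one-row grid, e.g. a single-site extraction
if len(theta) > 1:
    theta[0] = theta[1]  # fix the roll row (safely skipped here)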
25 changes: 24 additions & 1 deletion tests/data_handling/test_data_handling_nc.py
@@ -13,7 +13,7 @@
from sup3r.preprocessing.batch_handling import (BatchHandler,
SpatialBatchHandler)
from sup3r.utilities.interpolation import Interpolator
from sup3r.utilities.pytest import make_fake_nc_files
from sup3r.utilities.pytest import make_fake_nc_files, make_fake_era_files

INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00')
features = ['U_100m', 'V_100m', 'BVF_MO_200m']
@@ -105,6 +105,29 @@ def test_height_interpolation():
assert compare_val - stdev <= val <= compare_val + stdev


def test_single_site_extraction():
"""Make sure single location can be extracted from ERA data without
error."""

height = 10
features = [f'windspeed_{height}m']
with tempfile.TemporaryDirectory() as td:
input_files = make_fake_era_files(td, INPUT_FILE, 8)
kwargs = dh_kwargs.copy()
kwargs['shape'] = [1, 1]
data_handler = DataHandler(input_files, features, val_split=0.0,
**kwargs)

data = data_handler.data[0, 0, :, 0]

data_handler = DataHandler(input_files, features, val_split=0.0,
**dh_kwargs)

baseline = data_handler.data[-1, 0, :, 0]

assert np.allclose(baseline, data)


@pytest.mark.parametrize('sample_shape',
[(4, 4, 6), (2, 2, 6), (4, 4, 4), (2, 2, 4)])
def test_spatiotemporal_batch_caching(sample_shape):