diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 62d72290e..3e461c750 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -8,6 +8,7 @@ import logging import numpy as np import pandas as pd +from glob import glob from scipy.spatial import KDTree from scipy.stats import ks_2samp from scipy.ndimage.filters import gaussian_filter @@ -83,9 +84,9 @@ def __init__(self, base_fps, bias_fps, base_dset, bias_feature, bias_handler_kwargs = bias_handler_kwargs or {} if isinstance(self.base_fps, str): - self.base_fps = [self.base_fps] + self.base_fps = sorted(glob(self.base_fps)) if isinstance(self.bias_fps, str): - self.bias_fps = [self.bias_fps] + self.bias_fps = sorted(glob(self.bias_fps)) self.base_handler = getattr(rex, base_handler) self.bias_handler = getattr(sup3r.preprocessing.data_handling, @@ -297,6 +298,8 @@ def get_bias_data(self, bias_gid): 1D array of temporal data at the requested gid. """ idx = np.where(self.bias_gid_raster == bias_gid) + if self.bias_dh.data is None: + self.bias_dh.load_cached_data() bias_data = self.bias_dh.data[idx][0] if bias_data.shape[-1] == 1: @@ -423,9 +426,9 @@ def _read_base_data(res, base_dset, base_gid): base_data = res[base_dset, :, base_gid] if len(base_data.shape) == 2: - base_data = base_data.mean(axis=1) + base_data = np.nanmean(base_data, axis=1) if base_cs_ghi is not None: - base_cs_ghi = base_cs_ghi.mean(axis=1) + base_cs_ghi = np.nanmean(base_cs_ghi, axis=1) return base_data, base_cs_ghi @@ -516,17 +519,17 @@ def get_linear_correction(bias_data, base_data, bias_feature, base_dset): like: bias_data * scalar + adder """ - bias_std = bias_data.std() + bias_std = np.nanstd(bias_data) if bias_std == 0: - bias_std = base_data.std() + bias_std = np.nanstd(base_data) - scalar = base_data.std() / bias_std - adder = base_data.mean() - bias_data.mean() * scalar + scalar = np.nanstd(base_data) / bias_std + adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar - out = {f'bias_{bias_feature}_mean': bias_data.mean(), + out = {f'bias_{bias_feature}_mean': np.nanmean(bias_data), f'bias_{bias_feature}_std': bias_std, - f'base_{base_dset}_mean': base_data.mean(), - f'base_{base_dset}_std': base_data.std(), + f'base_{base_dset}_mean': np.nanmean(base_data), + f'base_{base_dset}_std': np.nanstd(base_data), f'{bias_feature}_scalar': scalar, f'{bias_feature}_adder': adder, } diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 0cf8cb5dd..c00d7ac8c 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -85,6 +85,19 @@ def generate(self, low_res): (n_obs, spatial_1, spatial_2, n_temporal, n_features) """ + @property + def input_dims(self): + """Get the rank of the model generator input. This is usually 4 for + spatial models and 5 for spatiotemporal models. For multi-step models + this is the rank of the input to the first step. Defaults to 5 for + linear models.""" + if hasattr(self, '_gen'): + return self._gen.layers[0].rank + elif hasattr(self, 'models'): + return self.models[0]._gen.layers[0].rank + else: + return 5 + @property def s_enhance(self): """Factor by which model will enhance spatial resolution.
Used in diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 1adf5d30a..44650a81e 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -1061,6 +1061,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0): out = self.pad_source_data(self.input_data, self.pad_width, self.exogenous_data, exo_s_en) self.input_data, self.exogenous_data = out + self.unpadded_input_data = self.data_handler.data[self.lr_slice[0], + self.lr_slice[1]] def update_input_handler_kwargs(self, strategy): """Update the kwargs for the input handler for the current forward pass @@ -1121,6 +1123,12 @@ def lr_times(self): high-resolution time index""" return self.data_handler.time_index[self.ti_crop_slice] + @property + def lr_lat_lon(self): + """Get low resolution lat lon for current chunk""" + return self.strategy.lr_lat_lon[self.lr_slice[0], + self.lr_slice[1]] + @property def hr_lat_lon(self): """Get high resolution lat lon for current chunk""" @@ -1520,6 +1528,9 @@ def _run_generator(cls, data_chunk, hr_crop_slices, logger.exception(msg) raise RuntimeError(msg) from e + if len(hi_res.shape) == 4: + hi_res = np.expand_dims(np.transpose(hi_res, (1, 2, 0, 3)), axis=0) + if (s_enhance is not None and hi_res.shape[1] != s_enhance * data_chunk.shape[i_lr_s]): msg = ('The stated spatial enhancement of {}x did not match ' @@ -1572,14 +1583,14 @@ def _reshape_data_chunk(model, data_chunk, exo_data): if exo_data is not None: for i, arr in enumerate(exo_data): if arr is not None: - tp = isinstance(model, sup3r.models.SPATIAL_FIRST_MODELS) - tp = tp and (i < len(model.spatial_models)) - if tp: + current_model = (model if not hasattr(model, 'models') + else model.models[i]) + if current_model.input_dims == 4: exo_data[i] = np.transpose(arr, axes=(2, 0, 1, 3)) else: exo_data[i] = np.expand_dims(arr, axis=0) - if isinstance(model, sup3r.models.SPATIAL_FIRST_MODELS): + if model.input_dims == 4: i_lr_t = 0 i_lr_s = 1 data_chunk = np.transpose(data_chunk, axes=(2, 0, 1, 3)) diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index 2e95ecf7a..08c247bf7 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -1036,9 +1036,10 @@ def noncached_features(self): @property def extract_features(self): """Features to extract directly from the source handler""" + lower_features = [f.lower() for f in self.handle_features] return [f for f in self.raw_features if self.lookup(f, 'compute') is None - or Feature.get_basename(f) in self.handle_features] + or Feature.get_basename(f.lower()) in lower_features] @property def derive_features(self): @@ -1842,7 +1843,7 @@ def lin_bc(self, bc_files, threshold=0.1): the data is a 3D array of shape (lat, lon, time) where time is length 1 for annual correction or 12 for monthly correction. threshold : float - Nearest neighbor euclidian distance threshold. If the DataHandler + Nearest neighbor euclidean distance threshold. If the DataHandler coordinates are more than this value away from the bias correction lat/lon, an error is raised. 
""" @@ -2132,18 +2133,20 @@ def extract_feature(cls, file_paths, raster_index, feature, interp_pressure = f_info.pressure basename = f_info.basename - if feature in handle: - fdata = cls.direct_extract(handle, feature, raster_index, + if feature in handle or feature.lower() in handle: + feat_key = feature if feature in handle else feature.lower() + fdata = cls.direct_extract(handle, feat_key, raster_index, time_slice) - elif basename in handle: + elif basename in handle or basename.lower() in handle: + feat_key = basename if basename in handle else basename.lower() if interp_height is not None: fdata = Interpolator.interp_var_to_height( - handle, basename, raster_index, np.float32(interp_height), + handle, feat_key, raster_index, np.float32(interp_height), time_slice) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( - handle, basename, raster_index, + handle, feat_key, raster_index, np.float32(interp_pressure), time_slice) else: @@ -2330,11 +2333,14 @@ def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): """ if (raster_index[0].stop > lat_lon.shape[0] or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 or raster_index[1].start < 0): + or raster_index[0].start < 0 + or raster_index[1].start < 0): msg = (f'Invalid target {target}, shape {grid_shape}, and raster ' f'{raster_index} for data domain of size ' f'{lat_lon.shape[:-1]} with lower left corner ' - f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}).') + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).') raise ValueError(msg) def get_raster_index(self): @@ -2366,40 +2372,15 @@ def get_raster_index(self): .format(raster_index)) if self.raster_file is not None: + basedir = os.path.dirname(self.raster_file) + if not os.path.exists(basedir): + os.makedirs(basedir) logger.debug(f'Saving raster index: {self.raster_file}') np.save(self.raster_file.replace('.txt', '.npy'), raster_index) return raster_index -class DataHandlerNCforERA(DataHandlerNC): - """Data Handler for NETCDF ERA5 data""" - - CHUNKS = {'time': 5, 'lat': 20, 'lon': 20} - """CHUNKS sets the chunk sizes to extract from the data in each dimension. 
- Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" - - @classmethod - def feature_registry(cls): - """Registry of methods for computing features or extracting renamed - features - - Returns - ------- - dict - Method registry - """ - registry = { - 'U_(.*)': 'u_(.*)', - 'V_(.*)': 'v_(.*)', - 'Windspeed_(.*)m': WindspeedNC, - 'Winddirection_(.*)m': WinddirectionNC, - 'topography': 'orog', - 'lat_lon': LatLonNC} - return registry - - class DataHandlerNCforCC(DataHandlerNC): """Data Handler for NETCDF climate change data""" @@ -2754,6 +2735,9 @@ def get_raster_index(self): self.grid_shape, max_delta=self.max_delta) if self.raster_file is not None: + basedir = os.path.dirname(self.raster_file) + if basedir and not os.path.exists(basedir): + os.makedirs(basedir) logger.debug(f'Saving raster index: {self.raster_file}') np.savetxt(self.raster_file, raster_index) return raster_index diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index 8622cad90..df56fd666 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1675,10 +1675,13 @@ def get_inputs_recursive(cls, feature, handle_features): """ raw_features = [] method = cls.lookup(feature, 'inputs', handle_features=handle_features) + lower_handle_features = [f.lower() for f in handle_features] check1 = feature not in raw_features - check2 = (cls.valid_handle_features([feature], handle_features) + check2 = (cls.valid_handle_features([feature.lower()], + lower_handle_features) or method is None) + if check1 and check2: raw_features.append(feature) diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index 6b95806bd..21279745f 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -68,6 +68,8 @@ def calc_height(cls, data, raster_index, time_slice=slice(None)): msg = ('Need either PHB/PH/HGT or zg/orog in data to perform ' 'height interpolation') raise ValueError(msg) + logger.debug('Spatiotemporally averaged height levels: ' + f'{list(np.nanmean(np.array(hgt), axis=(0, 2, 3)))}') return np.array(hgt) @classmethod @@ -162,6 +164,7 @@ def calc_pressure(cls, data, var, raster_index, time_slice=slice(None)): p_array = np.zeros(data[var][idx].shape, dtype=np.float32) for i in range(p_array.shape[1]): p_array[:, i, ...] = data.plev[i] + logger.info(f'Available pressure levels: {data.plev}') return p_array @@ -206,7 +209,7 @@ def interp_to_level(cls, var_array, lev_array, levels): raise RuntimeError(msg) nans = np.isnan(lev_array) - logger.debug('lev_array.shape: {}'.format(lev_array.shape)) + logger.debug('Level array shape: {}'.format(lev_array.shape)) bad_min = min(levels) < lev_array[:, 0, :, :] bad_max = max(levels) > lev_array[:, -1, :, :] diff --git a/sup3r/utilities/pytest.py b/sup3r/utilities/pytest.py index 1abdb22ec..879539921 100644 --- a/sup3r/utilities/pytest.py +++ b/sup3r/utilities/pytest.py @@ -38,6 +38,40 @@ def make_fake_nc_files(td, input_file, n_files): return fake_files +def make_fake_era_files(td, input_file, n_files): + """Make dummy ERA files with increasing times. ERA files have a different + naming convention than WRF.
+ + Parameters + ---------- + td : str + Temporary directory to create dummy files in + input_file : str + File to use as template for all dummy files + n_files : int + Number of dummy files to create + + Returns + ------- + fake_files : list + List of dummy files + """ + fake_dates = [f'2014-10-01_{str(i).zfill(2)}_00_00' + for i in range(n_files)] + fake_times = [f'2014-10-01 {str(i).zfill(2)}:00:00' + for i in range(n_files)] + fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] + for i in range(n_files): + with xr.open_dataset(input_file) as input_dset: + with xr.Dataset(input_dset) as dset: + dset['Times'][:] = np.array([fake_times[i].encode('ASCII')], + dtype='|S19') + dset['XTIME'][:] = i + dset = dset.rename({'U': 'u', 'V': 'v'}) + dset.to_netcdf(fake_files[i]) + return fake_files + + def make_fake_h5_chunks(td): """Make fake h5 chunked output files for a 5x spatial 2x temporal multi-node forward pass output. diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index ba84b96d1..30cac6614 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -496,7 +496,9 @@ def transform_rotate_wind(ws, wd, lat_lon): # calculate the angle from the vertical theta = (np.pi / 2) - np.arctan2(dy, dx) - theta[0] = theta[1] # fix the roll row + + if len(theta) > 1: + theta[0] = theta[1] # fix the roll row wd = np.radians(wd) u_rot = np.cos(theta)[:, :, np.newaxis] * ws * np.sin(wd) @@ -550,7 +552,8 @@ def invert_uv(u, v, lat_lon): # calculate the angle from the vertical theta = (np.pi / 2) - np.arctan2(dy, dx) - theta[0] = theta[1] # fix the roll row + if len(theta) > 1: + theta[0] = theta[1] # fix the roll row u_rot = np.cos(theta)[:, :, np.newaxis] * u u_rot -= np.sin(theta)[:, :, np.newaxis] * v diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 20be385f7..19e58ac93 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -13,7 +13,7 @@ from sup3r.preprocessing.batch_handling import (BatchHandler, SpatialBatchHandler) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.pytest import make_fake_nc_files +from sup3r.utilities.pytest import make_fake_nc_files, make_fake_era_files INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') features = ['U_100m', 'V_100m', 'BVF_MO_200m'] @@ -105,6 +105,29 @@ def test_height_interpolation(): assert compare_val - stdev <= val <= compare_val + stdev +def test_single_site_extraction(): + """Make sure a single location can be extracted from ERA data without + error.""" + + height = 10 + features = [f'windspeed_{height}m'] + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_era_files(td, INPUT_FILE, 8) + kwargs = dh_kwargs.copy() + kwargs['shape'] = [1, 1] + data_handler = DataHandler(input_files, features, val_split=0.0, + **kwargs) + + data = data_handler.data[0, 0, :, 0] + + data_handler = DataHandler(input_files, features, val_split=0.0, + **dh_kwargs) + + baseline = data_handler.data[-1, 0, :, 0] + + assert np.allclose(baseline, data) + + @pytest.mark.parametrize('sample_shape', [(4, 4, 6), (2, 2, 6), (4, 4, 4), (2, 2, 4)]) def test_spatiotemporal_batch_caching(sample_shape): diff --git a/tests/data_handling/test_feature_handling.py b/tests/data_handling/test_feature_handling.py index 65a174e98..e7eea8590 100644 --- a/tests/data_handling/test_feature_handling.py +++ b/tests/data_handling/test_feature_handling.py @@ -19,6 +19,8 @@ WRF_FEAT = ['U', 'V', 'T', 'UST', 'HFX', 'HGT'] +ERA_FEAT = 
['u', 'v'] + NSRDB_FEAT = ['ghi', 'clearsky_ghi', 'wind_speed', 'wind_direction'] CC_FEAT = ['ua', 'uv', 'tas', 'hurs', 'zg', 'orog', 'ta'] @@ -61,6 +63,16 @@ def test_feature_inputs_nc(): out = DataHandlerNC.get_inputs_recursive('BVF2_200m', WRF_FEAT) assert out == ['T_200m', 'T_100m'] + out = DataHandlerNC.get_inputs_recursive('windspeed_200m', WRF_FEAT) + assert out == ['U_200m', 'V_200m', 'lat_lon'] + + +def test_feature_inputs_lowercase(): + """Test basic NC feature name / inputs parsing with lowercase raw + features.""" + out = DataHandlerNC.get_inputs_recursive('windspeed_200m', ERA_FEAT) + assert out == ['U_200m', 'V_200m', 'lat_lon'] + def test_feature_inputs_cc(): """Test basic CC feature name / inputs parsing""" diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index c06ddeb4e..efb544d08 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -95,6 +95,59 @@ def test_fwp_nc_cc(log=False): s_enhance * fwp_chunk_shape[1]) +def test_fwp_spatial_only(): + """Test forward pass handler output for spatial only model.""" + + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + _ = model.generate(np.ones((4, 10, 10, len(FEATURES)))) + model.meta['training_features'] = FEATURES + model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['s_enhance'] = 2 + model.meta['t_enhance'] = 1 + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + out_dir = os.path.join(td, 's_gan') + model.save(out_dir) + + cache_pattern = os.path.join(td, 'cache') + out_files = os.path.join(td, 'out_{file_id}.nc') + + max_workers = 1 + input_handler_kwargs = dict( + target=target, shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + cache_pattern=cache_pattern, + overwrite_cache=True) + handler = ForwardPassStrategy( + input_files, model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers)) + forward_pass = ForwardPass(handler) + assert forward_pass.output_workers == max_workers + assert forward_pass.data_handler.compute_workers == max_workers + assert forward_pass.data_handler.load_workers == max_workers + assert forward_pass.data_handler.norm_workers == max_workers + assert forward_pass.data_handler.extract_workers == max_workers + forward_pass.run(handler, node_index=0) + + with xr.open_dataset(handler.out_files[0]) as fh: + assert fh[FEATURES[0]].shape == ( + len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) + assert fh[FEATURES[1]].shape == ( + len(handler.time_index), + 2 * fwp_chunk_shape[0], + 2 * fwp_chunk_shape[1]) + + def test_fwp_nc(): """Test forward pass handler output for netcdf write."""
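For reviewers, two short sketches of the behavior this patch introduces. First, the NaN-safe linear correction: `get_linear_correction` now computes its statistics with `np.nanmean`/`np.nanstd` and keeps the zero-variance fallback. A minimal standalone sketch (the function name and synthetic arrays below are illustrative, not part of the patch):

```python
import numpy as np

def linear_correction(bias_data, base_data):
    """Return (scalar, adder) such that bias_data * scalar + adder
    matches the mean and standard deviation of base_data, with NaNs
    ignored in both datasets."""
    bias_std = np.nanstd(bias_data)
    if bias_std == 0:
        # Constant bias data: fall back to the base std so the scalar
        # stays finite (same fallback as get_linear_correction).
        bias_std = np.nanstd(base_data)
    scalar = np.nanstd(base_data) / bias_std
    adder = np.nanmean(base_data) - np.nanmean(bias_data) * scalar
    return scalar, adder

# Sanity check on synthetic data containing NaNs
rng = np.random.default_rng(0)
base = rng.normal(5.0, 2.0, 1000)
bias = rng.normal(4.0, 1.5, 1000)
bias[:10] = np.nan
scalar, adder = linear_correction(bias, base)
corrected = bias * scalar + adder
assert np.isclose(np.nanmean(corrected), np.nanmean(base))
assert np.isclose(np.nanstd(corrected), np.nanstd(base))
```

Second, the `input_dims`-based dispatch that replaces the `SPATIAL_FIRST_MODELS` isinstance checks in `_reshape_data_chunk`. A hedged sketch of the reshaping rule only (the helper name is hypothetical; the real method also handles exogenous data):

```python
import numpy as np

def reshape_for_model(input_dims, data_chunk):
    """Reshape a (spatial_1, spatial_2, temporal, features) chunk for a
    generator whose first layer has the given input rank."""
    if input_dims == 4:
        # Spatial-only generators consume (obs, spatial_1, spatial_2,
        # features), so the temporal axis becomes the observation axis.
        return np.transpose(data_chunk, axes=(2, 0, 1, 3))
    # Spatiotemporal generators consume a 5D tensor with a leading
    # singleton observation axis.
    return np.expand_dims(data_chunk, axis=0)
```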