diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index d909e45b4..2ac31677f 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -1,11 +1,11 @@ """Utilities to calculate the bias correction factors for biased data that is going to be fed into the sup3r downscaling models. This is typically used to bias correct GCM data vs. some historical record like the WTK or NSRDB.""" -from abc import abstractmethod import copy import json import logging import os +from abc import abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed from glob import glob @@ -875,12 +875,7 @@ def _run_single(cls, decimals=decimals) base_arr = np.full(cls.NT, np.nan, dtype=np.float32) - out = {f'bias_{bias_feature}_mean': base_arr.copy(), - f'bias_{bias_feature}_std': base_arr.copy(), - f'base_{base_dset}_mean': base_arr.copy(), - f'base_{base_dset}_std': base_arr.copy(), - f'{bias_feature}_scalar': base_arr.copy(), - f'{bias_feature}_adder': base_arr.copy()} + out = {} for month in range(1, 13): bias_mask = bias_ti.month == month @@ -892,11 +887,68 @@ def _run_single(cls, bias_feature, base_dset) for k, v in mout.items(): + if k not in out: + out[k] = base_arr.copy() out[k][month - 1] = v return out +class MonthlyScalarCorrection(MonthlyLinearCorrection): + """Calculate linear correction *scalar factors to bias correct data. This + typically used when base data is just monthly means and standard deviations + cannot be computed. This is case for vortex data, for example. Thus, just + scalar factors are computed as mean(base_data) / mean(bias_data). Adder + factors are still written but are exactly zero. + + This calculation operates on single bias sites on a montly basis + """ + + @staticmethod + def get_linear_correction(bias_data, base_data, bias_feature, base_dset): + """Get the linear correction factors based on 1D bias and base datasets + + Parameters + ---------- + bias_data : np.ndarray + 1D array of biased data observations. + base_data : np.ndarray + 1D array of base data observations. + bias_feature : str + This is the biased feature from bias_fps to retrieve. This should + be a single feature name corresponding to base_dset + base_dset : str + A single dataset from the base_fps to retrieve. In the case of wind + components, this can be U_100m or V_100m which will retrieve + windspeed and winddirection and derive the U/V component. + + Returns + ------- + out : dict + Dictionary of values defining the mean/std of the bias + base + data and the scalar + adder factors to correct the biased data + like: bias_data * scalar + adder + """ + + bias_std = np.nanstd(bias_data) + if bias_std == 0: + bias_std = np.nanstd(base_data) + + scalar = np.nanmean(base_data) / np.nanmean(bias_data) + adder = np.zeros(scalar.shape) + + out = { + f'bias_{bias_feature}_mean': np.nanmean(bias_data), + f'bias_{bias_feature}_std': bias_std, + f'base_{base_dset}_mean': np.nanmean(base_data), + f'base_{base_dset}_std': np.nanstd(base_data), + f'{bias_feature}_scalar': scalar, + f'{bias_feature}_adder': adder, + } + + return out + + class SkillAssessment(MonthlyLinearCorrection): """Calculate historical skill of one dataset compared to another.""" diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py new file mode 100644 index 000000000..7e63858cf --- /dev/null +++ b/sup3r/bias/bias_correct_means.py @@ -0,0 +1,558 @@ +"""Classes to compute means from vortex and era data and compute bias +correction factors. + +Vortex mean files can be downloaded from IRENA. 
+https://globalatlas.irena.org/workspace +""" + + +import calendar +import logging +import os +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import pandas as pd +import rioxarray +import xarray as xr +from rex import Resource +from scipy.interpolate import interp1d + +from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs +from sup3r.utilities import VERSION_RECORD + +logger = logging.getLogger(__name__) + + +class VortexMeanPrepper: + """Class for converting monthly vortex tif files for each height to a + single h5 files containing all monthly means for all requested output + heights. + """ + + def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): + """ + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. + """ + msg = 'path_pattern needs to have {month} and {height} format keys' + assert '{month}' in path_pattern and '{height}' in path_pattern, msg + self.path_pattern = path_pattern + self.in_heights = in_heights + self.out_heights = out_heights + self.out_dir = os.path.dirname(path_pattern) + self.overwrite = overwrite + self._mask = None + self._meta = None + + @property + def in_features(self): + """List of features corresponding to input heights.""" + return [f"windspeed_{h}m" for h in self.in_heights] + + @property + def out_features(self): + """List of features corresponding to output heights""" + return [f"windspeed_{h}m" for h in self.out_heights] + + def get_input_file(self, month, height): + """Get vortex tif file for given month and height.""" + return self.path_pattern.format(month=month, height=height) + + def get_height_files(self, month): + """Get set of netcdf files for given month""" + files = [] + for height in self.in_heights: + infile = self.get_input_file(month, height) + outfile = infile.replace(".tif", ".nc") + files.append(outfile) + return files + + @property + def input_files(self): + """Get list of all input files used for h5 meta.""" + files = [] + for height in self.in_heights: + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_input_file(month, height)) + return files + + def get_output_file(self, month): + """Get name of netcdf file for a given month.""" + return os.path.join( + self.out_dir.replace("{month}", month), f"{month}.nc" + ) + + @property + def output_files(self): + """List of output monthly output files each with windspeed for all + input heights + """ + files = [] + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_output_file(month)) + return files + + def convert_month_height_tif(self, month, height): + """Get windspeed mean for the given month and hub height from the + corresponding input file and write this to a netcdf file. 
+ """ + infile = self.get_input_file(month, height) + logger.info(f"Getting mean windspeed_{height}m for {month}.") + outfile = infile.replace(".tif", ".nc") + if os.path.exists(outfile) and self.overwrite: + os.remove(outfile) + + if not os.path.exists(outfile) or self.overwrite: + tmp = rioxarray.open_rasterio(infile) + ds = tmp.to_dataset("band") + ds = ds.rename( + {1: f"windspeed_{height}m", "x": "longitude", "y": "latitude"} + ) + ds.to_netcdf(outfile) + return outfile + + def convert_month_tif(self, month): + """Write netcdf files for all heights for the given month.""" + for height in self.in_heights: + self.convert_month_height_tif(month, height) + + def convert_all_tifs(self): + """Write netcdf files for all heights for all months.""" + for i in range(1, 13): + month = calendar.month_name[i] + logger.info(f"Converting tif files to netcdf files for {month}") + self.convert_month_tif(month) + + @property + def mask(self): + """Mask coordinates without data""" + if self._mask is None: + with xr.open_mfdataset(self.get_height_files("January")) as res: + mask = (res[self.in_features[0]] != -999) & ( + ~np.isnan(res[self.in_features[0]]) + ) + for feat in self.in_features[1:]: + tmp = (res[feat] != -999) & (~np.isnan(res[feat])) + mask = mask & tmp + self._mask = np.array(mask).flatten() + return self._mask + + def get_month(self, month): + """Get interpolated means for all hub heights for the given month. + + Parameters + ---------- + month : str + Name of month to get data for + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing interpolated monthly windspeed + means for all input and output heights + + """ + month_file = self.get_output_file(month) + if os.path.exists(month_file) and self.overwrite: + os.remove(month_file) + + if os.path.exists(month_file) and not self.overwrite: + logger.info(f"Loading month_file {month_file}.") + data = xr.open_dataset(month_file) + else: + logger.info( + "Getting mean windspeed for all heights " + f"({self.in_heights}) for {month}" + ) + data = xr.open_mfdataset(self.get_height_files(month)) + logger.info( + "Interpolating windspeed for all heights " + f"({self.out_heights}) for {month}." + ) + data = self.interp(data) + data.to_netcdf(month_file) + logger.info( + "Saved interpolated means for all heights for " + f"{month} to {month_file}." + ) + return data + + def interp(self, data): + """Interpolate data to requested output heights. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input heights + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input and output + heights + """ + var_array = np.zeros( + ( + len(data.latitude) * len(data.longitude), + len(self.in_heights), + ), + dtype=np.float32, + ) + lev_array = var_array.copy() + for i, (h, feat) in enumerate(zip(self.in_heights, self.in_features)): + var_array[..., i] = data[feat].values.flatten() + lev_array[..., i] = h + + logger.info( + f"Interpolating {self.in_features} to {self.out_features} " + f"for {var_array.shape[0]} coordinates." 
+ ) + tmp = [ + interp1d(h, v, fill_value="extrapolate")(self.out_heights) + for h, v in zip(lev_array[self.mask], var_array[self.mask]) + ] + out = np.full( + (len(data.latitude), len(data.longitude), len(self.out_heights)), + np.nan, + dtype=np.float32, + ) + out[self.mask.reshape((len(data.latitude), len(data.longitude)))] = tmp + for i, feat in enumerate(self.out_features): + if feat not in data: + data[feat] = (("latitude", "longitude"), out[..., i]) + return data + + def get_lat_lon(self): + """Get lat lon grid""" + with xr.open_mfdataset(self.get_height_files("January")) as res: + lons, lats = np.meshgrid( + res["longitude"].values, res["latitude"].values + ) + return np.array(lats), np.array(lons) + + @property + def meta(self): + """Get meta with latitude/longitude""" + if self._meta is None: + lats, lons = self.get_lat_lon() + self._meta = pd.DataFrame() + self._meta["latitude"] = lats.flatten()[self.mask] + self._meta["longitude"] = lons.flatten()[self.mask] + return self._meta + + @property + def time_index(self): + """Get time index so output conforms to standard format""" + times = [f'2000-{str(i).zfill(2)}' for i in range(1, 13)] + time_index = pd.DatetimeIndex(times) + return time_index + + def get_all_data(self): + """Get interpolated monthly means for all out heights as a dictionary + to use for h5 writing. + + Returns + ------- + out : dict + Dictionary of arrays containing monthly means for each hub height. + Also includes latitude and longitude. Spatial dimensions are + flattened + """ + data_dict = {} + s_num = len(self.meta) + for i in range(1, 13): + month = calendar.month_name[i] + out = self.get_month(month) + for feat in self.out_features: + if feat not in data_dict: + data_dict[feat] = np.full((s_num, 12), np.nan) + data = out[feat].values.flatten()[self.mask] + data_dict[feat][..., i - 1] = data + return data_dict + + @property + def global_attrs(self): + """Get dictionary on how this data is prepared""" + attrs = { + "input_files": self.input_files, + "class": str(self.__class__), + "version_record": str(VERSION_RECORD), + } + return attrs + + def write_data(self, fp_out, out): + """Write monthly means for all heights to h5 file""" + if fp_out is not None: + if not os.path.exists(os.path.dirname(fp_out)): + os.makedirs(os.path.dirname(fp_out), exist_ok=True) + + if not os.path.exists(fp_out) or self.overwrite: + + OutputHandler._init_h5( + fp_out, self.time_index, self.meta, self.global_attrs + ) + with RexOutputs(fp_out, "a") as f: + + for dset, data in out.items(): + OutputHandler._ensure_dset_in_output(fp_out, dset) + f[dset] = data.T + logger.info(f"Added {dset} to {fp_out}.") + + logger.info( + f"Wrote monthly means for all out heights: {fp_out}" + ) + elif os.path.exists(fp_out): + logger.info(f"{fp_out} already exists and overwrite=False.") + + @classmethod + def run( + cls, path_pattern, in_heights, out_heights, fp_out, overwrite=False + ): + """Read vortex tif files, convert these to monthly netcdf files for all + input heights, interpolate this data to requested output heights, mask + fill values, and write all data to h5 file. + + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + fp_out : str + Name of final h5 output file to write with means. 
+ overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. + """ + vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) + vprep.convert_all_tifs() + out = vprep.get_all_data() + vprep.write_data(fp_out, out) + + +class BiasCorrectUpdate: + """Class for bias correcting existing files and writing corrected files.""" + + @classmethod + def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): + """Get bias correction factors for the given dset and month + + Parameters + ---------- + bc_file : str + Name of h5 file containing bias correction factors + dset : str + Name of dataset to apply bias correction factors for + month : int + Index of month to bias correct + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + + Returns + ------- + factors : ndarray + Array of bias correction factors for the given dset and month. + """ + with Resource(bc_file) as res: + logger.info( + f"Getting {dset} bias correction factors for month {month}." + ) + bc_factor = res[f"{dset}_scalar", :, month - 1] + factors = global_scalar * bc_factor + logger.info( + f"Retrieved {dset} bias correction factors for month {month}. " + f"Using global_scalar={global_scalar}." + ) + return factors + + @classmethod + def _correct_month( + cls, fh_in, month, out_file, dset, bc_file, global_scalar + ): + """Bias correct data for a given month. + + Parameters + ---------- + fh_in : Resource() + Resource handler for input file being corrected + month : int + Index of month to be corrected + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + """ + with RexOutputs(out_file, "a") as fh: + mask = fh.time_index.month == month + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=month, + global_scalar=global_scalar, + ) + logger.info(f"Applying bias correction factors for month {month}") + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + + @classmethod + def update_file( + cls, + in_file, + out_file, + dset, + bc_file, + global_scalar=1, + max_workers=None, + ): + """Update the in_file with bias corrected values for the given dset + and write to out_file. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. 
+ """ + tmp_file = out_file.replace(".h5", ".h5.tmp") + logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") + with Resource(in_file) as fh_in: + OutputHandler._init_h5( + tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs + ) + OutputHandler._ensure_dset_in_output(tmp_file, dset) + + if max_workers == 1: + for i in range(1, 13): + try: + cls._correct_month( + fh_in, + month=i, + out_file=tmp_file, + dset=dset, + bc_file=bc_file, + global_scalar=global_scalar, + ) + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file " + f"{tmp_file}." + ) + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i in range(1, 13): + future = exe.submit( + cls._correct_month, + fh_in=fh_in, + month=i, + out_file=tmp_file, + dset=dset, + bc_file=bc_file, + global_scalar=global_scalar, + ) + futures[future] = i + + logger.info( + f"Submitted bias correction for month {i} " + f"to {tmp_file}." + ) + + for future in as_completed(futures): + _ = future.result() + i = futures[future] + logger.info( + f"Completed bias correction for month {i} " + f"to {tmp_file}." + ) + + os.replace(tmp_file, out_file) + msg = f"Saved bias corrected {dset} to: {out_file}" + logger.info(msg) + + @classmethod + def run( + cls, + in_file, + out_file, + dset, + bc_file, + overwrite=False, + global_scalar=1, + max_workers=None + ): + """Run bias correction update. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + overwrite : bool + Whether to overwrite the output file if it already exists. + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ + if os.path.exists(out_file) and not overwrite: + logger.info( + f"{out_file} already exists and overwrite=False. Skipping." + ) + else: + if os.path.exists(out_file) and overwrite: + logger.info( + f"{out_file} exists but overwrite=True. " + f"Removing {out_file}." 
+ ) + os.remove(out_file) + cls.update_file( + in_file, out_file, dset, bc_file, global_scalar=global_scalar, + max_workers=max_workers + ) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 71d74f568..3995a0f3b 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -58,7 +58,11 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1): raise RuntimeError(msg) msg = (f'Either {dset_scalar} or {dset_adder} not found in {bias_fp}.') - assert dset_scalar in res.dsets and dset_adder in res.dsets, msg + dsets = [dset.lower() for dset in res.dsets] + check = dset_scalar.lower() in dsets and dset_adder.lower() in dsets + assert check, msg + dset_scalar = res.dsets[dsets.index(dset_scalar.lower())] + dset_adder = res.dsets[dsets.index(dset_adder.lower())] scalar = res[dset_scalar, slice_y, slice_x] adder = res[dset_adder, slice_y, slice_x] return scalar, adder diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 9eb6e9ce7..68d0fabff 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -96,6 +96,16 @@ def input_dims(self): else: return 5 + @property + def is_5d(self): + """Check if model expects spatiotemporal input""" + return self.input_dims == 5 + + @property + def is_4d(self): + """Check if model expects spatial only input""" + return self.input_dims == 4 + # pylint: disable=E1101 def get_s_enhance_from_layers(self): """Compute factor by which model will enhance spatial resolution from diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 28c0ffdcd..01cb40cfe 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -137,10 +137,10 @@ def _transpose_model_input(self, model, hi_res): Synthetically generated high-resolution data transposed according to the number of model input dimensions """ - if model.input_dims == 5 and len(hi_res.shape) == 4: + if model.is_5d and len(hi_res.shape) == 4: hi_res = np.transpose( hi_res, axes=(1, 2, 0, 3))[np.newaxis] - elif model.input_dims == 4 and len(hi_res.shape) == 5: + elif model.is_4d and len(hi_res.shape) == 5: msg = ('Recieved 5D input data with shape ' f'({hi_res.shape}) to a 4D model.') assert hi_res.shape[0] == 1, msg diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 89b90291c..722239ee2 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -383,9 +383,8 @@ def hr_crop_slices(self): if self._hr_crop_slices is None: self._hr_crop_slices = [] for t in self.t_hr_crop_slices: - node_slices = [] - for s in self.s_hr_crop_slices: - node_slices.append((s[0], s[1], t, slice(None))) + node_slices = [(s[0], s[1], t, slice(None)) + for s in self.s_hr_crop_slices] self._hr_crop_slices.append(node_slices) return self._hr_crop_slices @@ -805,7 +804,8 @@ def preflight(self): f'n_spatial_chunks={self.fwp_slicer.n_spatial_chunks}, ' f'n_temporal_chunks={self.fwp_slicer.n_temporal_chunks}, ' f'and n_total_chunks={self.chunks}. 
' - f'{self.chunks / self.nodes} chunks per node on average.') + f'{self.chunks / self.nodes:.3f} chunks per node on ' + 'average.') logger.info(f'Using max_workers={self.max_workers}, ' f'pass_workers={self.pass_workers}, ' f'output_workers={self.output_workers}') @@ -847,7 +847,7 @@ def init_handler(self): out = self.input_handler_class(self.file_paths[0], [], target=self.target, shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1)) + worker_kwargs={"ti_workers": 1}) self._init_handler = out return self._init_handler @@ -1165,17 +1165,17 @@ def update_input_handler_kwargs(self, strategy): data handler for the current forward pass chunk """ input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs) - fwp_input_handler_kwargs = dict( - file_paths=self.file_paths, - features=self.features, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_pad_slice, - raster_file=self.raster_file, - cache_pattern=self.cache_pattern, - single_ts_files=self.single_ts_files, - handle_features=strategy.handle_features, - val_split=0.0) + fwp_input_handler_kwargs = { + "file_paths": self.file_paths, + "features": self.features, + "target": self.target, + "shape": self.shape, + "temporal_slice": self.temporal_pad_slice, + "raster_file": self.raster_file, + "cache_pattern": self.cache_pattern, + "single_ts_files": self.single_ts_files, + "handle_features": strategy.handle_features, + "val_split": 0.0} input_handler_kwargs.update(fwp_input_handler_kwargs) return input_handler_kwargs @@ -1705,13 +1705,13 @@ def _reshape_data_chunk(model, data_chunk, exo_data): 'exceeds the number of model steps') assert entry['model'] < len(models), msg current_model = models[entry['model']] - if current_model.input_dims == 4: + if current_model.is_4d: out = np.transpose(entry['data'], axes=(2, 0, 1, 3)) else: out = np.expand_dims(entry['data'], axis=0) exo_data[feature]['steps'][i]['data'] = out - if model.input_dims == 4: + if model.is_4d: i_lr_t = 0 i_lr_s = 1 data_chunk = np.transpose(data_chunk, axes=(2, 0, 1, 3)) @@ -1900,7 +1900,7 @@ def _run_parallel(cls, strategy, node_index): futures = {} start = dt.now() - pool_kws = dict(max_workers=strategy.pass_workers, loggers=['sup3r']) + pool_kws = {"max_workers": strategy.pass_workers, "loggers": ['sup3r']} with SpawnProcessPool(**pool_kws) as exe: now = dt.now() for _i, chunk_index in enumerate(strategy.node_chunks[node_index]): diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 7a6977dca..f9f584d9b 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- -"""H5 file collection.""" +"""H5/NETCDF file collection.""" import glob import logging import os import time +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed from warnings import warn import numpy as np import pandas as pd import psutil +import xarray as xr from gaps import Status from rex.utilities.fun_utils import get_fun_call_str from rex.utilities.loggers import init_logger @@ -22,18 +24,16 @@ logger = logging.getLogger(__name__) -class Collector(OutputMixIn): - """Sup3r H5 file collection framework""" +class BaseCollector(OutputMixIn, ABC): + """Base collector class for H5/NETCDF collection""" def __init__(self, file_paths): - """ - Parameters + """Parameters ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single string with unix-style /search/patt*ern.h5. 
Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.. Files + should have non-overlapping time_index and spatial domains. """ if not isinstance(file_paths, list): file_paths = glob.glob(file_paths) @@ -41,6 +41,11 @@ def __init__(self, file_paths): self.data = None self.file_attrs = {} + @classmethod + @abstractmethod + def collect(cls, *args, **kwargs): + """Collect data files from a dir to one output file.""" + @classmethod def get_node_cmd(cls, config): """Get a CLI call to collect data. @@ -51,10 +56,9 @@ def get_node_cmd(cls, config): sup3r collection config with all necessary args and kwargs to run data collection. """ - import_str = ( 'from sup3r.postprocessing.collection ' - 'import Collector;\n' + f'import {cls.__name__};\n' 'from rex import init_logger;\n' 'import time;\n' 'from gaps import Status;\n' @@ -82,6 +86,98 @@ def get_node_cmd(cls, config): return cmd.replace('\\', '/') + +class CollectorNC(BaseCollector): + """Sup3r NETCDF file collection framework""" + + @classmethod + def collect( + cls, + file_paths, + out_file, + features, + log_level=None, + log_file=None, + write_status=False, + job_name=None, + overwrite=True, + res_kwargs=None + ): + """Collect data files from a dir to one output file. + + Filename requirements: + - Should end with ".nc" + + Parameters + ---------- + file_paths : list | str + Explicit list of str file paths that will be sorted and collected + or a single string with unix-style /search/patt*ern.nc. + out_file : str + File path of final output file. + features : list + List of dsets to collect + log_level : str | None + Desired log level, None will not initialize logging. + log_file : str | None + Target log file. None logs to stdout. + write_status : bool + Flag to write status file once complete if running from pipeline. + job_name : str + Job name for status file if running from pipeline. + overwrite : bool + Whether to overwrite existing output file + res_kwargs : dict | None + Dictionary of kwargs to pass to xarray.open_mfdataset. 
+ """ + t0 = time.time() + + logger.info( + f'Initializing collection for file_paths={file_paths}' + ) + + if log_level is not None: + init_logger( + 'sup3r.preprocessing', log_file=log_file, log_level=log_level + ) + + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file), exist_ok=True) + + collector = cls(file_paths) + logger.info( + 'Collecting {} files to {}'.format(len(collector.flist), out_file) + ) + if overwrite and os.path.exists(out_file): + logger.info(f'overwrite=True, removing {out_file}.') + os.remove(out_file) + + if not os.path.exists(out_file): + res_kwargs = (res_kwargs + or {"concat_dim": "Time", "combine": "nested"}) + out = xr.open_mfdataset(collector.flist, **res_kwargs) + features = [feat for feat in out if feat in features + or feat.lower() in features] + out[features].to_netcdf(out_file) + + if write_status and job_name is not None: + status = { + 'out_dir': os.path.dirname(out_file), + 'fout': out_file, + 'flist': collector.flist, + 'job_status': 'successful', + 'runtime': (time.time() - t0) / 60, + } + Status.make_single_job_file( + os.path.dirname(out_file), 'collect', job_name, status + ) + + logger.info('Finished file collection.') + + +class CollectorH5(BaseCollector): + """Sup3r H5 file collection framework""" + @classmethod def get_slices( cls, final_time_index, final_meta, new_time_index, new_meta @@ -107,7 +203,6 @@ def get_slices( col_slice : slice final_meta[col_slice] = new_meta """ - final_index = final_meta.index new_index = new_meta.index row_loc = np.where(final_time_index.isin(new_time_index))[0] @@ -205,7 +300,6 @@ def get_data( col_slice : slice final_meta[col_slice] = new_meta """ - with RexOutputs(file_path, unscale=False, mode='r') as f: f_ti = f.time_index f_meta = f.meta @@ -234,7 +328,7 @@ def get_data( warn(msg) else: - row_slice, col_slice = Collector.get_slices( + row_slice, col_slice = self.get_slices( time_index, meta, f_ti, f_meta ) @@ -246,7 +340,16 @@ def get_data( f_data = np.round(f_data) f_data = f_data.astype(dtype) - self.data[row_slice, col_slice] = f_data + + try: + self.data[row_slice, col_slice] = f_data + except Exception as e: + msg = (f'Failed to add data to self.data[{row_slice}, ' + f'{col_slice}] for feature={feature}, ' + f'file_path={file_path}, time_index={time_index}, ' + f'meta={meta}. {e}') + logger.error(msg) + raise OSError(msg) from e def _get_file_attrs(self, file): """Get meta data and time index for a single file""" @@ -498,13 +601,13 @@ def _write_flist_data( """ with RexOutputs(out_file, mode='r') as f: target_ti = f.time_index - y_write_slice, x_write_slice = Collector.get_slices( + y_write_slice, x_write_slice = self.get_slices( target_ti, target_masked_meta, time_index, subset_masked_meta, ) - Collector._ensure_dset_in_output(out_file, feature) + self._ensure_dset_in_output(out_file, feature) with RexOutputs(out_file, mode='a') as f: try: @@ -684,9 +787,7 @@ def group_time_chunks(self, file_paths, n_writes=None): for file in file_paths: t_chunk = file.split('_')[-2] file_split[t_chunk] = [*file_split.get(t_chunk, []), file] - file_chunks = [] - for files in file_split.values(): - file_chunks.append(files) + file_chunks = list(file_split.values()) logger.debug( f'Split file list into {len(file_chunks)} chunks ' @@ -698,7 +799,7 @@ def group_time_chunks(self, file_paths, n_writes=None): f'n_writes ({n_writes}) must be less than or equal ' f'to the number of temporal chunks ({len(file_chunks)}).' 
) - assert n_writes < len(file_chunks), msg + assert n_writes <= len(file_chunks), msg return file_chunks def get_flist_chunks(self, file_paths, n_writes=None, join_times=False): diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index ac148dd53..fc8621274 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -2,15 +2,16 @@ """ sup3r data collection CLI entry points. """ -import click -import logging import copy +import logging + +import click +from sup3r.postprocessing.collection import CollectorH5, CollectorNC from sup3r.utilities import ModuleName -from sup3r.version import __version__ -from sup3r.postprocessing.collection import Collector from sup3r.utilities.cli import AVAILABLE_HARDWARE_OPTIONS, BaseCLI - +from sup3r.utilities.utilities import get_source_type +from sup3r.version import __version__ logger = logging.getLogger(__name__) @@ -42,13 +43,17 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): dset_split = config.get('dset_split', False) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.pop('option', 'local') + source_type = get_source_type(config['file_paths']) + collector_types = {'h5': CollectorH5, 'nc': CollectorNC} + Collector = collector_types[source_type] configs = [config] if dset_split: configs = [] for feature in config['features']: f_config = copy.deepcopy(config) - f_out_file = config['out_file'].replace('.h5', f'_{feature}.h5') + f_out_file = config['out_file'].replace( + f'.{source_type}', f'_{feature}.{source_type}') f_job_name = config['job_name'] + f'_{feature}' f_log_file = config.get('log_file', None) if f_log_file is not None: diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index f86339f16..c1244366c 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -560,7 +560,7 @@ def _write_output(cls, data, features, lat_lon, times, out_file, List of coordinate indices used to label each lat lon pair and to help with spatial chunk data collection """ - coords = {'Times': (['Time'], times), + coords = {'Times': (['Time'], [str(t).encode('utf-8') for t in times]), 'XLAT': (['south_north', 'east_west'], lat_lon[..., 0]), 'XLONG': (['south_north', 'east_west'], lat_lon[..., 1])} diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 6390f71e7..8f5563119 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -499,14 +499,14 @@ def __init__(self, ] logger.info(f'Initializing BatchHandler with ' - f'{len(self.data_handlers)} data handlers with handler' + f'{len(self.data_handlers)} data handlers with handler ' f'weights={self.handler_weights}, smoothing={smoothing}. 
' f'Using stats_workers={self.stats_workers}, ' f'norm_workers={self.norm_workers}, ' f'load_workers={self.load_workers}.') now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug(f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.') log_mem(logger, log_level='INFO') @@ -655,8 +655,8 @@ def parallel_normalization(self): logger.debug(f'{i+1} out of {len(futures)} data handlers' ' normalized.') - def parallel_load(self): - """Load data handler data in parallel""" + def load_handler_data(self): + """Load data handler data in parallel or serial""" logger.info(f'Loading {len(self.data_handlers)} data handlers') max_workers = self.load_workers if max_workers == 1: @@ -686,7 +686,7 @@ def parallel_load(self): logger.debug(f'{i+1} out of {len(futures)} handlers ' 'loaded.') - def parallel_stats(self): + def _get_stats(self): """Get standard deviations and means for training features in parallel.""" logger.info(f'Calculating stats for {len(self.training_features)} ' @@ -788,7 +788,7 @@ def get_stats(self): now = dt.now() logger.info('Calculating stdevs/means.') - self.parallel_stats() + self._get_stats() logger.info(f'Finished calculating stats in {dt.now() - now}.') self.cache_stats() diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index a1418b101..4d27a92da 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -902,7 +902,7 @@ def __init__( ) now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug( f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.' diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 0fa62d0fd..1ffb48b2d 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -335,6 +335,8 @@ def _run_data_init_if_needed(self): extraction""" if any(self.features): self.data = self.run_all_data_init() + mask = np.isinf(self.data) + self.data[mask] = np.nan nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: msg = 'Data has {:.3f}% NaN values!'.format(nan_perc) @@ -1019,11 +1021,11 @@ def load_cached_data(self, with_split=True): logger.warning(msg) warnings.warn(msg) - logger.debug('Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}') + if with_split and self.val_split > 0: + logger.debug('Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}') - if with_split: self.data, self.val_data = self.split_data( val_split=self.val_split, shuffle_time=self.shuffle_time) @@ -1261,8 +1263,9 @@ def lin_bc(self, bc_files, threshold=0.1): dset_scalar = f'{feature}_scalar' dset_adder = f'{feature}_adder' with Resource(fp) as res: - check = (dset_scalar in res.dsets - and dset_adder in res.dsets) + dsets = [dset.lower() for dset in res.dsets] + check = (dset_scalar.lower() in dsets + and dset_adder.lower() in dsets) if feature not in completed and check: scalar, adder = get_spatial_bc_factors( lat_lon=self.lat_lon, diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index a35b6012c..e22507b57 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ 
b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -24,8 +24,8 @@ class DualDataHandler(CacheHandlingMixIn, TrainingPrepMixIn): def __init__(self, hr_handler, lr_handler, - regrid_cache_pattern=None, - overwrite_regrid_cache=False, + cache_pattern=None, + overwrite_cache=False, regrid_workers=1, load_cached=True, shuffle_time=False, @@ -41,9 +41,9 @@ def __init__(self, DataHandler for high_res data lr_handler : DataHandler DataHandler for low_res data - regrid_cache_pattern : str + cache_pattern : str Pattern for files to use for saving regridded ERA data. - overwrite_regrid_cache : bool + overwrite_cache : bool Whether to overwrite regrid cache regrid_workers : int | None Number of workers to use for regridding routine. @@ -63,10 +63,10 @@ def __init__(self, self.t_enhance = t_enhance self.lr_dh = lr_handler self.hr_dh = hr_handler - self._cache_pattern = regrid_cache_pattern + self._cache_pattern = cache_pattern self._cached_features = None self._noncached_features = None - self.overwrite_cache = overwrite_regrid_cache + self.overwrite_cache = overwrite_cache self.val_split = val_split self.current_obs_index = None self.load_cached = load_cached @@ -162,24 +162,27 @@ def normalize(self, means, stdevs): dimensions (features) array of means for all features with same ordering as data features """ - logger.info('Normalizing low resolution data.') + logger.info('Normalizing low resolution data features=' + f'{self.features}') self._normalize(data=self.lr_data, val_data=self.lr_val_data, means=means, stds=stdevs, max_workers=self.lr_dh.norm_workers) - logger.info('Normalizing high resolution data.') + logger.info('Normalizing high resolution data features=' + f'{self.output_features}') + indices = [self.features.index(f) for f in self.output_features] self._normalize(data=self.hr_data, val_data=self.hr_val_data, - means=means, - stds=stdevs, + means=means[indices], + stds=stdevs[indices], max_workers=self.hr_dh.norm_workers) @property def output_features(self): """Get list of output features. e.g. those that are returned by a GAN""" - return self.hr_dh.output_features + return self.lr_dh.output_features @property def train_only_features(self): @@ -219,7 +222,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.features, msg + assert hr_handler.features == self.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, @@ -233,21 +236,21 @@ def _run_pair_checks(self, hr_handler, lr_handler): hr_shape = self.hr_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_data.shape {self.hr_data.shape} and ' f'lr_data.shape {self.lr_data.shape} are ' f'incompatible. Must be {hr_shape} and {lr_shape}.') - assert self.lr_data.shape == lr_shape, msg + assert self.lr_data.shape[:-1] == lr_shape, msg if self.lr_val_data is not None and self.hr_val_data is not None: hr_shape = self.hr_val_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_val_data.shape {self.hr_val_data.shape} ' f'and lr_val_data.shape {self.lr_val_data.shape}' f' are incompatible. 
Must be {hr_shape} and {lr_shape}.') - assert self.lr_val_data.shape == lr_shape, msg + assert self.lr_val_data.shape[:-1] == lr_shape, msg @property def grid_mem(self): @@ -378,8 +381,8 @@ def cache_files(self): cache_files = self._get_cache_file_names(self.cache_pattern, grid_shape=self.lr_grid_shape, time_index=self.lr_time_index, - target=self.hr_dh.target, - features=self.hr_dh.features) + target=self.lr_dh.target, + features=self.lr_dh.features) return cache_files @property @@ -499,7 +502,8 @@ def get_next(self): for s in lr_obs_idx[2:-1]: hr_obs_idx.append( slice(s.start * self.t_enhance, s.stop * self.t_enhance)) - hr_obs_idx.append(lr_obs_idx[-1]) + + hr_obs_idx.append(np.arange(len(self.output_features))) hr_obs_idx = tuple(hr_obs_idx) self.current_obs_index = { 'hr_index': hr_obs_idx, diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 6d8c6bdfe..581bd8d0a 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -1,6 +1,9 @@ """Sup3r topography utilities""" import logging +import os +import pickle +import shutil from abc import ABC, abstractmethod import numpy as np @@ -12,7 +15,7 @@ from sup3r.postprocessing.file_handling import OutputHandler from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.utilities.utilities import get_source_type +from sup3r.utilities.utilities import generate_random_string, get_source_type logger = logging.getLogger(__name__) @@ -37,9 +40,11 @@ def __init__(self, raster_file=None, max_delta=20, input_handler=None, - ti_workers=1): - """ - Parameters + cache_data=True, + cache_dir='./exo_cache/', + ti_workers=1, + res_kwargs=None): + """Parameters ---------- file_paths : str | list A single source h5 file to extract raster data from or a list @@ -100,6 +105,12 @@ def __init__(self, data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' ti_workers : int | None max number of workers to use to get full time index. Useful when there are many input files each with a single time step. If this is @@ -107,8 +118,10 @@ def __init__(self, parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. 
+ xr.open_dataset(file_paths, **res_kwargs) """ - logger.info(f'Initializing {self.__class__.__name__} utility.') self.ti_workers = ti_workers @@ -122,6 +135,12 @@ def __init__(self, self._source_lat_lon = None self._hr_time_index = None self._src_time_index = None + self.cache_data = cache_data + self.cache_dir = cache_dir + self.temporal_slice = temporal_slice + self.target = target + self.shape = shape + self.res_kwargs = res_kwargs if input_handler is None: in_type = get_source_type(file_paths) @@ -152,6 +171,7 @@ def __init__(self, raster_file=raster_file, max_delta=max_delta, worker_kwargs=dict(ti_workers=ti_workers), + res_kwargs=self.res_kwargs ) @property @@ -159,10 +179,52 @@ def __init__(self, def source_data(self): """Get the 1D array of source data from the exo_source_h5""" + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name + + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. + + Returns + ------- + cache_fp : str + Name of cache file + """ + tsteps = (None if self.temporal_slice is None + or self.temporal_slice.start is None + or self.temporal_slice.stop is None + else self.temporal_slice.stop - self.temporal_slice.start) + fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp + @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the - input file temporal slice""" + input file temporal slice + """ start_index = self.source_time_index.get_indexer( [self.input_handler.hr_time_index[0]], method='nearest')[0] end_index = self.source_time_index.get_indexer( @@ -266,6 +328,34 @@ def nn(self): @property def data(self): + """Get a raster of source values corresponding to the + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). 
The shape is (lats, lons, temporal, 1) + """ + cache_fp = self.get_cache_file(feature=self.__class__.__name__, + s_enhance=self._s_enhance, + t_enhance=self._t_enhance, + s_agg_factor=self._s_agg_factor, + t_agg_factor=self._t_agg_factor) + tmp_fp = cache_fp + f'.{generate_random_string(10)}.tmp' + if os.path.exists(cache_fp): + with open(cache_fp, 'rb') as f: + data = pickle.load(f) + + else: + data = self.get_data() + + if self.cache_data: + with open(tmp_fp, 'wb') as f: + pickle.dump(data, f) + shutil.move(tmp_fp, cache_fp) + + if data.shape[-1] == 1 and self.hr_shape[-1] > 1: + data = np.repeat(data, self.hr_shape[-1], axis=-1) + + return data[..., np.newaxis] + + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -278,7 +368,7 @@ def data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data @classmethod def get_exo_raster(cls, @@ -293,7 +383,9 @@ def get_exo_raster(cls, temporal_slice=None, raster_file=None, max_delta=20, - input_handler=None): + input_handler=None, + cache_data=True, + cache_dir='./exo_cache/'): """Get the exo feature raster corresponding to the spatially enhanced grid from the file_paths input @@ -356,6 +448,12 @@ class will output a topography raster corresponding to the data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' Returns ------- @@ -377,7 +475,9 @@ class will output a topography raster corresponding to the temporal_slice=temporal_slice, raster_file=raster_file, max_delta=max_delta, - input_handler=input_handler) + input_handler=input_handler, + cache_data=cache_data, + cache_dir=cache_dir) return exo.data @@ -389,8 +489,7 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') - elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) - return elev + return elev[:, np.newaxis] @property def source_lat_lon(self): @@ -407,8 +506,7 @@ def source_time_index(self): self._src_time_index = res.time_index return self._src_time_index - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -417,26 +515,61 @@ def data(self): hr_data = [] for j in range(self._s_agg_factor): out = self.source_data[nn[:, j]] - out = out.reshape(self.hr_shape) + out = out.reshape((*self.hr_shape[:-1], -1)) hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data + + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name. This uses a time independent naming convention. 
+ + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. + + Returns + ------- + cache_fp : str + Name of cache file + """ + fn = f'exo_{feature}_{self.target}_{self.shape}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp class TopoExtractNC(TopoExtractH5): """TopoExtract for netCDF files""" def __init__(self, *args, **kwargs): - """ - Parameters + """Parameters ---------- args : list Same positional arguments as TopoExtract kwargs : dict Same keyword arguments as TopoExtract """ - super().__init__(*args, **kwargs) logger.info('Getting topography for full domain from ' f'{self._exo_source}') @@ -450,8 +583,8 @@ def __init__(self, *args, **kwargs): @property def source_data(self): """Get the 1D array of elevation data from the exo_source_nc""" - elev = self.source_handler.data.reshape((-1, self.lr_shape[-1])) - return elev + elev = self.source_handler.data[..., 0, 0].flatten() + return elev[..., np.newaxis] @property def source_lat_lon(self): @@ -469,12 +602,11 @@ def source_data(self): return SolarPosition(self.hr_time_index, self.hr_lat_lon.reshape((-1, 2))).zenith.T - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) """ hr_data = self.source_data.reshape(self.hr_shape) logger.info('Finished computing SZA data') - return hr_data[..., np.newaxis] + return hr_data diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 5ff81b831..3149887e4 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -1,9 +1,6 @@ """Sup3r exogenous data handling""" import logging -import os -import pickle import re -import shutil from typing import ClassVar import numpy as np @@ -209,7 +206,8 @@ def __init__(self, input_handler=None, exo_handler=None, cache_data=True, - cache_dir='./exo_cache'): + cache_dir='./exo_cache', + res_kwargs=None): """ Parameters ---------- @@ -282,6 +280,9 @@ def __init__(self, speed up forward passes with large temporal extents cache_dir : str Directory for storing cache data. Default is './exo_cache' + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. 
+ xr.open_dataset(file_paths, **res_kwargs) """ self.feature = feature @@ -300,6 +301,7 @@ def __init__(self, self.cache_data = cache_data self.cache_dir = cache_dir self.data = {feature: {'steps': []}} + self.res_kwargs = res_kwargs self.input_check() agg_enhance = self._get_all_agg_and_enhancement() @@ -541,42 +543,6 @@ def _get_all_agg_and_enhancement(self): for step in self.steps] return agg_enhance_dict - def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, - t_agg_factor): - """Get cache file name - - Parameters - ---------- - feature : str - Name of feature to get cache file for - s_enhance : int - Spatial enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - t_enhance : int - Temporal enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - s_agg_factor : int - Factor by which to aggregate the exo_source data to the spatial - resolution of the file_paths input enhanced by s_enhance. - t_agg_factor : int - Factor by which to aggregate the exo_source data to the temporal - resolution of the file_paths input enhanced by t_enhance. - - Returns - ------- - cache_fp : str - Name of cache file - """ - fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_' - fn += f'tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' - fn = fn.replace('(', '').replace(')', '') - fn = fn.replace('[', '').replace(']', '') - fn = fn.replace(',', 'x').replace(' ', '') - cache_fp = os.path.join(self.cache_dir, fn) - if self.cache_data: - os.makedirs(self.cache_dir, exist_ok=True) - return cache_fp - def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): """Get the exogenous topography data @@ -605,35 +571,23 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, lon, temporal) """ - cache_fp = self.get_cache_file(feature=feature, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor) - tmp_fp = cache_fp + '.tmp' - if os.path.exists(cache_fp): - with open(cache_fp, 'rb') as f: - data = pickle.load(f) - - else: - exo_handler = self.get_exo_handler(feature, self.source_file, - self.exo_handler) - data = exo_handler(self.file_paths, - self.source_file, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_slice, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler).data - if self.cache_data: - with open(tmp_fp, 'wb') as f: - pickle.dump(data, f) - shutil.move(tmp_fp, cache_fp) + exo_handler = self.get_exo_handler(feature, self.source_file, + self.exo_handler) + data = exo_handler(self.file_paths, + self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler, + cache_data=self.cache_data, + cache_dir=self.cache_dir, + res_kwargs=self.res_kwargs).data return data @classmethod diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index c52c3640d..781cd5f67 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd +import psutil from scipy.stats import mode from 
sup3r.utilities.utilities import ( @@ -31,6 +32,7 @@ class CacheHandlingMixIn: """Collection of methods for handling data caching and loading""" def __init__(self): + """Initialize common attributes""" self._noncached_features = None self._cache_pattern = None self._cache_files = None @@ -282,10 +284,13 @@ def _load_single_cached_feature(self, fp, cache_files, features, Error raised if shape conflicts with requested shape """ idx = cache_files.index(fp) - assert features[idx].lower() in fp.lower() + msg = f'{features[idx].lower()} not found in {fp.lower()}.' + assert features[idx].lower() in fp.lower(), msg fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {features[idx]} from ' - f'{fp}.') + mem = psutil.virtual_memory() + logger.info(f'Loading {features[idx]} from {fp}. Current memory ' + f'usage is {mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') out = None with open(fp, 'rb') as fh: @@ -909,6 +914,10 @@ class TrainingPrepMixIn: """Collection of training related methods. e.g. Training + Validation splitting, normalization""" + def __init__(self): + """Initialize common attributes""" + self.features = None + @classmethod def _split_data_indices(cls, data, @@ -1029,11 +1038,13 @@ def _normalize_data(self, data, val_data, feature_index, mean, std): val_data[..., feature_index] /= std data[..., feature_index] /= std else: - msg = ( - f'Standard Deviation is zero for feature #{feature_index + 1}') + msg = ('Standard Deviation is zero for ' + f'{self.features[feature_index]}') logger.warning(msg) warnings.warn(msg) + logger.info(f'Finished normalizing {self.features[feature_index]}.') + def _normalize(self, data, val_data, means, stds, max_workers=None): """Normalize all data features @@ -1054,6 +1065,10 @@ def _normalize(self, data, val_data, means, stds, max_workers=None): max_workers : int | None Number of workers to use in thread pool for nomalization. 
""" + msg = f'Received {len(means)} means for {data.shape[-1]} features' + assert len(means) == data.shape[-1], msg + msg = f'Received {len(stds)} stds for {data.shape[-1]} features' + assert len(stds) == data.shape[-1], msg logger.info(f'Normalizing {data.shape[-1]} features.') if max_workers == 1: for i in range(data.shape[-1]): diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py index ac1901d9a..5244374e8 100644 --- a/sup3r/preprocessing/dual_batch_handling.py +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -3,9 +3,11 @@ import numpy as np -from sup3r.preprocessing.batch_handling import (Batch, BatchHandler, - ValidationData, - ) +from sup3r.preprocessing.batch_handling import ( + Batch, + BatchHandler, + ValidationData, +) from sup3r.utilities.utilities import uniform_box_sampler, uniform_time_sampler logger = logging.getLogger(__name__) @@ -143,6 +145,16 @@ class DualBatchHandler(BatchHandler): BATCH_CLASS = Batch VAL_CLASS = DualValidationData + @property + def lr_features(self): + """Features in low res batch.""" + return self.data_handlers[0].lr_dh.features + + @property + def hr_features(self): + """Features in high res batch.""" + return self.data_handlers[0].lr_dh.output_features + @property def hr_sample_shape(self): """Get sample shape for high_res data""" @@ -173,19 +185,19 @@ def __next__(self): handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], self.hr_sample_shape[1], - self.hr_sample_shape[2], self.shape[-1]), + self.hr_sample_shape[2], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], self.lr_sample_shape[1], - self.lr_sample_shape[2], self.shape[-1]), + self.lr_sample_shape[2], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): high_res[i, ...], low_res[i, ...] = handler.get_next() self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 @@ -220,10 +232,12 @@ def __next__(self): self.current_handler_index = handler_index handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], - self.hr_sample_shape[1], self.shape[-1]), + self.hr_sample_shape[1], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], - self.lr_sample_shape[1], self.shape[-1]), + self.lr_sample_shape[1], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): @@ -232,8 +246,6 @@ def __next__(self): low_res[i, ...] = lr[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index c5f712e34..9a1ac6eb2 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1,5 +1,4 @@ -""" -Sup3r feature handling module. +"""Sup3r feature handling module. @author: bbenton """ @@ -35,7 +34,8 @@ class DerivedFeature(ABC): """Abstract class for special features which need to be derived from raw - features""" + features + """ @classmethod @abstractmethod @@ -86,7 +86,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. 
the all-sky ghi / the clearsky ghi. NaN where nighttime. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -104,7 +103,8 @@ def compute(cls, data, height=None): class ClearSkyRatioCC(DerivedFeature): """Clear Sky Ratio feature class for computing from climate change netcdf - data""" + data + """ @classmethod def inputs(cls, feature): @@ -142,7 +142,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. the all-sky ghi / the clearsky ghi. This is assumed to be daily average data for climate change source data. """ - cs_ratio = data['rsds'] / data['clearsky_ghi'] cs_ratio = np.minimum(cs_ratio, 1) cs_ratio = np.maximum(cs_ratio, 0) @@ -189,7 +188,6 @@ def compute(cls, data, height=None): nighttime. Data is float32 so it can be normalized without any integer weirdness. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -208,7 +206,8 @@ def compute(cls, data, height=None): class PotentialTempNC(DerivedFeature): """Potential Temperature feature class for NETCDF data. Needed since T is - perturbation potential temperature.""" + perturbation potential temperature. + """ @classmethod def inputs(cls, feature): @@ -239,7 +238,8 @@ def compute(cls, data, height): class TempNC(DerivedFeature): """Temperature feature class for NETCDF data. Needed since T is potential - temperature not standard temp.""" + temperature not standard temp. + """ @classmethod def inputs(cls, feature): @@ -271,7 +271,8 @@ def compute(cls, data, height): class PressureNC(DerivedFeature): """Pressure feature class for NETCDF data. Needed since P is perturbation - pressure.""" + pressure. 
+ """ @classmethod def inputs(cls, feature): @@ -302,7 +303,8 @@ def compute(cls, data, height): class BVFreqSquaredNC(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -355,7 +357,6 @@ def inputs(cls, feature): list List of required features for computing RMOL """ - assert feature == 'RMOL' features = ['UST', 'HFX'] return features @@ -431,7 +432,8 @@ def compute(cls, data, height): class BVFreqSquaredH5(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -495,7 +497,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -516,7 +517,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - ws, _ = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return ws @@ -539,7 +539,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -560,7 +559,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - _, wd = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return wd @@ -585,7 +583,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -612,7 +609,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -643,7 +639,6 @@ def inputs(cls, feature): list List of required features for computing Veer """ - height = Feature.get_height(feature) heights = [int(height), int(height) + 20] features = [] @@ -667,7 +662,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - heights = [int(height), int(height) + 20] shear = np.cos(np.radians(data[f'winddirection_{int(height) + 20}m'])) shear -= np.cos(np.radians(data[f'winddirection_{int(height)}m'])) @@ -694,7 +688,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -722,7 +715,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -829,7 +821,8 @@ def compute(cls, data, height): class UWind(DerivedFeature): """U wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -876,7 +869,8 @@ def compute(cls, data, height): class Vorticity(DerivedFeature): """Vorticity feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -919,7 +913,8 @@ def compute(cls, data, height): class VWind(DerivedFeature): """V wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -1047,14 +1042,16 @@ def compute(cls, data, height): class TasMin(Tas): """Daily min air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 'tasmin' class TasMax(Tas): """Daily max air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 
'tasmax' @@ -1079,7 +1076,6 @@ def compute(file_paths, raster_index): lat lon array (spatial_1, spatial_2, 2) """ - fp = file_paths if isinstance(file_paths, str) else file_paths[0] handle = xr.open_dataset(fp) valid_vars = set(handle.variables) @@ -1166,7 +1162,8 @@ def compute(file_paths, raster_index): class Feature: """Class to simplify feature computations. Stores feature height, feature - basename, name of feature in handle""" + basename, name of feature in handle + """ def __init__(self, feature, handle): """Takes a feature (e.g. U_100m) and gets the height (100), basename @@ -1204,7 +1201,6 @@ def get_basename(feature): str feature basename """ - height = Feature.get_height(feature) pressure = Feature.get_pressure(feature) if height is not None or pressure is not None: @@ -1260,7 +1256,8 @@ def get_pressure(feature): class FeatureHandler: """Feature Handler with cache for previously loaded features used in other - calculations """ + calculations + """ FEATURE_REGISTRY: ClassVar[dict] = {} @@ -1280,7 +1277,6 @@ def valid_handle_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1304,7 +1300,6 @@ def valid_input_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1440,16 +1435,13 @@ def serial_extract(cls, file_paths, raster_index, time_chunks, keys for features. e.g. data[chunk_number][feature] = array. (spatial_1, spatial_2, temporal) """ - data = defaultdict(dict) for t, t_slice in enumerate(time_chunks): for f in input_features: data[t][f] = cls.extract_feature(file_paths, raster_index, f, t_slice, **kwargs) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks extracted.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks extracted.') return data @classmethod @@ -1511,7 +1503,6 @@ def parallel_extract(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(input_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] try: @@ -1521,12 +1512,11 @@ def parallel_extract(cls, f' {v["feature"]}') logger.error(msg) raise RuntimeError(msg) from e - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks extracted. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks extracted. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data @@ -1613,7 +1603,6 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, e.g. data[chunk_number][feature] = array. 
(spatial_1, spatial_2, temporal) """ - if len(derived_features) == 0: return data @@ -1628,10 +1617,8 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, file_paths=file_paths, raster_index=raster_index) cls.pop_old_data(data, t, all_features) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks computed.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks computed.') return data @@ -1709,18 +1696,16 @@ def parallel_compute(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(derived_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] chunk_idx = v['chunk'] data[chunk_idx] = data.get(chunk_idx, {}) data[chunk_idx][v['feature']] = future.result() - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks computed. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks computed. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data @@ -1768,7 +1753,6 @@ def _exact_lookup(cls, feature): out : str Matching feature registry entry. """ - out = None for k, v in cls.FEATURE_REGISTRY.items(): if k.lower() == feature.lower(): @@ -1791,7 +1775,6 @@ def _pattern_lookup(cls, feature): out : str Matching feature registry entry. """ - out = None for k, v in cls.FEATURE_REGISTRY.items(): if re.match(k.lower(), feature.lower()): diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index c6e6dc2ca..fee3b8827 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -8,9 +8,11 @@ import logging import os from calendar import monthrange -from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor, - as_completed, - ) +from concurrent.futures import ( + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from glob import glob from typing import ClassVar from warnings import warn @@ -35,7 +37,8 @@ class EraDownloader: """Class to handle ERA5 downloading, variable renaming, file combination, - and interpolation.""" + and interpolation. + """ msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' 'with a valid url and api key. Follow the instructions here: ' @@ -43,36 +46,24 @@ class EraDownloader: req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') assert os.path.exists(req_file), msg - VALID_VARIABLES: ClassVar[list] = [ - 'u', 'v', 'pressure', 'temperature', 'relative_humidity', - 'specific_humidity', 'total_precipitation', - ] - - KEEP_VARIABLES: ClassVar[list] = ['orog'] - KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] - - DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', - 'temperature', 'pressure', - ] - DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', '10m_v_component_of_wind', - '100m_u_component_of_wind', '100m_v_component_of_wind', - 'u_component_of_wind', 'v_component_of_wind', '2m_temperature', - 'temperature', 'surface_pressure', 'relative_humidity', - 'total_precipitation', - ] - + # variables available on a single level (e.g. 
surface) SFC_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', '100m_u_component_of_wind', '100m_v_component_of_wind', 'surface_pressure', '2m_temperature', 'geopotential', - 'total_precipitation', + 'total_precipitation', "convective_available_potential_energy", + "2m_dewpoint_temperature", "convective_inhibition", + "surface_latent_heat_flux", "instantaneous_moisture_flux", + "mean_total_precipitation_rate" ] + + # variables available on multiple pressure levels LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', - 'temperature', 'relative_humidity', 'specific_humidity', + 'temperature', 'relative_humidity', 'specific_humidity', 'divergence', + 'vertical_velocity', 'pressure', 'potential_vorticity' ] + NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', 'v10': 'v_10m', @@ -80,12 +71,23 @@ class EraDownloader: 'v100': 'v_100m', 't': 'temperature', 't2m': 'temperature_2m', - 'u': 'u', - 'v': 'v', 'sp': 'pressure_0m', 'r': 'relative_humidity', 'q': 'specific_humidity', - 'tp': 'total_precipitation', + 'd': 'divergence', + } + + SHORT_NAME_MAP: ClassVar[dict] = { + 'convective_inhibition': 'cin', + '2m_dewpoint_temperature': 'd2m', + 'potential_vorticity': 'pv', + 'vertical_velocity': 'w', + 'surface_latent_heat_flux': 'slhf', + 'instantaneous_moisture_flux': 'ie', + 'divergence': 'd', + 'total_precipitation': 'tp', + 'convective_available_potential_energy': 'cape', + 'mean_total_precipitation_rate': 'mtpr' } def __init__(self, @@ -97,8 +99,8 @@ def __init__(self, interp_out_pattern=None, run_interp=True, overwrite=False, - required_shape=None, - variables=None): + variables=None, + check_files=False): """Initialize the class. Parameters @@ -122,13 +124,11 @@ def __init__(self, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + check_files : bool + Check existing files. Remove and redownload if checks fail. 
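For orientation, a hypothetical instantiation with the updated signature might look like the sketch below. The leading argument names (year, month, area, levels, combined_out_pattern) are taken from the attribute assignments and docstrings elsewhere in this diff; the bounding box, levels, variable list, and output pattern are invented, and a valid ~/.cdsapirc is still required as noted above.

```python
from sup3r.utilities.era_downloader import EraDownloader

downloader = EraDownloader(
    year=2017,
    month=1,
    area=[50, -110, 30, -90],        # [max_lat, min_lon, min_lat, max_lon]
    levels=[1000, 900, 800, 700],
    combined_out_pattern='./era5_{year}_{month}_combined.nc',
    variables=['u_component_of_wind', 'v_component_of_wind', 'geopotential'],
    check_files=True,                # replaces the removed required_shape arg
)
```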
""" self.year = year self.month = month @@ -138,15 +138,14 @@ def __init__(self, self.overwrite = overwrite self.combined_out_pattern = combined_out_pattern self.interp_out_pattern = interp_out_pattern + self.check_files = check_files + self.required_shape = None self._interp_file = None self._combined_file = None self._variables = variables self.hours = [str(n).zfill(2) + ":00" for n in range(0, 24)] self.sfc_file_variables = ['geopotential'] self.level_file_variables = ['geopotential'] - - self.shape_check(required_shape, levels) - self.check_good_vars(self.variables) self.prep_var_lists(self.variables) msg = ('Initialized EraDownloader with: ' @@ -158,7 +157,7 @@ def __init__(self, def variables(self): """Get list of requested variables""" if self._variables is None: - self._variables = self.VALID_VARIABLES + raise OSError('Received empty variable list.') return self._variables @property @@ -173,11 +172,11 @@ def days(self): @property def interp_file(self): """Get name of file with interpolated variables""" - if self._interp_file is None: - if self.interp_out_pattern is not None and self.run_interp: - self._interp_file = self.interp_out_pattern.format( - year=self.year, month=str(self.month).zfill(2)) - os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) + if (self._interp_file is None and self.interp_out_pattern is not None + and self.run_interp): + self._interp_file = self.interp_out_pattern.format( + year=self.year, month=str(self.month).zfill(2)) + os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) return self._interp_file @property @@ -233,39 +232,15 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace(".nc", "_tmp.nc") return tmp_file - def shape_check(self, required_shape, levels): - """Check given required shape""" - if required_shape is None or len(required_shape) == 3: - self.required_shape = required_shape - elif len(required_shape) == 2 and len(levels) != required_shape[0]: - self.required_shape = (len(levels), *required_shape) - else: - msg = f'Received weird required_shape: {required_shape}.' - logger.error(msg) - raise OSError(msg) - - def check_good_vars(self, variables): - """Make sure requested variables are valid. - - Parameters - ---------- - variables : list - List of variables to download. Can be any of VALID_VARIABLES - """ - good = all(var in self.VALID_VARIABLES for var in variables) - if not good: - msg = (f'Received variables {variables} not in valid variables ' - f'list {self.VALID_VARIABLES}') - logger.error(msg) - raise OSError(msg) - def _prep_var_lists(self, variables): """Add all downloadable variables for the generic requested variables. - e.g. if variable = 'u' add all downloadable u variables to list.""" + e.g. if variable = 'u' add all downloadable u variables to list. + """ d_vars = [] vars = variables.copy() for i, v in enumerate(vars): @@ -279,7 +254,8 @@ def _prep_var_lists(self, variables): def prep_var_lists(self, variables): """Create surface and level variable lists based on requested - variables.""" + variables. 
+ """ variables = self._prep_var_lists(variables) for var in variables: if var in self.SFC_VARS and var not in self.sfc_file_variables: @@ -287,7 +263,7 @@ def prep_var_lists(self, variables): elif (var in self.LEVEL_VARS and var not in self.level_file_variables): self.level_file_variables.append(var) - else: + elif var not in self.SFC_VARS and var not in self.LEVEL_VARS: msg = f'Requested {var} is not available for download.' logger.warning(msg) warn(msg) @@ -296,64 +272,73 @@ def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 level_check = (len(self.level_file_variables) > 0 - and self.levels is not None) + and self.levels is not None + and len(self.levels) > 0) if self.level_file_variables: msg = (f'{self.level_file_variables} requested but no levels' ' were provided.') if self.levels is None: logger.warning(msg) warn(msg) + + time_dict = {'year': self.year, 'month': self.month, 'day': self.days, + 'time': self.hours} if sfc_check: - self.download_surface_file() + self.download_file(self.sfc_file_variables, time_dict=time_dict, + area=self.area, out_file=self.surface_file, + level_type='single', overwrite=self.overwrite) if level_check: - self.download_levels_file() + self.download_file(self.level_file_variables, time_dict=time_dict, + area=self.area, out_file=self.level_file, + level_type='pressure', levels=self.levels, + overwrite=self.overwrite) if sfc_check or level_check: self.process_and_combine() - def download_levels_file(self): - """Download file with requested pressure levels""" - if not os.path.exists(self.level_file) or self.overwrite: - msg = (f'Downloading {self.level_file_variables} to ' - f'{self.level_file}.') - logger.info(msg) - CDS_API_CLIENT.retrieve( - 'reanalysis-era5-pressure-levels', { - 'product_type': 'reanalysis', - 'format': 'netcdf', - 'variable': self.level_file_variables, - 'pressure_level': self.levels, - 'year': self.year, - 'month': self.month, - 'day': self.days, - 'time': self.hours, - 'area': self.area, - }, self.level_file) - else: - logger.info(f'File already exists: {self.level_file}.') + @classmethod + def download_file(cls, variables, time_dict, area, out_file, level_type, + levels=None, overwrite=False): + """Download either single-level or pressure-level file - def download_surface_file(self): - """Download surface file""" - if not os.path.exists(self.surface_file) or self.overwrite: - msg = (f'Downloading {self.sfc_file_variables} to ' - f'{self.surface_file}.') + Parameters + ---------- + variables : list + List of variables to download + time_dict : dict + Dictionary with year, month, day, time entries. + area : list + List of bounding box coordinates. + e.g. 
[max_lat, min_lon, min_lat, max_lon] + out_file : str + Name of output file + level_type : str + Either 'single' or 'pressure' + levels : list + List of pressure levels to download, if level_type == 'pressure' + overwrite : bool + Whether to overwrite existing file + """ + if not os.path.exists(out_file) or overwrite: + msg = (f'Downloading {variables} to ' + f'{out_file} with levels = {levels}.') logger.info(msg) + entry = { + 'product_type': 'reanalysis', + 'format': 'netcdf', + 'variable': variables, + 'area': area} + entry.update(time_dict) + if level_type == 'pressure': + entry['pressure_level'] = levels + logger.info(f'Calling CDS-API with {entry}.') CDS_API_CLIENT.retrieve( - 'reanalysis-era5-single-levels', { - 'product_type': 'reanalysis', - 'format': 'netcdf', - 'variable': self.sfc_file_variables, - 'year': self.year, - 'month': self.month, - 'day': self.days, - 'time': self.hours, - 'area': self.area, - }, self.surface_file) + f'reanalysis-era5-{level_type}-levels', + entry, out_file) else: - logger.info(f'File already exists: {self.surface_file}.') + logger.info(f'File already exists: {out_file}.') def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" - dims = ('time', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.surface_file) with Dataset(self.surface_file, "r") as old_ds: @@ -382,8 +367,9 @@ def map_vars(self, old_ds, ds): ds : Dataset Dataset() object for new file with new variables written. """ - for old_name, new_name in self.NAME_MAP.items(): - if old_name in old_ds.variables: + for old_name in old_ds.variables: + new_name = self.NAME_MAP.get(old_name, old_name) + if new_name not in ds.variables: _ = ds.createVariable(new_name, np.float32, dimensions=old_ds[old_name].dimensions, @@ -414,7 +400,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): ds : Dataset Dataset() object for new file with new height variable written. 
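To make the consolidated `download_file` above concrete, the request it assembles for a pressure-level download looks roughly like this. All values are illustrative; the dataset name comes from the `level_type` f-string shown above.

```python
# Roughly the 'entry' dictionary built by download_file for a pressure-level
# request; dates, area, and variables here are placeholders.
time_dict = {'year': 2017, 'month': 1, 'day': list(range(1, 32)),
             'time': [f'{h:02d}:00' for h in range(24)]}
entry = {'product_type': 'reanalysis',
         'format': 'netcdf',
         'variable': ['u_component_of_wind', 'v_component_of_wind',
                      'geopotential'],
         'area': [50, -110, 30, -90]}
entry.update(time_dict)
entry['pressure_level'] = [1000, 900, 800]   # only when level_type == 'pressure'

# download_file then issues a single retrieve call such as:
# CDS_API_CLIENT.retrieve('reanalysis-era5-pressure-levels', entry,
#                         'era5_2017_01_levels.nc')
```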
""" - _ = ds.createVariable(standard_name, np.float32, dimensions=old_ds['z'].dimensions) @@ -426,7 +411,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): def process_level_file(self): """Convert geopotential to geopotential height.""" - dims = ('time', 'level', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.level_file) with Dataset(self.level_file, "r") as old_ds: @@ -437,10 +421,12 @@ def process_level_file(self): ds = self.map_vars(old_ds, ds) - if 'pressure' in self.variables: + if ('pressure' in self.variables + and 'pressure' not in ds.variables): tmp = np.zeros(ds.variables['zg'].shape) for i in range(tmp.shape[1]): tmp[:, i, :, :] = ds.variables['level'][i] * 100 + _ = ds.createVariable('pressure', np.float32, dimensions=dims) @@ -454,7 +440,6 @@ def process_level_file(self): def process_and_combine(self): """Process variables and combine.""" - if not os.path.exists(self.combined_file) or self.overwrite: files = [] if os.path.exists(self.level_file): @@ -466,9 +451,8 @@ def process_and_combine(self): self.process_surface_file() files.append(self.surface_file) - logger.info(f'Combining {files} and {self.surface_file} ' - f'to {self.combined_file}.') - with xr.open_mfdataset(files) as ds: + logger.info(f'Combining {files} to {self.combined_file}.') + with xr.open_mfdataset(files, compat='override') as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing {self.combined_file}') @@ -477,15 +461,17 @@ def process_and_combine(self): if os.path.exists(self.surface_file): os.remove(self.surface_file) - def good_file(self, file, required_shape): + def good_file(self, file, required_shape=None): """Check if file has the required shape and variables. Parameters ---------- file : str Name of file to check for required variables and shape - required_shape : tuple - Required shape for data. Should be (n_levels, n_lats, n_lons). + required_shape : tuple | None + Required shape of data to download. Used to check downloaded data. + Should be (n_levels, n_lats, n_lons). If None, no check is + performed. Returns ------- @@ -497,17 +483,44 @@ def good_file(self, file, required_shape): check_nans=False, check_heights=False, required_shape=required_shape) - good_vars, good_shape, _, _ = out - check = good_vars and good_shape - return check + good_vars, good_shape, good_hgts, _ = out + return bool(good_vars and good_shape and good_hgts) + + def shape_check(self, required_shape, levels): + """Check given required shape""" + if required_shape is None or len(required_shape) == 3: + self.required_shape = required_shape + elif len(required_shape) == 2 and len(levels) != required_shape[0]: + self.required_shape = (len(levels), *required_shape) + else: + msg = f'Received weird required_shape: {required_shape}.' + logger.error(msg) + raise OSError(msg) + + def check_good_vars(self, variables): + """Make sure requested variables are valid. - def check_existing_files(self): + Parameters + ---------- + variables : list + List of variables to download. Can be any of VALID_VARIABLES + """ + valid_vars = list(self.LEVEL_VARS) + list(self.SFC_VARS) + good = all(var in valid_vars for var in variables) + if not good: + msg = (f'Received variables {variables} not in valid variables ' + f'list {valid_vars}') + logger.error(msg) + raise OSError(msg) + + def check_existing_files(self, required_shape=None): """If files exist already check them for good shape and required variables. 
Remove them if there was a problem so we can continue with - routine from scratch.""" + routine from scratch. + """ if os.path.exists(self.combined_file): try: - check = self.good_file(self.combined_file, self.required_shape) + check = self.good_file(self.combined_file, required_shape) if not check: msg = f'Bad file: {self.combined_file}' logger.error(msg) @@ -530,23 +543,35 @@ def check_existing_files(self): def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to - max_log_height (usually 100m) and linear interpolation above this.""" + max_log_height (usually 100m) and linear interpolation above this. + """ + variables = [var for var in self.variables if var in self.LEVEL_VARS] + for var in self.variables: + if var in self.NAME_MAP: + variables.append(self.NAME_MAP[var]) + elif (var in self.SHORT_NAME_MAP + and var not in self.NAME_MAP.values()): + variables.append(self.SHORT_NAME_MAP[var]) + else: + variables.append(var) LogLinInterpolator.run(infile=self.combined_file, outfile=self.interp_file, max_workers=max_workers, - variables=self.variables, + variables=variables, overwrite=self.overwrite, **kwargs) - def get_monthly_file(self, interp_workers=None, **interp_kwargs): + def get_monthly_file(self, interp_workers=None, keep_variables=None, + **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to - interpolate.""" - + interpolate. + """ if os.path.exists(self.combined_file) and self.overwrite: os.remove(self.combined_file) - self.check_existing_files() + if self.check_files: + self.check_existing_files() if not os.path.exists(self.combined_file): self.download_process_combine() @@ -555,10 +580,10 @@ def get_monthly_file(self, interp_workers=None, **interp_kwargs): self.run_interpolation(max_workers=interp_workers, **interp_kwargs) if self.interp_file is not None and os.path.exists(self.interp_file): - if self.already_pruned(self.interp_file): + if self.already_pruned(self.interp_file, keep_variables): logger.info(f'{self.interp_file} pruned already.') else: - self.prune_output(self.interp_file) + self.prune_output(self.interp_file, keep_variables) @classmethod def all_months_exist(cls, year, file_pattern): @@ -583,20 +608,32 @@ def all_months_exist(cls, year, file_pattern): for month in range(1, 13)) @classmethod - def already_pruned(cls, infile): + def already_pruned(cls, infile, keep_variables): """Check if file has been pruned already.""" + if keep_variables is None: + logger.info('Received keep_variables=None. Skipping pruning.') + return + else: + logger.info(f'Received keep_variables={keep_variables}.') pruned = True with Dataset(infile, 'r') as ds: - for var in ds.variables: - if not any(name in var for name in cls.KEEP_VARIABLES): + variables = [var for var in ds.variables + if var not in ('time', 'latitude', 'longitude')] + for var in variables: + if not any(name in var for name in keep_variables): logger.info(f'Pruning {var} in {infile}.') pruned = False return pruned @classmethod - def prune_output(cls, infile): + def prune_output(cls, infile, keep_variables=None): """Prune output file to keep just single level variables""" + if keep_variables is None: + logger.info('Received keep_variables=None. 
Skipping pruning.') + return + else: + logger.info(f'Received keep_variables={keep_variables}.') logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) @@ -605,7 +642,7 @@ def prune_output(cls, infile): new_ds = cls.init_dims(old_ds, new_ds, ('time', 'latitude', 'longitude')) for var in old_ds.variables: - if any(name in var for name in cls.KEEP_VARIABLES): + if any(name in var for name in keep_variables): old_var = old_ds[var] vals = old_var[:] _ = new_ds.createVariable( @@ -632,9 +669,10 @@ def run_month(cls, interp_out_pattern=None, run_interp=True, overwrite=False, - required_shape=None, interp_workers=None, variables=None, + keep_variables=None, + check_files=False, **interp_kwargs): """Run routine for all months in the requested year. @@ -659,15 +697,16 @@ def run_month(cls, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. interp_workers : int | None Max number of workers to use for interpolation. variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. + check_files : bool + Check existing files. Remove and redownload if checks fail. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -679,9 +718,10 @@ def run_month(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, - variables=variables) + variables=variables, + check_files=check_files) downloader.get_monthly_file(interp_workers=interp_workers, + keep_variables=keep_variables, **interp_kwargs) @classmethod @@ -695,10 +735,11 @@ def run_year(cls, interp_yearly_file=None, run_interp=True, overwrite=False, - required_shape=None, max_workers=None, interp_workers=None, variables=None, + keep_variables=None, + check_files=False, **interp_kwargs): """Run routine for all months in the requested year. @@ -725,10 +766,6 @@ def run_year(cls, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. max_workers : int Max number of workers to use for downloading and processing monthly files. @@ -737,6 +774,11 @@ def run_year(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. + check_files : bool + Check existing files. Remove and redownload if checks fail. 
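The keep_variables behavior documented here follows the substring rule used by `already_pruned` and `prune_output` above: a variable survives pruning if any keep_variables entry appears in its name, and the coordinate variables are never pruned. A small sketch with a hypothetical `keep_var` helper:

```python
def keep_var(var_name, keep_variables,
             dims=('time', 'latitude', 'longitude')):
    """Hypothetical helper: keep coordinate variables and any variable whose
    name contains one of the keep_variables substrings."""
    return var_name in dims or any(name in var_name
                                   for name in keep_variables)


keep_variables = ['u_', 'v_', 'orog']
assert keep_var('u_100m', keep_variables)
assert keep_var('latitude', keep_variables)
assert not keep_var('temperature_2m', keep_variables)
```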
**interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -750,9 +792,10 @@ def run_year(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, interp_workers=interp_workers, variables=variables, + keep_variables=keep_variables, + check_files=check_files, **interp_kwargs) else: futures = {} @@ -768,9 +811,10 @@ def run_year(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, interp_workers=interp_workers, + keep_variables=keep_variables, variables=variables, + check_files=check_files, **interp_kwargs) futures[future] = {'year': year, 'month': month} logger.info(f'Submitted future for year {year} and month ' @@ -1040,7 +1084,6 @@ def check_single_file(cls, good_shape = None good_vars = None good_hgts = None - var_list = (var_list if var_list is not None else cls.VALID_VARIABLES) try: res = xr.open_dataset(file) except Exception as e: @@ -1064,7 +1107,6 @@ def check_single_file(cls, def run_files_checks(cls, file_pattern, var_list=None, - required_shape=None, check_nans=True, check_heights=True, max_interp_height=200, @@ -1080,9 +1122,6 @@ def run_files_checks(cls, var_list : list | None List of variables to check. If None: ['zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m'] - required_shape : None | tuple - Required shape for data. Should include (n_levels, n_lats, n_lons). - If None the shape check will be skipped. check_nans : bool Whether to check data for NaNs. check_heights : bool @@ -1119,8 +1158,7 @@ def run_files_checks(cls, check_nans=check_nans, check_heights=check_heights, max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape) + max_workers=height_check_workers) df.loc[i, df.columns[1:]] = out logger.info(f'Finished checking {file}.') else: @@ -1133,8 +1171,7 @@ def run_files_checks(cls, check_nans=check_nans, check_heights=check_heights, max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape) + max_workers=height_check_workers) msg = (f'Submitted file check future for {file}. 
Future ' f'{i + 1} of {len(files)}.') logger.info(msg) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 7a7f51910..52bfb5f12 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -29,14 +29,18 @@ class LogLinInterpolator: """Open ERA5 file, log interpolate wind components between 0 - max_log_height, linearly interpolate components above max_log_height - meters, and save to file""" + meters, and save to file + """ DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { - 'u': [40, 80, 120, 160, 200], - 'v': [40, 80, 120, 160, 200], + 'u': [10, 40, 80, 100, 120, 160, 200], + 'v': [10, 40, 80, 100, 120, 160, 200], + 'w': [10, 40, 80, 100, 120, 160, 200], + 'pv': [10, 40, 80, 100, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], + 'divergence': [80, 100, 120] } def __init__( @@ -144,7 +148,8 @@ def _load_single_var(self, variable): def load(self): """Load ERA5 data and create data arrays""" self.data_dict = {} - for var in self.variables: + vars = [var for var in self.variables if var in self.new_heights] + for var in vars: self.data_dict[var] = {} out = self._load_single_var(var) self.data_dict[var]['heights'] = out[0] @@ -153,7 +158,8 @@ def load(self): def interpolate_vars(self, max_workers=None): """Interpolate u/v wind components below 100m using log profile. - Interpolate non wind data linearly.""" + Interpolate non wind data linearly. + """ for var, arrs in self.data_dict.items(): max_log_height = self.max_log_height if var not in ('u', 'v'): @@ -232,7 +238,8 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace('.nc', '_tmp.nc') return tmp_file @@ -401,8 +408,8 @@ def ws_log_profile(z, a, b): good = True levels = np.array(levels) - lev_mask = (0 < levels) & (levels <= max_log_height) - var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) + lev_mask = (levels > 0) & (levels <= max_log_height) + var_mask = (lev_array_samp > 0) & (lev_array_samp <= max_log_height) try: popt, _ = curve_fit(ws_log_profile, lev_array_samp[var_mask], @@ -476,8 +483,8 @@ def _interp_var_to_height(cls, max_log_height=max_log_height) if any(levels > max_log_height): - lev_mask = levels >= max_log_height - var_mask = lev_array >= max_log_height + lev_mask = levels > max_log_height + var_mask = lev_array > max_log_height if len(lev_array[var_mask]) > 1: lin_ws = interp1d(lev_array[var_mask], var_array[var_mask], @@ -581,7 +588,6 @@ def interp_single_ts(cls, out_array : ndarray Array of interpolated values. 
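The corrected masks above split each vertical profile into a log-law fit at or below max_log_height and a linear interpolation strictly above it. A self-contained sketch with synthetic data follows; it assumes `ws_log_profile` has the form a * log(z) + b, since the full helper is not shown in this diff.

```python
import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import curve_fit


def ws_log_profile(z, a, b):
    # Assumed functional form for the log-law fit.
    return a * np.log(z) + b


# Synthetic observation heights and windspeeds, plus target output levels.
heights = np.array([5., 20., 50., 90., 140., 190., 260.])
windspeed = 1.5 * np.log(heights) + 0.3
levels = np.array([10., 40., 80., 100., 120., 160., 200.])
max_log_height = 100

# Masks mirror the corrected comparisons: strictly positive and at or below
# max_log_height for the log fit, strictly above it for linear interpolation.
log_mask = (levels > 0) & (levels <= max_log_height)
lin_mask = levels > max_log_height

popt, _ = curve_fit(ws_log_profile,
                    heights[heights <= max_log_height],
                    windspeed[heights <= max_log_height])
log_ws = ws_log_profile(levels[log_mask], *popt)

lin_interp = interp1d(heights[heights > max_log_height],
                      windspeed[heights > max_log_height],
                      fill_value='extrapolate')
lin_ws = lin_interp(levels[lin_mask])

interpolated = np.concatenate([log_ws, lin_ws])  # one value per output level
```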
""" - # Interp each vertical column of height and var to requested levels zip_iter = zip(hgt_t, var_t, mask) out_array = [] diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 64a0680ed..71d36ede1 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -15,6 +15,8 @@ def gaussian_kernel(x1, x2, sigma=1.0): x2 : tf.tensor high resolution data (n_obs, spatial_1, spatial_2, temporal, features) + sigma : float + Standard deviation for gaussian kernel Returns ------- diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index b1b898893..3e015b3dc 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -344,8 +344,12 @@ def __call__(self, data): Flattened regridded spatiotemporal data (spatial, temporal) """ + if len(data.shape) == 3: + data = data.reshape((data.shape[0] * data.shape[1], -1)) + msg = 'Input data must be 2D (spatial, temporal)' + assert len(data.shape) == 2, msg vals = [ - data[:, :, i].flatten()[self.indices][np.newaxis] + data[np.array(self.indices), i][np.newaxis] for i in range(data.shape[-1]) ] vals = np.concatenate(vals, axis=0) diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index d6875aa4a..17a85e4fa 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -7,7 +7,9 @@ import glob import logging import os +import random import re +import string from fnmatch import fnmatch from warnings import warn @@ -26,6 +28,14 @@ logger = logging.getLogger(__name__) +def generate_random_string(length): + """Generate random string with given length. Used for naming temporary + files to avoid collisions.""" + letters = string.ascii_letters + random_string = ''.join(random.choice(letters) for i in range(length)) + return random_string + + def windspeed_log_law(z, a, b, c): """Windspeed log profile. 
diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 59ada2965..110c3016e 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -2,18 +2,21 @@ """pytests for data handling""" import os +import tempfile + +import matplotlib.pyplot as plt import numpy as np import pytest -import matplotlib.pyplot as plt -import tempfile import xarray as xr from sup3r import TEST_DATA_DIR +from sup3r.preprocessing.batch_handling import ( + BatchHandler, + SpatialBatchHandler, +) from sup3r.preprocessing.data_handling import DataHandlerNC as DataHandler -from sup3r.preprocessing.batch_handling import (BatchHandler, - SpatialBatchHandler) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.pytest import make_fake_nc_files, make_fake_era_files +from sup3r.utilities.pytest import make_fake_era_files, make_fake_nc_files INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') features = ['U_100m', 'V_100m', 'BVF_MO_200m'] @@ -166,7 +169,8 @@ def test_data_caching(): if os.path.exists(cache_pattern): os.system(f'rm {cache_pattern}') handler = DataHandler(INPUT_FILE, features, - cache_pattern=cache_pattern, **dh_kwargs) + cache_pattern=cache_pattern, **dh_kwargs, + val_split=0.1) assert handler.data is None handler.load_cached_data() assert handler.data.shape == (shape[0], shape[1], diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 248954673..1a105c3c1 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -102,7 +102,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again @@ -126,7 +126,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) assert np.array_equal(old_dh.lr_data, new_dh.lr_data) assert np.array_equal(old_dh.hr_data, new_dh.hr_data) @@ -161,7 +161,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again with one cached feature and one noncached feature @@ -185,7 +185,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl') + cache_pattern=f'{td}/cache.pkl') assert np.array_equal(dh_step2.lr_data[..., 0:1], dh_step1.lr_data) assert np.array_equal(dh_step2.noncached_features, FEATURES[1:]) diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index 6f85c9e99..a4cc9629f 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -9,7 +9,7 @@ from rex import ResourceX, init_logger from sup3r import __version__ -from sup3r.postprocessing.collection import Collector +from sup3r.postprocessing.collection import CollectorH5 from sup3r.postprocessing.file_handling import OutputHandlerH5, OutputHandlerNC from sup3r.utilities.pytest import make_fake_h5_chunks from sup3r.utilities.utilities import invert_uv, transform_rotate_wind @@ -140,7 +140,7 @@ def test_h5_out_and_collect(): low_res_times, ) = out - Collector.collect(out_files, fp_out, features=features) + CollectorH5.collect(out_files, fp_out, features=features) with 
ResourceX(fp_out) as fh: full_ti = fh.time_index combined_ti = [] @@ -207,7 +207,7 @@ def test_h5_collect_mask(log=False): out = make_fake_h5_chunks(td) (out_files, data, _, _, features, _, _, _, _, _, _) = out - Collector.collect(out_files, fp_out, features=features) + CollectorH5.collect(out_files, fp_out, features=features) indices = np.arange(np.product(data.shape[:2])) indices = indices[slice(-len(indices) // 2, None)] removed = [] @@ -220,7 +220,7 @@ def test_h5_collect_mask(log=False): mask_meta['gid'][:] = np.arange(len(mask_meta)) mask_meta.to_csv(mask_file, index=False) - Collector.collect( + CollectorH5.collect( out_files, fp_out_mask, features=features, diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py index 93ae8be82..63f4a7c37 100644 --- a/tests/utilities/test_utilities.py +++ b/tests/utilities/test_utilities.py @@ -11,7 +11,7 @@ from scipy.interpolate import interp1d from sup3r import TEST_DATA_DIR -from sup3r.postprocessing.collection import Collector +from sup3r.postprocessing.collection import CollectorH5 from sup3r.postprocessing.file_handling import OutputHandler from sup3r.utilities.interpolate_log_profile import LogLinInterpolator from sup3r.utilities.regridder import RegridOutput @@ -100,13 +100,13 @@ def test_regridding(log=False): for node_index in range(regrid_output.nodes): regrid_output.run(node_index=node_index) - Collector.collect(regrid_output.out_files, - collect_file, - regrid_output.output_features, - target_final_meta_file=meta_path, - join_times=False, - n_writes=2, - max_workers=1) + CollectorH5.collect(regrid_output.out_files, + collect_file, + regrid_output.output_features, + target_final_meta_file=meta_path, + join_times=False, + n_writes=2, + max_workers=1) with Resource(collect_file) as out_res: for height in heights: ws_name = f'windspeed_{height}m'
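For reference, the collection call exercised by these tests follows the pattern below; the chunk file names and feature list are placeholders for whatever a forward pass actually wrote to disk.

```python
from sup3r.postprocessing.collection import CollectorH5

out_files = ['./chunks/sup3r_chunk_000000.h5',
             './chunks/sup3r_chunk_000001.h5']
CollectorH5.collect(out_files, './sup3r_collected.h5',
                    features=['windspeed_100m', 'winddirection_100m'])
```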