From b15d994446887b835edcf84bd1ff88576fd02c22 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 24 Oct 2023 07:55:59 -0600 Subject: [PATCH 01/44] doc string fix --- sup3r/utilities/era_downloader.py | 10 ++++++---- sup3r/utilities/loss_metrics.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index c6e6dc2ca..a7d7b5216 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -8,9 +8,11 @@ import logging import os from calendar import monthrange -from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor, - as_completed, - ) +from concurrent.futures import ( + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from glob import glob from typing import ClassVar from warnings import warn @@ -52,7 +54,7 @@ class EraDownloader: KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', + 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', 'temperature', 'pressure', ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 64a0680ed..71d36ede1 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -15,6 +15,8 @@ def gaussian_kernel(x1, x2, sigma=1.0): x2 : tf.tensor high resolution data (n_obs, spatial_1, spatial_2, temporal, features) + sigma : float + Standard deviation for gaussian kernel Returns ------- From 8cec3dc07ee0cc73f211eebbc09c52ae9c1388d7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 11:56:53 -0600 Subject: [PATCH 02/44] nc topoextract fix --- sup3r/preprocessing/data_handling/exo_extraction.py | 3 ++- sup3r/utilities/regridder.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 6d8c6bdfe..1290703f2 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -450,7 +450,8 @@ def __init__(self, *args, **kwargs): @property def source_data(self): """Get the 1D array of elevation data from the exo_source_nc""" - elev = self.source_handler.data.reshape((-1, self.lr_shape[-1])) + elev = self.source_handler.data[..., 0, 0].flatten() + elev = np.repeat(elev[..., np.newaxis], self.hr_shape[-1], axis=-1) return elev @property diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index f73cfc74d..f1c215319 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -345,7 +345,7 @@ def __call__(self, data): (spatial, temporal) """ vals = [ - data[:, :, i].flatten()[self.indices][np.newaxis] + data[:, :, i].flatten()[np.array(self.indices)][np.newaxis] for i in range(data.shape[-1]) ] vals = np.concatenate(vals, axis=0) From 779f67941bf7ce6481cd5a95807f91d2ae0340bf Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 13:36:40 -0600 Subject: [PATCH 03/44] ignore case for bc adder/scalar lookup --- sup3r/bias/bias_transforms.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 71d74f568..3995a0f3b 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -58,7 +58,11 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1): raise RuntimeError(msg) msg = 
(f'Either {dset_scalar} or {dset_adder} not found in {bias_fp}.') - assert dset_scalar in res.dsets and dset_adder in res.dsets, msg + dsets = [dset.lower() for dset in res.dsets] + check = dset_scalar.lower() in dsets and dset_adder.lower() in dsets + assert check, msg + dset_scalar = res.dsets[dsets.index(dset_scalar.lower())] + dset_adder = res.dsets[dsets.index(dset_adder.lower())] scalar = res[dset_scalar, slice_y, slice_x] adder = res[dset_adder, slice_y, slice_x] return scalar, adder From b746cb7b4889be65a36fccf5dbb272f8dbda234b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 20:55:47 -0600 Subject: [PATCH 04/44] some tweaks for training with cape - only found in era not wtk so need to make sure different sets of features are indexed carefully. cape added to era_downloader. added vortex mean based bias correction code. --- sup3r/bias/bias_correct_means.py | 707 ++++++++++++++++++ .../data_handling/dual_data_handling.py | 31 +- sup3r/preprocessing/data_handling/mixin.py | 6 +- sup3r/preprocessing/dual_batch_handling.py | 34 +- sup3r/utilities/era_downloader.py | 16 +- 5 files changed, 758 insertions(+), 36 deletions(-) create mode 100644 sup3r/bias/bias_correct_means.py diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py new file mode 100644 index 000000000..cbd98e1b5 --- /dev/null +++ b/sup3r/bias/bias_correct_means.py @@ -0,0 +1,707 @@ +"""Classes to compute means from vortex and era data and compute bias +correction factors.""" + + +import calendar +import json +import logging +import os + +import h5py +import numpy as np +import pandas as pd +import rioxarray +import xarray as xr +from rex import Resource +from sklearn.neighbors import BallTree +from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs +from sup3r.preprocessing.feature_handling import Feature +from sup3r.utilities import VERSION_RECORD +from sup3r.utilities.interpolation import Interpolator +from scipy.interpolate import interp1d + +logger = logging.getLogger(__name__) + + +class VortexMeanPrepper: + """Class for converting monthly vortex tif files for each height to a + single h5 files containing all monthly means for all requested output + heights""" + + def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): + """ + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. 
+ """ + self.path_pattern = path_pattern + self.in_heights = in_heights + self.out_heights = out_heights + self.out_dir = os.path.dirname(path_pattern) + self.overwrite = overwrite + self._mask = None + + @property + def in_features(self): + """List of features corresponding to input heights""" + return [f"windspeed_{h}m" for h in self.in_heights] + + @property + def out_features(self): + """List of features corresponding to output heights""" + return [f"windspeed_{h}m" for h in self.out_heights] + + def get_input_file(self, month, height): + """Get vortex tif file for given month and height.""" + return self.path_pattern.format(month=month, height=height) + + def get_height_files(self, month): + """Get set of netcdf files for given month""" + files = [] + for height in self.in_heights: + infile = self.get_input_file(month, height) + outfile = infile.replace(".tif", ".nc") + files.append(outfile) + return files + + @property + def input_files(self): + """Get list of all input files used for h5 meta.""" + files = [] + for height in self.in_heights: + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_input_file(month, height)) + return files + + def get_output_file(self, month): + """Get name of netcdf file for a given month.""" + return os.path.join( + self.out_dir.replace("{month}", month), f"{month}.nc" + ) + + @property + def output_files(self): + """List of output monthly output files each with windspeed for all + input heights""" + files = [] + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_output_file(month)) + return files + + def convert_month_height_tif(self, month, height): + """Get windspeed mean for the given month and hub height from the + corresponding input file and write this to a netcdf file.""" + infile = self.get_input_file(month, height) + logger.info(f"Getting mean windspeed_{height}m for {month}.") + outfile = infile.replace(".tif", ".nc") + if os.path.exists(outfile) and self.overwrite: + os.remove(outfile) + + if not os.path.exists(outfile) or self.overwrite: + tmp = rioxarray.open_rasterio(infile) + ds = tmp.to_dataset("band") + ds = ds.rename( + {1: f"windspeed_{height}m", "x": "longitude", "y": "latitude"} + ) + ds.to_netcdf(outfile) + return outfile + + def convert_month_tif(self, month): + """Write netcdf files for all heights for the given month""" + for height in self.in_heights: + self.convert_month_height_tif(month, height) + + def convert_all_tifs(self): + """Write netcdf files for all heights for all months""" + for i in range(1, 13): + month = calendar.month_name[i] + logger.info(f"Converting tif files to netcdf files for {month}") + self.convert_month_tif(month) + + @property + def mask(self): + """Mask coordinates without data""" + if self._mask is None: + with xr.open_mfdataset(self.get_height_files('January')) as res: + mask = ((res[self.in_features[0]] != -999) + & (~np.isnan(res[self.in_features[0]]))) + for feat in self.in_features[1:]: + tmp = ((res[feat] != -999) & (~np.isnan(res[feat]))) + mask = mask & tmp + self._mask = np.array(mask).flatten() + return self._mask + + def get_month(self, month): + """Get interpolated means for all hub heights for the given month. 
+ + Parameters + ---------- + month : str + Name of month to get data for + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing interpolated monthly windspeed + means for all input and output heights + + """ + month_file = self.get_output_file(month) + if os.path.exists(month_file) and self.overwrite: + os.remove(month_file) + + if os.path.exists(month_file) and not self.overwrite: + logger.info(f"Loading month_file {month_file}.") + data = xr.open_dataset(month_file) + else: + logger.info( + "Getting mean windspeed for all heights " + f"({self.in_heights}) for {month}" + ) + data = xr.open_mfdataset(self.get_height_files(month)) + logger.info( + "Interpolating windspeed for all heights " + f"({self.out_heights}) for {month}." + ) + data = self.interp(data) + data.to_netcdf(month_file) + logger.info( + "Saved interpolated means for all heights for " + f"{month} to {month_file}." + ) + return data + + def interp(self, data): + """Interpolate data to requested output heights. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input heights + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input and output + heights + """ + var_array = np.zeros( + ( + len(data.latitude) * len(data.longitude), + len(self.in_heights), + ), dtype=np.float32 + ) + lev_array = var_array.copy() + for i, (h, feat) in enumerate(zip(self.in_heights, self.in_features)): + var_array[..., i] = data[feat].values.flatten() + lev_array[..., i] = h + + logger.info(f'Interpolating {self.in_features} to {self.out_features} ' + f'for {var_array.shape[0]} coordinates.') + tmp = [interp1d(h, v, fill_value='extrapolate')(self.out_heights) + for h, v in zip(lev_array[self.mask], var_array[self.mask])] + out = np.full((len(data.latitude), len(data.longitude), + len(self.out_heights)), np.nan, dtype=np.float32) + out[self.mask.reshape((len(data.latitude), len(data.longitude)))] = tmp + for i, feat in enumerate(self.out_features): + if feat not in data: + data[feat] = (("latitude", "longitude"), out[..., i]) + return data + + def get_lat_lon(self): + """Get lat lon grid""" + with xr.open_mfdataset(self.get_height_files("January")) as res: + lons, lats = np.meshgrid( + res["longitude"].values, res["latitude"].values + ) + return np.array(lats), np.array(lons) + + def get_all_data(self): + """Get interpolated monthly means for all out heights as a dictionary + to use for h5 writing. + + Returns + ------- + out : dict + Dictionary of arrays containing monthly means for each hub height. + Also includes latitude and longitude. 
Spatial dimensions are + flattened + """ + data_dict = {} + lats, lons = self.get_lat_lon() + data_dict["latitude"] = lats.flatten()[self.mask] + data_dict["longitude"] = lons.flatten()[self.mask] + s_num = len(data_dict["longitude"]) + for i in range(1, 13): + month = calendar.month_name[i] + out = self.get_month(month) + for feat in self.out_features: + if feat not in data_dict: + data_dict[feat] = np.full((s_num, 12), np.nan) + data = out[feat].values.flatten()[self.mask] + data_dict[feat][..., i - 1] = data + return data_dict + + @property + def meta(self): + """Get a meta data dictionary on how this data is prepared""" + meta = { + "input_files": self.input_files, + "class": str(self.__class__), + "version_record": VERSION_RECORD, + } + return meta + + def write_data(self, fp_out, out): + """Write monthly means for all heights to h5 file""" + if fp_out is not None: + if not os.path.exists(os.path.dirname(fp_out)): + os.makedirs(os.path.dirname(fp_out), exist_ok=True) + + if not os.path.exists(fp_out) or self.overwrite: + with h5py.File(fp_out, "w") as f: + for dset, data in out.items(): + f.create_dataset(dset, data=data) + logger.info(f"Added {dset} to {fp_out}.") + + for k, v in self.meta.items(): + f.attrs[k] = json.dumps(v) + + logger.info( + f"Wrote monthly means for all out heights: {fp_out}" + ) + elif os.path.exists(fp_out): + logger.info(f"{fp_out} already exists and overwrite=False.") + + @classmethod + def run( + cls, path_pattern, in_heights, out_heights, fp_out, overwrite=False + ): + """Read vortex tif files, convert these to monthly netcdf files for all + input heights, interpolate this data to requested output heights, mask + fill values, and write all data to h5 file.""" + vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) + vprep.convert_all_tifs() + out = vprep.get_all_data() + vprep.write_data(fp_out, out) + + +class EraMeanPrepper: + """Class to compute monthly windspeed means from ERA data.""" + + def __init__(self, era_pattern, years, features): + """ + Parameters + ---------- + era_pattern : str + Pattern pointing to era files with u/v wind components at the given + heights. Must have a {year} format key. + years : list + List of ERA years to use for calculating means. + features : list + List of features to compute means for. e.g. ['windspeed_10m'] + """ + self.era_pattern = era_pattern + self.years = years + self.features = features + self.lats, self.lons = self.get_lat_lon() + + @property + def shape(self): + """Get shape of spatial dimensions (lats, lons)""" + return self.lats.shape + + @property + def heights(self): + """List of feature heights""" + heights = [Feature.get_height(feature) for feature in self.features] + return heights + + @property + def input_files(self): + """List of ERA input files to use for calculating means.""" + return [self.era_pattern.format(year=year) for year in self.years] + + def get_lat_lon(self): + """Get arrays of latitude and longitude for ERA domain""" + with xr.open_dataset(self.input_files[0]) as res: + lons, lats = np.meshgrid( + res["longitude"].values, res["latitude"].values + ) + return lats, lons + + def get_windspeed(self, data, height): + """Compute windspeed from u/v wind components from given data. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object for a year of ERA data. Must include u/v + components for the given height. e.g. u_{height}m, v_{height}m. + height : int + Height to compute windspeed for. 
+ """ + return np.hypot( + data[f"u_{height}m"].values, data[f"v_{height}m"].values + ) + + def get_month_mean(self, data, height, month): + """Get windspeed_{height}m mean for the given month. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object for a year of ERA data. Must include u/v + components for the given height. e.g. u_{height}m, v_{height}m. + height : int + Height to compute windspeed for. + month : int + Index of month to get mean for. e.g. 1 = Jan, 2 = Feb, etc. + + Returns + ------- + out : np.ndarray + Array of time averaged windspeed data for the given month. + + """ + mask = pd.to_datetime(data["time"]).month == month + ws = self.get_windspeed(data, height)[mask] + return ws.mean(axis=0) + + def get_all_means(self, height): + """Get monthly means for all months across all given years for the + given height. + + Parameters + ---------- + height : int + Height to compute windspeed for. + + Returns + ------- + means : dict + Dictionary of windspeed_{height}m means for each month + """ + feature = self.features[self.heights.index(height)] + means = {i: [] for i in range(1, 13)} + for i, year in enumerate(self.years): + logger.info(f"Getting means for year={year}, feature={feature}.") + data = xr.open_dataset(self.input_files[i]) + for m in range(1, 13): + means[m].append(self.get_month_mean(data, height, month=m)) + means = {m: np.dstack(arr).mean(axis=-1) for m, arr in means.items()} + return means + + def write_csv(self, out, out_file): + """Write monthly means to a csv file. + + Parameters + ---------- + out : dict + Dictionary of windspeed_{height}m means for each month + out_file : str + Name of csv output file. + """ + logger.info(f"Writing means to {out_file}.") + out = { + f"{str(calendar.month_name[m])[:3]}_mean": v.flatten() + for m, v in out.items() + } + df = pd.DataFrame.from_dict(out) + df["latitude"] = self.lats.flatten() + df["longitude"] = self.lons.flatten() + df["gid"] = np.arange(len(df["latitude"])) + df.to_csv(out_file) + logger.info(f"Finished writing means for {out_file}.") + + @classmethod + def run(cls, era_pattern, years, features, out_pattern): + """Compute monthly windspeed means for the given heights, using the + given years of ERA data, and write the means to csv files for each + height.""" + em = cls(era_pattern=era_pattern, years=years, features=features) + for height, feature in zip(em.heights, em.features): + means = em.get_all_means(height) + out_file = out_pattern.format(feature=feature) + em.write_csv(means, out_file=out_file) + logger.info( + f"Finished writing means for years={years} and " + f"heights={em.heights}." + ) + + +class BiasCorrectionFromMeans: + """Class for getting bias correction factors from bias and base data files + with precomputed monthly means.""" + + MIN_DISTANCE = 1e-12 + + def __init__(self, bias_fp, base_fp, dset, leaf_size=4): + self.dset = dset + self.bias_fp = bias_fp + self.base_fp = base_fp + self.leaf_size = leaf_size + self.bias_means = pd.read_csv(bias_fp) + self.base_means = Resource(base_fp) + self.bias_meta = self.bias_means[["latitude", "longitude"]] + self.base_meta = pd.DataFrame(columns=["latitude", "longitude"]) + self.base_meta["latitude"] = self.base_means["latitude"] + self.base_meta["longitude"] = self.base_means["longitude"] + self._base_tree = None + logger.info( + "Finished initializing BiasCorrectionFromMeans for " + f"bias_fp={bias_fp}, base_fp={base_fp}, dset={dset}." 
+ ) + + @property + def base_tree(self): + """Build ball tree from source_meta""" + if self._base_tree is None: + logger.info("Building ball tree for regridding.") + self._base_tree = BallTree( + np.deg2rad(self.base_meta), + leaf_size=self.leaf_size, + metric="haversine", + ) + return self._base_tree + + @property + def meta(self): + """Get a meta data dictionary on how these bias factors were + calculated""" + meta = { + "base_fp": self.base_fp, + "bias_fp": self.bias_fp, + "dset": self.dset, + "class": str(self.__class__), + "version_record": VERSION_RECORD, + "NOTES": ("scalar factors computed from base_data / bias_data."), + } + return meta + + @property + def height(self): + """Get feature height""" + return Feature.get_height(self.dset) + + @property + def u_name(self): + """Get corresponding u component for given height""" + return f"u_{self.height}m" + + @property + def v_name(self): + """Get corresponding v component for given height""" + return f"v_{self.height}m" + + def get_base_data(self, knn=1): + """Get means for baseline data.""" + logger.info(f"Getting base data for {self.dset}.") + dists, gids = self.base_tree.query(np.deg2rad(self.bias_meta), k=knn) + mask = dists < self.MIN_DISTANCE + if mask.sum() > 0: + logger.info( + f"{np.sum(mask)} of {np.product(mask.shape)} " + "distances are zero." + ) + dists[mask] = self.MIN_DISTANCE + weights = 1 / dists + norm = np.sum(weights, axis=-1) + out = self.base_means[self.dset, gids] + out = np.einsum("ijk,ij->ik", out, weights) / norm[:, np.newaxis] + return out + + def get_bias_data(self): + """Get means for biased data.""" + logger.info(f"Getting bias data for {self.dset}.") + cols = [col for col in self.bias_means.columns if "mean" in col] + bias_data = self.bias_means[cols].to_numpy() + return bias_data + + def get_corrections(self, global_scalar=1, knn=1): + """Get bias correction factors.""" + logger.info(f"Getting correction factors for {self.dset}.") + base_data = self.get_base_data(knn=knn) + bias_data = self.get_bias_data() + scaler = global_scalar * base_data / bias_data + adder = 0 + + out = { + "latitude": self.bias_meta["latitude"], + "longitude": self.bias_meta["longitude"], + f"base_{self.dset}_mean": base_data, + f"bias_{self.dset}_mean": bias_data, + f"{self.dset}_adder": adder, + f"{self.dset}_scalar": scaler, + f"{self.dset}_global_scalar": global_scalar, + } + return out + + def get_uv_corrections(self, global_scalar=1, knn=1): + """Write windspeed bias correction factors for u/v components""" + u_out = self.get_corrections(global_scalar=global_scalar, knn=knn) + v_out = u_out.copy() + u_out[f"{self.u_name}_scalar"] = u_out[f"{self.dset}_scalar"] + v_out[f"{self.v_name}_scalar"] = v_out[f"{self.dset}_scalar"] + u_out[f"{self.u_name}_adder"] = u_out[f"{self.dset}_adder"] + v_out[f"{self.v_name}_adder"] = v_out[f"{self.dset}_adder"] + return u_out, v_out + + def write_output(self, fp_out, out): + """Write bias correction factors to h5 file.""" + logger.info(f"Writing correction factors to file: {fp_out}.") + with h5py.File(fp_out, "w") as f: + for dset, data in out.items(): + f.create_dataset(dset, data=data) + logger.info(f"Added {dset} to {fp_out}.") + for k, v in self.meta.items(): + f.attrs[k] = json.dumps(v) + logger.info(f"Finished writing output to {fp_out}.") + + @classmethod + def run( + cls, + bias_fp, + base_fp, + dset, + fp_out, + global_scalar=1.0, + knn=1, + out_shape=None + ): + """Run bias correction factor computation and write.""" + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + out = 
bc.get_corrections(global_scalar=global_scalar, knn=knn) + if out_shape is not None: + for k, v in out.items(): + if k in ('latitude', 'longitude'): + out[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out[k] = np.array(v).reshape((*out_shape, 12)) + bc.write_output(fp_out, out) + + @classmethod + def run_uv( + cls, + bias_fp, + base_fp, + dset, + fp_pattern, + global_scalar=1.0, + knn=1, + out_shape=None + ): + """Run bias correction factor computation and write.""" + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + out_u, out_v = bc.get_uv_corrections( + global_scalar=global_scalar, knn=knn + ) + if out_shape is not None: + for k, v in out_u.items(): + if k in ('latitude', 'longitude'): + out_u[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out_u[k] = np.array(v).reshape((*out_shape, 12)) + for k, v in out_v.items(): + if k in ('latitude', 'longitude'): + out_v[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out_v[k] = np.array(v).reshape((*out_shape, 12)) + bc.write_output(fp_pattern.format(feature=bc.u_name), out_u) + bc.write_output(fp_pattern.format(feature=bc.v_name), out_v) + + +class BiasCorrectUpdate: + """Class for bias correcting existing files and writing corrected files.""" + + @classmethod + def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): + """Get bias correction factors for the given dset and month""" + with Resource(bc_file) as res: + logger.info( + f"Getting {dset} bias correction factors for month {month}." + ) + bc_factor = res[f"{dset}_scalar", :, month - 1] + factors = global_scalar * bc_factor + logger.info( + f"Retrieved {dset} bias correction factors for month {month}. " + f"Using global_scalar={global_scalar}." + ) + return factors + + @classmethod + def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): + """Update the in_file with bias corrected values for the given dset + and write to out_file.""" + tmp_file = out_file.replace(".h5", ".h5.tmp") + logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") + with Resource(in_file) as fh_in: + OutputHandler._init_h5( + tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs + ) + OutputHandler._ensure_dset_in_output(tmp_file, dset) + for i in range(1, 13): + try: + with RexOutputs(tmp_file, "a") as fh: + mask = fh.time_index.month == i + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=i, + global_scalar=global_scalar, + ) + logger.info( + f"Applying bias correction factors for month {i}" + ) + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file {tmp_file}." + ) + + os.replace(tmp_file, out_file) + msg = f"Saved bias corrected {dset} to: {out_file}" + logger.info(msg) + + @classmethod + def run( + cls, + in_file, + out_file, + dset, + bc_file, + overwrite=False, + global_scalar=1, + ): + """Run bias correction update.""" + if os.path.exists(out_file) and not overwrite: + logger.info( + f"{out_file} already exists and overwrite=False. " "Skipping." + ) + else: + if os.path.exists(out_file) and overwrite: + logger.info( + f"{out_file} exists but overwrite=True. " + f"Removing {out_file}." 
+ ) + os.remove(out_file) + cls.update_file( + in_file, out_file, dset, bc_file, global_scalar=global_scalar + ) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index a35b6012c..49c6a3076 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -24,8 +24,8 @@ class DualDataHandler(CacheHandlingMixIn, TrainingPrepMixIn): def __init__(self, hr_handler, lr_handler, - regrid_cache_pattern=None, - overwrite_regrid_cache=False, + cache_pattern=None, + overwrite_cache=False, regrid_workers=1, load_cached=True, shuffle_time=False, @@ -41,9 +41,9 @@ def __init__(self, DataHandler for high_res data lr_handler : DataHandler DataHandler for low_res data - regrid_cache_pattern : str + cache_pattern : str Pattern for files to use for saving regridded ERA data. - overwrite_regrid_cache : bool + overwrite_cache : bool Whether to overwrite regrid cache regrid_workers : int | None Number of workers to use for regridding routine. @@ -63,10 +63,10 @@ def __init__(self, self.t_enhance = t_enhance self.lr_dh = lr_handler self.hr_dh = hr_handler - self._cache_pattern = regrid_cache_pattern + self._cache_pattern = cache_pattern self._cached_features = None self._noncached_features = None - self.overwrite_cache = overwrite_regrid_cache + self.overwrite_cache = overwrite_cache self.val_split = val_split self.current_obs_index = None self.load_cached = load_cached @@ -179,7 +179,7 @@ def normalize(self, means, stdevs): def output_features(self): """Get list of output features. e.g. those that are returned by a GAN""" - return self.hr_dh.output_features + return self.lr_dh.output_features @property def train_only_features(self): @@ -219,7 +219,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.features, msg + assert hr_handler.features == lr_handler.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, @@ -233,21 +233,21 @@ def _run_pair_checks(self, hr_handler, lr_handler): hr_shape = self.hr_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_data.shape {self.hr_data.shape} and ' f'lr_data.shape {self.lr_data.shape} are ' f'incompatible. Must be {hr_shape} and {lr_shape}.') - assert self.lr_data.shape == lr_shape, msg + assert self.lr_data.shape[:-1] == lr_shape, msg if self.lr_val_data is not None and self.hr_val_data is not None: hr_shape = self.hr_val_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_val_data.shape {self.hr_val_data.shape} ' f'and lr_val_data.shape {self.lr_val_data.shape}' f' are incompatible. 
Must be {hr_shape} and {lr_shape}.') - assert self.lr_val_data.shape == lr_shape, msg + assert self.lr_val_data.shape[:-1] == lr_shape, msg @property def grid_mem(self): @@ -378,8 +378,8 @@ def cache_files(self): cache_files = self._get_cache_file_names(self.cache_pattern, grid_shape=self.lr_grid_shape, time_index=self.lr_time_index, - target=self.hr_dh.target, - features=self.hr_dh.features) + target=self.lr_dh.target, + features=self.lr_dh.features) return cache_files @property @@ -499,7 +499,8 @@ def get_next(self): for s in lr_obs_idx[2:-1]: hr_obs_idx.append( slice(s.start * self.t_enhance, s.stop * self.t_enhance)) - hr_obs_idx.append(lr_obs_idx[-1]) + + hr_obs_idx.append(np.arange(len(self.output_features))) hr_obs_idx = tuple(hr_obs_idx) self.current_obs_index = { 'hr_index': hr_obs_idx, diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index c52c3640d..79da9bbe8 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -282,10 +282,10 @@ def _load_single_cached_feature(self, fp, cache_files, features, Error raised if shape conflicts with requested shape """ idx = cache_files.index(fp) - assert features[idx].lower() in fp.lower() + msg = f'{features[idx].lower()} not found in {fp.lower()}.' + assert features[idx].lower() in fp.lower(), msg fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {features[idx]} from ' - f'{fp}.') + logger.info(f'Loading {features[idx]} from {fp}.') out = None with open(fp, 'rb') as fh: diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py index ac1901d9a..5244374e8 100644 --- a/sup3r/preprocessing/dual_batch_handling.py +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -3,9 +3,11 @@ import numpy as np -from sup3r.preprocessing.batch_handling import (Batch, BatchHandler, - ValidationData, - ) +from sup3r.preprocessing.batch_handling import ( + Batch, + BatchHandler, + ValidationData, +) from sup3r.utilities.utilities import uniform_box_sampler, uniform_time_sampler logger = logging.getLogger(__name__) @@ -143,6 +145,16 @@ class DualBatchHandler(BatchHandler): BATCH_CLASS = Batch VAL_CLASS = DualValidationData + @property + def lr_features(self): + """Features in low res batch.""" + return self.data_handlers[0].lr_dh.features + + @property + def hr_features(self): + """Features in high res batch.""" + return self.data_handlers[0].lr_dh.output_features + @property def hr_sample_shape(self): """Get sample shape for high_res data""" @@ -173,19 +185,19 @@ def __next__(self): handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], self.hr_sample_shape[1], - self.hr_sample_shape[2], self.shape[-1]), + self.hr_sample_shape[2], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], self.lr_sample_shape[1], - self.lr_sample_shape[2], self.shape[-1]), + self.lr_sample_shape[2], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): high_res[i, ...], low_res[i, ...] 
= handler.get_next() self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 @@ -220,10 +232,12 @@ def __next__(self): self.current_handler_index = handler_index handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], - self.hr_sample_shape[1], self.shape[-1]), + self.hr_sample_shape[1], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], - self.lr_sample_shape[1], self.shape[-1]), + self.lr_sample_shape[1], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): @@ -232,8 +246,6 @@ def __next__(self): low_res[i, ...] = lr[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index a7d7b5216..a9771906c 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -48,14 +48,15 @@ class EraDownloader: VALID_VARIABLES: ClassVar[list] = [ 'u', 'v', 'pressure', 'temperature', 'relative_humidity', 'specific_humidity', 'total_precipitation', + 'convective_available_potential_energy' ] - KEEP_VARIABLES: ClassVar[list] = ['orog'] + KEEP_VARIABLES: ClassVar[list] = ['orog', 'cape'] KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', - 'temperature', 'pressure', + 'temperature', 'pressure', 'cape' ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', @@ -69,7 +70,7 @@ class EraDownloader: '10m_u_component_of_wind', '10m_v_component_of_wind', '100m_u_component_of_wind', '100m_v_component_of_wind', 'surface_pressure', '2m_temperature', 'geopotential', - 'total_precipitation', + 'total_precipitation', "convective_available_potential_energy" ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', @@ -88,6 +89,7 @@ class EraDownloader: 'r': 'relative_humidity', 'q': 'specific_humidity', 'tp': 'total_precipitation', + 'cape': 'cape' } def __init__(self, @@ -298,7 +300,8 @@ def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 level_check = (len(self.level_file_variables) > 0 - and self.levels is not None) + and self.levels is not None + and len(self.levels) > 0) if self.level_file_variables: msg = (f'{self.level_file_variables} requested but no levels' ' were provided.') @@ -316,7 +319,7 @@ def download_levels_file(self): """Download file with requested pressure levels""" if not os.path.exists(self.level_file) or self.overwrite: msg = (f'Downloading {self.level_file_variables} to ' - f'{self.level_file}.') + f'{self.level_file} with levels = {self.levels}.') logger.info(msg) CDS_API_CLIENT.retrieve( 'reanalysis-era5-pressure-levels', { @@ -468,8 +471,7 @@ def process_and_combine(self): self.process_surface_file() files.append(self.surface_file) - logger.info(f'Combining {files} and {self.surface_file} ' - f'to {self.combined_file}.') + logger.info(f'Combining {files} to {self.combined_file}.') with xr.open_mfdataset(files) as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing 
{self.combined_file}') From f406d7b1cec18be19327cb4ac379887b9a8b4221 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 21:10:46 -0600 Subject: [PATCH 05/44] unused import --- sup3r/bias/bias_correct_means.py | 45 +++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index cbd98e1b5..3a4f19214 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -13,12 +13,12 @@ import rioxarray import xarray as xr from rex import Resource +from scipy.interpolate import interp1d from sklearn.neighbors import BallTree + from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.interpolation import Interpolator -from scipy.interpolate import interp1d logger = logging.getLogger(__name__) @@ -133,11 +133,12 @@ def convert_all_tifs(self): def mask(self): """Mask coordinates without data""" if self._mask is None: - with xr.open_mfdataset(self.get_height_files('January')) as res: - mask = ((res[self.in_features[0]] != -999) - & (~np.isnan(res[self.in_features[0]]))) + with xr.open_mfdataset(self.get_height_files("January")) as res: + mask = (res[self.in_features[0]] != -999) & ( + ~np.isnan(res[self.in_features[0]]) + ) for feat in self.in_features[1:]: - tmp = ((res[feat] != -999) & (~np.isnan(res[feat]))) + tmp = (res[feat] != -999) & (~np.isnan(res[feat])) mask = mask & tmp self._mask = np.array(mask).flatten() return self._mask @@ -200,19 +201,27 @@ def interp(self, data): ( len(data.latitude) * len(data.longitude), len(self.in_heights), - ), dtype=np.float32 + ), + dtype=np.float32, ) lev_array = var_array.copy() for i, (h, feat) in enumerate(zip(self.in_heights, self.in_features)): var_array[..., i] = data[feat].values.flatten() lev_array[..., i] = h - logger.info(f'Interpolating {self.in_features} to {self.out_features} ' - f'for {var_array.shape[0]} coordinates.') - tmp = [interp1d(h, v, fill_value='extrapolate')(self.out_heights) - for h, v in zip(lev_array[self.mask], var_array[self.mask])] - out = np.full((len(data.latitude), len(data.longitude), - len(self.out_heights)), np.nan, dtype=np.float32) + logger.info( + f"Interpolating {self.in_features} to {self.out_features} " + f"for {var_array.shape[0]} coordinates." 
+ ) + tmp = [ + interp1d(h, v, fill_value="extrapolate")(self.out_heights) + for h, v in zip(lev_array[self.mask], var_array[self.mask]) + ] + out = np.full( + (len(data.latitude), len(data.longitude), len(self.out_heights)), + np.nan, + dtype=np.float32, + ) out[self.mask.reshape((len(data.latitude), len(data.longitude)))] = tmp for i, feat in enumerate(self.out_features): if feat not in data: @@ -578,14 +587,14 @@ def run( fp_out, global_scalar=1.0, knn=1, - out_shape=None + out_shape=None, ): """Run bias correction factor computation and write.""" bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: for k, v in out.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out[k] = np.array(v).reshape((*out_shape, 12)) @@ -600,7 +609,7 @@ def run_uv( fp_pattern, global_scalar=1.0, knn=1, - out_shape=None + out_shape=None, ): """Run bias correction factor computation and write.""" bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) @@ -609,12 +618,12 @@ def run_uv( ) if out_shape is not None: for k, v in out_u.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out_u[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out_u[k] = np.array(v).reshape((*out_shape, 12)) for k, v in out_v.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out_v[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out_v[k] = np.array(v).reshape((*out_shape, 12)) From 765dbd845d9573db68bb8e4a7b3ef7fb0e3b4c55 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 31 Oct 2023 12:45:15 -0600 Subject: [PATCH 06/44] exo cache naming issue now that data has temporal dimension --- .../data_handling/exogenous_data_handling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 5ff81b831..f8bda65a2 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -567,8 +567,12 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, cache_fp : str Name of cache file """ - fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_' - fn += f'tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' + tsteps = (None if self.temporal_slice.start is None + or self.temporal_slice.end is None + else self.temporal_slice.end - self.temporal_slice.start) + fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') fn = fn.replace(',', 'x').replace(' ', '') From 81a0db5ccbe8ef64fd16d031ce4cd1b6a6bed4c7 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 08:50:26 -0600 Subject: [PATCH 07/44] exo caching moved to exo_extract and time independent naming convention used for topo. 
--- .../data_handling/exo_extraction.py | 139 +++++++++++++++++- .../data_handling/exogenous_data_handling.py | 88 ++--------- sup3r/utilities/era_downloader.py | 82 ++++++----- sup3r/utilities/interpolate_log_profile.py | 1 + 4 files changed, 196 insertions(+), 114 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 1290703f2..5066d082c 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -1,6 +1,9 @@ """Sup3r topography utilities""" import logging +import os +import pickle +import shutil from abc import ABC, abstractmethod import numpy as np @@ -37,6 +40,8 @@ def __init__(self, raster_file=None, max_delta=20, input_handler=None, + cache_data=True, + cache_dir='./exo_cache/', ti_workers=1): """ Parameters @@ -100,6 +105,12 @@ def __init__(self, data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' ti_workers : int | None max number of workers to use to get full time index. Useful when there are many input files each with a single time step. If this is @@ -122,6 +133,11 @@ def __init__(self, self._source_lat_lon = None self._hr_time_index = None self._src_time_index = None + self.cache_data = cache_data + self.cache_dir = cache_dir + self.temporal_slice = temporal_slice + self.target = target + self.shape = shape if input_handler is None: in_type = get_source_type(file_paths) @@ -159,6 +175,46 @@ def __init__(self, def source_data(self): """Get the 1D array of source data from the exo_source_h5""" + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name + + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. 
+ + Returns + ------- + cache_fp : str + Name of cache file + """ + tsteps = (None if self.temporal_slice.start is None + or self.temporal_slice.stop is None + else self.temporal_slice.stop - self.temporal_slice.start) + fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp + @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the @@ -266,6 +322,30 @@ def nn(self): @property def data(self): + """Get a raster of source values corresponding to the + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). The shape is (lats, lons, temporal, 1) + """ + cache_fp = self.get_cache_file(feature=self.__class__.__name__, + s_enhance=self._s_enhance, + t_enhance=self._t_enhance, + s_agg_factor=self._s_agg_factor, + t_agg_factor=self._t_agg_factor) + tmp_fp = cache_fp + '.tmp' + if os.path.exists(cache_fp): + with open(cache_fp, 'rb') as f: + data = pickle.load(f) + + else: + data = self.get_data() + + if self.cache_data: + with open(tmp_fp, 'wb') as f: + pickle.dump(data, f) + shutil.move(tmp_fp, cache_fp) + return data + + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -293,7 +373,9 @@ def get_exo_raster(cls, temporal_slice=None, raster_file=None, max_delta=20, - input_handler=None): + input_handler=None, + cache_data=True, + cache_dir='./exo_cache/'): """Get the exo feature raster corresponding to the spatially enhanced grid from the file_paths input @@ -356,6 +438,12 @@ class will output a topography raster corresponding to the data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' Returns ------- @@ -377,7 +465,9 @@ class will output a topography raster corresponding to the temporal_slice=temporal_slice, raster_file=raster_file, max_delta=max_delta, - input_handler=input_handler) + input_handler=input_handler, + cache_data=cache_data, + cache_dir=cache_dir) return exo.data @@ -407,8 +497,7 @@ def source_time_index(self): self._src_time_index = res.time_index return self._src_time_index - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -421,7 +510,44 @@ def data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + out = hr_data[..., np.newaxis] + + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name. This uses a time independent naming convention. 
+ + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. + + Returns + ------- + cache_fp : str + Name of cache file + """ + fn = f'exo_{feature}_{self.target}_{self.shape}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp class TopoExtractNC(TopoExtractH5): @@ -470,8 +596,7 @@ def source_data(self): return SolarPosition(self.hr_time_index, self.hr_lat_lon.reshape((-1, 2))).zenith.T - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index f8bda65a2..f3d35fe73 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -1,9 +1,6 @@ """Sup3r exogenous data handling""" import logging -import os -import pickle import re -import shutil from typing import ClassVar import numpy as np @@ -541,46 +538,6 @@ def _get_all_agg_and_enhancement(self): for step in self.steps] return agg_enhance_dict - def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, - t_agg_factor): - """Get cache file name - - Parameters - ---------- - feature : str - Name of feature to get cache file for - s_enhance : int - Spatial enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - t_enhance : int - Temporal enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - s_agg_factor : int - Factor by which to aggregate the exo_source data to the spatial - resolution of the file_paths input enhanced by s_enhance. - t_agg_factor : int - Factor by which to aggregate the exo_source data to the temporal - resolution of the file_paths input enhanced by t_enhance. 
- - Returns - ------- - cache_fp : str - Name of cache file - """ - tsteps = (None if self.temporal_slice.start is None - or self.temporal_slice.end is None - else self.temporal_slice.end - self.temporal_slice.start) - fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' - fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' - fn += f'{t_enhance}x.pkl' - fn = fn.replace('(', '').replace(')', '') - fn = fn.replace('[', '').replace(']', '') - fn = fn.replace(',', 'x').replace(' ', '') - cache_fp = os.path.join(self.cache_dir, fn) - if self.cache_data: - os.makedirs(self.cache_dir, exist_ok=True) - return cache_fp - def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): """Get the exogenous topography data @@ -609,35 +566,22 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, lon, temporal) """ - cache_fp = self.get_cache_file(feature=feature, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor) - tmp_fp = cache_fp + '.tmp' - if os.path.exists(cache_fp): - with open(cache_fp, 'rb') as f: - data = pickle.load(f) - - else: - exo_handler = self.get_exo_handler(feature, self.source_file, - self.exo_handler) - data = exo_handler(self.file_paths, - self.source_file, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_slice, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler).data - if self.cache_data: - with open(tmp_fp, 'wb') as f: - pickle.dump(data, f) - shutil.move(tmp_fp, cache_fp) + exo_handler = self.get_exo_handler(feature, self.source_file, + self.exo_handler) + data = exo_handler(self.file_paths, + self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler, + cache_data=self.cache_data, + cache_dir=self.cache_dir).data return data @classmethod diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index a9771906c..5fb5c94e9 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -48,22 +48,7 @@ class EraDownloader: VALID_VARIABLES: ClassVar[list] = [ 'u', 'v', 'pressure', 'temperature', 'relative_humidity', 'specific_humidity', 'total_precipitation', - 'convective_available_potential_energy' - ] - - KEEP_VARIABLES: ClassVar[list] = ['orog', 'cape'] - KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] - - DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', - 'temperature', 'pressure', 'cape' - ] - DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', '10m_v_component_of_wind', - '100m_u_component_of_wind', '100m_v_component_of_wind', - 'u_component_of_wind', 'v_component_of_wind', '2m_temperature', - 'temperature', 'surface_pressure', 'relative_humidity', - 'total_precipitation', + 'convective_available_potential_energy', 'divergence' ] SFC_VARS: ClassVar[list] = [ @@ -74,7 +59,7 @@ class EraDownloader: ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', - 'temperature', 'relative_humidity', 'specific_humidity', + 'temperature', 'relative_humidity', 'specific_humidity', 'divergence' ] NAME_MAP: 
ClassVar[dict] = { 'u10': 'u_10m', @@ -89,7 +74,8 @@ class EraDownloader: 'r': 'relative_humidity', 'q': 'specific_humidity', 'tp': 'total_precipitation', - 'cape': 'cape' + 'cape': 'cape', + 'd': 'divergence' } def __init__(self, @@ -387,16 +373,16 @@ def map_vars(self, old_ds, ds): ds : Dataset Dataset() object for new file with new variables written. """ - for old_name, new_name in self.NAME_MAP.items(): - if old_name in old_ds.variables: - _ = ds.createVariable(new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) - vals = old_ds.variables[old_name][:] - if 'temperature' in new_name: - vals -= 273.15 - ds.variables[new_name][:] = vals + for old_name in old_ds.variables: + new_name = self.NAME_MAP.get(old_name, old_name) + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) + vals = old_ds.variables[old_name][:] + if 'temperature' in new_name: + vals -= 273.15 + ds.variables[new_name][:] = vals return ds def convert_z(self, standard_name, long_name, old_ds, ds): @@ -542,7 +528,8 @@ def run_interpolation(self, max_workers=None, **kwargs): overwrite=self.overwrite, **kwargs) - def get_monthly_file(self, interp_workers=None, **interp_kwargs): + def get_monthly_file(self, interp_workers=None, keep_variables=None, + **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to interpolate.""" @@ -559,10 +546,10 @@ def get_monthly_file(self, interp_workers=None, **interp_kwargs): self.run_interpolation(max_workers=interp_workers, **interp_kwargs) if self.interp_file is not None and os.path.exists(self.interp_file): - if self.already_pruned(self.interp_file): + if self.already_pruned(self.interp_file, keep_variables): logger.info(f'{self.interp_file} pruned already.') else: - self.prune_output(self.interp_file) + self.prune_output(self.interp_file, keep_variables) @classmethod def all_months_exist(cls, year, file_pattern): @@ -587,21 +574,35 @@ def all_months_exist(cls, year, file_pattern): for month in range(1, 13)) @classmethod - def already_pruned(cls, infile): + def already_pruned(cls, infile, keep_variables): """Check if file has been pruned already.""" + if keep_variables is None: + logger.info('Received keep_variables=None. Skipping pruning.') + return + else: + logger.info( + f'Received keep_variables={keep_variables}. Skipping pruning.') + pruned = True with Dataset(infile, 'r') as ds: for var in ds.variables: - if not any(name in var for name in cls.KEEP_VARIABLES): + if not any(name in var for name in keep_variables): logger.info(f'Pruning {var} in {infile}.') pruned = False return pruned @classmethod - def prune_output(cls, infile): + def prune_output(cls, infile, keep_variables=None): """Prune output file to keep just single level variables""" + if keep_variables is None: + logger.info('Received keep_variables=None. Skipping pruning.') + return + else: + logger.info( + f'Received keep_variables={keep_variables}. 
Skipping pruning.') + logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) with Dataset(infile, 'r') as old_ds: @@ -609,7 +610,7 @@ def prune_output(cls, infile): new_ds = cls.init_dims(old_ds, new_ds, ('time', 'latitude', 'longitude')) for var in old_ds.variables: - if any(name in var for name in cls.KEEP_VARIABLES): + if any(name in var for name in keep_variables): old_var = old_ds[var] vals = old_var[:] _ = new_ds.createVariable( @@ -639,6 +640,7 @@ def run_month(cls, required_shape=None, interp_workers=None, variables=None, + keep_variables=None, **interp_kwargs): """Run routine for all months in the requested year. @@ -672,6 +674,9 @@ def run_month(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -686,6 +691,7 @@ def run_month(cls, required_shape=required_shape, variables=variables) downloader.get_monthly_file(interp_workers=interp_workers, + keep_variables=keep_variables, **interp_kwargs) @classmethod @@ -703,6 +709,7 @@ def run_year(cls, max_workers=None, interp_workers=None, variables=None, + keep_variables=None, **interp_kwargs): """Run routine for all months in the requested year. @@ -741,6 +748,9 @@ def run_year(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -757,6 +767,7 @@ def run_year(cls, required_shape=required_shape, interp_workers=interp_workers, variables=variables, + keep_variables=keep_variables, **interp_kwargs) else: futures = {} @@ -774,6 +785,7 @@ def run_year(cls, overwrite=overwrite, required_shape=required_shape, interp_workers=interp_workers, + keep_variables=keep_variables, variables=variables, **interp_kwargs) futures[future] = {'year': year, 'month': month} diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 7a7f51910..993b1dcc1 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -37,6 +37,7 @@ class LogLinInterpolator: 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], + 'divergence': [80, 100, 120] } def __init__( From fe8546db57ae1e73a590eecc9af5cec9665e6905 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 09:11:14 -0600 Subject: [PATCH 08/44] modified methods to cache single time step for time independent exo data and apply repeat after loading cache. 
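
For illustration, a minimal sketch of the cache-then-repeat idea described above (hypothetical helper names, not part of the sup3r API): cache a single temporal slice for time-independent exo data (e.g. topography) and tile it to the requested number of time steps after loading.

    import numpy as np

    def repeat_cached_exo(data, n_tsteps):
        # data: cached exo raster with a single time step, shape (lat, lon, 1)
        # tile along the temporal axis so downstream code sees (lat, lon, n_tsteps)
        if data.shape[-1] == 1 and n_tsteps > 1:
            data = np.repeat(data, n_tsteps, axis=-1)
        # add a trailing feature channel -> (lat, lon, n_tsteps, 1)
        return data[..., np.newaxis]
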
--- .../data_handling/exo_extraction.py | 17 ++++++++++------- sup3r/utilities/era_downloader.py | 6 ++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 5066d082c..428580ac7 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -201,7 +201,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, cache_fp : str Name of cache file """ - tsteps = (None if self.temporal_slice.start is None + tsteps = (None if self.temporal_slice is None + or self.temporal_slice.start is None or self.temporal_slice.stop is None else self.temporal_slice.stop - self.temporal_slice.start) fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' @@ -343,6 +344,10 @@ def data(self): with open(tmp_fp, 'wb') as f: pickle.dump(data, f) shutil.move(tmp_fp, cache_fp) + + if data.shape[-2] == 1 and self.hr_shape[-1] > 1: + data = np.repeat(data[..., np.newaxis, :], self.hr_shape[-1], + axis=-2) return data def get_data(self): @@ -479,8 +484,7 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') - elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) - return elev + return elev[:, np.newaxis] @property def source_lat_lon(self): @@ -506,11 +510,11 @@ def get_data(self): hr_data = [] for j in range(self._s_agg_factor): out = self.source_data[nn[:, j]] - out = out.reshape(self.hr_shape) + out = out.reshape((*self.hr_shape[:-1], -1)) hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - out = hr_data[..., np.newaxis] + return hr_data[..., np.newaxis] def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): @@ -577,8 +581,7 @@ def __init__(self, *args, **kwargs): def source_data(self): """Get the 1D array of elevation data from the exo_source_nc""" elev = self.source_handler.data[..., 0, 0].flatten() - elev = np.repeat(elev[..., np.newaxis], self.hr_shape[-1], axis=-1) - return elev + return elev[..., np.newaxis] @property def source_lat_lon(self): diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 5fb5c94e9..2d9fa0711 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -581,8 +581,7 @@ def already_pruned(cls, infile, keep_variables): logger.info('Received keep_variables=None. Skipping pruning.') return else: - logger.info( - f'Received keep_variables={keep_variables}. Skipping pruning.') + logger.info(f'Received keep_variables={keep_variables}.') pruned = True with Dataset(infile, 'r') as ds: @@ -600,8 +599,7 @@ def prune_output(cls, infile, keep_variables=None): logger.info('Received keep_variables=None. Skipping pruning.') return else: - logger.info( - f'Received keep_variables={keep_variables}. Skipping pruning.') + logger.info(f'Received keep_variables={keep_variables}.') logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) From 5123490624871cb4227c9346204479b59208a5cc Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 11:30:22 -0600 Subject: [PATCH 09/44] normalization bug for dual handlers - assumed lr and hr handlers had same features originally but this isnt required. means/stds have to be indexed carefully. 
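
The careful indexing can be illustrated with a small sketch (made-up feature names and stats, not the handler code itself): when the high-res handler only carries the output features, the means/stds computed for the full low-res feature list have to be subset before normalizing the high-res data.

    import numpy as np

    features = ['u_100m', 'v_100m', 'cape']    # low-res handler features
    output_features = ['u_100m', 'v_100m']     # high-res handler features
    means = np.array([5.0, 1.0, 300.0])        # stats ordered like `features`
    stds = np.array([2.0, 2.5, 150.0])

    # subset the stats to the high-res features before normalizing hr_data
    indices = [features.index(f) for f in output_features]
    hr_means, hr_stds = means[indices], stds[indices]

    hr_data = np.random.rand(4, 4, 8, len(output_features))  # (lat, lon, time, features)
    hr_data = (hr_data - hr_means) / hr_stds
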
--- .../data_handling/dual_data_handling.py | 13 ++++++++----- sup3r/preprocessing/data_handling/mixin.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 49c6a3076..e22507b57 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -162,17 +162,20 @@ def normalize(self, means, stdevs): dimensions (features) array of means for all features with same ordering as data features """ - logger.info('Normalizing low resolution data.') + logger.info('Normalizing low resolution data features=' + f'{self.features}') self._normalize(data=self.lr_data, val_data=self.lr_val_data, means=means, stds=stdevs, max_workers=self.lr_dh.norm_workers) - logger.info('Normalizing high resolution data.') + logger.info('Normalizing high resolution data features=' + f'{self.output_features}') + indices = [self.features.index(f) for f in self.output_features] self._normalize(data=self.hr_data, val_data=self.hr_val_data, - means=means, - stds=stdevs, + means=means[indices], + stds=stdevs[indices], max_workers=self.hr_dh.norm_workers) @property @@ -219,7 +222,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.output_features, msg + assert hr_handler.features == self.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 79da9bbe8..f2f2c0494 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -1054,6 +1054,10 @@ def _normalize(self, data, val_data, means, stds, max_workers=None): max_workers : int | None Number of workers to use in thread pool for nomalization. """ + msg = f'Received {len(means)} means for {data.shape[-1]} features' + assert len(means) == data.shape[-1], msg + msg = f'Received {len(stds)} stds for {data.shape[-1]} features' + assert len(stds) == data.shape[-1], msg logger.info(f'Normalizing {data.shape[-1]} features.') if max_workers == 1: for i in range(data.shape[-1]): From 3b80de90b7c5d9cae490a2cd2c9073b6adfda5e1 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 11:30:22 -0600 Subject: [PATCH 10/44] normalization bug for dual handlers - assumed lr and hr handlers had same features originally but this isnt required. means/stds have to be indexed carefully. 
--- .../data_handling/dual_data_handling.py | 13 ++++++++----- sup3r/preprocessing/data_handling/exo_extraction.py | 11 +++++------ sup3r/preprocessing/data_handling/mixin.py | 4 ++++ 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 49c6a3076..e22507b57 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -162,17 +162,20 @@ def normalize(self, means, stdevs): dimensions (features) array of means for all features with same ordering as data features """ - logger.info('Normalizing low resolution data.') + logger.info('Normalizing low resolution data features=' + f'{self.features}') self._normalize(data=self.lr_data, val_data=self.lr_val_data, means=means, stds=stdevs, max_workers=self.lr_dh.norm_workers) - logger.info('Normalizing high resolution data.') + logger.info('Normalizing high resolution data features=' + f'{self.output_features}') + indices = [self.features.index(f) for f in self.output_features] self._normalize(data=self.hr_data, val_data=self.hr_val_data, - means=means, - stds=stdevs, + means=means[indices], + stds=stdevs[indices], max_workers=self.hr_dh.norm_workers) @property @@ -219,7 +222,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.output_features, msg + assert hr_handler.features == self.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 428580ac7..9901b4b83 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -345,10 +345,9 @@ def data(self): pickle.dump(data, f) shutil.move(tmp_fp, cache_fp) - if data.shape[-2] == 1 and self.hr_shape[-1] > 1: - data = np.repeat(data[..., np.newaxis, :], self.hr_shape[-1], - axis=-2) - return data + if data.shape[-1] == 1 and self.hr_shape[-1] > 1: + data = np.repeat(data, self.hr_shape[-1], axis=-1) + return data[..., np.newaxis] def get_data(self): """Get a raster of source values corresponding to the @@ -363,7 +362,7 @@ def get_data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data @classmethod def get_exo_raster(cls, @@ -514,7 +513,7 @@ def get_data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 79da9bbe8..f2f2c0494 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -1054,6 +1054,10 @@ def _normalize(self, data, val_data, means, stds, max_workers=None): max_workers : int | None Number of workers to use in thread pool for nomalization. 
""" + msg = f'Received {len(means)} means for {data.shape[-1]} features' + assert len(means) == data.shape[-1], msg + msg = f'Received {len(stds)} stds for {data.shape[-1]} features' + assert len(stds) == data.shape[-1], msg logger.info(f'Normalizing {data.shape[-1]} features.') if max_workers == 1: for i in range(data.shape[-1]): From 8e1b5500d8582092d6e01e059816a12b8764d7fd Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 07:08:55 -0600 Subject: [PATCH 11/44] some edits to netcdf output for multistep forward-pass pipeline (run spatial fwp with netcdf output and then temporal fwp on that output) --- sup3r/postprocessing/file_handling.py | 2 +- .../data_handling/exo_extraction.py | 8 +++++++- .../data_handling/exogenous_data_handling.py | 10 ++++++++-- sup3r/utilities/era_downloader.py | 17 +++++++++-------- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index f86339f16..c1244366c 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -560,7 +560,7 @@ def _write_output(cls, data, features, lat_lon, times, out_file, List of coordinate indices used to label each lat lon pair and to help with spatial chunk data collection """ - coords = {'Times': (['Time'], times), + coords = {'Times': (['Time'], [str(t).encode('utf-8') for t in times]), 'XLAT': (['south_north', 'east_west'], lat_lon[..., 0]), 'XLONG': (['south_north', 'east_west'], lat_lon[..., 1])} diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 9901b4b83..8e3c60d07 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -42,7 +42,8 @@ def __init__(self, input_handler=None, cache_data=True, cache_dir='./exo_cache/', - ti_workers=1): + ti_workers=1, + res_kwargs=None): """ Parameters ---------- @@ -118,6 +119,9 @@ def __init__(self, parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. + xr.open_dataset(file_paths, **res_kwargs) """ logger.info(f'Initializing {self.__class__.__name__} utility.') @@ -138,6 +142,7 @@ def __init__(self, self.temporal_slice = temporal_slice self.target = target self.shape = shape + self.res_kwargs = res_kwargs if input_handler is None: in_type = get_source_type(file_paths) @@ -168,6 +173,7 @@ def __init__(self, raster_file=raster_file, max_delta=max_delta, worker_kwargs=dict(ti_workers=ti_workers), + res_kwargs=self.res_kwargs ) @property diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index f3d35fe73..3149887e4 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -206,7 +206,8 @@ def __init__(self, input_handler=None, exo_handler=None, cache_data=True, - cache_dir='./exo_cache'): + cache_dir='./exo_cache', + res_kwargs=None): """ Parameters ---------- @@ -279,6 +280,9 @@ def __init__(self, speed up forward passes with large temporal extents cache_dir : str Directory for storing cache data. Default is './exo_cache' + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. 
+ xr.open_dataset(file_paths, **res_kwargs) """ self.feature = feature @@ -297,6 +301,7 @@ def __init__(self, self.cache_data = cache_data self.cache_dir = cache_dir self.data = {feature: {'steps': []}} + self.res_kwargs = res_kwargs self.input_check() agg_enhance = self._get_all_agg_and_enhancement() @@ -581,7 +586,8 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, max_delta=self.max_delta, input_handler=self.input_handler, cache_data=self.cache_data, - cache_dir=self.cache_dir).data + cache_dir=self.cache_dir, + res_kwargs=self.res_kwargs).data return data @classmethod diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 2d9fa0711..0a572c9ef 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -375,14 +375,15 @@ def map_vars(self, old_ds, ds): """ for old_name in old_ds.variables: new_name = self.NAME_MAP.get(old_name, old_name) - _ = ds.createVariable(new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) - vals = old_ds.variables[old_name][:] - if 'temperature' in new_name: - vals -= 273.15 - ds.variables[new_name][:] = vals + if new_name not in ds.variables: + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) + vals = old_ds.variables[old_name][:] + if 'temperature' in new_name: + vals -= 273.15 + ds.variables[new_name][:] = vals return ds def convert_z(self, standard_name, long_name, old_ds, ds): From 3a7edb6f56defb95594d668bf56b7f815132a5da Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 08:20:29 -0600 Subject: [PATCH 12/44] bc code doc strings --- sup3r/bias/bias_correct_means.py | 318 +++++++++++++++--- .../data_handling/exo_extraction.py | 14 +- .../data_handling/test_dual_data_handling.py | 8 +- 3 files changed, 286 insertions(+), 54 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 3a4f19214..eb4f1f996 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -1,11 +1,13 @@ """Classes to compute means from vortex and era data and compute bias -correction factors.""" +correction factors. +""" import calendar import json import logging import os +from concurrent.futures import ThreadPoolExecutor, as_completed import h5py import numpy as np @@ -15,7 +17,6 @@ from rex import Resource from scipy.interpolate import interp1d from sklearn.neighbors import BallTree - from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD @@ -26,11 +27,11 @@ class VortexMeanPrepper: """Class for converting monthly vortex tif files for each height to a single h5 files containing all monthly means for all requested output - heights""" + heights. + """ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): - """ - Parameters + """Parameters ---------- path_pattern : str Pattern for input tif files. 
Needs to include {month} and {height} @@ -52,7 +53,7 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): @property def in_features(self): - """List of features corresponding to input heights""" + """List of features corresponding to input heights.""" return [f"windspeed_{h}m" for h in self.in_heights] @property @@ -92,7 +93,8 @@ def get_output_file(self, month): @property def output_files(self): """List of output monthly output files each with windspeed for all - input heights""" + input heights + """ files = [] for i in range(1, 13): month = calendar.month_name[i] @@ -101,7 +103,8 @@ def output_files(self): def convert_month_height_tif(self, month, height): """Get windspeed mean for the given month and hub height from the - corresponding input file and write this to a netcdf file.""" + corresponding input file and write this to a netcdf file. + """ infile = self.get_input_file(month, height) logger.info(f"Getting mean windspeed_{height}m for {month}.") outfile = infile.replace(".tif", ".nc") @@ -118,12 +121,12 @@ def convert_month_height_tif(self, month, height): return outfile def convert_month_tif(self, month): - """Write netcdf files for all heights for the given month""" + """Write netcdf files for all heights for the given month.""" for height in self.in_heights: self.convert_month_height_tif(month, height) def convert_all_tifs(self): - """Write netcdf files for all heights for all months""" + """Write netcdf files for all heights for all months.""" for i in range(1, 13): month = calendar.month_name[i] logger.info(f"Converting tif files to netcdf files for {month}") @@ -299,7 +302,23 @@ def run( ): """Read vortex tif files, convert these to monthly netcdf files for all input heights, interpolate this data to requested output heights, mask - fill values, and write all data to h5 file.""" + fill values, and write all data to h5 file. + + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + fp_out : str + Name of final h5 output file to write with means. + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. + """ vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) vprep.convert_all_tifs() out = vprep.get_all_data() @@ -310,8 +329,7 @@ class EraMeanPrepper: """Class to compute monthly windspeed means from ERA data.""" def __init__(self, era_pattern, years, features): - """ - Parameters + """Parameters ---------- era_pattern : str Pattern pointing to era files with u/v wind components at the given @@ -438,7 +456,21 @@ def write_csv(self, out, out_file): def run(cls, era_pattern, years, features, out_pattern): """Compute monthly windspeed means for the given heights, using the given years of ERA data, and write the means to csv files for each - height.""" + height. + + Parameters + ---------- + era_pattern : str + Pattern pointing to era files with u/v wind components at the given + heights. Must have a {year} format key. + years : list + List of ERA years to use for calculating means. + features : list + List of features to compute means for. e.g. ['windspeed_10m'] + out_pattern : str + Pattern pointing to csv files to write means to. Must have a + {feature} format key. 
+ """ em = cls(era_pattern=era_pattern, years=years, features=features) for height, feature in zip(em.heights, em.features): means = em.get_all_means(height) @@ -452,11 +484,23 @@ def run(cls, era_pattern, years, features, out_pattern): class BiasCorrectionFromMeans: """Class for getting bias correction factors from bias and base data files - with precomputed monthly means.""" + with precomputed monthly means. + """ MIN_DISTANCE = 1e-12 def __init__(self, bias_fp, base_fp, dset, leaf_size=4): + """Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + leaf_size : int + Leaf size for ball tree used to match bias and base grids + """ self.dset = dset self.bias_fp = bias_fp self.base_fp = base_fp @@ -488,7 +532,8 @@ def base_tree(self): @property def meta(self): """Get a meta data dictionary on how these bias factors were - calculated""" + calculated + """ meta = { "base_fp": self.base_fp, "bias_fp": self.bias_fp, @@ -585,12 +630,44 @@ def run( base_fp, dset, fp_out, + leaf_size=4, global_scalar=1.0, knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" - bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_out : str + Name of output file containing bias correction factors + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. + out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset, + leaf_size=leaf_size) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: for k, v in out.items(): @@ -611,7 +688,37 @@ def run_uv( knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_pattern : str + Pattern for output file. Should contain {feature} format key. + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. 
This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. + out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) out_u, out_v = bc.get_uv_corrections( global_scalar=global_scalar, knn=knn @@ -636,7 +743,26 @@ class BiasCorrectUpdate: @classmethod def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): - """Get bias correction factors for the given dset and month""" + """Get bias correction factors for the given dset and month + + Parameters + ---------- + bc_file : str + Name of h5 file containing bias correction factors + dset : str + Name of dataset to apply bias correction factors for + month : int + Index of month to bias correct + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + + Returns + ------- + factors : ndarray + Array of bias correction factors for the given dset and month. + """ with Resource(bc_file) as res: logger.info( f"Getting {dset} bias correction factors for month {month}." @@ -650,9 +776,71 @@ def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): return factors @classmethod - def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): + def _correct_month( + cls, fh_in, month, out_file, dset, bc_file, global_scalar + ): + """Bias correct data for a given month. + + Parameters + ---------- + fh_in : Resource() + Resource handler for input file being corrected + month : int + Index of month to be corrected + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + """ + with RexOutputs(out_file, "a") as fh: + mask = fh.time_index.month == month + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=month, + global_scalar=global_scalar, + ) + logger.info(f"Applying bias correction factors for month {month}") + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + + @classmethod + def update_file( + cls, + in_file, + out_file, + dset, + bc_file, + global_scalar=1, + max_workers=None, + ): """Update the in_file with bias corrected values for the given dset - and write to out_file.""" + and write to out_file. 
+ + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ tmp_file = out_file.replace(".h5", ".h5.tmp") logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") with Resource(in_file) as fh_in: @@ -660,30 +848,54 @@ def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs ) OutputHandler._ensure_dset_in_output(tmp_file, dset) - for i in range(1, 13): - try: - with RexOutputs(tmp_file, "a") as fh: - mask = fh.time_index.month == i - mask = np.arange(len(fh.time_index))[mask] - mask = slice(mask[0], mask[-1] + 1) - bc_factors = cls.get_bc_factors( - bc_file=bc_file, + + if max_workers == 1: + for i in range(1, 13): + try: + cls._correct_month( + fh_in, + month=i, + out_file=tmp_file, dset=dset, + bc_file=bc_file, + global_scalar=global_scalar, + ) + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file " + f"{tmp_file}." + ) + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i in range(1, 13): + future = exe.submit( + cls._correct_month, + fh_in=fh_in, month=i, + out_file=tmp_file, + dset=dset, + bc_file=bc_file, global_scalar=global_scalar, ) + futures[future] = i + logger.info( - f"Applying bias correction factors for month {i}" + f"Submitted bias correction for month {i} " + f"to {tmp_file}." ) - fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] - except Exception as e: - raise RuntimeError( - f"Bias correction failed for month {i}." - ) from e - logger.info( - f"Added {dset} for month {i} to output file {tmp_file}." - ) + for future in as_completed(futures): + _ = future.result() + i = futures[future] + logger.info( + f"Completed bias correction for month {i} " + f"to {tmp_file}." + ) os.replace(tmp_file, out_file) msg = f"Saved bias corrected {dset} to: {out_file}" @@ -698,11 +910,32 @@ def run( bc_file, overwrite=False, global_scalar=1, + max_workers=None ): - """Run bias correction update.""" + """Run bias correction update. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + overwrite : bool + Whether to overwrite the output file if it already exists. + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ if os.path.exists(out_file) and not overwrite: logger.info( - f"{out_file} already exists and overwrite=False. " "Skipping." + f"{out_file} already exists and overwrite=False. Skipping." 
) else: if os.path.exists(out_file) and overwrite: @@ -712,5 +945,6 @@ def run( ) os.remove(out_file) cls.update_file( - in_file, out_file, dset, bc_file, global_scalar=global_scalar + in_file, out_file, dset, bc_file, global_scalar=global_scalar, + max_workers=max_workers ) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 8e3c60d07..81727400d 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -44,8 +44,7 @@ def __init__(self, cache_dir='./exo_cache/', ti_workers=1, res_kwargs=None): - """ - Parameters + """Parameters ---------- file_paths : str | list A single source h5 file to extract raster data from or a list @@ -123,7 +122,6 @@ def __init__(self, Dictionary of kwargs passed to lowest level resource handler. e.g. xr.open_dataset(file_paths, **res_kwargs) """ - logger.info(f'Initializing {self.__class__.__name__} utility.') self.ti_workers = ti_workers @@ -225,7 +223,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the - input file temporal slice""" + input file temporal slice + """ start_index = self.source_time_index.get_indexer( [self.input_handler.hr_time_index[0]], method='nearest')[0] end_index = self.source_time_index.get_indexer( @@ -353,6 +352,7 @@ def data(self): if data.shape[-1] == 1 and self.hr_shape[-1] > 1: data = np.repeat(data, self.hr_shape[-1], axis=-1) + return data[..., np.newaxis] def get_data(self): @@ -563,15 +563,13 @@ class TopoExtractNC(TopoExtractH5): """TopoExtract for netCDF files""" def __init__(self, *args, **kwargs): - """ - Parameters + """Parameters ---------- args : list Same positional arguments as TopoExtract kwargs : dict Same keyword arguments as TopoExtract """ - super().__init__(*args, **kwargs) logger.info('Getting topography for full domain from ' f'{self._exo_source}') @@ -611,4 +609,4 @@ def get_data(self): """ hr_data = self.source_data.reshape(self.hr_shape) logger.info('Finished computing SZA data') - return hr_data[..., np.newaxis] + return hr_data diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 248954673..1a105c3c1 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -102,7 +102,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again @@ -126,7 +126,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) assert np.array_equal(old_dh.lr_data, new_dh.lr_data) assert np.array_equal(old_dh.hr_data, new_dh.hr_data) @@ -161,7 +161,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again with one cached feature and one noncached feature @@ -185,7 +185,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl') + cache_pattern=f'{td}/cache.pkl') assert np.array_equal(dh_step2.lr_data[..., 0:1], dh_step1.lr_data) assert np.array_equal(dh_step2.noncached_features, FEATURES[1:]) From 
a33e2d9769f55c2f5b767c5ccc38dacf040e735e Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 08:20:29 -0600 Subject: [PATCH 13/44] bc code doc strings --- sup3r/bias/bias_correct_means.py | 318 +++++++++++++++--- sup3r/postprocessing/collection.py | 8 +- .../data_handling/exo_extraction.py | 14 +- sup3r/preprocessing/feature_handling.py | 85 ++--- sup3r/utilities/interpolate_log_profile.py | 10 +- .../data_handling/test_dual_data_handling.py | 8 +- 6 files changed, 330 insertions(+), 113 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 3a4f19214..eb4f1f996 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -1,11 +1,13 @@ """Classes to compute means from vortex and era data and compute bias -correction factors.""" +correction factors. +""" import calendar import json import logging import os +from concurrent.futures import ThreadPoolExecutor, as_completed import h5py import numpy as np @@ -15,7 +17,6 @@ from rex import Resource from scipy.interpolate import interp1d from sklearn.neighbors import BallTree - from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD @@ -26,11 +27,11 @@ class VortexMeanPrepper: """Class for converting monthly vortex tif files for each height to a single h5 files containing all monthly means for all requested output - heights""" + heights. + """ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): - """ - Parameters + """Parameters ---------- path_pattern : str Pattern for input tif files. Needs to include {month} and {height} @@ -52,7 +53,7 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): @property def in_features(self): - """List of features corresponding to input heights""" + """List of features corresponding to input heights.""" return [f"windspeed_{h}m" for h in self.in_heights] @property @@ -92,7 +93,8 @@ def get_output_file(self, month): @property def output_files(self): """List of output monthly output files each with windspeed for all - input heights""" + input heights + """ files = [] for i in range(1, 13): month = calendar.month_name[i] @@ -101,7 +103,8 @@ def output_files(self): def convert_month_height_tif(self, month, height): """Get windspeed mean for the given month and hub height from the - corresponding input file and write this to a netcdf file.""" + corresponding input file and write this to a netcdf file. 
+ """ infile = self.get_input_file(month, height) logger.info(f"Getting mean windspeed_{height}m for {month}.") outfile = infile.replace(".tif", ".nc") @@ -118,12 +121,12 @@ def convert_month_height_tif(self, month, height): return outfile def convert_month_tif(self, month): - """Write netcdf files for all heights for the given month""" + """Write netcdf files for all heights for the given month.""" for height in self.in_heights: self.convert_month_height_tif(month, height) def convert_all_tifs(self): - """Write netcdf files for all heights for all months""" + """Write netcdf files for all heights for all months.""" for i in range(1, 13): month = calendar.month_name[i] logger.info(f"Converting tif files to netcdf files for {month}") @@ -299,7 +302,23 @@ def run( ): """Read vortex tif files, convert these to monthly netcdf files for all input heights, interpolate this data to requested output heights, mask - fill values, and write all data to h5 file.""" + fill values, and write all data to h5 file. + + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + fp_out : str + Name of final h5 output file to write with means. + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. + """ vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) vprep.convert_all_tifs() out = vprep.get_all_data() @@ -310,8 +329,7 @@ class EraMeanPrepper: """Class to compute monthly windspeed means from ERA data.""" def __init__(self, era_pattern, years, features): - """ - Parameters + """Parameters ---------- era_pattern : str Pattern pointing to era files with u/v wind components at the given @@ -438,7 +456,21 @@ def write_csv(self, out, out_file): def run(cls, era_pattern, years, features, out_pattern): """Compute monthly windspeed means for the given heights, using the given years of ERA data, and write the means to csv files for each - height.""" + height. + + Parameters + ---------- + era_pattern : str + Pattern pointing to era files with u/v wind components at the given + heights. Must have a {year} format key. + years : list + List of ERA years to use for calculating means. + features : list + List of features to compute means for. e.g. ['windspeed_10m'] + out_pattern : str + Pattern pointing to csv files to write means to. Must have a + {feature} format key. + """ em = cls(era_pattern=era_pattern, years=years, features=features) for height, feature in zip(em.heights, em.features): means = em.get_all_means(height) @@ -452,11 +484,23 @@ def run(cls, era_pattern, years, features, out_pattern): class BiasCorrectionFromMeans: """Class for getting bias correction factors from bias and base data files - with precomputed monthly means.""" + with precomputed monthly means. 
+ """ MIN_DISTANCE = 1e-12 def __init__(self, bias_fp, base_fp, dset, leaf_size=4): + """Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + leaf_size : int + Leaf size for ball tree used to match bias and base grids + """ self.dset = dset self.bias_fp = bias_fp self.base_fp = base_fp @@ -488,7 +532,8 @@ def base_tree(self): @property def meta(self): """Get a meta data dictionary on how these bias factors were - calculated""" + calculated + """ meta = { "base_fp": self.base_fp, "bias_fp": self.bias_fp, @@ -585,12 +630,44 @@ def run( base_fp, dset, fp_out, + leaf_size=4, global_scalar=1.0, knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" - bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_out : str + Name of output file containing bias correction factors + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. + out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset, + leaf_size=leaf_size) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: for k, v in out.items(): @@ -611,7 +688,37 @@ def run_uv( knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_pattern : str + Pattern for output file. Should contain {feature} format key. + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. 
+ out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) out_u, out_v = bc.get_uv_corrections( global_scalar=global_scalar, knn=knn @@ -636,7 +743,26 @@ class BiasCorrectUpdate: @classmethod def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): - """Get bias correction factors for the given dset and month""" + """Get bias correction factors for the given dset and month + + Parameters + ---------- + bc_file : str + Name of h5 file containing bias correction factors + dset : str + Name of dataset to apply bias correction factors for + month : int + Index of month to bias correct + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + + Returns + ------- + factors : ndarray + Array of bias correction factors for the given dset and month. + """ with Resource(bc_file) as res: logger.info( f"Getting {dset} bias correction factors for month {month}." @@ -650,9 +776,71 @@ def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): return factors @classmethod - def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): + def _correct_month( + cls, fh_in, month, out_file, dset, bc_file, global_scalar + ): + """Bias correct data for a given month. + + Parameters + ---------- + fh_in : Resource() + Resource handler for input file being corrected + month : int + Index of month to be corrected + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + """ + with RexOutputs(out_file, "a") as fh: + mask = fh.time_index.month == month + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=month, + global_scalar=global_scalar, + ) + logger.info(f"Applying bias correction factors for month {month}") + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + + @classmethod + def update_file( + cls, + in_file, + out_file, + dset, + bc_file, + global_scalar=1, + max_workers=None, + ): """Update the in_file with bias corrected values for the given dset - and write to out_file.""" + and write to out_file. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. 
+ """ tmp_file = out_file.replace(".h5", ".h5.tmp") logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") with Resource(in_file) as fh_in: @@ -660,30 +848,54 @@ def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs ) OutputHandler._ensure_dset_in_output(tmp_file, dset) - for i in range(1, 13): - try: - with RexOutputs(tmp_file, "a") as fh: - mask = fh.time_index.month == i - mask = np.arange(len(fh.time_index))[mask] - mask = slice(mask[0], mask[-1] + 1) - bc_factors = cls.get_bc_factors( - bc_file=bc_file, + + if max_workers == 1: + for i in range(1, 13): + try: + cls._correct_month( + fh_in, + month=i, + out_file=tmp_file, dset=dset, + bc_file=bc_file, + global_scalar=global_scalar, + ) + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file " + f"{tmp_file}." + ) + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i in range(1, 13): + future = exe.submit( + cls._correct_month, + fh_in=fh_in, month=i, + out_file=tmp_file, + dset=dset, + bc_file=bc_file, global_scalar=global_scalar, ) + futures[future] = i + logger.info( - f"Applying bias correction factors for month {i}" + f"Submitted bias correction for month {i} " + f"to {tmp_file}." ) - fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] - except Exception as e: - raise RuntimeError( - f"Bias correction failed for month {i}." - ) from e - logger.info( - f"Added {dset} for month {i} to output file {tmp_file}." - ) + for future in as_completed(futures): + _ = future.result() + i = futures[future] + logger.info( + f"Completed bias correction for month {i} " + f"to {tmp_file}." + ) os.replace(tmp_file, out_file) msg = f"Saved bias corrected {dset} to: {out_file}" @@ -698,11 +910,32 @@ def run( bc_file, overwrite=False, global_scalar=1, + max_workers=None ): - """Run bias correction update.""" + """Run bias correction update. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + overwrite : bool + Whether to overwrite the output file if it already exists. + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ if os.path.exists(out_file) and not overwrite: logger.info( - f"{out_file} already exists and overwrite=False. " "Skipping." + f"{out_file} already exists and overwrite=False. Skipping." 
) else: if os.path.exists(out_file) and overwrite: @@ -712,5 +945,6 @@ def run( ) os.remove(out_file) cls.update_file( - in_file, out_file, dset, bc_file, global_scalar=global_scalar + in_file, out_file, dset, bc_file, global_scalar=global_scalar, + max_workers=max_workers ) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index bcbaef968..2384c914d 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -26,8 +26,7 @@ class Collector(OutputMixIn): """Sup3r H5 file collection framework""" def __init__(self, file_paths): - """ - Parameters + """Parameters ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected @@ -51,7 +50,6 @@ def get_node_cmd(cls, config): sup3r collection config with all necessary args and kwargs to run data collection. """ - import_str = ( 'from sup3r.postprocessing.collection ' 'import Collector;\n' @@ -107,7 +105,6 @@ def get_slices( col_slice : slice final_meta[col_slice] = new_meta """ - final_index = final_meta.index new_index = new_meta.index row_loc = np.where(final_time_index.isin(new_time_index))[0] @@ -205,7 +202,6 @@ def get_data( col_slice : slice final_meta[col_slice] = new_meta """ - with RexOutputs(file_path, unscale=False, mode='r') as f: f_ti = f.time_index f_meta = f.meta @@ -698,7 +694,7 @@ def group_time_chunks(self, file_paths, n_writes=None): f'n_writes ({n_writes}) must be less than or equal ' f'to the number of temporal chunks ({len(file_chunks)}).' ) - assert n_writes < len(file_chunks), msg + assert n_writes <= len(file_chunks), msg return file_chunks def get_flist_chunks(self, file_paths, n_writes=None, join_times=False): diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 8e3c60d07..81727400d 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -44,8 +44,7 @@ def __init__(self, cache_dir='./exo_cache/', ti_workers=1, res_kwargs=None): - """ - Parameters + """Parameters ---------- file_paths : str | list A single source h5 file to extract raster data from or a list @@ -123,7 +122,6 @@ def __init__(self, Dictionary of kwargs passed to lowest level resource handler. e.g. 
xr.open_dataset(file_paths, **res_kwargs) """ - logger.info(f'Initializing {self.__class__.__name__} utility.') self.ti_workers = ti_workers @@ -225,7 +223,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the - input file temporal slice""" + input file temporal slice + """ start_index = self.source_time_index.get_indexer( [self.input_handler.hr_time_index[0]], method='nearest')[0] end_index = self.source_time_index.get_indexer( @@ -353,6 +352,7 @@ def data(self): if data.shape[-1] == 1 and self.hr_shape[-1] > 1: data = np.repeat(data, self.hr_shape[-1], axis=-1) + return data[..., np.newaxis] def get_data(self): @@ -563,15 +563,13 @@ class TopoExtractNC(TopoExtractH5): """TopoExtract for netCDF files""" def __init__(self, *args, **kwargs): - """ - Parameters + """Parameters ---------- args : list Same positional arguments as TopoExtract kwargs : dict Same keyword arguments as TopoExtract """ - super().__init__(*args, **kwargs) logger.info('Getting topography for full domain from ' f'{self._exo_source}') @@ -611,4 +609,4 @@ def get_data(self): """ hr_data = self.source_data.reshape(self.hr_shape) logger.info('Finished computing SZA data') - return hr_data[..., np.newaxis] + return hr_data diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index c5f712e34..3738ce03c 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1,5 +1,4 @@ -""" -Sup3r feature handling module. +"""Sup3r feature handling module. @author: bbenton """ @@ -35,7 +34,8 @@ class DerivedFeature(ABC): """Abstract class for special features which need to be derived from raw - features""" + features + """ @classmethod @abstractmethod @@ -86,7 +86,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. the all-sky ghi / the clearsky ghi. NaN where nighttime. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -104,7 +103,8 @@ def compute(cls, data, height=None): class ClearSkyRatioCC(DerivedFeature): """Clear Sky Ratio feature class for computing from climate change netcdf - data""" + data + """ @classmethod def inputs(cls, feature): @@ -142,7 +142,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. the all-sky ghi / the clearsky ghi. This is assumed to be daily average data for climate change source data. """ - cs_ratio = data['rsds'] / data['clearsky_ghi'] cs_ratio = np.minimum(cs_ratio, 1) cs_ratio = np.maximum(cs_ratio, 0) @@ -189,7 +188,6 @@ def compute(cls, data, height=None): nighttime. Data is float32 so it can be normalized without any integer weirdness. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -208,7 +206,8 @@ def compute(cls, data, height=None): class PotentialTempNC(DerivedFeature): """Potential Temperature feature class for NETCDF data. Needed since T is - perturbation potential temperature.""" + perturbation potential temperature. + """ @classmethod def inputs(cls, feature): @@ -239,7 +238,8 @@ def compute(cls, data, height): class TempNC(DerivedFeature): """Temperature feature class for NETCDF data. 
Needed since T is potential - temperature not standard temp.""" + temperature not standard temp. + """ @classmethod def inputs(cls, feature): @@ -271,7 +271,8 @@ def compute(cls, data, height): class PressureNC(DerivedFeature): """Pressure feature class for NETCDF data. Needed since P is perturbation - pressure.""" + pressure. + """ @classmethod def inputs(cls, feature): @@ -302,7 +303,8 @@ def compute(cls, data, height): class BVFreqSquaredNC(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -355,7 +357,6 @@ def inputs(cls, feature): list List of required features for computing RMOL """ - assert feature == 'RMOL' features = ['UST', 'HFX'] return features @@ -431,7 +432,8 @@ def compute(cls, data, height): class BVFreqSquaredH5(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -495,7 +497,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -516,7 +517,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - ws, _ = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return ws @@ -539,7 +539,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -560,7 +559,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - _, wd = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return wd @@ -585,7 +583,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -612,7 +609,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -643,7 +639,6 @@ def inputs(cls, feature): list List of required features for computing Veer """ - height = Feature.get_height(feature) heights = [int(height), int(height) + 20] features = [] @@ -667,7 +662,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - heights = [int(height), int(height) + 20] shear = np.cos(np.radians(data[f'winddirection_{int(height) + 20}m'])) shear -= np.cos(np.radians(data[f'winddirection_{int(height)}m'])) @@ -694,7 +688,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -722,7 +715,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -829,7 +821,8 @@ def compute(cls, data, height): class UWind(DerivedFeature): """U wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -876,7 +869,8 @@ def compute(cls, data, height): class Vorticity(DerivedFeature): """Vorticity feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -919,7 +913,8 @@ def compute(cls, data, height): class VWind(DerivedFeature): """V wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ 
-1047,14 +1042,16 @@ def compute(cls, data, height): class TasMin(Tas): """Daily min air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 'tasmin' class TasMax(Tas): """Daily max air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 'tasmax' @@ -1079,7 +1076,6 @@ def compute(file_paths, raster_index): lat lon array (spatial_1, spatial_2, 2) """ - fp = file_paths if isinstance(file_paths, str) else file_paths[0] handle = xr.open_dataset(fp) valid_vars = set(handle.variables) @@ -1166,7 +1162,8 @@ def compute(file_paths, raster_index): class Feature: """Class to simplify feature computations. Stores feature height, feature - basename, name of feature in handle""" + basename, name of feature in handle + """ def __init__(self, feature, handle): """Takes a feature (e.g. U_100m) and gets the height (100), basename @@ -1204,7 +1201,6 @@ def get_basename(feature): str feature basename """ - height = Feature.get_height(feature) pressure = Feature.get_pressure(feature) if height is not None or pressure is not None: @@ -1260,7 +1256,8 @@ def get_pressure(feature): class FeatureHandler: """Feature Handler with cache for previously loaded features used in other - calculations """ + calculations + """ FEATURE_REGISTRY: ClassVar[dict] = {} @@ -1280,7 +1277,6 @@ def valid_handle_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1304,7 +1300,6 @@ def valid_input_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1440,16 +1435,13 @@ def serial_extract(cls, file_paths, raster_index, time_chunks, keys for features. e.g. data[chunk_number][feature] = array. (spatial_1, spatial_2, temporal) """ - data = defaultdict(dict) for t, t_slice in enumerate(time_chunks): for f in input_features: data[t][f] = cls.extract_feature(file_paths, raster_index, f, t_slice, **kwargs) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks extracted.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks extracted.') return data @classmethod @@ -1511,7 +1503,6 @@ def parallel_extract(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(input_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] try: @@ -1521,12 +1512,11 @@ def parallel_extract(cls, f' {v["feature"]}') logger.error(msg) raise RuntimeError(msg) from e - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks extracted. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks extracted. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data @@ -1613,7 +1603,6 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, e.g. data[chunk_number][feature] = array. (spatial_1, spatial_2, temporal) """ - if len(derived_features) == 0: return data @@ -1768,7 +1757,6 @@ def _exact_lookup(cls, feature): out : str Matching feature registry entry. 
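Helpers like _exact_lookup and _pattern_lookup above resolve a requested feature name against FEATURE_REGISTRY either by a case-insensitive key comparison or by a regex match on the key. A minimal standalone sketch combining the two stages, with an invented registry purely for illustration:

import re

FEATURE_REGISTRY = {'temperature_(.*)m': 'TempNC', 'cloud_mask': 'CloudMask'}

def lookup(feature):
    """Try an exact (case-insensitive) key match, then fall back to regex."""
    for key, val in FEATURE_REGISTRY.items():
        if key.lower() == feature.lower():
            return val
    for key, val in FEATURE_REGISTRY.items():
        if re.match(key.lower(), feature.lower()):
            return val
    return None

assert lookup('cloud_mask') == 'CloudMask'
assert lookup('Temperature_100m') == 'TempNC'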
""" - out = None for k, v in cls.FEATURE_REGISTRY.items(): if k.lower() == feature.lower(): @@ -1791,7 +1779,6 @@ def _pattern_lookup(cls, feature): out : str Matching feature registry entry. """ - out = None for k, v in cls.FEATURE_REGISTRY.items(): if re.match(k.lower(), feature.lower()): diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 993b1dcc1..be6921f5e 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -29,7 +29,8 @@ class LogLinInterpolator: """Open ERA5 file, log interpolate wind components between 0 - max_log_height, linearly interpolate components above max_log_height - meters, and save to file""" + meters, and save to file + """ DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { 'u': [40, 80, 120, 160, 200], @@ -154,7 +155,8 @@ def load(self): def interpolate_vars(self, max_workers=None): """Interpolate u/v wind components below 100m using log profile. - Interpolate non wind data linearly.""" + Interpolate non wind data linearly. + """ for var, arrs in self.data_dict.items(): max_log_height = self.max_log_height if var not in ('u', 'v'): @@ -233,7 +235,8 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace('.nc', '_tmp.nc') return tmp_file @@ -582,7 +585,6 @@ def interp_single_ts(cls, out_array : ndarray Array of interpolated values. """ - # Interp each vertical column of height and var to requested levels zip_iter = zip(hgt_t, var_t, mask) out_array = [] diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 248954673..1a105c3c1 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -102,7 +102,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again @@ -126,7 +126,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) assert np.array_equal(old_dh.lr_data, new_dh.lr_data) assert np.array_equal(old_dh.hr_data, new_dh.hr_data) @@ -161,7 +161,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again with one cached feature and one noncached feature @@ -185,7 +185,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl') + cache_pattern=f'{td}/cache.pkl') assert np.array_equal(dh_step2.lr_data[..., 0:1], dh_step1.lr_data) assert np.array_equal(dh_step2.noncached_features, FEATURES[1:]) From f7e107acd3b60fa36de00b0273e9c1157f8cfe13 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Nov 2023 14:50:42 -0600 Subject: [PATCH 14/44] misc logging tweaks --- sup3r/bias/bias_correct_means.py | 31 ++++++++------- sup3r/preprocessing/batch_handling.py | 12 +++--- .../conditional_moment_batch_handling.py | 2 +- sup3r/preprocessing/data_handling/base.py | 8 ++-- sup3r/preprocessing/data_handling/mixin.py | 7 +++- sup3r/preprocessing/feature_handling.py | 18 ++++----- sup3r/utilities/era_downloader.py | 38 ++++++++++--------- 
tests/data_handling/test_data_handling_nc.py | 16 +++++--- 8 files changed, 70 insertions(+), 62 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index eb4f1f996..4a878d725 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -17,6 +17,7 @@ from rex import Resource from scipy.interpolate import interp1d from sklearn.neighbors import BallTree + from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD @@ -589,7 +590,7 @@ def get_corrections(self, global_scalar=1, knn=1): base_data = self.get_base_data(knn=knn) bias_data = self.get_bias_data() scaler = global_scalar * base_data / bias_data - adder = 0 + adder = np.zeros(scaler.shape) out = { "latitude": self.bias_meta["latitude"], @@ -670,13 +671,19 @@ def run( leaf_size=leaf_size) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: - for k, v in out.items(): - if k in ("latitude", "longitude"): - out[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out[k] = np.array(v).reshape((*out_shape, 12)) + out = cls._reshape_output(out, out_shape) bc.write_output(fp_out, out) + @classmethod + def _reshape_output(cls, out, out_shape): + """Reshape output according to given output shape""" + for k, v in out.items(): + if k in ("latitude", "longitude"): + out[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, (int, float)): + out[k] = np.array(v).reshape((*out_shape, 12)) + return out + @classmethod def run_uv( cls, @@ -724,16 +731,8 @@ def run_uv( global_scalar=global_scalar, knn=knn ) if out_shape is not None: - for k, v in out_u.items(): - if k in ("latitude", "longitude"): - out_u[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out_u[k] = np.array(v).reshape((*out_shape, 12)) - for k, v in out_v.items(): - if k in ("latitude", "longitude"): - out_v[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out_v[k] = np.array(v).reshape((*out_shape, 12)) + out_u = cls._reshape_output(out_u, out_shape) + out_v = cls._reshape_output(out_v, out_shape) bc.write_output(fp_pattern.format(feature=bc.u_name), out_u) bc.write_output(fp_pattern.format(feature=bc.v_name), out_v) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 6390f71e7..8f5563119 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -499,14 +499,14 @@ def __init__(self, ] logger.info(f'Initializing BatchHandler with ' - f'{len(self.data_handlers)} data handlers with handler' + f'{len(self.data_handlers)} data handlers with handler ' f'weights={self.handler_weights}, smoothing={smoothing}. 
' f'Using stats_workers={self.stats_workers}, ' f'norm_workers={self.norm_workers}, ' f'load_workers={self.load_workers}.') now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug(f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.') log_mem(logger, log_level='INFO') @@ -655,8 +655,8 @@ def parallel_normalization(self): logger.debug(f'{i+1} out of {len(futures)} data handlers' ' normalized.') - def parallel_load(self): - """Load data handler data in parallel""" + def load_handler_data(self): + """Load data handler data in parallel or serial""" logger.info(f'Loading {len(self.data_handlers)} data handlers') max_workers = self.load_workers if max_workers == 1: @@ -686,7 +686,7 @@ def parallel_load(self): logger.debug(f'{i+1} out of {len(futures)} handlers ' 'loaded.') - def parallel_stats(self): + def _get_stats(self): """Get standard deviations and means for training features in parallel.""" logger.info(f'Calculating stats for {len(self.training_features)} ' @@ -788,7 +788,7 @@ def get_stats(self): now = dt.now() logger.info('Calculating stdevs/means.') - self.parallel_stats() + self._get_stats() logger.info(f'Finished calculating stats in {dt.now() - now}.') self.cache_stats() diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index a1418b101..4d27a92da 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -902,7 +902,7 @@ def __init__( ) now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug( f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.' diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 94eea1673..28425663b 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -1019,11 +1019,11 @@ def load_cached_data(self, with_split=True): logger.warning(msg) warnings.warn(msg) - logger.debug('Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}') + if with_split and self.val_split > 0: + logger.debug('Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}') - if with_split: self.data, self.val_data = self.split_data( val_split=self.val_split, shuffle_time=self.shuffle_time) diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index f2f2c0494..67e11cb5b 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd +import psutil from scipy.stats import mode from sup3r.utilities.utilities import ( @@ -31,6 +32,7 @@ class CacheHandlingMixIn: """Collection of methods for handling data caching and loading""" def __init__(self): + """Initialize common attributes""" self._noncached_features = None self._cache_pattern = None self._cache_files = None @@ -285,7 +287,10 @@ def _load_single_cached_feature(self, fp, cache_files, features, msg = f'{features[idx].lower()} not found in {fp.lower()}.' assert features[idx].lower() in fp.lower(), msg fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {features[idx]} from {fp}.') + mem = psutil.virtual_memory() + logger.info(f'Loading {features[idx]} from {fp}. 
Current memory ' + f'usage is {mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') out = None with open(fp, 'rb') as fh: diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index 3738ce03c..9a1ac6eb2 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1617,10 +1617,8 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, file_paths=file_paths, raster_index=raster_index) cls.pop_old_data(data, t, all_features) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks computed.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks computed.') return data @@ -1698,18 +1696,16 @@ def parallel_compute(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(derived_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] chunk_idx = v['chunk'] data[chunk_idx] = data.get(chunk_idx, {}) data[chunk_idx][v['feature']] = future.result() - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks computed. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks computed. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 0a572c9ef..d627ca6c5 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -37,7 +37,8 @@ class EraDownloader: """Class to handle ERA5 downloading, variable renaming, file combination, - and interpolation.""" + and interpolation. + """ msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' 'with a valid url and api key. Follow the instructions here: ' @@ -223,7 +224,8 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace(".nc", "_tmp.nc") return tmp_file @@ -255,7 +257,8 @@ def check_good_vars(self, variables): def _prep_var_lists(self, variables): """Add all downloadable variables for the generic requested variables. - e.g. if variable = 'u' add all downloadable u variables to list.""" + e.g. if variable = 'u' add all downloadable u variables to list. + """ d_vars = [] vars = variables.copy() for i, v in enumerate(vars): @@ -269,7 +272,8 @@ def _prep_var_lists(self, variables): def prep_var_lists(self, variables): """Create surface and level variable lists based on requested - variables.""" + variables. 
+ """ variables = self._prep_var_lists(variables) for var in variables: if var in self.SFC_VARS and var not in self.sfc_file_variables: @@ -344,7 +348,6 @@ def download_surface_file(self): def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" - dims = ('time', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.surface_file) with Dataset(self.surface_file, "r") as old_ds: @@ -406,7 +409,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): ds : Dataset Dataset() object for new file with new height variable written. """ - _ = ds.createVariable(standard_name, np.float32, dimensions=old_ds['z'].dimensions) @@ -418,7 +420,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): def process_level_file(self): """Convert geopotential to geopotential height.""" - dims = ('time', 'level', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.level_file) with Dataset(self.level_file, "r") as old_ds: @@ -429,10 +430,12 @@ def process_level_file(self): ds = self.map_vars(old_ds, ds) - if 'pressure' in self.variables: + if ('pressure' in self.variables + and 'pressure' not in ds.variables): tmp = np.zeros(ds.variables['zg'].shape) for i in range(tmp.shape[1]): tmp[:, i, :, :] = ds.variables['level'][i] * 100 + _ = ds.createVariable('pressure', np.float32, dimensions=dims) @@ -446,7 +449,6 @@ def process_level_file(self): def process_and_combine(self): """Process variables and combine.""" - if not os.path.exists(self.combined_file) or self.overwrite: files = [] if os.path.exists(self.level_file): @@ -459,7 +461,7 @@ def process_and_combine(self): files.append(self.surface_file) logger.info(f'Combining {files} to {self.combined_file}.') - with xr.open_mfdataset(files) as ds: + with xr.open_mfdataset(files, compat='override') as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing {self.combined_file}') @@ -495,7 +497,8 @@ def good_file(self, file, required_shape): def check_existing_files(self): """If files exist already check them for good shape and required variables. Remove them if there was a problem so we can continue with - routine from scratch.""" + routine from scratch. + """ if os.path.exists(self.combined_file): try: check = self.good_file(self.combined_file, self.required_shape) @@ -521,7 +524,8 @@ def check_existing_files(self): def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to - max_log_height (usually 100m) and linear interpolation above this.""" + max_log_height (usually 100m) and linear interpolation above this. + """ LogLinInterpolator.run(infile=self.combined_file, outfile=self.interp_file, max_workers=max_workers, @@ -533,8 +537,8 @@ def get_monthly_file(self, interp_workers=None, keep_variables=None, **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to - interpolate.""" - + interpolate. + """ if os.path.exists(self.combined_file) and self.overwrite: os.remove(self.combined_file) @@ -577,7 +581,6 @@ def all_months_exist(cls, year, file_pattern): @classmethod def already_pruned(cls, infile, keep_variables): """Check if file has been pruned already.""" - if keep_variables is None: logger.info('Received keep_variables=None. 
Skipping pruning.') return @@ -586,7 +589,9 @@ def already_pruned(cls, infile, keep_variables): pruned = True with Dataset(infile, 'r') as ds: - for var in ds.variables: + variables = [var for var in ds.variables + if var not in ('time', 'latitude', 'longitude')] + for var in variables: if not any(name in var for name in keep_variables): logger.info(f'Pruning {var} in {infile}.') pruned = False @@ -595,7 +600,6 @@ def already_pruned(cls, infile, keep_variables): @classmethod def prune_output(cls, infile, keep_variables=None): """Prune output file to keep just single level variables""" - if keep_variables is None: logger.info('Received keep_variables=None. Skipping pruning.') return diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 59ada2965..110c3016e 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -2,18 +2,21 @@ """pytests for data handling""" import os +import tempfile + +import matplotlib.pyplot as plt import numpy as np import pytest -import matplotlib.pyplot as plt -import tempfile import xarray as xr from sup3r import TEST_DATA_DIR +from sup3r.preprocessing.batch_handling import ( + BatchHandler, + SpatialBatchHandler, +) from sup3r.preprocessing.data_handling import DataHandlerNC as DataHandler -from sup3r.preprocessing.batch_handling import (BatchHandler, - SpatialBatchHandler) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.pytest import make_fake_nc_files, make_fake_era_files +from sup3r.utilities.pytest import make_fake_era_files, make_fake_nc_files INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') features = ['U_100m', 'V_100m', 'BVF_MO_200m'] @@ -166,7 +169,8 @@ def test_data_caching(): if os.path.exists(cache_pattern): os.system(f'rm {cache_pattern}') handler = DataHandler(INPUT_FILE, features, - cache_pattern=cache_pattern, **dh_kwargs) + cache_pattern=cache_pattern, **dh_kwargs, + val_split=0.1) assert handler.data is None handler.load_cached_data() assert handler.data.shape == (shape[0], shape[1], From a5b4c3ec4f441fe3f8277e02ecd15d5be52002e2 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Nov 2023 09:42:57 -0700 Subject: [PATCH 15/44] nc collector for intermediate write during multi fwp pipeline --- sup3r/postprocessing/collection.py | 141 +++++++++++++++++++-- sup3r/postprocessing/data_collect_cli.py | 14 +- sup3r/preprocessing/data_handling/mixin.py | 10 +- sup3r/utilities/regridder.py | 2 + 4 files changed, 151 insertions(+), 16 deletions(-) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 2384c914d..0cba2744e 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -4,12 +4,14 @@ import logging import os import time +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed from warnings import warn import numpy as np import pandas as pd import psutil +import xarray as xr from gaps import Status from rex.utilities.fun_utils import get_fun_call_str from rex.utilities.loggers import init_logger @@ -22,17 +24,16 @@ logger = logging.getLogger(__name__) -class Collector(OutputMixIn): - """Sup3r H5 file collection framework""" +class BaseCollector(ABC, OutputMixIn): + """Base collector class for H5/NETCDF collection""" def __init__(self, file_paths): """Parameters ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single 
string with unix-style /search/patt*ern.h5. Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.. Files + should have non-overlapping time_index and spatial domains. """ if not isinstance(file_paths, list): file_paths = glob.glob(file_paths) @@ -40,6 +41,11 @@ def __init__(self, file_paths): self.data = None self.file_attrs = {} + @classmethod + @abstractmethod + def collect(cls, *args, **kwargs): + """Collect data files from a dir to one output file.""" + @classmethod def get_node_cmd(cls, config): """Get a CLI call to collect data. @@ -52,7 +58,7 @@ def get_node_cmd(cls, config): """ import_str = ( 'from sup3r.postprocessing.collection ' - 'import Collector;\n' + f'import {cls.__class__.__name__};\n' 'from rex import init_logger;\n' 'import time;\n' 'from gaps import Status;\n' @@ -80,6 +86,123 @@ def get_node_cmd(cls, config): return cmd.replace('\\', '/') + +class CollectorNC(BaseCollector): + """Sup3r NETCDF file collection framework""" + + @classmethod + def collect( + cls, + file_paths, + out_file, + features, + max_workers=None, + log_level=None, + log_file=None, + write_status=False, + job_name=None, + join_times=False, + target_final_meta_file=None, + n_writes=None, + overwrite=True, + threshold=1e-4, + ): + """Collect data files from a dir to one output file. + + Filename requirements: + - Should end with ".nc" + + Parameters + ---------- + file_paths : list | str + Explicit list of str file paths that will be sorted and collected + or a single string with unix-style /search/patt*ern.nc. + out_file : str + File path of final output file. + features : list + List of dsets to collect + max_workers : int | None + Number of workers to use in parallel. 1 runs serial, + None will use all available workers. + log_level : str | None + Desired log level, None will not initialize logging. + log_file : str | None + Target log file. None logs to stdout. + write_status : bool + Flag to write status file once complete if running from pipeline. + job_name : str + Job name for status file if running from pipeline. + join_times : bool + Option to split full file list into chunks with each chunk having + the same temporal_chunk_index. The number of writes will then be + min(number of temporal chunks, n_writes). This ensures that each + write has all the spatial chunks for a given time index. Assumes + file_paths have a suffix format + _{temporal_chunk_index}_{spatial_chunk_index}.h5. This is required + if there are multiple writes and chunks have different time + indices. + target_final_meta_file : str + Path to target final meta containing coordinates to keep from the + full file list collected meta. This can be but is not necessarily a + subset of the full list of coordinates for all files in the file + list. This is used to remove coordinates from the full file list + which are not present in the target_final_meta. Either this full + meta or a subset, depending on which coordinates are present in + the data to be collected, will be the final meta for the collected + output files. + n_writes : int | None + Number of writes to split full file list into. Must be less than + or equal to the number of temporal chunks if chunks have different + time indices. 
+ overwrite : bool + Whether to overwrite existing output file + threshold : float + Threshold distance for finding target coordinates within full meta + """ + t0 = time.time() + + logger.info( + f'Initializing collection for file_paths={file_paths}, ' + f'with max_workers={max_workers}.' + ) + + if log_level is not None: + init_logger( + 'sup3r.preprocessing', log_file=log_file, log_level=log_level + ) + + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file), exist_ok=True) + + collector = cls(file_paths) + logger.info( + 'Collecting {} files to {}'.format(len(collector.flist), out_file) + ) + if overwrite and os.path.exists(out_file): + logger.info(f'overwrite=True, removing {out_file}.') + os.remove(out_file) + + out = xr.open_mfdataset(collector.flist) + out.to_netcdf(out_file) + + if write_status and job_name is not None: + status = { + 'out_dir': os.path.dirname(out_file), + 'fout': out_file, + 'flist': collector.flist, + 'job_status': 'successful', + 'runtime': (time.time() - t0) / 60, + } + Status.make_single_job_file( + os.path.dirname(out_file), 'collect', job_name, status + ) + + logger.info('Finished file collection.') + + +class CollectorH5(BaseCollector): + """Sup3r H5 file collection framework""" + @classmethod def get_slices( cls, final_time_index, final_meta, new_time_index, new_meta @@ -230,7 +353,7 @@ def get_data( warn(msg) else: - row_slice, col_slice = Collector.get_slices( + row_slice, col_slice = self.get_slices( time_index, meta, f_ti, f_meta ) @@ -494,13 +617,13 @@ def _write_flist_data( """ with RexOutputs(out_file, mode='r') as f: target_ti = f.time_index - y_write_slice, x_write_slice = Collector.get_slices( + y_write_slice, x_write_slice = self.get_slices( target_ti, target_masked_meta, time_index, subset_masked_meta, ) - Collector._ensure_dset_in_output(out_file, feature) + self._ensure_dset_in_output(out_file, feature) with RexOutputs(out_file, mode='a') as f: try: diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index 5133f5586..e2ad55871 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -2,15 +2,16 @@ """ sup3r data collection CLI entry points. """ -import click -import logging import copy +import logging + +import click +from sup3r.postprocessing.collection import CollectorH5, CollectorNC from sup3r.utilities import ModuleName -from sup3r.version import __version__ -from sup3r.postprocessing.collection import Collector from sup3r.utilities.cli import AVAILABLE_HARDWARE_OPTIONS, BaseCLI - +from sup3r.utilities.utilities import get_source_type +from sup3r.version import __version__ logger = logging.getLogger(__name__) @@ -42,6 +43,9 @@ def from_config(ctx, config_file, verbose=False, **__): dset_split = config.get('dset_split', False) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.pop('option', 'local') + source_type = get_source_type(config['file_paths']) + collector_types = {'h5': CollectorH5, 'nc': CollectorNC} + Collector = collector_types[source_type] configs = [config] if dset_split: diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 67e11cb5b..781cd5f67 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -914,6 +914,10 @@ class TrainingPrepMixIn: """Collection of training related methods. e.g. 
Training + Validation splitting, normalization""" + def __init__(self): + """Initialize common attributes""" + self.features = None + @classmethod def _split_data_indices(cls, data, @@ -1034,11 +1038,13 @@ def _normalize_data(self, data, val_data, feature_index, mean, std): val_data[..., feature_index] /= std data[..., feature_index] /= std else: - msg = ( - f'Standard Deviation is zero for feature #{feature_index + 1}') + msg = ('Standard Deviation is zero for ' + f'{self.features[feature_index]}') logger.warning(msg) warnings.warn(msg) + logger.info(f'Finished normalizing {self.features[feature_index]}.') + def _normalize(self, data, val_data, means, stds, max_workers=None): """Normalize all data features diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index f1c215319..64ed901af 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -344,6 +344,8 @@ def __call__(self, data): Flattened regridded spatiotemporal data (spatial, temporal) """ + msg = 'Input data must be 3D (spatial_1, spatial_2, temporal)' + assert len(data.shape) == 3, msg vals = [ data[:, :, i].flatten()[np.array(self.indices)][np.newaxis] for i in range(data.shape[-1]) From 3a80ec798e3029ac6021bfb26f8a5bbe19148147 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Nov 2023 18:22:49 -0700 Subject: [PATCH 16/44] cleaning up args --- sup3r/postprocessing/collection.py | 53 ++++++---------------- sup3r/postprocessing/data_collect_cli.py | 3 +- sup3r/utilities/interpolate_log_profile.py | 4 +- 3 files changed, 18 insertions(+), 42 deletions(-) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 0cba2744e..5b87a9a48 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""H5 file collection.""" +"""H5/NETCDF file collection.""" import glob import logging import os @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class BaseCollector(ABC, OutputMixIn): +class BaseCollector(OutputMixIn, ABC): """Base collector class for H5/NETCDF collection""" def __init__(self, file_paths): @@ -58,7 +58,7 @@ def get_node_cmd(cls, config): """ import_str = ( 'from sup3r.postprocessing.collection ' - f'import {cls.__class__.__name__};\n' + f'import {cls.__name__};\n' 'from rex import init_logger;\n' 'import time;\n' 'from gaps import Status;\n' @@ -96,16 +96,12 @@ def collect( file_paths, out_file, features, - max_workers=None, log_level=None, log_file=None, write_status=False, job_name=None, - join_times=False, - target_final_meta_file=None, - n_writes=None, overwrite=True, - threshold=1e-4, + res_kwargs=None ): """Collect data files from a dir to one output file. @@ -121,9 +117,6 @@ def collect( File path of final output file. features : list List of dsets to collect - max_workers : int | None - Number of workers to use in parallel. 1 runs serial, - None will use all available workers. log_level : str | None Desired log level, None will not initialize logging. log_file : str | None @@ -132,38 +125,15 @@ def collect( Flag to write status file once complete if running from pipeline. job_name : str Job name for status file if running from pipeline. - join_times : bool - Option to split full file list into chunks with each chunk having - the same temporal_chunk_index. The number of writes will then be - min(number of temporal chunks, n_writes). This ensures that each - write has all the spatial chunks for a given time index. 
Assumes - file_paths have a suffix format - _{temporal_chunk_index}_{spatial_chunk_index}.h5. This is required - if there are multiple writes and chunks have different time - indices. - target_final_meta_file : str - Path to target final meta containing coordinates to keep from the - full file list collected meta. This can be but is not necessarily a - subset of the full list of coordinates for all files in the file - list. This is used to remove coordinates from the full file list - which are not present in the target_final_meta. Either this full - meta or a subset, depending on which coordinates are present in - the data to be collected, will be the final meta for the collected - output files. - n_writes : int | None - Number of writes to split full file list into. Must be less than - or equal to the number of temporal chunks if chunks have different - time indices. overwrite : bool Whether to overwrite existing output file - threshold : float - Threshold distance for finding target coordinates within full meta + res_kwargs : dict | None + Dictionary of kwargs to pass to xarray.open_mfdataset. """ t0 = time.time() logger.info( - f'Initializing collection for file_paths={file_paths}, ' - f'with max_workers={max_workers}.' + f'Initializing collection for file_paths={file_paths}' ) if log_level is not None: @@ -182,8 +152,13 @@ def collect( logger.info(f'overwrite=True, removing {out_file}.') os.remove(out_file) - out = xr.open_mfdataset(collector.flist) - out.to_netcdf(out_file) + if not os.path.exists(out_file): + res_kwargs = (res_kwargs + or {"concat_dim": "Time", "combine": "nested"}) + out = xr.open_mfdataset(collector.flist, **res_kwargs) + features = [feat for feat in out if feat in features + or feat.lower() in features] + out[features].to_netcdf(out_file) if write_status and job_name is not None: status = { diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index e2ad55871..dbf3228dc 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -52,7 +52,8 @@ def from_config(ctx, config_file, verbose=False, **__): configs = [] for feature in config['features']: f_config = copy.deepcopy(config) - f_out_file = config['out_file'].replace('.h5', f'_{feature}.h5') + f_out_file = config['out_file'].replace( + f'.{source_type}', f'_{feature}.{source_type}') f_job_name = config['job_name'] + f'_{feature}' f_log_file = config.get('log_file', None) if f_log_file is not None: diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index be6921f5e..1a42b8ef7 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -33,8 +33,8 @@ class LogLinInterpolator: """ DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { - 'u': [40, 80, 120, 160, 200], - 'v': [40, 80, 120, 160, 200], + 'u': [10, 40, 80, 100, 120, 160, 200], + 'v': [10, 40, 80, 100, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], From c6784ff1e4fc2fb3ae5278a3c31cb6e85d3e0663 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 6 Nov 2023 16:12:24 -0700 Subject: [PATCH 17/44] break at max_log_height --- sup3r/utilities/interpolate_log_profile.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 1a42b8ef7..2ed8c39db 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ 
b/sup3r/utilities/interpolate_log_profile.py @@ -405,8 +405,8 @@ def ws_log_profile(z, a, b): good = True levels = np.array(levels) - lev_mask = (0 < levels) & (levels <= max_log_height) - var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) + lev_mask = (levels > 0) & (levels <= max_log_height) + var_mask = (lev_array_samp > 0) & (lev_array_samp <= max_log_height) try: popt, _ = curve_fit(ws_log_profile, lev_array_samp[var_mask], @@ -480,8 +480,8 @@ def _interp_var_to_height(cls, max_log_height=max_log_height) if any(levels > max_log_height): - lev_mask = levels >= max_log_height - var_mask = lev_array >= max_log_height + lev_mask = levels > max_log_height + var_mask = lev_array > max_log_height if len(lev_array[var_mask]) > 1: lin_ws = interp1d(lev_array[var_mask], var_array[var_mask], From e6586f35cca7f4f02d02b55362f5685c366560a5 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 24 Oct 2023 07:55:59 -0600 Subject: [PATCH 18/44] doc string fix --- sup3r/utilities/era_downloader.py | 10 ++++++---- sup3r/utilities/loss_metrics.py | 2 ++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index c6e6dc2ca..a7d7b5216 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -8,9 +8,11 @@ import logging import os from calendar import monthrange -from concurrent.futures import (ProcessPoolExecutor, ThreadPoolExecutor, - as_completed, - ) +from concurrent.futures import ( + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) from glob import glob from typing import ClassVar from warnings import warn @@ -52,7 +54,7 @@ class EraDownloader: KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', + 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', 'temperature', 'pressure', ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ diff --git a/sup3r/utilities/loss_metrics.py b/sup3r/utilities/loss_metrics.py index 64a0680ed..71d36ede1 100644 --- a/sup3r/utilities/loss_metrics.py +++ b/sup3r/utilities/loss_metrics.py @@ -15,6 +15,8 @@ def gaussian_kernel(x1, x2, sigma=1.0): x2 : tf.tensor high resolution data (n_obs, spatial_1, spatial_2, temporal, features) + sigma : float + Standard deviation for gaussian kernel Returns ------- From 9d6ec49e1dd52962296bce3e3e830d47c3e082cc Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 11:56:53 -0600 Subject: [PATCH 19/44] nc topoextract fix --- sup3r/preprocessing/data_handling/exo_extraction.py | 3 ++- sup3r/utilities/regridder.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 6d8c6bdfe..1290703f2 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -450,7 +450,8 @@ def __init__(self, *args, **kwargs): @property def source_data(self): """Get the 1D array of elevation data from the exo_source_nc""" - elev = self.source_handler.data.reshape((-1, self.lr_shape[-1])) + elev = self.source_handler.data[..., 0, 0].flatten() + elev = np.repeat(elev[..., np.newaxis], self.hr_shape[-1], axis=-1) return elev @property diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index b1b898893..d0af2abf6 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -345,7 +345,7 @@ def 
__call__(self, data): (spatial, temporal) """ vals = [ - data[:, :, i].flatten()[self.indices][np.newaxis] + data[:, :, i].flatten()[np.array(self.indices)][np.newaxis] for i in range(data.shape[-1]) ] vals = np.concatenate(vals, axis=0) From 2b298b42c74a5364720ce7283d8b3b12c626da08 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 13:36:40 -0600 Subject: [PATCH 20/44] ignore case for bc adder/scalar lookup --- sup3r/bias/bias_transforms.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index 71d74f568..3995a0f3b 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -58,7 +58,11 @@ def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1): raise RuntimeError(msg) msg = (f'Either {dset_scalar} or {dset_adder} not found in {bias_fp}.') - assert dset_scalar in res.dsets and dset_adder in res.dsets, msg + dsets = [dset.lower() for dset in res.dsets] + check = dset_scalar.lower() in dsets and dset_adder.lower() in dsets + assert check, msg + dset_scalar = res.dsets[dsets.index(dset_scalar.lower())] + dset_adder = res.dsets[dsets.index(dset_adder.lower())] scalar = res[dset_scalar, slice_y, slice_x] adder = res[dset_adder, slice_y, slice_x] return scalar, adder From 6aa3aa9655c0316da5e9fb938f0525a33daab716 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 20:55:47 -0600 Subject: [PATCH 21/44] some tweaks for training with cape - only found in era not wtk so need to make sure different sets of features are indexed carefully. cape added to era_downloader. added vortex mean based bias correction code. --- sup3r/bias/bias_correct_means.py | 707 ++++++++++++++++++ .../data_handling/dual_data_handling.py | 31 +- sup3r/preprocessing/data_handling/mixin.py | 6 +- sup3r/preprocessing/dual_batch_handling.py | 34 +- sup3r/utilities/era_downloader.py | 16 +- 5 files changed, 758 insertions(+), 36 deletions(-) create mode 100644 sup3r/bias/bias_correct_means.py diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py new file mode 100644 index 000000000..cbd98e1b5 --- /dev/null +++ b/sup3r/bias/bias_correct_means.py @@ -0,0 +1,707 @@ +"""Classes to compute means from vortex and era data and compute bias +correction factors.""" + + +import calendar +import json +import logging +import os + +import h5py +import numpy as np +import pandas as pd +import rioxarray +import xarray as xr +from rex import Resource +from sklearn.neighbors import BallTree +from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs +from sup3r.preprocessing.feature_handling import Feature +from sup3r.utilities import VERSION_RECORD +from sup3r.utilities.interpolation import Interpolator +from scipy.interpolate import interp1d + +logger = logging.getLogger(__name__) + + +class VortexMeanPrepper: + """Class for converting monthly vortex tif files for each height to a + single h5 files containing all monthly means for all requested output + heights""" + + def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): + """ + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. 
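The "ignore case" hunk above keeps the original dataset names for reading from the h5 file while matching the requested scalar/adder names case-insensitively. Without the resource handler, the pattern reduces to the small sketch below; the names are placeholders.

dsets = ['Windspeed_100m_scalar', 'Windspeed_100m_adder']  # names in the file
wanted = 'windspeed_100m_scalar'                           # requested name

lowered = [d.lower() for d in dsets]
assert wanted.lower() in lowered, f'{wanted} not found'
resolved = dsets[lowered.index(wanted.lower())]            # original-case name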
+ """ + self.path_pattern = path_pattern + self.in_heights = in_heights + self.out_heights = out_heights + self.out_dir = os.path.dirname(path_pattern) + self.overwrite = overwrite + self._mask = None + + @property + def in_features(self): + """List of features corresponding to input heights""" + return [f"windspeed_{h}m" for h in self.in_heights] + + @property + def out_features(self): + """List of features corresponding to output heights""" + return [f"windspeed_{h}m" for h in self.out_heights] + + def get_input_file(self, month, height): + """Get vortex tif file for given month and height.""" + return self.path_pattern.format(month=month, height=height) + + def get_height_files(self, month): + """Get set of netcdf files for given month""" + files = [] + for height in self.in_heights: + infile = self.get_input_file(month, height) + outfile = infile.replace(".tif", ".nc") + files.append(outfile) + return files + + @property + def input_files(self): + """Get list of all input files used for h5 meta.""" + files = [] + for height in self.in_heights: + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_input_file(month, height)) + return files + + def get_output_file(self, month): + """Get name of netcdf file for a given month.""" + return os.path.join( + self.out_dir.replace("{month}", month), f"{month}.nc" + ) + + @property + def output_files(self): + """List of output monthly output files each with windspeed for all + input heights""" + files = [] + for i in range(1, 13): + month = calendar.month_name[i] + files.append(self.get_output_file(month)) + return files + + def convert_month_height_tif(self, month, height): + """Get windspeed mean for the given month and hub height from the + corresponding input file and write this to a netcdf file.""" + infile = self.get_input_file(month, height) + logger.info(f"Getting mean windspeed_{height}m for {month}.") + outfile = infile.replace(".tif", ".nc") + if os.path.exists(outfile) and self.overwrite: + os.remove(outfile) + + if not os.path.exists(outfile) or self.overwrite: + tmp = rioxarray.open_rasterio(infile) + ds = tmp.to_dataset("band") + ds = ds.rename( + {1: f"windspeed_{height}m", "x": "longitude", "y": "latitude"} + ) + ds.to_netcdf(outfile) + return outfile + + def convert_month_tif(self, month): + """Write netcdf files for all heights for the given month""" + for height in self.in_heights: + self.convert_month_height_tif(month, height) + + def convert_all_tifs(self): + """Write netcdf files for all heights for all months""" + for i in range(1, 13): + month = calendar.month_name[i] + logger.info(f"Converting tif files to netcdf files for {month}") + self.convert_month_tif(month) + + @property + def mask(self): + """Mask coordinates without data""" + if self._mask is None: + with xr.open_mfdataset(self.get_height_files('January')) as res: + mask = ((res[self.in_features[0]] != -999) + & (~np.isnan(res[self.in_features[0]]))) + for feat in self.in_features[1:]: + tmp = ((res[feat] != -999) & (~np.isnan(res[feat]))) + mask = mask & tmp + self._mask = np.array(mask).flatten() + return self._mask + + def get_month(self, month): + """Get interpolated means for all hub heights for the given month. 
+ + Parameters + ---------- + month : str + Name of month to get data for + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing interpolated monthly windspeed + means for all input and output heights + + """ + month_file = self.get_output_file(month) + if os.path.exists(month_file) and self.overwrite: + os.remove(month_file) + + if os.path.exists(month_file) and not self.overwrite: + logger.info(f"Loading month_file {month_file}.") + data = xr.open_dataset(month_file) + else: + logger.info( + "Getting mean windspeed for all heights " + f"({self.in_heights}) for {month}" + ) + data = xr.open_mfdataset(self.get_height_files(month)) + logger.info( + "Interpolating windspeed for all heights " + f"({self.out_heights}) for {month}." + ) + data = self.interp(data) + data.to_netcdf(month_file) + logger.info( + "Saved interpolated means for all heights for " + f"{month} to {month_file}." + ) + return data + + def interp(self, data): + """Interpolate data to requested output heights. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input heights + + Returns + ------- + data : xarray.Dataset + xarray dataset object containing windspeed for all input and output + heights + """ + var_array = np.zeros( + ( + len(data.latitude) * len(data.longitude), + len(self.in_heights), + ), dtype=np.float32 + ) + lev_array = var_array.copy() + for i, (h, feat) in enumerate(zip(self.in_heights, self.in_features)): + var_array[..., i] = data[feat].values.flatten() + lev_array[..., i] = h + + logger.info(f'Interpolating {self.in_features} to {self.out_features} ' + f'for {var_array.shape[0]} coordinates.') + tmp = [interp1d(h, v, fill_value='extrapolate')(self.out_heights) + for h, v in zip(lev_array[self.mask], var_array[self.mask])] + out = np.full((len(data.latitude), len(data.longitude), + len(self.out_heights)), np.nan, dtype=np.float32) + out[self.mask.reshape((len(data.latitude), len(data.longitude)))] = tmp + for i, feat in enumerate(self.out_features): + if feat not in data: + data[feat] = (("latitude", "longitude"), out[..., i]) + return data + + def get_lat_lon(self): + """Get lat lon grid""" + with xr.open_mfdataset(self.get_height_files("January")) as res: + lons, lats = np.meshgrid( + res["longitude"].values, res["latitude"].values + ) + return np.array(lats), np.array(lons) + + def get_all_data(self): + """Get interpolated monthly means for all out heights as a dictionary + to use for h5 writing. + + Returns + ------- + out : dict + Dictionary of arrays containing monthly means for each hub height. + Also includes latitude and longitude. 
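The interp step above fits each masked pixel independently, interpolating the known means at the input heights to the requested output heights and extrapolating beyond them. A compact per-pixel sketch with placeholder heights and means:

import numpy as np
from scipy.interpolate import interp1d

in_heights = np.array([50.0, 100.0, 150.0])
out_heights = np.array([80.0, 120.0, 200.0])
ws_means = np.array([6.5, 7.2, 7.6])  # windspeed means at in_heights, one pixel

interp = interp1d(in_heights, ws_means, fill_value='extrapolate')
print(interp(out_heights))  # means at 80 m, 120 m and (extrapolated) 200 m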
Spatial dimensions are + flattened + """ + data_dict = {} + lats, lons = self.get_lat_lon() + data_dict["latitude"] = lats.flatten()[self.mask] + data_dict["longitude"] = lons.flatten()[self.mask] + s_num = len(data_dict["longitude"]) + for i in range(1, 13): + month = calendar.month_name[i] + out = self.get_month(month) + for feat in self.out_features: + if feat not in data_dict: + data_dict[feat] = np.full((s_num, 12), np.nan) + data = out[feat].values.flatten()[self.mask] + data_dict[feat][..., i - 1] = data + return data_dict + + @property + def meta(self): + """Get a meta data dictionary on how this data is prepared""" + meta = { + "input_files": self.input_files, + "class": str(self.__class__), + "version_record": VERSION_RECORD, + } + return meta + + def write_data(self, fp_out, out): + """Write monthly means for all heights to h5 file""" + if fp_out is not None: + if not os.path.exists(os.path.dirname(fp_out)): + os.makedirs(os.path.dirname(fp_out), exist_ok=True) + + if not os.path.exists(fp_out) or self.overwrite: + with h5py.File(fp_out, "w") as f: + for dset, data in out.items(): + f.create_dataset(dset, data=data) + logger.info(f"Added {dset} to {fp_out}.") + + for k, v in self.meta.items(): + f.attrs[k] = json.dumps(v) + + logger.info( + f"Wrote monthly means for all out heights: {fp_out}" + ) + elif os.path.exists(fp_out): + logger.info(f"{fp_out} already exists and overwrite=False.") + + @classmethod + def run( + cls, path_pattern, in_heights, out_heights, fp_out, overwrite=False + ): + """Read vortex tif files, convert these to monthly netcdf files for all + input heights, interpolate this data to requested output heights, mask + fill values, and write all data to h5 file.""" + vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) + vprep.convert_all_tifs() + out = vprep.get_all_data() + vprep.write_data(fp_out, out) + + +class EraMeanPrepper: + """Class to compute monthly windspeed means from ERA data.""" + + def __init__(self, era_pattern, years, features): + """ + Parameters + ---------- + era_pattern : str + Pattern pointing to era files with u/v wind components at the given + heights. Must have a {year} format key. + years : list + List of ERA years to use for calculating means. + features : list + List of features to compute means for. e.g. ['windspeed_10m'] + """ + self.era_pattern = era_pattern + self.years = years + self.features = features + self.lats, self.lons = self.get_lat_lon() + + @property + def shape(self): + """Get shape of spatial dimensions (lats, lons)""" + return self.lats.shape + + @property + def heights(self): + """List of feature heights""" + heights = [Feature.get_height(feature) for feature in self.features] + return heights + + @property + def input_files(self): + """List of ERA input files to use for calculating means.""" + return [self.era_pattern.format(year=year) for year in self.years] + + def get_lat_lon(self): + """Get arrays of latitude and longitude for ERA domain""" + with xr.open_dataset(self.input_files[0]) as res: + lons, lats = np.meshgrid( + res["longitude"].values, res["latitude"].values + ) + return lats, lons + + def get_windspeed(self, data, height): + """Compute windspeed from u/v wind components from given data. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object for a year of ERA data. Must include u/v + components for the given height. e.g. u_{height}m, v_{height}m. + height : int + Height to compute windspeed for. 
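The windspeed and monthly-mean helpers in this class amount to a hypot of the u/v components followed by a time average over a calendar-month mask. A rough standalone recipe with random placeholder arrays:

import numpy as np
import pandas as pd

times = pd.date_range('2020-01-01', '2020-12-31 23:00', freq='h')
u = np.random.rand(len(times), 4, 5)   # (time, lat, lon)
v = np.random.rand(len(times), 4, 5)

ws = np.hypot(u, v)                    # windspeed magnitude
jan_mask = times.month == 1
jan_mean = ws[jan_mask].mean(axis=0)   # (lat, lon) mean for January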
+ """ + return np.hypot( + data[f"u_{height}m"].values, data[f"v_{height}m"].values + ) + + def get_month_mean(self, data, height, month): + """Get windspeed_{height}m mean for the given month. + + Parameters + ---------- + data : xarray.Dataset + xarray dataset object for a year of ERA data. Must include u/v + components for the given height. e.g. u_{height}m, v_{height}m. + height : int + Height to compute windspeed for. + month : int + Index of month to get mean for. e.g. 1 = Jan, 2 = Feb, etc. + + Returns + ------- + out : np.ndarray + Array of time averaged windspeed data for the given month. + + """ + mask = pd.to_datetime(data["time"]).month == month + ws = self.get_windspeed(data, height)[mask] + return ws.mean(axis=0) + + def get_all_means(self, height): + """Get monthly means for all months across all given years for the + given height. + + Parameters + ---------- + height : int + Height to compute windspeed for. + + Returns + ------- + means : dict + Dictionary of windspeed_{height}m means for each month + """ + feature = self.features[self.heights.index(height)] + means = {i: [] for i in range(1, 13)} + for i, year in enumerate(self.years): + logger.info(f"Getting means for year={year}, feature={feature}.") + data = xr.open_dataset(self.input_files[i]) + for m in range(1, 13): + means[m].append(self.get_month_mean(data, height, month=m)) + means = {m: np.dstack(arr).mean(axis=-1) for m, arr in means.items()} + return means + + def write_csv(self, out, out_file): + """Write monthly means to a csv file. + + Parameters + ---------- + out : dict + Dictionary of windspeed_{height}m means for each month + out_file : str + Name of csv output file. + """ + logger.info(f"Writing means to {out_file}.") + out = { + f"{str(calendar.month_name[m])[:3]}_mean": v.flatten() + for m, v in out.items() + } + df = pd.DataFrame.from_dict(out) + df["latitude"] = self.lats.flatten() + df["longitude"] = self.lons.flatten() + df["gid"] = np.arange(len(df["latitude"])) + df.to_csv(out_file) + logger.info(f"Finished writing means for {out_file}.") + + @classmethod + def run(cls, era_pattern, years, features, out_pattern): + """Compute monthly windspeed means for the given heights, using the + given years of ERA data, and write the means to csv files for each + height.""" + em = cls(era_pattern=era_pattern, years=years, features=features) + for height, feature in zip(em.heights, em.features): + means = em.get_all_means(height) + out_file = out_pattern.format(feature=feature) + em.write_csv(means, out_file=out_file) + logger.info( + f"Finished writing means for years={years} and " + f"heights={em.heights}." + ) + + +class BiasCorrectionFromMeans: + """Class for getting bias correction factors from bias and base data files + with precomputed monthly means.""" + + MIN_DISTANCE = 1e-12 + + def __init__(self, bias_fp, base_fp, dset, leaf_size=4): + self.dset = dset + self.bias_fp = bias_fp + self.base_fp = base_fp + self.leaf_size = leaf_size + self.bias_means = pd.read_csv(bias_fp) + self.base_means = Resource(base_fp) + self.bias_meta = self.bias_means[["latitude", "longitude"]] + self.base_meta = pd.DataFrame(columns=["latitude", "longitude"]) + self.base_meta["latitude"] = self.base_means["latitude"] + self.base_meta["longitude"] = self.base_means["longitude"] + self._base_tree = None + logger.info( + "Finished initializing BiasCorrectionFromMeans for " + f"bias_fp={bias_fp}, base_fp={base_fp}, dset={dset}." 
+ ) + + @property + def base_tree(self): + """Build ball tree from source_meta""" + if self._base_tree is None: + logger.info("Building ball tree for regridding.") + self._base_tree = BallTree( + np.deg2rad(self.base_meta), + leaf_size=self.leaf_size, + metric="haversine", + ) + return self._base_tree + + @property + def meta(self): + """Get a meta data dictionary on how these bias factors were + calculated""" + meta = { + "base_fp": self.base_fp, + "bias_fp": self.bias_fp, + "dset": self.dset, + "class": str(self.__class__), + "version_record": VERSION_RECORD, + "NOTES": ("scalar factors computed from base_data / bias_data."), + } + return meta + + @property + def height(self): + """Get feature height""" + return Feature.get_height(self.dset) + + @property + def u_name(self): + """Get corresponding u component for given height""" + return f"u_{self.height}m" + + @property + def v_name(self): + """Get corresponding v component for given height""" + return f"v_{self.height}m" + + def get_base_data(self, knn=1): + """Get means for baseline data.""" + logger.info(f"Getting base data for {self.dset}.") + dists, gids = self.base_tree.query(np.deg2rad(self.bias_meta), k=knn) + mask = dists < self.MIN_DISTANCE + if mask.sum() > 0: + logger.info( + f"{np.sum(mask)} of {np.product(mask.shape)} " + "distances are zero." + ) + dists[mask] = self.MIN_DISTANCE + weights = 1 / dists + norm = np.sum(weights, axis=-1) + out = self.base_means[self.dset, gids] + out = np.einsum("ijk,ij->ik", out, weights) / norm[:, np.newaxis] + return out + + def get_bias_data(self): + """Get means for biased data.""" + logger.info(f"Getting bias data for {self.dset}.") + cols = [col for col in self.bias_means.columns if "mean" in col] + bias_data = self.bias_means[cols].to_numpy() + return bias_data + + def get_corrections(self, global_scalar=1, knn=1): + """Get bias correction factors.""" + logger.info(f"Getting correction factors for {self.dset}.") + base_data = self.get_base_data(knn=knn) + bias_data = self.get_bias_data() + scaler = global_scalar * base_data / bias_data + adder = 0 + + out = { + "latitude": self.bias_meta["latitude"], + "longitude": self.bias_meta["longitude"], + f"base_{self.dset}_mean": base_data, + f"bias_{self.dset}_mean": bias_data, + f"{self.dset}_adder": adder, + f"{self.dset}_scalar": scaler, + f"{self.dset}_global_scalar": global_scalar, + } + return out + + def get_uv_corrections(self, global_scalar=1, knn=1): + """Write windspeed bias correction factors for u/v components""" + u_out = self.get_corrections(global_scalar=global_scalar, knn=knn) + v_out = u_out.copy() + u_out[f"{self.u_name}_scalar"] = u_out[f"{self.dset}_scalar"] + v_out[f"{self.v_name}_scalar"] = v_out[f"{self.dset}_scalar"] + u_out[f"{self.u_name}_adder"] = u_out[f"{self.dset}_adder"] + v_out[f"{self.v_name}_adder"] = v_out[f"{self.dset}_adder"] + return u_out, v_out + + def write_output(self, fp_out, out): + """Write bias correction factors to h5 file.""" + logger.info(f"Writing correction factors to file: {fp_out}.") + with h5py.File(fp_out, "w") as f: + for dset, data in out.items(): + f.create_dataset(dset, data=data) + logger.info(f"Added {dset} to {fp_out}.") + for k, v in self.meta.items(): + f.attrs[k] = json.dumps(v) + logger.info(f"Finished writing output to {fp_out}.") + + @classmethod + def run( + cls, + bias_fp, + base_fp, + dset, + fp_out, + global_scalar=1.0, + knn=1, + out_shape=None + ): + """Run bias correction factor computation and write.""" + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + out = 
bc.get_corrections(global_scalar=global_scalar, knn=knn) + if out_shape is not None: + for k, v in out.items(): + if k in ('latitude', 'longitude'): + out[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out[k] = np.array(v).reshape((*out_shape, 12)) + bc.write_output(fp_out, out) + + @classmethod + def run_uv( + cls, + bias_fp, + base_fp, + dset, + fp_pattern, + global_scalar=1.0, + knn=1, + out_shape=None + ): + """Run bias correction factor computation and write.""" + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + out_u, out_v = bc.get_uv_corrections( + global_scalar=global_scalar, knn=knn + ) + if out_shape is not None: + for k, v in out_u.items(): + if k in ('latitude', 'longitude'): + out_u[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out_u[k] = np.array(v).reshape((*out_shape, 12)) + for k, v in out_v.items(): + if k in ('latitude', 'longitude'): + out_v[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, float): + out_v[k] = np.array(v).reshape((*out_shape, 12)) + bc.write_output(fp_pattern.format(feature=bc.u_name), out_u) + bc.write_output(fp_pattern.format(feature=bc.v_name), out_v) + + +class BiasCorrectUpdate: + """Class for bias correcting existing files and writing corrected files.""" + + @classmethod + def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): + """Get bias correction factors for the given dset and month""" + with Resource(bc_file) as res: + logger.info( + f"Getting {dset} bias correction factors for month {month}." + ) + bc_factor = res[f"{dset}_scalar", :, month - 1] + factors = global_scalar * bc_factor + logger.info( + f"Retrieved {dset} bias correction factors for month {month}. " + f"Using global_scalar={global_scalar}." + ) + return factors + + @classmethod + def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): + """Update the in_file with bias corrected values for the given dset + and write to out_file.""" + tmp_file = out_file.replace(".h5", ".h5.tmp") + logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") + with Resource(in_file) as fh_in: + OutputHandler._init_h5( + tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs + ) + OutputHandler._ensure_dset_in_output(tmp_file, dset) + for i in range(1, 13): + try: + with RexOutputs(tmp_file, "a") as fh: + mask = fh.time_index.month == i + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=i, + global_scalar=global_scalar, + ) + logger.info( + f"Applying bias correction factors for month {i}" + ) + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file {tmp_file}." + ) + + os.replace(tmp_file, out_file) + msg = f"Saved bias corrected {dset} to: {out_file}" + logger.info(msg) + + @classmethod + def run( + cls, + in_file, + out_file, + dset, + bc_file, + overwrite=False, + global_scalar=1, + ): + """Run bias correction update.""" + if os.path.exists(out_file) and not overwrite: + logger.info( + f"{out_file} already exists and overwrite=False. " "Skipping." + ) + else: + if os.path.exists(out_file) and overwrite: + logger.info( + f"{out_file} exists but overwrite=True. " + f"Removing {out_file}." 
+ ) + os.remove(out_file) + cls.update_file( + in_file, out_file, dset, bc_file, global_scalar=global_scalar + ) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index a35b6012c..49c6a3076 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -24,8 +24,8 @@ class DualDataHandler(CacheHandlingMixIn, TrainingPrepMixIn): def __init__(self, hr_handler, lr_handler, - regrid_cache_pattern=None, - overwrite_regrid_cache=False, + cache_pattern=None, + overwrite_cache=False, regrid_workers=1, load_cached=True, shuffle_time=False, @@ -41,9 +41,9 @@ def __init__(self, DataHandler for high_res data lr_handler : DataHandler DataHandler for low_res data - regrid_cache_pattern : str + cache_pattern : str Pattern for files to use for saving regridded ERA data. - overwrite_regrid_cache : bool + overwrite_cache : bool Whether to overwrite regrid cache regrid_workers : int | None Number of workers to use for regridding routine. @@ -63,10 +63,10 @@ def __init__(self, self.t_enhance = t_enhance self.lr_dh = lr_handler self.hr_dh = hr_handler - self._cache_pattern = regrid_cache_pattern + self._cache_pattern = cache_pattern self._cached_features = None self._noncached_features = None - self.overwrite_cache = overwrite_regrid_cache + self.overwrite_cache = overwrite_cache self.val_split = val_split self.current_obs_index = None self.load_cached = load_cached @@ -179,7 +179,7 @@ def normalize(self, means, stdevs): def output_features(self): """Get list of output features. e.g. those that are returned by a GAN""" - return self.hr_dh.output_features + return self.lr_dh.output_features @property def train_only_features(self): @@ -219,7 +219,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.features, msg + assert hr_handler.features == lr_handler.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, @@ -233,21 +233,21 @@ def _run_pair_checks(self, hr_handler, lr_handler): hr_shape = self.hr_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_data.shape {self.hr_data.shape} and ' f'lr_data.shape {self.lr_data.shape} are ' f'incompatible. Must be {hr_shape} and {lr_shape}.') - assert self.lr_data.shape == lr_shape, msg + assert self.lr_data.shape[:-1] == lr_shape, msg if self.lr_val_data is not None and self.hr_val_data is not None: hr_shape = self.hr_val_data.shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, - hr_shape[2] // self.t_enhance, hr_shape[3]) + hr_shape[2] // self.t_enhance) msg = (f'hr_val_data.shape {self.hr_val_data.shape} ' f'and lr_val_data.shape {self.lr_val_data.shape}' f' are incompatible. 
Must be {hr_shape} and {lr_shape}.') - assert self.lr_val_data.shape == lr_shape, msg + assert self.lr_val_data.shape[:-1] == lr_shape, msg @property def grid_mem(self): @@ -378,8 +378,8 @@ def cache_files(self): cache_files = self._get_cache_file_names(self.cache_pattern, grid_shape=self.lr_grid_shape, time_index=self.lr_time_index, - target=self.hr_dh.target, - features=self.hr_dh.features) + target=self.lr_dh.target, + features=self.lr_dh.features) return cache_files @property @@ -499,7 +499,8 @@ def get_next(self): for s in lr_obs_idx[2:-1]: hr_obs_idx.append( slice(s.start * self.t_enhance, s.stop * self.t_enhance)) - hr_obs_idx.append(lr_obs_idx[-1]) + + hr_obs_idx.append(np.arange(len(self.output_features))) hr_obs_idx = tuple(hr_obs_idx) self.current_obs_index = { 'hr_index': hr_obs_idx, diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index c52c3640d..79da9bbe8 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -282,10 +282,10 @@ def _load_single_cached_feature(self, fp, cache_files, features, Error raised if shape conflicts with requested shape """ idx = cache_files.index(fp) - assert features[idx].lower() in fp.lower() + msg = f'{features[idx].lower()} not found in {fp.lower()}.' + assert features[idx].lower() in fp.lower(), msg fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {features[idx]} from ' - f'{fp}.') + logger.info(f'Loading {features[idx]} from {fp}.') out = None with open(fp, 'rb') as fh: diff --git a/sup3r/preprocessing/dual_batch_handling.py b/sup3r/preprocessing/dual_batch_handling.py index ac1901d9a..5244374e8 100644 --- a/sup3r/preprocessing/dual_batch_handling.py +++ b/sup3r/preprocessing/dual_batch_handling.py @@ -3,9 +3,11 @@ import numpy as np -from sup3r.preprocessing.batch_handling import (Batch, BatchHandler, - ValidationData, - ) +from sup3r.preprocessing.batch_handling import ( + Batch, + BatchHandler, + ValidationData, +) from sup3r.utilities.utilities import uniform_box_sampler, uniform_time_sampler logger = logging.getLogger(__name__) @@ -143,6 +145,16 @@ class DualBatchHandler(BatchHandler): BATCH_CLASS = Batch VAL_CLASS = DualValidationData + @property + def lr_features(self): + """Features in low res batch.""" + return self.data_handlers[0].lr_dh.features + + @property + def hr_features(self): + """Features in high res batch.""" + return self.data_handlers[0].lr_dh.output_features + @property def hr_sample_shape(self): """Get sample shape for high_res data""" @@ -173,19 +185,19 @@ def __next__(self): handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], self.hr_sample_shape[1], - self.hr_sample_shape[2], self.shape[-1]), + self.hr_sample_shape[2], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], self.lr_sample_shape[1], - self.lr_sample_shape[2], self.shape[-1]), + self.lr_sample_shape[2], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): high_res[i, ...], low_res[i, ...] 
= handler.get_next() self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 @@ -220,10 +232,12 @@ def __next__(self): self.current_handler_index = handler_index handler = self.data_handlers[handler_index] high_res = np.zeros((self.batch_size, self.hr_sample_shape[0], - self.hr_sample_shape[1], self.shape[-1]), + self.hr_sample_shape[1], + len(self.hr_features)), dtype=np.float32) low_res = np.zeros((self.batch_size, self.lr_sample_shape[0], - self.lr_sample_shape[1], self.shape[-1]), + self.lr_sample_shape[1], + len(self.lr_features)), dtype=np.float32) for i in range(self.batch_size): @@ -232,8 +246,6 @@ def __next__(self): low_res[i, ...] = lr[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) - high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) batch = self.BATCH_CLASS(low_res=low_res, high_res=high_res) self._i += 1 diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index a7d7b5216..a9771906c 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -48,14 +48,15 @@ class EraDownloader: VALID_VARIABLES: ClassVar[list] = [ 'u', 'v', 'pressure', 'temperature', 'relative_humidity', 'specific_humidity', 'total_precipitation', + 'convective_available_potential_energy' ] - KEEP_VARIABLES: ClassVar[list] = ['orog'] + KEEP_VARIABLES: ClassVar[list] = ['orog', 'cape'] KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] DEFAULT_RENAMED_VARS: ClassVar[list] = [ 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', - 'temperature', 'pressure', + 'temperature', 'pressure', 'cape' ] DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', @@ -69,7 +70,7 @@ class EraDownloader: '10m_u_component_of_wind', '10m_v_component_of_wind', '100m_u_component_of_wind', '100m_v_component_of_wind', 'surface_pressure', '2m_temperature', 'geopotential', - 'total_precipitation', + 'total_precipitation', "convective_available_potential_energy" ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', @@ -88,6 +89,7 @@ class EraDownloader: 'r': 'relative_humidity', 'q': 'specific_humidity', 'tp': 'total_precipitation', + 'cape': 'cape' } def __init__(self, @@ -298,7 +300,8 @@ def download_process_combine(self): """Run the download routine.""" sfc_check = len(self.sfc_file_variables) > 0 level_check = (len(self.level_file_variables) > 0 - and self.levels is not None) + and self.levels is not None + and len(self.levels) > 0) if self.level_file_variables: msg = (f'{self.level_file_variables} requested but no levels' ' were provided.') @@ -316,7 +319,7 @@ def download_levels_file(self): """Download file with requested pressure levels""" if not os.path.exists(self.level_file) or self.overwrite: msg = (f'Downloading {self.level_file_variables} to ' - f'{self.level_file}.') + f'{self.level_file} with levels = {self.levels}.') logger.info(msg) CDS_API_CLIENT.retrieve( 'reanalysis-era5-pressure-levels', { @@ -468,8 +471,7 @@ def process_and_combine(self): self.process_surface_file() files.append(self.surface_file) - logger.info(f'Combining {files} and {self.surface_file} ' - f'to {self.combined_file}.') + logger.info(f'Combining {files} to {self.combined_file}.') with xr.open_mfdataset(files) as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing 
{self.combined_file}') From 6878a9319c720dc281379a7cd5b3823d7ed7393b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 30 Oct 2023 21:10:46 -0600 Subject: [PATCH 22/44] unused import --- sup3r/bias/bias_correct_means.py | 45 +++++++++++++++++++------------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index cbd98e1b5..3a4f19214 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -13,12 +13,12 @@ import rioxarray import xarray as xr from rex import Resource +from scipy.interpolate import interp1d from sklearn.neighbors import BallTree + from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD -from sup3r.utilities.interpolation import Interpolator -from scipy.interpolate import interp1d logger = logging.getLogger(__name__) @@ -133,11 +133,12 @@ def convert_all_tifs(self): def mask(self): """Mask coordinates without data""" if self._mask is None: - with xr.open_mfdataset(self.get_height_files('January')) as res: - mask = ((res[self.in_features[0]] != -999) - & (~np.isnan(res[self.in_features[0]]))) + with xr.open_mfdataset(self.get_height_files("January")) as res: + mask = (res[self.in_features[0]] != -999) & ( + ~np.isnan(res[self.in_features[0]]) + ) for feat in self.in_features[1:]: - tmp = ((res[feat] != -999) & (~np.isnan(res[feat]))) + tmp = (res[feat] != -999) & (~np.isnan(res[feat])) mask = mask & tmp self._mask = np.array(mask).flatten() return self._mask @@ -200,19 +201,27 @@ def interp(self, data): ( len(data.latitude) * len(data.longitude), len(self.in_heights), - ), dtype=np.float32 + ), + dtype=np.float32, ) lev_array = var_array.copy() for i, (h, feat) in enumerate(zip(self.in_heights, self.in_features)): var_array[..., i] = data[feat].values.flatten() lev_array[..., i] = h - logger.info(f'Interpolating {self.in_features} to {self.out_features} ' - f'for {var_array.shape[0]} coordinates.') - tmp = [interp1d(h, v, fill_value='extrapolate')(self.out_heights) - for h, v in zip(lev_array[self.mask], var_array[self.mask])] - out = np.full((len(data.latitude), len(data.longitude), - len(self.out_heights)), np.nan, dtype=np.float32) + logger.info( + f"Interpolating {self.in_features} to {self.out_features} " + f"for {var_array.shape[0]} coordinates." 
+ ) + tmp = [ + interp1d(h, v, fill_value="extrapolate")(self.out_heights) + for h, v in zip(lev_array[self.mask], var_array[self.mask]) + ] + out = np.full( + (len(data.latitude), len(data.longitude), len(self.out_heights)), + np.nan, + dtype=np.float32, + ) out[self.mask.reshape((len(data.latitude), len(data.longitude)))] = tmp for i, feat in enumerate(self.out_features): if feat not in data: @@ -578,14 +587,14 @@ def run( fp_out, global_scalar=1.0, knn=1, - out_shape=None + out_shape=None, ): """Run bias correction factor computation and write.""" bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: for k, v in out.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out[k] = np.array(v).reshape((*out_shape, 12)) @@ -600,7 +609,7 @@ def run_uv( fp_pattern, global_scalar=1.0, knn=1, - out_shape=None + out_shape=None, ): """Run bias correction factor computation and write.""" bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) @@ -609,12 +618,12 @@ def run_uv( ) if out_shape is not None: for k, v in out_u.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out_u[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out_u[k] = np.array(v).reshape((*out_shape, 12)) for k, v in out_v.items(): - if k in ('latitude', 'longitude'): + if k in ("latitude", "longitude"): out_v[k] = np.array(v).reshape(out_shape) elif not isinstance(v, float): out_v[k] = np.array(v).reshape((*out_shape, 12)) From 874e88e67ff92e8854b2580f61125084ae942140 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 31 Oct 2023 12:45:15 -0600 Subject: [PATCH 23/44] exo cache naming issue now that data has temporal dimension --- .../data_handling/exogenous_data_handling.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index 5ff81b831..f8bda65a2 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -567,8 +567,12 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, cache_fp : str Name of cache file """ - fn = f'exo_{feature}_{self.target}_{self.shape}_sagg{s_agg_factor}_' - fn += f'tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl' + tsteps = (None if self.temporal_slice.start is None + or self.temporal_slice.end is None + else self.temporal_slice.end - self.temporal_slice.start) + fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' fn = fn.replace('(', '').replace(')', '') fn = fn.replace('[', '').replace(']', '') fn = fn.replace(',', 'x').replace(' ', '') From 26d01634c17bde1199cd400292a719693d8683d3 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 08:50:26 -0600 Subject: [PATCH 24/44] exo caching moved to exo_extract and time independent naming convention used for topo. 
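A rough sketch of the naming convention (illustrative only; the helper name and arguments below are hypothetical, the real logic lives in the get_cache_file methods in this diff): time-dependent exo data encodes the number of time steps in the cache file name, while time-independent topography drops it so one cache file can be reused across runs with different temporal extents.

    import os

    def sketch_cache_name(feature, target, shape, s_enhance, t_enhance,
                          s_agg_factor, t_agg_factor, tsteps=None,
                          cache_dir='./exo_cache'):
        # tsteps=None -> time independent (e.g. topography) naming
        fn = f'exo_{feature}_{target}_{shape}'
        if tsteps is not None:
            fn += f',{tsteps}'
        fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_{t_enhance}x.pkl'
        for old, new in (('(', ''), (')', ''), ('[', ''), (']', ''),
                         (',', 'x'), (' ', '')):
            fn = fn.replace(old, new)
        return os.path.join(cache_dir, fn)

    # sketch_cache_name('topography', (39.0, -105.0), (20, 20), 2, 1, 1, 1)
    # -> './exo_cache/exo_topography_39.0x-105.0_20x20_sagg1_tagg1_2x_1x.pkl'
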
--- .../data_handling/exo_extraction.py | 139 +++++++++++++++++- .../data_handling/exogenous_data_handling.py | 88 ++--------- sup3r/utilities/era_downloader.py | 82 ++++++----- sup3r/utilities/interpolate_log_profile.py | 1 + 4 files changed, 196 insertions(+), 114 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 1290703f2..5066d082c 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -1,6 +1,9 @@ """Sup3r topography utilities""" import logging +import os +import pickle +import shutil from abc import ABC, abstractmethod import numpy as np @@ -37,6 +40,8 @@ def __init__(self, raster_file=None, max_delta=20, input_handler=None, + cache_data=True, + cache_dir='./exo_cache/', ti_workers=1): """ Parameters @@ -100,6 +105,12 @@ def __init__(self, data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' ti_workers : int | None max number of workers to use to get full time index. Useful when there are many input files each with a single time step. If this is @@ -122,6 +133,11 @@ def __init__(self, self._source_lat_lon = None self._hr_time_index = None self._src_time_index = None + self.cache_data = cache_data + self.cache_dir = cache_dir + self.temporal_slice = temporal_slice + self.target = target + self.shape = shape if input_handler is None: in_type = get_source_type(file_paths) @@ -159,6 +175,46 @@ def __init__(self, def source_data(self): """Get the 1D array of source data from the exo_source_h5""" + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name + + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. 
+ + Returns + ------- + cache_fp : str + Name of cache file + """ + tsteps = (None if self.temporal_slice.start is None + or self.temporal_slice.stop is None + else self.temporal_slice.stop - self.temporal_slice.start) + fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp + @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the @@ -266,6 +322,30 @@ def nn(self): @property def data(self): + """Get a raster of source values corresponding to the + high-resolution grid (the file_paths input grid * s_enhance * + t_enhance). The shape is (lats, lons, temporal, 1) + """ + cache_fp = self.get_cache_file(feature=self.__class__.__name__, + s_enhance=self._s_enhance, + t_enhance=self._t_enhance, + s_agg_factor=self._s_agg_factor, + t_agg_factor=self._t_agg_factor) + tmp_fp = cache_fp + '.tmp' + if os.path.exists(cache_fp): + with open(cache_fp, 'rb') as f: + data = pickle.load(f) + + else: + data = self.get_data() + + if self.cache_data: + with open(tmp_fp, 'wb') as f: + pickle.dump(data, f) + shutil.move(tmp_fp, cache_fp) + return data + + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -293,7 +373,9 @@ def get_exo_raster(cls, temporal_slice=None, raster_file=None, max_delta=20, - input_handler=None): + input_handler=None, + cache_data=True, + cache_dir='./exo_cache/'): """Get the exo feature raster corresponding to the spatially enhanced grid from the file_paths input @@ -356,6 +438,12 @@ class will output a topography raster corresponding to the data handler class to use for input data. Provide a string name to match a class in data_handling.py. If None the correct handler will be guessed based on file type and time series properties. + cache_data : bool + Flag to cache exogeneous data in /exo_cache/ this can + speed up forward passes with large temporal extents when the exo + data is time independent. + cache_dir : str + Directory for storing cache data. Default is './exo_cache' Returns ------- @@ -377,7 +465,9 @@ class will output a topography raster corresponding to the temporal_slice=temporal_slice, raster_file=raster_file, max_delta=max_delta, - input_handler=input_handler) + input_handler=input_handler, + cache_data=cache_data, + cache_dir=cache_dir) return exo.data @@ -407,8 +497,7 @@ def source_time_index(self): self._src_time_index = res.time_index return self._src_time_index - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) @@ -421,7 +510,44 @@ def data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + out = hr_data[..., np.newaxis] + + def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, + t_agg_factor): + """Get cache file name. This uses a time independent naming convention. 
+ + Parameters + ---------- + feature : str + Name of feature to get cache file for + s_enhance : int + Spatial enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + t_enhance : int + Temporal enhancement for this exogeneous data step (cumulative for + all model steps up to the current step). + s_agg_factor : int + Factor by which to aggregate the exo_source data to the spatial + resolution of the file_paths input enhanced by s_enhance. + t_agg_factor : int + Factor by which to aggregate the exo_source data to the temporal + resolution of the file_paths input enhanced by t_enhance. + + Returns + ------- + cache_fp : str + Name of cache file + """ + fn = f'exo_{feature}_{self.target}_{self.shape}' + fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' + fn += f'{t_enhance}x.pkl' + fn = fn.replace('(', '').replace(')', '') + fn = fn.replace('[', '').replace(']', '') + fn = fn.replace(',', 'x').replace(' ', '') + cache_fp = os.path.join(self.cache_dir, fn) + if self.cache_data: + os.makedirs(self.cache_dir, exist_ok=True) + return cache_fp class TopoExtractNC(TopoExtractH5): @@ -470,8 +596,7 @@ def source_data(self): return SolarPosition(self.hr_time_index, self.hr_lat_lon.reshape((-1, 2))).zenith.T - @property - def data(self): + def get_data(self): """Get a raster of source values corresponding to the high-resolution grid (the file_paths input grid * s_enhance * t_enhance). The shape is (lats, lons, temporal, 1) diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index f8bda65a2..f3d35fe73 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -1,9 +1,6 @@ """Sup3r exogenous data handling""" import logging -import os -import pickle import re -import shutil from typing import ClassVar import numpy as np @@ -541,46 +538,6 @@ def _get_all_agg_and_enhancement(self): for step in self.steps] return agg_enhance_dict - def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, - t_agg_factor): - """Get cache file name - - Parameters - ---------- - feature : str - Name of feature to get cache file for - s_enhance : int - Spatial enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - t_enhance : int - Temporal enhancement for this exogeneous data step (cumulative for - all model steps up to the current step). - s_agg_factor : int - Factor by which to aggregate the exo_source data to the spatial - resolution of the file_paths input enhanced by s_enhance. - t_agg_factor : int - Factor by which to aggregate the exo_source data to the temporal - resolution of the file_paths input enhanced by t_enhance. 
- - Returns - ------- - cache_fp : str - Name of cache file - """ - tsteps = (None if self.temporal_slice.start is None - or self.temporal_slice.end is None - else self.temporal_slice.end - self.temporal_slice.start) - fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' - fn += f'_sagg{s_agg_factor}_tagg{t_agg_factor}_{s_enhance}x_' - fn += f'{t_enhance}x.pkl' - fn = fn.replace('(', '').replace(')', '') - fn = fn.replace('[', '').replace(']', '') - fn = fn.replace(',', 'x').replace(' ', '') - cache_fp = os.path.join(self.cache_dir, fn) - if self.cache_data: - os.makedirs(self.cache_dir, exist_ok=True) - return cache_fp - def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): """Get the exogenous topography data @@ -609,35 +566,22 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, lon, temporal) """ - cache_fp = self.get_cache_file(feature=feature, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor) - tmp_fp = cache_fp + '.tmp' - if os.path.exists(cache_fp): - with open(cache_fp, 'rb') as f: - data = pickle.load(f) - - else: - exo_handler = self.get_exo_handler(feature, self.source_file, - self.exo_handler) - data = exo_handler(self.file_paths, - self.source_file, - s_enhance=s_enhance, - t_enhance=t_enhance, - s_agg_factor=s_agg_factor, - t_agg_factor=t_agg_factor, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_slice, - raster_file=self.raster_file, - max_delta=self.max_delta, - input_handler=self.input_handler).data - if self.cache_data: - with open(tmp_fp, 'wb') as f: - pickle.dump(data, f) - shutil.move(tmp_fp, cache_fp) + exo_handler = self.get_exo_handler(feature, self.source_file, + self.exo_handler) + data = exo_handler(self.file_paths, + self.source_file, + s_enhance=s_enhance, + t_enhance=t_enhance, + s_agg_factor=s_agg_factor, + t_agg_factor=t_agg_factor, + target=self.target, + shape=self.shape, + temporal_slice=self.temporal_slice, + raster_file=self.raster_file, + max_delta=self.max_delta, + input_handler=self.input_handler, + cache_data=self.cache_data, + cache_dir=self.cache_dir).data return data @classmethod diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index a9771906c..5fb5c94e9 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -48,22 +48,7 @@ class EraDownloader: VALID_VARIABLES: ClassVar[list] = [ 'u', 'v', 'pressure', 'temperature', 'relative_humidity', 'specific_humidity', 'total_precipitation', - 'convective_available_potential_energy' - ] - - KEEP_VARIABLES: ClassVar[list] = ['orog', 'cape'] - KEEP_VARIABLES += [f'{v}_' for v in VALID_VARIABLES] - - DEFAULT_RENAMED_VARS: ClassVar[list] = [ - 'z', 'zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m', - 'temperature', 'pressure', 'cape' - ] - DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ - '10m_u_component_of_wind', '10m_v_component_of_wind', - '100m_u_component_of_wind', '100m_v_component_of_wind', - 'u_component_of_wind', 'v_component_of_wind', '2m_temperature', - 'temperature', 'surface_pressure', 'relative_humidity', - 'total_precipitation', + 'convective_available_potential_energy', 'divergence' ] SFC_VARS: ClassVar[list] = [ @@ -74,7 +59,7 @@ class EraDownloader: ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', - 'temperature', 'relative_humidity', 'specific_humidity', + 'temperature', 'relative_humidity', 'specific_humidity', 'divergence' ] NAME_MAP: 
ClassVar[dict] = { 'u10': 'u_10m', @@ -89,7 +74,8 @@ class EraDownloader: 'r': 'relative_humidity', 'q': 'specific_humidity', 'tp': 'total_precipitation', - 'cape': 'cape' + 'cape': 'cape', + 'd': 'divergence' } def __init__(self, @@ -387,16 +373,16 @@ def map_vars(self, old_ds, ds): ds : Dataset Dataset() object for new file with new variables written. """ - for old_name, new_name in self.NAME_MAP.items(): - if old_name in old_ds.variables: - _ = ds.createVariable(new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) - vals = old_ds.variables[old_name][:] - if 'temperature' in new_name: - vals -= 273.15 - ds.variables[new_name][:] = vals + for old_name in old_ds.variables: + new_name = self.NAME_MAP.get(old_name, old_name) + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) + vals = old_ds.variables[old_name][:] + if 'temperature' in new_name: + vals -= 273.15 + ds.variables[new_name][:] = vals return ds def convert_z(self, standard_name, long_name, old_ds, ds): @@ -542,7 +528,8 @@ def run_interpolation(self, max_workers=None, **kwargs): overwrite=self.overwrite, **kwargs) - def get_monthly_file(self, interp_workers=None, **interp_kwargs): + def get_monthly_file(self, interp_workers=None, keep_variables=None, + **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to interpolate.""" @@ -559,10 +546,10 @@ def get_monthly_file(self, interp_workers=None, **interp_kwargs): self.run_interpolation(max_workers=interp_workers, **interp_kwargs) if self.interp_file is not None and os.path.exists(self.interp_file): - if self.already_pruned(self.interp_file): + if self.already_pruned(self.interp_file, keep_variables): logger.info(f'{self.interp_file} pruned already.') else: - self.prune_output(self.interp_file) + self.prune_output(self.interp_file, keep_variables) @classmethod def all_months_exist(cls, year, file_pattern): @@ -587,21 +574,35 @@ def all_months_exist(cls, year, file_pattern): for month in range(1, 13)) @classmethod - def already_pruned(cls, infile): + def already_pruned(cls, infile, keep_variables): """Check if file has been pruned already.""" + if keep_variables is None: + logger.info('Received keep_variables=None. Skipping pruning.') + return + else: + logger.info( + f'Received keep_variables={keep_variables}. Skipping pruning.') + pruned = True with Dataset(infile, 'r') as ds: for var in ds.variables: - if not any(name in var for name in cls.KEEP_VARIABLES): + if not any(name in var for name in keep_variables): logger.info(f'Pruning {var} in {infile}.') pruned = False return pruned @classmethod - def prune_output(cls, infile): + def prune_output(cls, infile, keep_variables=None): """Prune output file to keep just single level variables""" + if keep_variables is None: + logger.info('Received keep_variables=None. Skipping pruning.') + return + else: + logger.info( + f'Received keep_variables={keep_variables}. 
Skipping pruning.') + logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) with Dataset(infile, 'r') as old_ds: @@ -609,7 +610,7 @@ def prune_output(cls, infile): new_ds = cls.init_dims(old_ds, new_ds, ('time', 'latitude', 'longitude')) for var in old_ds.variables: - if any(name in var for name in cls.KEEP_VARIABLES): + if any(name in var for name in keep_variables): old_var = old_ds[var] vals = old_var[:] _ = new_ds.createVariable( @@ -639,6 +640,7 @@ def run_month(cls, required_shape=None, interp_workers=None, variables=None, + keep_variables=None, **interp_kwargs): """Run routine for all months in the requested year. @@ -672,6 +674,9 @@ def run_month(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -686,6 +691,7 @@ def run_month(cls, required_shape=required_shape, variables=variables) downloader.get_monthly_file(interp_workers=interp_workers, + keep_variables=keep_variables, **interp_kwargs) @classmethod @@ -703,6 +709,7 @@ def run_year(cls, max_workers=None, interp_workers=None, variables=None, + keep_variables=None, **interp_kwargs): """Run routine for all months in the requested year. @@ -741,6 +748,9 @@ def run_year(cls, variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + keep_variables : list | None + Variables to keep in final files. All other variables will be + pruned. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -757,6 +767,7 @@ def run_year(cls, required_shape=required_shape, interp_workers=interp_workers, variables=variables, + keep_variables=keep_variables, **interp_kwargs) else: futures = {} @@ -774,6 +785,7 @@ def run_year(cls, overwrite=overwrite, required_shape=required_shape, interp_workers=interp_workers, + keep_variables=keep_variables, variables=variables, **interp_kwargs) futures[future] = {'year': year, 'month': month} diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 7a7f51910..993b1dcc1 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -37,6 +37,7 @@ class LogLinInterpolator: 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], + 'divergence': [80, 100, 120] } def __init__( From 34ca76a4d043dde9c72145ff39130a7bbcba8c08 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 09:11:14 -0600 Subject: [PATCH 25/44] modified methods to cache single time step for time independent exo data and apply repeat after loading cache. 
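Roughly what the repeat-after-load behavior looks like (a minimal sketch with made-up shapes, not the actual diff): the cache stores a single time step for time-independent data and it is tiled out to the temporal extent of the current run after loading.

    import numpy as np

    cached = np.zeros((10, 10, 1))                 # single cached time step
    hr_tsteps = 24                                 # temporal extent of this run
    data = np.repeat(cached, hr_tsteps, axis=-1)   # expand after loading cache
    assert data.shape == (10, 10, 24)
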
--- .../data_handling/exo_extraction.py | 17 ++++++++++------- sup3r/utilities/era_downloader.py | 6 ++---- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 5066d082c..428580ac7 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -201,7 +201,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, cache_fp : str Name of cache file """ - tsteps = (None if self.temporal_slice.start is None + tsteps = (None if self.temporal_slice is None + or self.temporal_slice.start is None or self.temporal_slice.stop is None else self.temporal_slice.stop - self.temporal_slice.start) fn = f'exo_{feature}_{self.target}_{self.shape},{tsteps}' @@ -343,6 +344,10 @@ def data(self): with open(tmp_fp, 'wb') as f: pickle.dump(data, f) shutil.move(tmp_fp, cache_fp) + + if data.shape[-2] == 1 and self.hr_shape[-1] > 1: + data = np.repeat(data[..., np.newaxis, :], self.hr_shape[-1], + axis=-2) return data def get_data(self): @@ -479,8 +484,7 @@ def source_data(self): """Get the 1D array of elevation data from the exo_source_h5""" with Resource(self._exo_source) as res: elev = res.get_meta_arr('elevation') - elev = np.repeat(elev[:, np.newaxis], self.hr_shape[-1], axis=-1) - return elev + return elev[:, np.newaxis] @property def source_lat_lon(self): @@ -506,11 +510,11 @@ def get_data(self): hr_data = [] for j in range(self._s_agg_factor): out = self.source_data[nn[:, j]] - out = out.reshape(self.hr_shape) + out = out.reshape((*self.hr_shape[:-1], -1)) hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - out = hr_data[..., np.newaxis] + return hr_data[..., np.newaxis] def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): @@ -577,8 +581,7 @@ def __init__(self, *args, **kwargs): def source_data(self): """Get the 1D array of elevation data from the exo_source_nc""" elev = self.source_handler.data[..., 0, 0].flatten() - elev = np.repeat(elev[..., np.newaxis], self.hr_shape[-1], axis=-1) - return elev + return elev[..., np.newaxis] @property def source_lat_lon(self): diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 5fb5c94e9..2d9fa0711 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -581,8 +581,7 @@ def already_pruned(cls, infile, keep_variables): logger.info('Received keep_variables=None. Skipping pruning.') return else: - logger.info( - f'Received keep_variables={keep_variables}. Skipping pruning.') + logger.info(f'Received keep_variables={keep_variables}.') pruned = True with Dataset(infile, 'r') as ds: @@ -600,8 +599,7 @@ def prune_output(cls, infile, keep_variables=None): logger.info('Received keep_variables=None. Skipping pruning.') return else: - logger.info( - f'Received keep_variables={keep_variables}. Skipping pruning.') + logger.info(f'Received keep_variables={keep_variables}.') logger.info(f'Pruning {infile}.') tmp_file = cls.get_tmp_file(infile) From 627e2a0c2cccc4f1fb2807f4f27acd534a6b1aa6 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 11:30:22 -0600 Subject: [PATCH 26/44] normalization bug for dual handlers - assumed lr and hr handlers had same features originally but this isnt required. means/stds have to be indexed carefully. 
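In other words the high-res handler can carry only the output features, so its normalization stats have to be picked out of the low-res means/stds by feature name rather than by position. A minimal sketch of that indexing (feature lists below are hypothetical):

    import numpy as np

    lr_features = ['u_100m', 'v_100m', 'cape']   # low-res training features
    hr_features = ['u_100m', 'v_100m']           # high-res output features
    means = np.array([1.2, -0.3, 310.0])         # ordered like lr_features
    stds = np.array([0.8, 0.8, 55.0])

    idx = [lr_features.index(f) for f in hr_features]
    hr_means, hr_stds = means[idx], stds[idx]    # stats for the hr data only
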
--- .../data_handling/dual_data_handling.py | 13 ++++++++----- sup3r/preprocessing/data_handling/mixin.py | 4 ++++ 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sup3r/preprocessing/data_handling/dual_data_handling.py b/sup3r/preprocessing/data_handling/dual_data_handling.py index 49c6a3076..e22507b57 100644 --- a/sup3r/preprocessing/data_handling/dual_data_handling.py +++ b/sup3r/preprocessing/data_handling/dual_data_handling.py @@ -162,17 +162,20 @@ def normalize(self, means, stdevs): dimensions (features) array of means for all features with same ordering as data features """ - logger.info('Normalizing low resolution data.') + logger.info('Normalizing low resolution data features=' + f'{self.features}') self._normalize(data=self.lr_data, val_data=self.lr_val_data, means=means, stds=stdevs, max_workers=self.lr_dh.norm_workers) - logger.info('Normalizing high resolution data.') + logger.info('Normalizing high resolution data features=' + f'{self.output_features}') + indices = [self.features.index(f) for f in self.output_features] self._normalize(data=self.hr_data, val_data=self.hr_val_data, - means=means, - stds=stdevs, + means=means[indices], + stds=stdevs[indices], max_workers=self.hr_dh.norm_workers) @property @@ -219,7 +222,7 @@ def _run_pair_checks(self, hr_handler, lr_handler): assert hr_handler.val_split == 0 and lr_handler.val_split == 0, msg msg = ('Handlers have incompatible number of features. ' f'({hr_handler.features} vs {lr_handler.features})') - assert hr_handler.features == lr_handler.output_features, msg + assert hr_handler.features == self.output_features, msg hr_shape = hr_handler.sample_shape lr_shape = (hr_shape[0] // self.s_enhance, hr_shape[1] // self.s_enhance, diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 79da9bbe8..f2f2c0494 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -1054,6 +1054,10 @@ def _normalize(self, data, val_data, means, stds, max_workers=None): max_workers : int | None Number of workers to use in thread pool for nomalization. """ + msg = f'Received {len(means)} means for {data.shape[-1]} features' + assert len(means) == data.shape[-1], msg + msg = f'Received {len(stds)} stds for {data.shape[-1]} features' + assert len(stds) == data.shape[-1], msg logger.info(f'Normalizing {data.shape[-1]} features.') if max_workers == 1: for i in range(data.shape[-1]): From 63b9227abe48b275656e7fd9544b262131850262 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 1 Nov 2023 11:30:22 -0600 Subject: [PATCH 27/44] normalization bug for dual handlers - assumed lr and hr handlers had same features originally but this isnt required. means/stds have to be indexed carefully. 
--- sup3r/preprocessing/data_handling/exo_extraction.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 428580ac7..9901b4b83 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -345,10 +345,9 @@ def data(self): pickle.dump(data, f) shutil.move(tmp_fp, cache_fp) - if data.shape[-2] == 1 and self.hr_shape[-1] > 1: - data = np.repeat(data[..., np.newaxis, :], self.hr_shape[-1], - axis=-2) - return data + if data.shape[-1] == 1 and self.hr_shape[-1] > 1: + data = np.repeat(data, self.hr_shape[-1], axis=-1) + return data[..., np.newaxis] def get_data(self): """Get a raster of source values corresponding to the @@ -363,7 +362,7 @@ def get_data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data @classmethod def get_exo_raster(cls, @@ -514,7 +513,7 @@ def get_data(self): hr_data.append(out[..., np.newaxis]) hr_data = np.concatenate(hr_data, axis=-1).mean(axis=-1) logger.info('Finished mapping raster from {}'.format(self._exo_source)) - return hr_data[..., np.newaxis] + return hr_data def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, t_agg_factor): From 7b6c5c2dc74a54093c03b676534c74035a210a6f Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 07:08:55 -0600 Subject: [PATCH 28/44] some edits to netcdf output for multistep forward-pass pipeline (run spatial fwp with netcdf output and then temporal fwp on that output) --- sup3r/postprocessing/file_handling.py | 2 +- .../data_handling/exo_extraction.py | 8 +++++++- .../data_handling/exogenous_data_handling.py | 10 ++++++++-- sup3r/utilities/era_downloader.py | 17 +++++++++-------- 4 files changed, 25 insertions(+), 12 deletions(-) diff --git a/sup3r/postprocessing/file_handling.py b/sup3r/postprocessing/file_handling.py index f86339f16..c1244366c 100644 --- a/sup3r/postprocessing/file_handling.py +++ b/sup3r/postprocessing/file_handling.py @@ -560,7 +560,7 @@ def _write_output(cls, data, features, lat_lon, times, out_file, List of coordinate indices used to label each lat lon pair and to help with spatial chunk data collection """ - coords = {'Times': (['Time'], times), + coords = {'Times': (['Time'], [str(t).encode('utf-8') for t in times]), 'XLAT': (['south_north', 'east_west'], lat_lon[..., 0]), 'XLONG': (['south_north', 'east_west'], lat_lon[..., 1])} diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 9901b4b83..8e3c60d07 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -42,7 +42,8 @@ def __init__(self, input_handler=None, cache_data=True, cache_dir='./exo_cache/', - ti_workers=1): + ti_workers=1, + res_kwargs=None): """ Parameters ---------- @@ -118,6 +119,9 @@ def __init__(self, parallel and then concatenated to get the full time index. If input files do not all have time indices or if there are few input files this should be set to one. + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. 
+ xr.open_dataset(file_paths, **res_kwargs) """ logger.info(f'Initializing {self.__class__.__name__} utility.') @@ -138,6 +142,7 @@ def __init__(self, self.temporal_slice = temporal_slice self.target = target self.shape = shape + self.res_kwargs = res_kwargs if input_handler is None: in_type = get_source_type(file_paths) @@ -168,6 +173,7 @@ def __init__(self, raster_file=raster_file, max_delta=max_delta, worker_kwargs=dict(ti_workers=ti_workers), + res_kwargs=self.res_kwargs ) @property diff --git a/sup3r/preprocessing/data_handling/exogenous_data_handling.py b/sup3r/preprocessing/data_handling/exogenous_data_handling.py index f3d35fe73..3149887e4 100644 --- a/sup3r/preprocessing/data_handling/exogenous_data_handling.py +++ b/sup3r/preprocessing/data_handling/exogenous_data_handling.py @@ -206,7 +206,8 @@ def __init__(self, input_handler=None, exo_handler=None, cache_data=True, - cache_dir='./exo_cache'): + cache_dir='./exo_cache', + res_kwargs=None): """ Parameters ---------- @@ -279,6 +280,9 @@ def __init__(self, speed up forward passes with large temporal extents cache_dir : str Directory for storing cache data. Default is './exo_cache' + res_kwargs : dict | None + Dictionary of kwargs passed to lowest level resource handler. e.g. + xr.open_dataset(file_paths, **res_kwargs) """ self.feature = feature @@ -297,6 +301,7 @@ def __init__(self, self.cache_data = cache_data self.cache_dir = cache_dir self.data = {feature: {'steps': []}} + self.res_kwargs = res_kwargs self.input_check() agg_enhance = self._get_all_agg_and_enhancement() @@ -581,7 +586,8 @@ def get_exo_data(self, feature, s_enhance, t_enhance, s_agg_factor, max_delta=self.max_delta, input_handler=self.input_handler, cache_data=self.cache_data, - cache_dir=self.cache_dir).data + cache_dir=self.cache_dir, + res_kwargs=self.res_kwargs).data return data @classmethod diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 2d9fa0711..0a572c9ef 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -375,14 +375,15 @@ def map_vars(self, old_ds, ds): """ for old_name in old_ds.variables: new_name = self.NAME_MAP.get(old_name, old_name) - _ = ds.createVariable(new_name, - np.float32, - dimensions=old_ds[old_name].dimensions, - ) - vals = old_ds.variables[old_name][:] - if 'temperature' in new_name: - vals -= 273.15 - ds.variables[new_name][:] = vals + if new_name not in ds.variables: + _ = ds.createVariable(new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) + vals = old_ds.variables[old_name][:] + if 'temperature' in new_name: + vals -= 273.15 + ds.variables[new_name][:] = vals return ds def convert_z(self, standard_name, long_name, old_ds, ds): From d61a7925c2473adc22f374643aafa97d36d60cc9 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 08:20:29 -0600 Subject: [PATCH 29/44] bc code doc strings --- sup3r/bias/bias_correct_means.py | 318 +++++++++++++++--- .../data_handling/exo_extraction.py | 14 +- .../data_handling/test_dual_data_handling.py | 8 +- 3 files changed, 286 insertions(+), 54 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 3a4f19214..eb4f1f996 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -1,11 +1,13 @@ """Classes to compute means from vortex and era data and compute bias -correction factors.""" +correction factors. 
+""" import calendar import json import logging import os +from concurrent.futures import ThreadPoolExecutor, as_completed import h5py import numpy as np @@ -15,7 +17,6 @@ from rex import Resource from scipy.interpolate import interp1d from sklearn.neighbors import BallTree - from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD @@ -26,11 +27,11 @@ class VortexMeanPrepper: """Class for converting monthly vortex tif files for each height to a single h5 files containing all monthly means for all requested output - heights""" + heights. + """ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): - """ - Parameters + """Parameters ---------- path_pattern : str Pattern for input tif files. Needs to include {month} and {height} @@ -52,7 +53,7 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): @property def in_features(self): - """List of features corresponding to input heights""" + """List of features corresponding to input heights.""" return [f"windspeed_{h}m" for h in self.in_heights] @property @@ -92,7 +93,8 @@ def get_output_file(self, month): @property def output_files(self): """List of output monthly output files each with windspeed for all - input heights""" + input heights + """ files = [] for i in range(1, 13): month = calendar.month_name[i] @@ -101,7 +103,8 @@ def output_files(self): def convert_month_height_tif(self, month, height): """Get windspeed mean for the given month and hub height from the - corresponding input file and write this to a netcdf file.""" + corresponding input file and write this to a netcdf file. + """ infile = self.get_input_file(month, height) logger.info(f"Getting mean windspeed_{height}m for {month}.") outfile = infile.replace(".tif", ".nc") @@ -118,12 +121,12 @@ def convert_month_height_tif(self, month, height): return outfile def convert_month_tif(self, month): - """Write netcdf files for all heights for the given month""" + """Write netcdf files for all heights for the given month.""" for height in self.in_heights: self.convert_month_height_tif(month, height) def convert_all_tifs(self): - """Write netcdf files for all heights for all months""" + """Write netcdf files for all heights for all months.""" for i in range(1, 13): month = calendar.month_name[i] logger.info(f"Converting tif files to netcdf files for {month}") @@ -299,7 +302,23 @@ def run( ): """Read vortex tif files, convert these to monthly netcdf files for all input heights, interpolate this data to requested output heights, mask - fill values, and write all data to h5 file.""" + fill values, and write all data to h5 file. + + Parameters + ---------- + path_pattern : str + Pattern for input tif files. Needs to include {month} and {height} + format keys. + in_heights : list + List of heights for input files. + out_heights : list + List of output heights used for interpolation + fp_out : str + Name of final h5 output file to write with means. + overwrite : bool + Whether to overwrite intermediate netcdf files containing the + interpolated masked monthly means. 
+ """ vprep = cls(path_pattern, in_heights, out_heights, overwrite=overwrite) vprep.convert_all_tifs() out = vprep.get_all_data() @@ -310,8 +329,7 @@ class EraMeanPrepper: """Class to compute monthly windspeed means from ERA data.""" def __init__(self, era_pattern, years, features): - """ - Parameters + """Parameters ---------- era_pattern : str Pattern pointing to era files with u/v wind components at the given @@ -438,7 +456,21 @@ def write_csv(self, out, out_file): def run(cls, era_pattern, years, features, out_pattern): """Compute monthly windspeed means for the given heights, using the given years of ERA data, and write the means to csv files for each - height.""" + height. + + Parameters + ---------- + era_pattern : str + Pattern pointing to era files with u/v wind components at the given + heights. Must have a {year} format key. + years : list + List of ERA years to use for calculating means. + features : list + List of features to compute means for. e.g. ['windspeed_10m'] + out_pattern : str + Pattern pointing to csv files to write means to. Must have a + {feature} format key. + """ em = cls(era_pattern=era_pattern, years=years, features=features) for height, feature in zip(em.heights, em.features): means = em.get_all_means(height) @@ -452,11 +484,23 @@ def run(cls, era_pattern, years, features, out_pattern): class BiasCorrectionFromMeans: """Class for getting bias correction factors from bias and base data files - with precomputed monthly means.""" + with precomputed monthly means. + """ MIN_DISTANCE = 1e-12 def __init__(self, bias_fp, base_fp, dset, leaf_size=4): + """Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + leaf_size : int + Leaf size for ball tree used to match bias and base grids + """ self.dset = dset self.bias_fp = bias_fp self.base_fp = base_fp @@ -488,7 +532,8 @@ def base_tree(self): @property def meta(self): """Get a meta data dictionary on how these bias factors were - calculated""" + calculated + """ meta = { "base_fp": self.base_fp, "bias_fp": self.bias_fp, @@ -585,12 +630,44 @@ def run( base_fp, dset, fp_out, + leaf_size=4, global_scalar=1.0, knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" - bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_out : str + Name of output file containing bias correction factors + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. + out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. 
If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ + bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset, + leaf_size=leaf_size) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: for k, v in out.items(): @@ -611,7 +688,37 @@ def run_uv( knn=1, out_shape=None, ): - """Run bias correction factor computation and write.""" + """Run bias correction factor computation and write. + + Parameters + ---------- + bias_fp : str + Path to csv file containing means for biased data + base_fp : str + Path to csv file containing means for unbiased data + dset : str + Name of dataset to compute bias correction factor for + fp_pattern : str + Pattern for output file. Should contain {feature} format key. + leaf_size : int + Leaf size for ball tree used to match bias and base grids + global_scalar : float + Optional global scalar to use for multiplying all bias correction + factors. This can be used to improve systemic bias against + observation data. This is just written to output files, not + included in the stored bias correction factor values. + knn : int + Number of nearest neighbors to use when matching bias and base + grids. This should be based on difference in resolution. e.g. if + bias grid is 30km and base grid is 3km then knn should be 100 to + aggregate 3km to 30km. + out_shape : tuple | None + Optional 2D shape for output. If this is provided then the bias + correction arrays will be reshaped to this shape. If not provided + the arrays will stay flattened. When using this to write bc files + that will be used in a forward-pass routine this shape should be + the same as the spatial shape of the forward-pass input data. + """ bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) out_u, out_v = bc.get_uv_corrections( global_scalar=global_scalar, knn=knn @@ -636,7 +743,26 @@ class BiasCorrectUpdate: @classmethod def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): - """Get bias correction factors for the given dset and month""" + """Get bias correction factors for the given dset and month + + Parameters + ---------- + bc_file : str + Name of h5 file containing bias correction factors + dset : str + Name of dataset to apply bias correction factors for + month : int + Index of month to bias correct + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + + Returns + ------- + factors : ndarray + Array of bias correction factors for the given dset and month. + """ with Resource(bc_file) as res: logger.info( f"Getting {dset} bias correction factors for month {month}." @@ -650,9 +776,71 @@ def get_bc_factors(cls, bc_file, dset, month, global_scalar=1): return factors @classmethod - def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): + def _correct_month( + cls, fh_in, month, out_file, dset, bc_file, global_scalar + ): + """Bias correct data for a given month. 
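
A self-contained sketch of the monthly update pattern this helper applies: rows whose timestamp falls in the target month are multiplied by that month's factors. The array shapes and factor values below are synthetic placeholders; the real routine reads its factors from the bc_file and writes through RexOutputs.

    import numpy as np
    import pandas as pd

    time_index = pd.date_range('2015-01-01', '2015-12-31 23:00', freq='h')
    data = np.random.rand(len(time_index), 50)        # (time, sites)
    bc_factors = np.random.uniform(0.9, 1.1, (50,))   # one month's scalars

    month = 6
    mask = time_index.month == month
    data[mask, :] *= bc_factors                       # scale June rows only
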
+ + Parameters + ---------- + fh_in : Resource() + Resource handler for input file being corrected + month : int + Index of month to be corrected + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + """ + with RexOutputs(out_file, "a") as fh: + mask = fh.time_index.month == month + mask = np.arange(len(fh.time_index))[mask] + mask = slice(mask[0], mask[-1] + 1) + bc_factors = cls.get_bc_factors( + bc_file=bc_file, + dset=dset, + month=month, + global_scalar=global_scalar, + ) + logger.info(f"Applying bias correction factors for month {month}") + fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] + + @classmethod + def update_file( + cls, + in_file, + out_file, + dset, + bc_file, + global_scalar=1, + max_workers=None, + ): """Update the in_file with bias corrected values for the given dset - and write to out_file.""" + and write to out_file. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ tmp_file = out_file.replace(".h5", ".h5.tmp") logger.info(f"Bias correcting {dset} in {in_file} with {bc_file}.") with Resource(in_file) as fh_in: @@ -660,30 +848,54 @@ def update_file(cls, in_file, out_file, dset, bc_file, global_scalar=1): tmp_file, fh_in.time_index, fh_in.meta, fh_in.global_attrs ) OutputHandler._ensure_dset_in_output(tmp_file, dset) - for i in range(1, 13): - try: - with RexOutputs(tmp_file, "a") as fh: - mask = fh.time_index.month == i - mask = np.arange(len(fh.time_index))[mask] - mask = slice(mask[0], mask[-1] + 1) - bc_factors = cls.get_bc_factors( - bc_file=bc_file, + + if max_workers == 1: + for i in range(1, 13): + try: + cls._correct_month( + fh_in, + month=i, + out_file=tmp_file, dset=dset, + bc_file=bc_file, + global_scalar=global_scalar, + ) + except Exception as e: + raise RuntimeError( + f"Bias correction failed for month {i}." + ) from e + + logger.info( + f"Added {dset} for month {i} to output file " + f"{tmp_file}." + ) + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i in range(1, 13): + future = exe.submit( + cls._correct_month, + fh_in=fh_in, month=i, + out_file=tmp_file, + dset=dset, + bc_file=bc_file, global_scalar=global_scalar, ) + futures[future] = i + logger.info( - f"Applying bias correction factors for month {i}" + f"Submitted bias correction for month {i} " + f"to {tmp_file}." ) - fh[dset, mask, :] = bc_factors * fh_in[dset, mask, :] - except Exception as e: - raise RuntimeError( - f"Bias correction failed for month {i}." - ) from e - logger.info( - f"Added {dset} for month {i} to output file {tmp_file}." - ) + for future in as_completed(futures): + _ = future.result() + i = futures[future] + logger.info( + f"Completed bias correction for month {i} " + f"to {tmp_file}." 
+ ) os.replace(tmp_file, out_file) msg = f"Saved bias corrected {dset} to: {out_file}" @@ -698,11 +910,32 @@ def run( bc_file, overwrite=False, global_scalar=1, + max_workers=None ): - """Run bias correction update.""" + """Run bias correction update. + + Parameters + ---------- + in_file : str + Name of h5 file containing data to bias correct + out_file : str + Name of h5 file containing bias corrected data + dset : str + Name of dataset to bias correct + bc_file : str + Name of file containing bias correction factors for the given dset + overwrite : bool + Whether to overwrite the output file if it already exists. + global_scalar : float + Optional global scalar to multiply all bias correction + factors. This can be used to improve systemic bias against + observation data. + max_workers : int | None + Number of workers to use for parallel processing. + """ if os.path.exists(out_file) and not overwrite: logger.info( - f"{out_file} already exists and overwrite=False. " "Skipping." + f"{out_file} already exists and overwrite=False. Skipping." ) else: if os.path.exists(out_file) and overwrite: @@ -712,5 +945,6 @@ def run( ) os.remove(out_file) cls.update_file( - in_file, out_file, dset, bc_file, global_scalar=global_scalar + in_file, out_file, dset, bc_file, global_scalar=global_scalar, + max_workers=max_workers ) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 8e3c60d07..81727400d 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -44,8 +44,7 @@ def __init__(self, cache_dir='./exo_cache/', ti_workers=1, res_kwargs=None): - """ - Parameters + """Parameters ---------- file_paths : str | list A single source h5 file to extract raster data from or a list @@ -123,7 +122,6 @@ def __init__(self, Dictionary of kwargs passed to lowest level resource handler. e.g. 
xr.open_dataset(file_paths, **res_kwargs) """ - logger.info(f'Initializing {self.__class__.__name__} utility.') self.ti_workers = ti_workers @@ -225,7 +223,8 @@ def get_cache_file(self, feature, s_enhance, t_enhance, s_agg_factor, @property def source_temporal_slice(self): """Get the temporal slice for the exo_source data corresponding to the - input file temporal slice""" + input file temporal slice + """ start_index = self.source_time_index.get_indexer( [self.input_handler.hr_time_index[0]], method='nearest')[0] end_index = self.source_time_index.get_indexer( @@ -353,6 +352,7 @@ def data(self): if data.shape[-1] == 1 and self.hr_shape[-1] > 1: data = np.repeat(data, self.hr_shape[-1], axis=-1) + return data[..., np.newaxis] def get_data(self): @@ -563,15 +563,13 @@ class TopoExtractNC(TopoExtractH5): """TopoExtract for netCDF files""" def __init__(self, *args, **kwargs): - """ - Parameters + """Parameters ---------- args : list Same positional arguments as TopoExtract kwargs : dict Same keyword arguments as TopoExtract """ - super().__init__(*args, **kwargs) logger.info('Getting topography for full domain from ' f'{self._exo_source}') @@ -611,4 +609,4 @@ def get_data(self): """ hr_data = self.source_data.reshape(self.hr_shape) logger.info('Finished computing SZA data') - return hr_data[..., np.newaxis] + return hr_data diff --git a/tests/data_handling/test_dual_data_handling.py b/tests/data_handling/test_dual_data_handling.py index 248954673..1a105c3c1 100644 --- a/tests/data_handling/test_dual_data_handling.py +++ b/tests/data_handling/test_dual_data_handling.py @@ -102,7 +102,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again @@ -126,7 +126,7 @@ def test_regrid_caching(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) assert np.array_equal(old_dh.lr_data, new_dh.lr_data) assert np.array_equal(old_dh.hr_data, new_dh.hr_data) @@ -161,7 +161,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl', + cache_pattern=f'{td}/cache.pkl', ) # Load handlers again with one cached feature and one noncached feature @@ -185,7 +185,7 @@ def test_regrid_caching_in_steps(log=False, s_enhance=2, t_enhance=1, val_split=0.1, - regrid_cache_pattern=f'{td}/cache.pkl') + cache_pattern=f'{td}/cache.pkl') assert np.array_equal(dh_step2.lr_data[..., 0:1], dh_step1.lr_data) assert np.array_equal(dh_step2.noncached_features, FEATURES[1:]) From f1a37b5cb06fbe70a46d4e421fc91f9e21ad0b51 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Fri, 3 Nov 2023 08:20:29 -0600 Subject: [PATCH 30/44] bc code doc strings --- sup3r/postprocessing/collection.py | 8 +- sup3r/preprocessing/feature_handling.py | 85 +++++++++------------- sup3r/utilities/interpolate_log_profile.py | 10 ++- 3 files changed, 44 insertions(+), 59 deletions(-) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 7a6977dca..3ac60dc36 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -26,8 +26,7 @@ class Collector(OutputMixIn): """Sup3r H5 file collection framework""" def __init__(self, file_paths): - """ - Parameters + """Parameters ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected @@ -51,7 +50,6 @@ def get_node_cmd(cls, config): sup3r collection config 
with all necessary args and kwargs to run data collection. """ - import_str = ( 'from sup3r.postprocessing.collection ' 'import Collector;\n' @@ -107,7 +105,6 @@ def get_slices( col_slice : slice final_meta[col_slice] = new_meta """ - final_index = final_meta.index new_index = new_meta.index row_loc = np.where(final_time_index.isin(new_time_index))[0] @@ -205,7 +202,6 @@ def get_data( col_slice : slice final_meta[col_slice] = new_meta """ - with RexOutputs(file_path, unscale=False, mode='r') as f: f_ti = f.time_index f_meta = f.meta @@ -698,7 +694,7 @@ def group_time_chunks(self, file_paths, n_writes=None): f'n_writes ({n_writes}) must be less than or equal ' f'to the number of temporal chunks ({len(file_chunks)}).' ) - assert n_writes < len(file_chunks), msg + assert n_writes <= len(file_chunks), msg return file_chunks def get_flist_chunks(self, file_paths, n_writes=None, join_times=False): diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index c5f712e34..3738ce03c 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1,5 +1,4 @@ -""" -Sup3r feature handling module. +"""Sup3r feature handling module. @author: bbenton """ @@ -35,7 +34,8 @@ class DerivedFeature(ABC): """Abstract class for special features which need to be derived from raw - features""" + features + """ @classmethod @abstractmethod @@ -86,7 +86,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. the all-sky ghi / the clearsky ghi. NaN where nighttime. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -104,7 +103,8 @@ def compute(cls, data, height=None): class ClearSkyRatioCC(DerivedFeature): """Clear Sky Ratio feature class for computing from climate change netcdf - data""" + data + """ @classmethod def inputs(cls, feature): @@ -142,7 +142,6 @@ def compute(cls, data, height=None): Clearsky ratio, e.g. the all-sky ghi / the clearsky ghi. This is assumed to be daily average data for climate change source data. """ - cs_ratio = data['rsds'] / data['clearsky_ghi'] cs_ratio = np.minimum(cs_ratio, 1) cs_ratio = np.maximum(cs_ratio, 0) @@ -189,7 +188,6 @@ def compute(cls, data, height=None): nighttime. Data is float32 so it can be normalized without any integer weirdness. """ - # need to use a nightime threshold of 1 W/m2 because cs_ghi is stored # in integer format and weird binning patterns happen in the clearsky # ratio and cloud mask between 0 and 1 W/m2 and sunrise/sunset @@ -208,7 +206,8 @@ def compute(cls, data, height=None): class PotentialTempNC(DerivedFeature): """Potential Temperature feature class for NETCDF data. Needed since T is - perturbation potential temperature.""" + perturbation potential temperature. + """ @classmethod def inputs(cls, feature): @@ -239,7 +238,8 @@ def compute(cls, data, height): class TempNC(DerivedFeature): """Temperature feature class for NETCDF data. Needed since T is potential - temperature not standard temp.""" + temperature not standard temp. + """ @classmethod def inputs(cls, feature): @@ -271,7 +271,8 @@ def compute(cls, data, height): class PressureNC(DerivedFeature): """Pressure feature class for NETCDF data. Needed since P is perturbation - pressure.""" + pressure. 
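
For readers unfamiliar with the WRF convention referenced here: in WRF output, P stores only the perturbation pressure and PB the base-state pressure, so the physical pressure is their sum. A toy illustration of that recovery follows; it is not necessarily the exact compute implementation in this class, and the array shapes are arbitrary.

    import numpy as np

    # Assumed WRF-style fields: P = perturbation pressure, PB = base-state
    # pressure, both in Pa.
    P = np.full((10, 10, 24), 150.0)
    PB = np.full((10, 10, 24), 85000.0)
    pressure = P + PB  # physical pressure in Pa
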
+ """ @classmethod def inputs(cls, feature): @@ -302,7 +303,8 @@ def compute(cls, data, height): class BVFreqSquaredNC(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -355,7 +357,6 @@ def inputs(cls, feature): list List of required features for computing RMOL """ - assert feature == 'RMOL' features = ['UST', 'HFX'] return features @@ -431,7 +432,8 @@ def compute(cls, data, height): class BVFreqSquaredH5(DerivedFeature): """BVF Squared feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -495,7 +497,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -516,7 +517,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - ws, _ = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return ws @@ -539,7 +539,6 @@ def inputs(cls, feature): list List of required features for computing windspeed """ - height = Feature.get_height(feature) features = [f'U_{height}m', f'V_{height}m', 'lat_lon'] return features @@ -560,7 +559,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - _, wd = invert_uv(data[f'U_{height}m'], data[f'V_{height}m'], data['lat_lon']) return wd @@ -585,7 +583,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -612,7 +609,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -643,7 +639,6 @@ def inputs(cls, feature): list List of required features for computing Veer """ - height = Feature.get_height(feature) heights = [int(height), int(height) + 20] features = [] @@ -667,7 +662,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - heights = [int(height), int(height) + 20] shear = np.cos(np.radians(data[f'winddirection_{int(height) + 20}m'])) shear -= np.cos(np.radians(data[f'winddirection_{int(height)}m'])) @@ -694,7 +688,6 @@ def inputs(cls, feature): list List of required features for computing REWS """ - rotor_center = Feature.get_height(feature) if rotor_center is None: heights = cls.HEIGHTS @@ -722,7 +715,6 @@ def compute(cls, data, height): ndarray Derived feature array """ - if height is None: heights = cls.HEIGHTS else: @@ -829,7 +821,8 @@ def compute(cls, data, height): class UWind(DerivedFeature): """U wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -876,7 +869,8 @@ def compute(cls, data, height): class Vorticity(DerivedFeature): """Vorticity feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -919,7 +913,8 @@ def compute(cls, data, height): class VWind(DerivedFeature): """V wind component feature class with needed inputs method and compute - method""" + method + """ @classmethod def inputs(cls, feature): @@ -1047,14 +1042,16 @@ def compute(cls, data, height): class TasMin(Tas): """Daily min air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 'tasmin' class TasMax(Tas): """Daily max air temperature near surface variable from climate change nc - files""" + files + """ CC_FEATURE_NAME = 
'tasmax' @@ -1079,7 +1076,6 @@ def compute(file_paths, raster_index): lat lon array (spatial_1, spatial_2, 2) """ - fp = file_paths if isinstance(file_paths, str) else file_paths[0] handle = xr.open_dataset(fp) valid_vars = set(handle.variables) @@ -1166,7 +1162,8 @@ def compute(file_paths, raster_index): class Feature: """Class to simplify feature computations. Stores feature height, feature - basename, name of feature in handle""" + basename, name of feature in handle + """ def __init__(self, feature, handle): """Takes a feature (e.g. U_100m) and gets the height (100), basename @@ -1204,7 +1201,6 @@ def get_basename(feature): str feature basename """ - height = Feature.get_height(feature) pressure = Feature.get_pressure(feature) if height is not None or pressure is not None: @@ -1260,7 +1256,8 @@ def get_pressure(feature): class FeatureHandler: """Feature Handler with cache for previously loaded features used in other - calculations """ + calculations + """ FEATURE_REGISTRY: ClassVar[dict] = {} @@ -1280,7 +1277,6 @@ def valid_handle_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1304,7 +1300,6 @@ def valid_input_features(cls, features, handle_features): bool Whether feature basename is in handle """ - if features is None: return False @@ -1440,16 +1435,13 @@ def serial_extract(cls, file_paths, raster_index, time_chunks, keys for features. e.g. data[chunk_number][feature] = array. (spatial_1, spatial_2, temporal) """ - data = defaultdict(dict) for t, t_slice in enumerate(time_chunks): for f in input_features: data[t][f] = cls.extract_feature(file_paths, raster_index, f, t_slice, **kwargs) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks extracted.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks extracted.') return data @classmethod @@ -1511,7 +1503,6 @@ def parallel_extract(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(input_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] try: @@ -1521,12 +1512,11 @@ def parallel_extract(cls, f' {v["feature"]}') logger.error(msg) raise RuntimeError(msg) from e - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks extracted. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks extracted. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data @@ -1613,7 +1603,6 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, e.g. data[chunk_number][feature] = array. (spatial_1, spatial_2, temporal) """ - if len(derived_features) == 0: return data @@ -1768,7 +1757,6 @@ def _exact_lookup(cls, feature): out : str Matching feature registry entry. """ - out = None for k, v in cls.FEATURE_REGISTRY.items(): if k.lower() == feature.lower(): @@ -1791,7 +1779,6 @@ def _pattern_lookup(cls, feature): out : str Matching feature registry entry. 
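
A toy version of the exact/pattern registry lookup used by _exact_lookup and _pattern_lookup, for readers skimming the diff; the registry keys and handler names are illustrative stand-ins rather than the project's real registry contents.

    import re

    FEATURE_REGISTRY = {'U_(.*)m': 'UWind',
                        'V_(.*)m': 'VWind',
                        'topography': 'TopoExtract'}

    def pattern_lookup(feature):
        """Return the first entry whose key regex matches the feature name."""
        for key, handler in FEATURE_REGISTRY.items():
            if re.match(key.lower(), feature.lower()):
                return handler
        return None

    assert pattern_lookup('U_100m') == 'UWind'
    assert pattern_lookup('topography') == 'TopoExtract'
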
""" - out = None for k, v in cls.FEATURE_REGISTRY.items(): if re.match(k.lower(), feature.lower()): diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 993b1dcc1..be6921f5e 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -29,7 +29,8 @@ class LogLinInterpolator: """Open ERA5 file, log interpolate wind components between 0 - max_log_height, linearly interpolate components above max_log_height - meters, and save to file""" + meters, and save to file + """ DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { 'u': [40, 80, 120, 160, 200], @@ -154,7 +155,8 @@ def load(self): def interpolate_vars(self, max_workers=None): """Interpolate u/v wind components below 100m using log profile. - Interpolate non wind data linearly.""" + Interpolate non wind data linearly. + """ for var, arrs in self.data_dict.items(): max_log_height = self.max_log_height if var not in ('u', 'v'): @@ -233,7 +235,8 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace('.nc', '_tmp.nc') return tmp_file @@ -582,7 +585,6 @@ def interp_single_ts(cls, out_array : ndarray Array of interpolated values. """ - # Interp each vertical column of height and var to requested levels zip_iter = zip(hgt_t, var_t, mask) out_array = [] From f56a471b3fe068998777c9a837e478eecec4a17b Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Nov 2023 14:50:42 -0600 Subject: [PATCH 31/44] misc logging tweaks --- sup3r/preprocessing/batch_handling.py | 12 +++--- .../conditional_moment_batch_handling.py | 2 +- sup3r/preprocessing/data_handling/base.py | 8 ++-- sup3r/preprocessing/data_handling/mixin.py | 7 +++- sup3r/preprocessing/feature_handling.py | 18 ++++----- sup3r/utilities/era_downloader.py | 38 ++++++++++--------- tests/data_handling/test_data_handling_nc.py | 16 +++++--- 7 files changed, 55 insertions(+), 46 deletions(-) diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 6390f71e7..8f5563119 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -499,14 +499,14 @@ def __init__(self, ] logger.info(f'Initializing BatchHandler with ' - f'{len(self.data_handlers)} data handlers with handler' + f'{len(self.data_handlers)} data handlers with handler ' f'weights={self.handler_weights}, smoothing={smoothing}. 
' f'Using stats_workers={self.stats_workers}, ' f'norm_workers={self.norm_workers}, ' f'load_workers={self.load_workers}.') now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug(f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.') log_mem(logger, log_level='INFO') @@ -655,8 +655,8 @@ def parallel_normalization(self): logger.debug(f'{i+1} out of {len(futures)} data handlers' ' normalized.') - def parallel_load(self): - """Load data handler data in parallel""" + def load_handler_data(self): + """Load data handler data in parallel or serial""" logger.info(f'Loading {len(self.data_handlers)} data handlers') max_workers = self.load_workers if max_workers == 1: @@ -686,7 +686,7 @@ def parallel_load(self): logger.debug(f'{i+1} out of {len(futures)} handlers ' 'loaded.') - def parallel_stats(self): + def _get_stats(self): """Get standard deviations and means for training features in parallel.""" logger.info(f'Calculating stats for {len(self.training_features)} ' @@ -788,7 +788,7 @@ def get_stats(self): now = dt.now() logger.info('Calculating stdevs/means.') - self.parallel_stats() + self._get_stats() logger.info(f'Finished calculating stats in {dt.now() - now}.') self.cache_stats() diff --git a/sup3r/preprocessing/conditional_moment_batch_handling.py b/sup3r/preprocessing/conditional_moment_batch_handling.py index a1418b101..4d27a92da 100644 --- a/sup3r/preprocessing/conditional_moment_batch_handling.py +++ b/sup3r/preprocessing/conditional_moment_batch_handling.py @@ -902,7 +902,7 @@ def __init__( ) now = dt.now() - self.parallel_load() + self.load_handler_data() logger.debug( f'Finished loading data of shape {self.shape} ' f'for BatchHandler in {dt.now() - now}.' diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 0fa62d0fd..dbae7b10e 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -1019,11 +1019,11 @@ def load_cached_data(self, with_split=True): logger.warning(msg) warnings.warn(msg) - logger.debug('Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}') + if with_split and self.val_split > 0: + logger.debug('Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}') - if with_split: self.data, self.val_data = self.split_data( val_split=self.val_split, shuffle_time=self.shuffle_time) diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index f2f2c0494..67e11cb5b 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -13,6 +13,7 @@ import numpy as np import pandas as pd +import psutil from scipy.stats import mode from sup3r.utilities.utilities import ( @@ -31,6 +32,7 @@ class CacheHandlingMixIn: """Collection of methods for handling data caching and loading""" def __init__(self): + """Initialize common attributes""" self._noncached_features = None self._cache_pattern = None self._cache_files = None @@ -285,7 +287,10 @@ def _load_single_cached_feature(self, fp, cache_files, features, msg = f'{features[idx].lower()} not found in {fp.lower()}.' assert features[idx].lower() in fp.lower(), msg fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {features[idx]} from {fp}.') + mem = psutil.virtual_memory() + logger.info(f'Loading {features[idx]} from {fp}. 
Current memory ' + f'usage is {mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') out = None with open(fp, 'rb') as fh: diff --git a/sup3r/preprocessing/feature_handling.py b/sup3r/preprocessing/feature_handling.py index 3738ce03c..9a1ac6eb2 100644 --- a/sup3r/preprocessing/feature_handling.py +++ b/sup3r/preprocessing/feature_handling.py @@ -1617,10 +1617,8 @@ def serial_compute(cls, data, file_paths, raster_index, time_chunks, file_paths=file_paths, raster_index=raster_index) cls.pop_old_data(data, t, all_features) - interval = int(np.ceil(len(time_chunks) / 10)) - if t % interval == 0: - logger.debug(f'{t+1} out of {len(time_chunks)} feature ' - 'chunks computed.') + logger.debug(f'{t+1} out of {len(time_chunks)} feature ' + 'chunks computed.') return data @@ -1698,18 +1696,16 @@ def parallel_compute(cls, f' time chunks of shape ({shape[0]}, {shape[1]}, ' f'{time_shape}) for {len(derived_features)} features') - interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): v = futures[future] chunk_idx = v['chunk'] data[chunk_idx] = data.get(chunk_idx, {}) data[chunk_idx][v['feature']] = future.result() - if i % interval == 0: - mem = psutil.virtual_memory() - logger.info(f'{i+1} out of {len(futures)} feature ' - 'chunks computed. Current memory usage is ' - f'{mem.used / 1e9:.3f} GB out of ' - f'{mem.total / 1e9:.3f} GB total.') + mem = psutil.virtual_memory() + logger.info(f'{i+1} out of {len(futures)} feature ' + 'chunks computed. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.') return data diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 0a572c9ef..d627ca6c5 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -37,7 +37,8 @@ class EraDownloader: """Class to handle ERA5 downloading, variable renaming, file combination, - and interpolation.""" + and interpolation. + """ msg = ('To download ERA5 data you need to have a ~/.cdsapirc file ' 'with a valid url and api key. Follow the instructions here: ' @@ -223,7 +224,8 @@ def init_dims(cls, old_ds, new_ds, dims): @classmethod def get_tmp_file(cls, file): """Get temp file for given file. Then only needed variables will be - written to the given file.""" + written to the given file. + """ tmp_file = file.replace(".nc", "_tmp.nc") return tmp_file @@ -255,7 +257,8 @@ def check_good_vars(self, variables): def _prep_var_lists(self, variables): """Add all downloadable variables for the generic requested variables. - e.g. if variable = 'u' add all downloadable u variables to list.""" + e.g. if variable = 'u' add all downloadable u variables to list. + """ d_vars = [] vars = variables.copy() for i, v in enumerate(vars): @@ -269,7 +272,8 @@ def _prep_var_lists(self, variables): def prep_var_lists(self, variables): """Create surface and level variable lists based on requested - variables.""" + variables. 
+ """ variables = self._prep_var_lists(variables) for var in variables: if var in self.SFC_VARS and var not in self.sfc_file_variables: @@ -344,7 +348,6 @@ def download_surface_file(self): def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" - dims = ('time', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.surface_file) with Dataset(self.surface_file, "r") as old_ds: @@ -406,7 +409,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): ds : Dataset Dataset() object for new file with new height variable written. """ - _ = ds.createVariable(standard_name, np.float32, dimensions=old_ds['z'].dimensions) @@ -418,7 +420,6 @@ def convert_z(self, standard_name, long_name, old_ds, ds): def process_level_file(self): """Convert geopotential to geopotential height.""" - dims = ('time', 'level', 'latitude', 'longitude') tmp_file = self.get_tmp_file(self.level_file) with Dataset(self.level_file, "r") as old_ds: @@ -429,10 +430,12 @@ def process_level_file(self): ds = self.map_vars(old_ds, ds) - if 'pressure' in self.variables: + if ('pressure' in self.variables + and 'pressure' not in ds.variables): tmp = np.zeros(ds.variables['zg'].shape) for i in range(tmp.shape[1]): tmp[:, i, :, :] = ds.variables['level'][i] * 100 + _ = ds.createVariable('pressure', np.float32, dimensions=dims) @@ -446,7 +449,6 @@ def process_level_file(self): def process_and_combine(self): """Process variables and combine.""" - if not os.path.exists(self.combined_file) or self.overwrite: files = [] if os.path.exists(self.level_file): @@ -459,7 +461,7 @@ def process_and_combine(self): files.append(self.surface_file) logger.info(f'Combining {files} to {self.combined_file}.') - with xr.open_mfdataset(files) as ds: + with xr.open_mfdataset(files, compat='override') as ds: ds.to_netcdf(self.combined_file) logger.info(f'Finished writing {self.combined_file}') @@ -495,7 +497,8 @@ def good_file(self, file, required_shape): def check_existing_files(self): """If files exist already check them for good shape and required variables. Remove them if there was a problem so we can continue with - routine from scratch.""" + routine from scratch. + """ if os.path.exists(self.combined_file): try: check = self.good_file(self.combined_file, self.required_shape) @@ -521,7 +524,8 @@ def check_existing_files(self): def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to - max_log_height (usually 100m) and linear interpolation above this.""" + max_log_height (usually 100m) and linear interpolation above this. + """ LogLinInterpolator.run(infile=self.combined_file, outfile=self.interp_file, max_workers=max_workers, @@ -533,8 +537,8 @@ def get_monthly_file(self, interp_workers=None, keep_variables=None, **interp_kwargs): """Download level and surface files, process variables, and combine processed files. Includes checks for shape and variables and option to - interpolate.""" - + interpolate. + """ if os.path.exists(self.combined_file) and self.overwrite: os.remove(self.combined_file) @@ -577,7 +581,6 @@ def all_months_exist(cls, year, file_pattern): @classmethod def already_pruned(cls, infile, keep_variables): """Check if file has been pruned already.""" - if keep_variables is None: logger.info('Received keep_variables=None. 
Skipping pruning.') return @@ -586,7 +589,9 @@ def already_pruned(cls, infile, keep_variables): pruned = True with Dataset(infile, 'r') as ds: - for var in ds.variables: + variables = [var for var in ds.variables + if var not in ('time', 'latitude', 'longitude')] + for var in variables: if not any(name in var for name in keep_variables): logger.info(f'Pruning {var} in {infile}.') pruned = False @@ -595,7 +600,6 @@ def already_pruned(cls, infile, keep_variables): @classmethod def prune_output(cls, infile, keep_variables=None): """Prune output file to keep just single level variables""" - if keep_variables is None: logger.info('Received keep_variables=None. Skipping pruning.') return diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 59ada2965..110c3016e 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -2,18 +2,21 @@ """pytests for data handling""" import os +import tempfile + +import matplotlib.pyplot as plt import numpy as np import pytest -import matplotlib.pyplot as plt -import tempfile import xarray as xr from sup3r import TEST_DATA_DIR +from sup3r.preprocessing.batch_handling import ( + BatchHandler, + SpatialBatchHandler, +) from sup3r.preprocessing.data_handling import DataHandlerNC as DataHandler -from sup3r.preprocessing.batch_handling import (BatchHandler, - SpatialBatchHandler) from sup3r.utilities.interpolation import Interpolator -from sup3r.utilities.pytest import make_fake_nc_files, make_fake_era_files +from sup3r.utilities.pytest import make_fake_era_files, make_fake_nc_files INPUT_FILE = os.path.join(TEST_DATA_DIR, 'test_wrf_2014-10-01_00_00_00') features = ['U_100m', 'V_100m', 'BVF_MO_200m'] @@ -166,7 +169,8 @@ def test_data_caching(): if os.path.exists(cache_pattern): os.system(f'rm {cache_pattern}') handler = DataHandler(INPUT_FILE, features, - cache_pattern=cache_pattern, **dh_kwargs) + cache_pattern=cache_pattern, **dh_kwargs, + val_split=0.1) assert handler.data is None handler.load_cached_data() assert handler.data.shape == (shape[0], shape[1], From d2216913bbbd1965f621f77d8f862f1edbe8f751 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sat, 4 Nov 2023 14:50:42 -0600 Subject: [PATCH 32/44] misc logging tweaks --- sup3r/bias/bias_correct_means.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index eb4f1f996..4a878d725 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -17,6 +17,7 @@ from rex import Resource from scipy.interpolate import interp1d from sklearn.neighbors import BallTree + from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD @@ -589,7 +590,7 @@ def get_corrections(self, global_scalar=1, knn=1): base_data = self.get_base_data(knn=knn) bias_data = self.get_bias_data() scaler = global_scalar * base_data / bias_data - adder = 0 + adder = np.zeros(scaler.shape) out = { "latitude": self.bias_meta["latitude"], @@ -670,13 +671,19 @@ def run( leaf_size=leaf_size) out = bc.get_corrections(global_scalar=global_scalar, knn=knn) if out_shape is not None: - for k, v in out.items(): - if k in ("latitude", "longitude"): - out[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out[k] = np.array(v).reshape((*out_shape, 12)) + out = cls._reshape_output(out, 
out_shape) bc.write_output(fp_out, out) + @classmethod + def _reshape_output(cls, out, out_shape): + """Reshape output according to given output shape""" + for k, v in out.items(): + if k in ("latitude", "longitude"): + out[k] = np.array(v).reshape(out_shape) + elif not isinstance(v, (int, float)): + out[k] = np.array(v).reshape((*out_shape, 12)) + return out + @classmethod def run_uv( cls, @@ -724,16 +731,8 @@ def run_uv( global_scalar=global_scalar, knn=knn ) if out_shape is not None: - for k, v in out_u.items(): - if k in ("latitude", "longitude"): - out_u[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out_u[k] = np.array(v).reshape((*out_shape, 12)) - for k, v in out_v.items(): - if k in ("latitude", "longitude"): - out_v[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, float): - out_v[k] = np.array(v).reshape((*out_shape, 12)) + out_u = cls._reshape_output(out_u, out_shape) + out_v = cls._reshape_output(out_v, out_shape) bc.write_output(fp_pattern.format(feature=bc.u_name), out_u) bc.write_output(fp_pattern.format(feature=bc.v_name), out_v) From f2c8e68451d003cddbb1623e8429aa4e32d6c090 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Nov 2023 09:42:57 -0700 Subject: [PATCH 33/44] nc collector for intermediate write during multi fwp pipeline --- sup3r/postprocessing/collection.py | 141 +++++++++++++++++++-- sup3r/postprocessing/data_collect_cli.py | 14 +- sup3r/preprocessing/data_handling/mixin.py | 10 +- sup3r/utilities/regridder.py | 2 + 4 files changed, 151 insertions(+), 16 deletions(-) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 3ac60dc36..b916915be 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -4,12 +4,14 @@ import logging import os import time +from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed from warnings import warn import numpy as np import pandas as pd import psutil +import xarray as xr from gaps import Status from rex.utilities.fun_utils import get_fun_call_str from rex.utilities.loggers import init_logger @@ -22,17 +24,16 @@ logger = logging.getLogger(__name__) -class Collector(OutputMixIn): - """Sup3r H5 file collection framework""" +class BaseCollector(ABC, OutputMixIn): + """Base collector class for H5/NETCDF collection""" def __init__(self, file_paths): """Parameters ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single string with unix-style /search/patt*ern.h5. Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.. Files + should have non-overlapping time_index and spatial domains. """ if not isinstance(file_paths, list): file_paths = glob.glob(file_paths) @@ -40,6 +41,11 @@ def __init__(self, file_paths): self.data = None self.file_attrs = {} + @classmethod + @abstractmethod + def collect(cls, *args, **kwargs): + """Collect data files from a dir to one output file.""" + @classmethod def get_node_cmd(cls, config): """Get a CLI call to collect data. 
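
Because the H5 and NETCDF collectors added in this patch share the same CLI entry point, a hypothetical data-collect config is sketched below. The keys mirror those read by the collection CLI and get_node_cmd; the paths, feature names, and job name are invented for illustration.

    # Hypothetical sup3r data-collect config for NETCDF chunk collection.
    config = {
        "file_paths": "./fwp_out/sup3r_chunk_*.nc",
        "out_file": "./fwp_out/sup3r_collected.nc",
        "features": ["U_100m", "V_100m"],
        "job_name": "sup3r_collect",
        "log_file": "./logs/sup3r_collect.log",
        "execution_control": {"option": "local"},
    }
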
@@ -52,7 +58,7 @@ def get_node_cmd(cls, config): """ import_str = ( 'from sup3r.postprocessing.collection ' - 'import Collector;\n' + f'import {cls.__class__.__name__};\n' 'from rex import init_logger;\n' 'import time;\n' 'from gaps import Status;\n' @@ -80,6 +86,123 @@ def get_node_cmd(cls, config): return cmd.replace('\\', '/') + +class CollectorNC(BaseCollector): + """Sup3r NETCDF file collection framework""" + + @classmethod + def collect( + cls, + file_paths, + out_file, + features, + max_workers=None, + log_level=None, + log_file=None, + write_status=False, + job_name=None, + join_times=False, + target_final_meta_file=None, + n_writes=None, + overwrite=True, + threshold=1e-4, + ): + """Collect data files from a dir to one output file. + + Filename requirements: + - Should end with ".nc" + + Parameters + ---------- + file_paths : list | str + Explicit list of str file paths that will be sorted and collected + or a single string with unix-style /search/patt*ern.nc. + out_file : str + File path of final output file. + features : list + List of dsets to collect + max_workers : int | None + Number of workers to use in parallel. 1 runs serial, + None will use all available workers. + log_level : str | None + Desired log level, None will not initialize logging. + log_file : str | None + Target log file. None logs to stdout. + write_status : bool + Flag to write status file once complete if running from pipeline. + job_name : str + Job name for status file if running from pipeline. + join_times : bool + Option to split full file list into chunks with each chunk having + the same temporal_chunk_index. The number of writes will then be + min(number of temporal chunks, n_writes). This ensures that each + write has all the spatial chunks for a given time index. Assumes + file_paths have a suffix format + _{temporal_chunk_index}_{spatial_chunk_index}.h5. This is required + if there are multiple writes and chunks have different time + indices. + target_final_meta_file : str + Path to target final meta containing coordinates to keep from the + full file list collected meta. This can be but is not necessarily a + subset of the full list of coordinates for all files in the file + list. This is used to remove coordinates from the full file list + which are not present in the target_final_meta. Either this full + meta or a subset, depending on which coordinates are present in + the data to be collected, will be the final meta for the collected + output files. + n_writes : int | None + Number of writes to split full file list into. Must be less than + or equal to the number of temporal chunks if chunks have different + time indices. + overwrite : bool + Whether to overwrite existing output file + threshold : float + Threshold distance for finding target coordinates within full meta + """ + t0 = time.time() + + logger.info( + f'Initializing collection for file_paths={file_paths}, ' + f'with max_workers={max_workers}.' 
+ ) + + if log_level is not None: + init_logger( + 'sup3r.preprocessing', log_file=log_file, log_level=log_level + ) + + if not os.path.exists(os.path.dirname(out_file)): + os.makedirs(os.path.dirname(out_file), exist_ok=True) + + collector = cls(file_paths) + logger.info( + 'Collecting {} files to {}'.format(len(collector.flist), out_file) + ) + if overwrite and os.path.exists(out_file): + logger.info(f'overwrite=True, removing {out_file}.') + os.remove(out_file) + + out = xr.open_mfdataset(collector.flist) + out.to_netcdf(out_file) + + if write_status and job_name is not None: + status = { + 'out_dir': os.path.dirname(out_file), + 'fout': out_file, + 'flist': collector.flist, + 'job_status': 'successful', + 'runtime': (time.time() - t0) / 60, + } + Status.make_single_job_file( + os.path.dirname(out_file), 'collect', job_name, status + ) + + logger.info('Finished file collection.') + + +class CollectorH5(BaseCollector): + """Sup3r H5 file collection framework""" + @classmethod def get_slices( cls, final_time_index, final_meta, new_time_index, new_meta @@ -230,7 +353,7 @@ def get_data( warn(msg) else: - row_slice, col_slice = Collector.get_slices( + row_slice, col_slice = self.get_slices( time_index, meta, f_ti, f_meta ) @@ -494,13 +617,13 @@ def _write_flist_data( """ with RexOutputs(out_file, mode='r') as f: target_ti = f.time_index - y_write_slice, x_write_slice = Collector.get_slices( + y_write_slice, x_write_slice = self.get_slices( target_ti, target_masked_meta, time_index, subset_masked_meta, ) - Collector._ensure_dset_in_output(out_file, feature) + self._ensure_dset_in_output(out_file, feature) with RexOutputs(out_file, mode='a') as f: try: diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index ac148dd53..9d4dd1b2c 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -2,15 +2,16 @@ """ sup3r data collection CLI entry points. """ -import click -import logging import copy +import logging + +import click +from sup3r.postprocessing.collection import CollectorH5, CollectorNC from sup3r.utilities import ModuleName -from sup3r.version import __version__ -from sup3r.postprocessing.collection import Collector from sup3r.utilities.cli import AVAILABLE_HARDWARE_OPTIONS, BaseCLI - +from sup3r.utilities.utilities import get_source_type +from sup3r.version import __version__ logger = logging.getLogger(__name__) @@ -42,6 +43,9 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): dset_split = config.get('dset_split', False) exec_kwargs = config.get('execution_control', {}) hardware_option = exec_kwargs.pop('option', 'local') + source_type = get_source_type(config['file_paths']) + collector_types = {'h5': CollectorH5, 'nc': CollectorNC} + Collector = collector_types[source_type] configs = [config] if dset_split: diff --git a/sup3r/preprocessing/data_handling/mixin.py b/sup3r/preprocessing/data_handling/mixin.py index 67e11cb5b..781cd5f67 100644 --- a/sup3r/preprocessing/data_handling/mixin.py +++ b/sup3r/preprocessing/data_handling/mixin.py @@ -914,6 +914,10 @@ class TrainingPrepMixIn: """Collection of training related methods. e.g. 
Training + Validation splitting, normalization""" + def __init__(self): + """Initialize common attributes""" + self.features = None + @classmethod def _split_data_indices(cls, data, @@ -1034,11 +1038,13 @@ def _normalize_data(self, data, val_data, feature_index, mean, std): val_data[..., feature_index] /= std data[..., feature_index] /= std else: - msg = ( - f'Standard Deviation is zero for feature #{feature_index + 1}') + msg = ('Standard Deviation is zero for ' + f'{self.features[feature_index]}') logger.warning(msg) warnings.warn(msg) + logger.info(f'Finished normalizing {self.features[feature_index]}.') + def _normalize(self, data, val_data, means, stds, max_workers=None): """Normalize all data features diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index d0af2abf6..21e04534d 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -344,6 +344,8 @@ def __call__(self, data): Flattened regridded spatiotemporal data (spatial, temporal) """ + msg = 'Input data must be 3D (spatial_1, spatial_2, temporal)' + assert len(data.shape) == 3, msg vals = [ data[:, :, i].flatten()[np.array(self.indices)][np.newaxis] for i in range(data.shape[-1]) From 2dac1bb6eb33da5fd2cf831b4bef0a803d1213d2 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Sun, 5 Nov 2023 18:22:49 -0700 Subject: [PATCH 34/44] cleaning up args --- sup3r/postprocessing/collection.py | 53 ++++++---------------- sup3r/postprocessing/data_collect_cli.py | 3 +- sup3r/utilities/interpolate_log_profile.py | 4 +- 3 files changed, 18 insertions(+), 42 deletions(-) diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index b916915be..10c4b9d10 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -"""H5 file collection.""" +"""H5/NETCDF file collection.""" import glob import logging import os @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class BaseCollector(ABC, OutputMixIn): +class BaseCollector(OutputMixIn, ABC): """Base collector class for H5/NETCDF collection""" def __init__(self, file_paths): @@ -58,7 +58,7 @@ def get_node_cmd(cls, config): """ import_str = ( 'from sup3r.postprocessing.collection ' - f'import {cls.__class__.__name__};\n' + f'import {cls.__name__};\n' 'from rex import init_logger;\n' 'import time;\n' 'from gaps import Status;\n' @@ -96,16 +96,12 @@ def collect( file_paths, out_file, features, - max_workers=None, log_level=None, log_file=None, write_status=False, job_name=None, - join_times=False, - target_final_meta_file=None, - n_writes=None, overwrite=True, - threshold=1e-4, + res_kwargs=None ): """Collect data files from a dir to one output file. @@ -121,9 +117,6 @@ def collect( File path of final output file. features : list List of dsets to collect - max_workers : int | None - Number of workers to use in parallel. 1 runs serial, - None will use all available workers. log_level : str | None Desired log level, None will not initialize logging. log_file : str | None @@ -132,38 +125,15 @@ def collect( Flag to write status file once complete if running from pipeline. job_name : str Job name for status file if running from pipeline. - join_times : bool - Option to split full file list into chunks with each chunk having - the same temporal_chunk_index. The number of writes will then be - min(number of temporal chunks, n_writes). This ensures that each - write has all the spatial chunks for a given time index. 
Assumes - file_paths have a suffix format - _{temporal_chunk_index}_{spatial_chunk_index}.h5. This is required - if there are multiple writes and chunks have different time - indices. - target_final_meta_file : str - Path to target final meta containing coordinates to keep from the - full file list collected meta. This can be but is not necessarily a - subset of the full list of coordinates for all files in the file - list. This is used to remove coordinates from the full file list - which are not present in the target_final_meta. Either this full - meta or a subset, depending on which coordinates are present in - the data to be collected, will be the final meta for the collected - output files. - n_writes : int | None - Number of writes to split full file list into. Must be less than - or equal to the number of temporal chunks if chunks have different - time indices. overwrite : bool Whether to overwrite existing output file - threshold : float - Threshold distance for finding target coordinates within full meta + res_kwargs : dict | None + Dictionary of kwargs to pass to xarray.open_mfdataset. """ t0 = time.time() logger.info( - f'Initializing collection for file_paths={file_paths}, ' - f'with max_workers={max_workers}.' + f'Initializing collection for file_paths={file_paths}' ) if log_level is not None: @@ -182,8 +152,13 @@ def collect( logger.info(f'overwrite=True, removing {out_file}.') os.remove(out_file) - out = xr.open_mfdataset(collector.flist) - out.to_netcdf(out_file) + if not os.path.exists(out_file): + res_kwargs = (res_kwargs + or {"concat_dim": "Time", "combine": "nested"}) + out = xr.open_mfdataset(collector.flist, **res_kwargs) + features = [feat for feat in out if feat in features + or feat.lower() in features] + out[features].to_netcdf(out_file) if write_status and job_name is not None: status = { diff --git a/sup3r/postprocessing/data_collect_cli.py b/sup3r/postprocessing/data_collect_cli.py index 9d4dd1b2c..fc8621274 100644 --- a/sup3r/postprocessing/data_collect_cli.py +++ b/sup3r/postprocessing/data_collect_cli.py @@ -52,7 +52,8 @@ def from_config(ctx, config_file, verbose=False, pipeline_step=None): configs = [] for feature in config['features']: f_config = copy.deepcopy(config) - f_out_file = config['out_file'].replace('.h5', f'_{feature}.h5') + f_out_file = config['out_file'].replace( + f'.{source_type}', f'_{feature}.{source_type}') f_job_name = config['job_name'] + f'_{feature}' f_log_file = config.get('log_file', None) if f_log_file is not None: diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index be6921f5e..1a42b8ef7 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -33,8 +33,8 @@ class LogLinInterpolator: """ DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { - 'u': [40, 80, 120, 160, 200], - 'v': [40, 80, 120, 160, 200], + 'u': [10, 40, 80, 100, 120, 160, 200], + 'v': [10, 40, 80, 100, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], From e796c4ac6b0aa902d3ae81874cc00472ad7c6410 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Mon, 6 Nov 2023 16:12:24 -0700 Subject: [PATCH 35/44] break at max_log_height --- sup3r/utilities/interpolate_log_profile.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 1a42b8ef7..2ed8c39db 100644 --- 
a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -405,8 +405,8 @@ def ws_log_profile(z, a, b): good = True levels = np.array(levels) - lev_mask = (0 < levels) & (levels <= max_log_height) - var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) + lev_mask = (levels > 0) & (levels <= max_log_height) + var_mask = (lev_array_samp > 0) & (lev_array_samp <= max_log_height) try: popt, _ = curve_fit(ws_log_profile, lev_array_samp[var_mask], @@ -480,8 +480,8 @@ def _interp_var_to_height(cls, max_log_height=max_log_height) if any(levels > max_log_height): - lev_mask = levels >= max_log_height - var_mask = lev_array >= max_log_height + lev_mask = levels > max_log_height + var_mask = lev_array > max_log_height if len(lev_array[var_mask]) > 1: lin_ws = interp1d(lev_array[var_mask], var_array[var_mask], From d4f637661431e8e0cb1a535154d6b2b1e6906f47 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 7 Nov 2023 06:00:43 -0700 Subject: [PATCH 36/44] collection try except --- sup3r/pipeline/forward_pass.py | 34 +++++++++++++++--------------- sup3r/postprocessing/collection.py | 15 +++++++++---- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 89b90291c..f6dc53b6a 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -383,9 +383,8 @@ def hr_crop_slices(self): if self._hr_crop_slices is None: self._hr_crop_slices = [] for t in self.t_hr_crop_slices: - node_slices = [] - for s in self.s_hr_crop_slices: - node_slices.append((s[0], s[1], t, slice(None))) + node_slices = [(s[0], s[1], t, slice(None)) + for s in self.s_hr_crop_slices] self._hr_crop_slices.append(node_slices) return self._hr_crop_slices @@ -805,7 +804,8 @@ def preflight(self): f'n_spatial_chunks={self.fwp_slicer.n_spatial_chunks}, ' f'n_temporal_chunks={self.fwp_slicer.n_temporal_chunks}, ' f'and n_total_chunks={self.chunks}. 
' - f'{self.chunks / self.nodes} chunks per node on average.') + f'{self.chunks / self.nodes:.3f} chunks per node on ' + 'average.') logger.info(f'Using max_workers={self.max_workers}, ' f'pass_workers={self.pass_workers}, ' f'output_workers={self.output_workers}') @@ -847,7 +847,7 @@ def init_handler(self): out = self.input_handler_class(self.file_paths[0], [], target=self.target, shape=self.grid_shape, - worker_kwargs=dict(ti_workers=1)) + worker_kwargs={"ti_workers": 1}) self._init_handler = out return self._init_handler @@ -1165,17 +1165,17 @@ def update_input_handler_kwargs(self, strategy): data handler for the current forward pass chunk """ input_handler_kwargs = copy.deepcopy(strategy._input_handler_kwargs) - fwp_input_handler_kwargs = dict( - file_paths=self.file_paths, - features=self.features, - target=self.target, - shape=self.shape, - temporal_slice=self.temporal_pad_slice, - raster_file=self.raster_file, - cache_pattern=self.cache_pattern, - single_ts_files=self.single_ts_files, - handle_features=strategy.handle_features, - val_split=0.0) + fwp_input_handler_kwargs = { + "file_paths": self.file_paths, + "features": self.features, + "target": self.target, + "shape": self.shape, + "temporal_slice": self.temporal_pad_slice, + "raster_file": self.raster_file, + "cache_pattern": self.cache_pattern, + "single_ts_files": self.single_ts_files, + "handle_features": strategy.handle_features, + "val_split": 0.0} input_handler_kwargs.update(fwp_input_handler_kwargs) return input_handler_kwargs @@ -1900,7 +1900,7 @@ def _run_parallel(cls, strategy, node_index): futures = {} start = dt.now() - pool_kws = dict(max_workers=strategy.pass_workers, loggers=['sup3r']) + pool_kws = {"max_workers": strategy.pass_workers, "loggers": ['sup3r']} with SpawnProcessPool(**pool_kws) as exe: now = dt.now() for _i, chunk_index in enumerate(strategy.node_chunks[node_index]): diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 10c4b9d10..f9f584d9b 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -340,7 +340,16 @@ def get_data( f_data = np.round(f_data) f_data = f_data.astype(dtype) - self.data[row_slice, col_slice] = f_data + + try: + self.data[row_slice, col_slice] = f_data + except Exception as e: + msg = (f'Failed to add data to self.data[{row_slice}, ' + f'{col_slice}] for feature={feature}, ' + f'file_path={file_path}, time_index={time_index}, ' + f'meta={meta}. 
{e}') + logger.error(msg) + raise OSError(msg) from e def _get_file_attrs(self, file): """Get meta data and time index for a single file""" @@ -778,9 +787,7 @@ def group_time_chunks(self, file_paths, n_writes=None): for file in file_paths: t_chunk = file.split('_')[-2] file_split[t_chunk] = [*file_split.get(t_chunk, []), file] - file_chunks = [] - for files in file_split.values(): - file_chunks.append(files) + file_chunks = list(file_split.values()) logger.debug( f'Split file list into {len(file_chunks)} chunks ' From a44078746489faa1f15024cd51401d5982c250cf Mon Sep 17 00:00:00 2001 From: bnb32 Date: Tue, 7 Nov 2023 10:53:19 -0700 Subject: [PATCH 37/44] random strings for tmp names to avoid collisions --- sup3r/models/abstract.py | 10 ++++++ sup3r/models/multi_step.py | 4 +-- sup3r/pipeline/forward_pass.py | 4 +-- .../data_handling/exo_extraction.py | 4 +-- sup3r/utilities/era_downloader.py | 33 +++++++++---------- sup3r/utilities/interpolate_log_profile.py | 2 ++ sup3r/utilities/utilities.py | 10 ++++++ 7 files changed, 44 insertions(+), 23 deletions(-) diff --git a/sup3r/models/abstract.py b/sup3r/models/abstract.py index 9eb6e9ce7..68d0fabff 100644 --- a/sup3r/models/abstract.py +++ b/sup3r/models/abstract.py @@ -96,6 +96,16 @@ def input_dims(self): else: return 5 + @property + def is_5d(self): + """Check if model expects spatiotemporal input""" + return self.input_dims == 5 + + @property + def is_4d(self): + """Check if model expects spatial only input""" + return self.input_dims == 4 + # pylint: disable=E1101 def get_s_enhance_from_layers(self): """Compute factor by which model will enhance spatial resolution from diff --git a/sup3r/models/multi_step.py b/sup3r/models/multi_step.py index 28c0ffdcd..01cb40cfe 100644 --- a/sup3r/models/multi_step.py +++ b/sup3r/models/multi_step.py @@ -137,10 +137,10 @@ def _transpose_model_input(self, model, hi_res): Synthetically generated high-resolution data transposed according to the number of model input dimensions """ - if model.input_dims == 5 and len(hi_res.shape) == 4: + if model.is_5d and len(hi_res.shape) == 4: hi_res = np.transpose( hi_res, axes=(1, 2, 0, 3))[np.newaxis] - elif model.input_dims == 4 and len(hi_res.shape) == 5: + elif model.is_4d and len(hi_res.shape) == 5: msg = ('Recieved 5D input data with shape ' f'({hi_res.shape}) to a 4D model.') assert hi_res.shape[0] == 1, msg diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index f6dc53b6a..722239ee2 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -1705,13 +1705,13 @@ def _reshape_data_chunk(model, data_chunk, exo_data): 'exceeds the number of model steps') assert entry['model'] < len(models), msg current_model = models[entry['model']] - if current_model.input_dims == 4: + if current_model.is_4d: out = np.transpose(entry['data'], axes=(2, 0, 1, 3)) else: out = np.expand_dims(entry['data'], axis=0) exo_data[feature]['steps'][i]['data'] = out - if model.input_dims == 4: + if model.is_4d: i_lr_t = 0 i_lr_s = 1 data_chunk = np.transpose(data_chunk, axes=(2, 0, 1, 3)) diff --git a/sup3r/preprocessing/data_handling/exo_extraction.py b/sup3r/preprocessing/data_handling/exo_extraction.py index 81727400d..581bd8d0a 100644 --- a/sup3r/preprocessing/data_handling/exo_extraction.py +++ b/sup3r/preprocessing/data_handling/exo_extraction.py @@ -15,7 +15,7 @@ from sup3r.postprocessing.file_handling import OutputHandler from sup3r.preprocessing.data_handling.h5_data_handling import DataHandlerH5 from 
sup3r.preprocessing.data_handling.nc_data_handling import DataHandlerNC -from sup3r.utilities.utilities import get_source_type +from sup3r.utilities.utilities import generate_random_string, get_source_type logger = logging.getLogger(__name__) @@ -337,7 +337,7 @@ def data(self): t_enhance=self._t_enhance, s_agg_factor=self._s_agg_factor, t_agg_factor=self._t_agg_factor) - tmp_fp = cache_fp + '.tmp' + tmp_fp = cache_fp + f'.{generate_random_string(10)}.tmp' if os.path.exists(cache_fp): with open(cache_fp, 'rb') as f: data = pickle.load(f) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index d627ca6c5..e2f607deb 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -46,21 +46,20 @@ class EraDownloader: req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') assert os.path.exists(req_file), msg - VALID_VARIABLES: ClassVar[list] = [ - 'u', 'v', 'pressure', 'temperature', 'relative_humidity', - 'specific_humidity', 'total_precipitation', - 'convective_available_potential_energy', 'divergence' - ] + VALID_VARIABLES: ClassVar[list] = ['u', 'v'] SFC_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', '100m_u_component_of_wind', '100m_v_component_of_wind', 'surface_pressure', '2m_temperature', 'geopotential', - 'total_precipitation', "convective_available_potential_energy" + 'total_precipitation', "convective_available_potential_energy", + "2m_dewpoint_temperature", "convective_inhibition", + "surface_latent_heat_flux", "instantaneous_moisture_flux" ] LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', - 'temperature', 'relative_humidity', 'specific_humidity', 'divergence' + 'temperature', 'relative_humidity', 'specific_humidity', 'divergence', + 'vertical_velocity', 'pressure', 'potential_vorticity' ] NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', @@ -69,14 +68,12 @@ class EraDownloader: 'v100': 'v_100m', 't': 'temperature', 't2m': 'temperature_2m', - 'u': 'u', - 'v': 'v', 'sp': 'pressure_0m', 'r': 'relative_humidity', 'q': 'specific_humidity', 'tp': 'total_precipitation', - 'cape': 'cape', - 'd': 'divergence' + 'd': 'divergence', + '2d': 'surface_dewpoint', } def __init__(self, @@ -164,11 +161,11 @@ def days(self): @property def interp_file(self): """Get name of file with interpolated variables""" - if self._interp_file is None: - if self.interp_out_pattern is not None and self.run_interp: - self._interp_file = self.interp_out_pattern.format( - year=self.year, month=str(self.month).zfill(2)) - os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) + if (self._interp_file is None and self.interp_out_pattern is not None + and self.run_interp): + self._interp_file = self.interp_out_pattern.format( + year=self.year, month=str(self.month).zfill(2)) + os.makedirs(os.path.dirname(self._interp_file), exist_ok=True) return self._interp_file @property @@ -248,7 +245,9 @@ def check_good_vars(self, variables): variables : list List of variables to download. 
Can be any of VALID_VARIABLES """ - good = all(var in self.VALID_VARIABLES for var in variables) + valid_vars = (self.VALID_VARIABLES + list(self.LEVEL_VARS) + + list(self.SFC_VARS)) + good = all(var in valid_vars for var in variables) if not good: msg = (f'Received variables {variables} not in valid variables ' f'list {self.VALID_VARIABLES}') diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 2ed8c39db..252061eaf 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -35,6 +35,8 @@ class LogLinInterpolator: DEFAULT_OUTPUT_HEIGHTS: ClassVar[dict] = { 'u': [10, 40, 80, 100, 120, 160, 200], 'v': [10, 40, 80, 100, 120, 160, 200], + 'w': [10, 40, 80, 100, 120, 160, 200], + 'pv': [10, 40, 80, 100, 120, 160, 200], 'temperature': [10, 40, 80, 100, 120, 160, 200], 'pressure': [0, 100, 200], 'relative_humidity': [80, 100, 120], diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index d6875aa4a..17a85e4fa 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -7,7 +7,9 @@ import glob import logging import os +import random import re +import string from fnmatch import fnmatch from warnings import warn @@ -26,6 +28,14 @@ logger = logging.getLogger(__name__) +def generate_random_string(length): + """Generate random string with given length. Used for naming temporary + files to avoid collisions.""" + letters = string.ascii_letters + random_string = ''.join(random.choice(letters) for i in range(length)) + return random_string + + def windspeed_log_law(z, a, b, c): """Windspeed log profile. From 101428ab1f3ea6efeeaae2d27818dc37491ed8e8 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 8 Nov 2023 12:54:44 -0700 Subject: [PATCH 38/44] some era downloader clean up. --- sup3r/utilities/era_downloader.py | 236 ++++++++++++++++-------------- 1 file changed, 123 insertions(+), 113 deletions(-) diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index e2f607deb..1e165f567 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -46,21 +46,24 @@ class EraDownloader: req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') assert os.path.exists(req_file), msg - VALID_VARIABLES: ClassVar[list] = ['u', 'v'] - + # variables available on a single level (e.g. 
surface) SFC_VARS: ClassVar[list] = [ '10m_u_component_of_wind', '10m_v_component_of_wind', '100m_u_component_of_wind', '100m_v_component_of_wind', 'surface_pressure', '2m_temperature', 'geopotential', 'total_precipitation', "convective_available_potential_energy", "2m_dewpoint_temperature", "convective_inhibition", - "surface_latent_heat_flux", "instantaneous_moisture_flux" + "surface_latent_heat_flux", "instantaneous_moisture_flux", + "mean_total_precipitation_rate" ] + + # variables available on multiple pressure levels LEVEL_VARS: ClassVar[list] = [ 'u_component_of_wind', 'v_component_of_wind', 'geopotential', 'temperature', 'relative_humidity', 'specific_humidity', 'divergence', 'vertical_velocity', 'pressure', 'potential_vorticity' ] + NAME_MAP: ClassVar[dict] = { 'u10': 'u_10m', 'v10': 'v_10m', @@ -71,9 +74,20 @@ class EraDownloader: 'sp': 'pressure_0m', 'r': 'relative_humidity', 'q': 'specific_humidity', - 'tp': 'total_precipitation', 'd': 'divergence', - '2d': 'surface_dewpoint', + } + + SHORT_NAME_MAP: ClassVar[dict] = { + 'convective_inhibition': 'cin', + '2m_dewpoint_temperature': 'd2m', + 'potential_vorticity': 'pv', + 'vertical_velocity': 'w', + 'surface_latent_heat_flux': 'slhf', + 'instantaneous_moisture_flux': 'ie', + 'divergence': 'd', + 'total_precipitation': 'tp', + 'convective_available_potential_energy': 'cape', + 'mean_total_precipitation_rate': 'mtpr' } def __init__(self, @@ -85,8 +99,8 @@ def __init__(self, interp_out_pattern=None, run_interp=True, overwrite=False, - required_shape=None, - variables=None): + variables=None, + check_files=False): """Initialize the class. Parameters @@ -110,13 +124,11 @@ def __init__(self, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. variables : list | None Variables to download. If None this defaults to just gepotential and wind components. + check_files : bool + Check existing files. Remove and redownload if checks fail. """ self.year = year self.month = month @@ -126,15 +138,13 @@ def __init__(self, self.overwrite = overwrite self.combined_out_pattern = combined_out_pattern self.interp_out_pattern = interp_out_pattern + self.check_files = check_files self._interp_file = None self._combined_file = None self._variables = variables self.hours = [str(n).zfill(2) + ":00" for n in range(0, 24)] self.sfc_file_variables = ['geopotential'] self.level_file_variables = ['geopotential'] - - self.shape_check(required_shape, levels) - self.check_good_vars(self.variables) self.prep_var_lists(self.variables) msg = ('Initialized EraDownloader with: ' @@ -146,7 +156,7 @@ def __init__(self, def variables(self): """Get list of requested variables""" if self._variables is None: - self._variables = self.VALID_VARIABLES + raise OSError('Received empty variable list.') return self._variables @property @@ -226,34 +236,6 @@ def get_tmp_file(cls, file): tmp_file = file.replace(".nc", "_tmp.nc") return tmp_file - def shape_check(self, required_shape, levels): - """Check given required shape""" - if required_shape is None or len(required_shape) == 3: - self.required_shape = required_shape - elif len(required_shape) == 2 and len(levels) != required_shape[0]: - self.required_shape = (len(levels), *required_shape) - else: - msg = f'Received weird required_shape: {required_shape}.' 
- logger.error(msg) - raise OSError(msg) - - def check_good_vars(self, variables): - """Make sure requested variables are valid. - - Parameters - ---------- - variables : list - List of variables to download. Can be any of VALID_VARIABLES - """ - valid_vars = (self.VALID_VARIABLES + list(self.LEVEL_VARS) - + list(self.SFC_VARS)) - good = all(var in valid_vars for var in variables) - if not good: - msg = (f'Received variables {variables} not in valid variables ' - f'list {self.VALID_VARIABLES}') - logger.error(msg) - raise OSError(msg) - def _prep_var_lists(self, variables): """Add all downloadable variables for the generic requested variables. e.g. if variable = 'u' add all downloadable u variables to list. @@ -280,7 +262,7 @@ def prep_var_lists(self, variables): elif (var in self.LEVEL_VARS and var not in self.level_file_variables): self.level_file_variables.append(var) - else: + elif var not in self.SFC_VARS and var not in self.LEVEL_VARS: msg = f'Requested {var} is not available for download.' logger.warning(msg) warn(msg) @@ -297,53 +279,62 @@ def download_process_combine(self): if self.levels is None: logger.warning(msg) warn(msg) + + time_dict = {'year': self.year, 'month': self.month, 'day': self.days, + 'time': self.hours} if sfc_check: - self.download_surface_file() + self.download_file(self.sfc_file_variables, time_dict=time_dict, + area=self.area, out_file=self.surface_file, + level_type='single', overwrite=self.overwrite) if level_check: - self.download_levels_file() + self.download_file(self.level_file_variables, time_dict=time_dict, + area=self.area, out_file=self.level_file, + level_type='pressure', levels=self.levels, + overwrite=self.overwrite) if sfc_check or level_check: self.process_and_combine() - def download_levels_file(self): - """Download file with requested pressure levels""" - if not os.path.exists(self.level_file) or self.overwrite: - msg = (f'Downloading {self.level_file_variables} to ' - f'{self.level_file} with levels = {self.levels}.') - logger.info(msg) - CDS_API_CLIENT.retrieve( - 'reanalysis-era5-pressure-levels', { - 'product_type': 'reanalysis', - 'format': 'netcdf', - 'variable': self.level_file_variables, - 'pressure_level': self.levels, - 'year': self.year, - 'month': self.month, - 'day': self.days, - 'time': self.hours, - 'area': self.area, - }, self.level_file) - else: - logger.info(f'File already exists: {self.level_file}.') + @classmethod + def download_file(cls, variables, time_dict, area, out_file, level_type, + levels=None, overwrite=False): + """Download either single-level or pressure-level file - def download_surface_file(self): - """Download surface file""" - if not os.path.exists(self.surface_file) or self.overwrite: - msg = (f'Downloading {self.sfc_file_variables} to ' - f'{self.surface_file}.') + Parameters + ---------- + variables : list + List of variables to download + time_dict : dict + Dictionary with year, month, day, time entries. + area : list + List of bounding box coordinates. + e.g. 
[max_lat, min_lon, min_lat, max_lon] + out_file : str + Name of output file + level_type : str + Either 'single' or 'pressure' + levels : list + List of pressure levels to download, if level_type == 'pressure' + overwrite : bool + Whether to overwrite existing file + """ + if not os.path.exists(out_file) or overwrite: + msg = (f'Downloading {variables} to ' + f'{out_file} with levels = {levels}.') logger.info(msg) + entry = { + 'product_type': 'reanalysis', + 'format': 'netcdf', + 'variable': variables, + 'area': area} + entry.update(time_dict) + if level_type == 'pressure': + entry['pressure_level'] = levels + logger.info(f'Calling CDS-API with {entry}.') CDS_API_CLIENT.retrieve( - 'reanalysis-era5-single-levels', { - 'product_type': 'reanalysis', - 'format': 'netcdf', - 'variable': self.sfc_file_variables, - 'year': self.year, - 'month': self.month, - 'day': self.days, - 'time': self.hours, - 'area': self.area, - }, self.surface_file) + f'reanalysis-era5-{level_type}-levels', + entry, out_file) else: - logger.info(f'File already exists: {self.surface_file}.') + logger.info(f'File already exists: {out_file}.') def process_surface_file(self): """Rename variables and convert geopotential to geopotential height.""" @@ -469,15 +460,17 @@ def process_and_combine(self): if os.path.exists(self.surface_file): os.remove(self.surface_file) - def good_file(self, file, required_shape): + def good_file(self, file, required_shape=None): """Check if file has the required shape and variables. Parameters ---------- file : str Name of file to check for required variables and shape - required_shape : tuple - Required shape for data. Should be (n_levels, n_lats, n_lons). + required_shape : tuple | None + Required shape of data to download. Used to check downloaded data. + Should be (n_levels, n_lats, n_lons). If None, no check is + performed. Returns ------- @@ -489,18 +482,44 @@ def good_file(self, file, required_shape): check_nans=False, check_heights=False, required_shape=required_shape) - good_vars, good_shape, _, _ = out - check = good_vars and good_shape - return check + good_vars, good_shape, good_hgts, _ = out + return bool(good_vars and good_shape and good_hgts) + + def shape_check(self, required_shape, levels): + """Check given required shape""" + if required_shape is None or len(required_shape) == 3: + self.required_shape = required_shape + elif len(required_shape) == 2 and len(levels) != required_shape[0]: + self.required_shape = (len(levels), *required_shape) + else: + msg = f'Received weird required_shape: {required_shape}.' + logger.error(msg) + raise OSError(msg) + + def check_good_vars(self, variables): + """Make sure requested variables are valid. + + Parameters + ---------- + variables : list + List of variables to download. Can be any of VALID_VARIABLES + """ + valid_vars = list(self.LEVEL_VARS) + list(self.SFC_VARS) + good = all(var in valid_vars for var in variables) + if not good: + msg = (f'Received variables {variables} not in valid variables ' + f'list {valid_vars}') + logger.error(msg) + raise OSError(msg) - def check_existing_files(self): + def check_existing_files(self, required_shape=None): """If files exist already check them for good shape and required variables. Remove them if there was a problem so we can continue with routine from scratch. 
""" if os.path.exists(self.combined_file): try: - check = self.good_file(self.combined_file, self.required_shape) + check = self.good_file(self.combined_file, required_shape) if not check: msg = f'Bad file: {self.combined_file}' logger.error(msg) @@ -541,7 +560,8 @@ def get_monthly_file(self, interp_workers=None, keep_variables=None, if os.path.exists(self.combined_file) and self.overwrite: os.remove(self.combined_file) - self.check_existing_files() + if self.check_files: + self.check_existing_files() if not os.path.exists(self.combined_file): self.download_process_combine() @@ -639,10 +659,10 @@ def run_month(cls, interp_out_pattern=None, run_interp=True, overwrite=False, - required_shape=None, interp_workers=None, variables=None, keep_variables=None, + check_files=False, **interp_kwargs): """Run routine for all months in the requested year. @@ -667,10 +687,6 @@ def run_month(cls, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. interp_workers : int | None Max number of workers to use for interpolation. variables : list | None @@ -679,6 +695,8 @@ def run_month(cls, keep_variables : list | None Variables to keep in final files. All other variables will be pruned. + check_files : bool + Check existing files. Remove and redownload if checks fail. **interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -690,8 +708,8 @@ def run_month(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, - variables=variables) + variables=variables, + check_files=check_files) downloader.get_monthly_file(interp_workers=interp_workers, keep_variables=keep_variables, **interp_kwargs) @@ -707,11 +725,11 @@ def run_year(cls, interp_yearly_file=None, run_interp=True, overwrite=False, - required_shape=None, max_workers=None, interp_workers=None, variables=None, keep_variables=None, + check_files=False, **interp_kwargs): """Run routine for all months in the requested year. @@ -738,10 +756,6 @@ def run_year(cls, Whether to run interpolation after downloading and combining files. overwrite : bool Whether to overwrite existing files. - required_shape : tuple | None - Required shape of data to download. Used to check downloaded data. - Should be (n_levels, n_lats, n_lons). If None, no check is - performed. max_workers : int Max number of workers to use for downloading and processing monthly files. @@ -753,6 +767,8 @@ def run_year(cls, keep_variables : list | None Variables to keep in final files. All other variables will be pruned. + check_files : bool + Check existing files. Remove and redownload if checks fail. 
**interp_kwargs : dict Keyword args for LogLinInterpolator.run() """ @@ -766,10 +782,10 @@ def run_year(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, interp_workers=interp_workers, variables=variables, keep_variables=keep_variables, + check_files=check_files, **interp_kwargs) else: futures = {} @@ -785,10 +801,10 @@ def run_year(cls, interp_out_pattern=interp_out_pattern, run_interp=run_interp, overwrite=overwrite, - required_shape=required_shape, interp_workers=interp_workers, keep_variables=keep_variables, variables=variables, + check_files=check_files, **interp_kwargs) futures[future] = {'year': year, 'month': month} logger.info(f'Submitted future for year {year} and month ' @@ -1082,7 +1098,6 @@ def check_single_file(cls, def run_files_checks(cls, file_pattern, var_list=None, - required_shape=None, check_nans=True, check_heights=True, max_interp_height=200, @@ -1098,9 +1113,6 @@ def run_files_checks(cls, var_list : list | None List of variables to check. If None: ['zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m'] - required_shape : None | tuple - Required shape for data. Should include (n_levels, n_lats, n_lons). - If None the shape check will be skipped. check_nans : bool Whether to check data for NaNs. check_heights : bool @@ -1137,8 +1149,7 @@ def run_files_checks(cls, check_nans=check_nans, check_heights=check_heights, max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape) + max_workers=height_check_workers) df.loc[i, df.columns[1:]] = out logger.info(f'Finished checking {file}.') else: @@ -1151,8 +1162,7 @@ def run_files_checks(cls, check_nans=check_nans, check_heights=check_heights, max_interp_height=max_interp_height, - max_workers=height_check_workers, - required_shape=required_shape) + max_workers=height_check_workers) msg = (f'Submitted file check future for {file}. Future ' f'{i + 1} of {len(files)}.') logger.info(msg) From 64f29ca26e9356fc8dee92eea20fc4c106927691 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Wed, 8 Nov 2023 16:48:56 -0700 Subject: [PATCH 39/44] bc case sensitivity fix --- sup3r/preprocessing/data_handling/base.py | 5 +++-- sup3r/utilities/era_downloader.py | 11 ++++++++++- sup3r/utilities/interpolate_log_profile.py | 3 ++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index dbae7b10e..24ff15e91 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -1261,8 +1261,9 @@ def lin_bc(self, bc_files, threshold=0.1): dset_scalar = f'{feature}_scalar' dset_adder = f'{feature}_adder' with Resource(fp) as res: - check = (dset_scalar in res.dsets - and dset_adder in res.dsets) + dsets = [dset.lower() for dset in res.dsets] + check = (dset_scalar.lower() in dsets + and dset_adder.lower() in dsets) if feature not in completed and check: scalar, adder = get_spatial_bc_factors( lat_lon=self.lat_lon, diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 1e165f567..26178aa6d 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -544,10 +544,19 @@ def run_interpolation(self, max_workers=None, **kwargs): """Run interpolation to get final final. Runs log interpolation up to max_log_height (usually 100m) and linear interpolation above this. 
""" + variables = [var for var in self.variables if var in self.LEVEL_VARS] + for var in self.variables: + if var in self.NAME_MAP: + variables.append(self.NAME_MAP[var]) + elif (var in self.SHORT_NAME_MAP + and var not in self.NAME_MAP.values()): + variables.append(self.SHORT_NAME_MAP[var]) + else: + variables.append(var) LogLinInterpolator.run(infile=self.combined_file, outfile=self.interp_file, max_workers=max_workers, - variables=self.variables, + variables=variables, overwrite=self.overwrite, **kwargs) diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py index 252061eaf..52bfb5f12 100644 --- a/sup3r/utilities/interpolate_log_profile.py +++ b/sup3r/utilities/interpolate_log_profile.py @@ -148,7 +148,8 @@ def _load_single_var(self, variable): def load(self): """Load ERA5 data and create data arrays""" self.data_dict = {} - for var in self.variables: + vars = [var for var in self.variables if var in self.new_heights] + for var in vars: self.data_dict[var] = {} out = self._load_single_var(var) self.data_dict[var]['heights'] = out[0] From f874437c55deb3ef1d8781a9505e2fe6daa43c34 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 9 Nov 2023 06:54:05 -0700 Subject: [PATCH 40/44] pr changes --- sup3r/bias/bias_correct_means.py | 8 +++++++- sup3r/utilities/era_downloader.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 4a878d725..89f83ece1 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -1,5 +1,8 @@ """Classes to compute means from vortex and era data and compute bias correction factors. + +Vortex mean files can be downloaded from IRENA. +https://globalatlas.irena.org/workspace """ @@ -32,7 +35,8 @@ class VortexMeanPrepper: """ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): - """Parameters + """ + Parameters ---------- path_pattern : str Pattern for input tif files. Needs to include {month} and {height} @@ -45,6 +49,8 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): Whether to overwrite intermediate netcdf files containing the interpolated masked monthly means. 
""" + msg = 'path_pattern needs to have {month} and {height} format keys' + assert '{month}' in path_pattern and '{year}' in path_pattern, msg self.path_pattern = path_pattern self.in_heights = in_heights self.out_heights = out_heights diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 26178aa6d..fee3b8827 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -139,6 +139,7 @@ def __init__(self, self.combined_out_pattern = combined_out_pattern self.interp_out_pattern = interp_out_pattern self.check_files = check_files + self.required_shape = None self._interp_file = None self._combined_file = None self._variables = variables @@ -1083,7 +1084,6 @@ def check_single_file(cls, good_shape = None good_vars = None good_hgts = None - var_list = (var_list if var_list is not None else cls.VALID_VARIABLES) try: res = xr.open_dataset(file) except Exception as e: From db62a8fc946683c7d1216a15b513fd056b2ddc68 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 9 Nov 2023 06:54:05 -0700 Subject: [PATCH 41/44] pr changes --- sup3r/bias/bias_correct_means.py | 8 +++++++- sup3r/utilities/era_downloader.py | 2 +- tests/output/test_output_handling.py | 8 ++++---- tests/utilities/test_utilities.py | 16 ++++++++-------- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 4a878d725..89f83ece1 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -1,5 +1,8 @@ """Classes to compute means from vortex and era data and compute bias correction factors. + +Vortex mean files can be downloaded from IRENA. +https://globalatlas.irena.org/workspace """ @@ -32,7 +35,8 @@ class VortexMeanPrepper: """ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): - """Parameters + """ + Parameters ---------- path_pattern : str Pattern for input tif files. Needs to include {month} and {height} @@ -45,6 +49,8 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): Whether to overwrite intermediate netcdf files containing the interpolated masked monthly means. 
""" + msg = 'path_pattern needs to have {month} and {height} format keys' + assert '{month}' in path_pattern and '{year}' in path_pattern, msg self.path_pattern = path_pattern self.in_heights = in_heights self.out_heights = out_heights diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py index 26178aa6d..fee3b8827 100644 --- a/sup3r/utilities/era_downloader.py +++ b/sup3r/utilities/era_downloader.py @@ -139,6 +139,7 @@ def __init__(self, self.combined_out_pattern = combined_out_pattern self.interp_out_pattern = interp_out_pattern self.check_files = check_files + self.required_shape = None self._interp_file = None self._combined_file = None self._variables = variables @@ -1083,7 +1084,6 @@ def check_single_file(cls, good_shape = None good_vars = None good_hgts = None - var_list = (var_list if var_list is not None else cls.VALID_VARIABLES) try: res = xr.open_dataset(file) except Exception as e: diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index 6f85c9e99..a4cc9629f 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -9,7 +9,7 @@ from rex import ResourceX, init_logger from sup3r import __version__ -from sup3r.postprocessing.collection import Collector +from sup3r.postprocessing.collection import CollectorH5 from sup3r.postprocessing.file_handling import OutputHandlerH5, OutputHandlerNC from sup3r.utilities.pytest import make_fake_h5_chunks from sup3r.utilities.utilities import invert_uv, transform_rotate_wind @@ -140,7 +140,7 @@ def test_h5_out_and_collect(): low_res_times, ) = out - Collector.collect(out_files, fp_out, features=features) + CollectorH5.collect(out_files, fp_out, features=features) with ResourceX(fp_out) as fh: full_ti = fh.time_index combined_ti = [] @@ -207,7 +207,7 @@ def test_h5_collect_mask(log=False): out = make_fake_h5_chunks(td) (out_files, data, _, _, features, _, _, _, _, _, _) = out - Collector.collect(out_files, fp_out, features=features) + CollectorH5.collect(out_files, fp_out, features=features) indices = np.arange(np.product(data.shape[:2])) indices = indices[slice(-len(indices) // 2, None)] removed = [] @@ -220,7 +220,7 @@ def test_h5_collect_mask(log=False): mask_meta['gid'][:] = np.arange(len(mask_meta)) mask_meta.to_csv(mask_file, index=False) - Collector.collect( + CollectorH5.collect( out_files, fp_out_mask, features=features, diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py index 93ae8be82..63f4a7c37 100644 --- a/tests/utilities/test_utilities.py +++ b/tests/utilities/test_utilities.py @@ -11,7 +11,7 @@ from scipy.interpolate import interp1d from sup3r import TEST_DATA_DIR -from sup3r.postprocessing.collection import Collector +from sup3r.postprocessing.collection import CollectorH5 from sup3r.postprocessing.file_handling import OutputHandler from sup3r.utilities.interpolate_log_profile import LogLinInterpolator from sup3r.utilities.regridder import RegridOutput @@ -100,13 +100,13 @@ def test_regridding(log=False): for node_index in range(regrid_output.nodes): regrid_output.run(node_index=node_index) - Collector.collect(regrid_output.out_files, - collect_file, - regrid_output.output_features, - target_final_meta_file=meta_path, - join_times=False, - n_writes=2, - max_workers=1) + CollectorH5.collect(regrid_output.out_files, + collect_file, + regrid_output.output_features, + target_final_meta_file=meta_path, + join_times=False, + n_writes=2, + max_workers=1) with Resource(collect_file) as out_res: for 
height in heights: ws_name = f'windspeed_{height}m' From 44e40e8af9c32bcce67e3b39cc979abd95d101cd Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 9 Nov 2023 10:11:23 -0700 Subject: [PATCH 42/44] removed era mean prepper and changed vortex mean prepper output to conform to format expected by classes in bias_calc.py. Added scalar only class to bias_calc.py --- sup3r/bias/bias_calc.py | 66 ++++++- sup3r/bias/bias_correct_means.py | 205 ++++------------------ sup3r/preprocessing/data_handling/base.py | 2 + 3 files changed, 94 insertions(+), 179 deletions(-) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index d909e45b4..2b7b91125 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -1,11 +1,11 @@ """Utilities to calculate the bias correction factors for biased data that is going to be fed into the sup3r downscaling models. This is typically used to bias correct GCM data vs. some historical record like the WTK or NSRDB.""" -from abc import abstractmethod import copy import json import logging import os +from abc import abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed from glob import glob @@ -875,12 +875,7 @@ def _run_single(cls, decimals=decimals) base_arr = np.full(cls.NT, np.nan, dtype=np.float32) - out = {f'bias_{bias_feature}_mean': base_arr.copy(), - f'bias_{bias_feature}_std': base_arr.copy(), - f'base_{base_dset}_mean': base_arr.copy(), - f'base_{base_dset}_std': base_arr.copy(), - f'{bias_feature}_scalar': base_arr.copy(), - f'{bias_feature}_adder': base_arr.copy()} + out = {} for month in range(1, 13): bias_mask = bias_ti.month == month @@ -892,11 +887,68 @@ def _run_single(cls, bias_feature, base_dset) for k, v in mout.items(): + if k not in out: + out[k] = base_arr.copy() out[k][month - 1] = v return out +class MonthlyLinearCorrectionScalarOnly(MonthlyLinearCorrection): + """Calculate linear correction *scalar factors to bias correct data. This + typically used when base data is just monthly means and standard deviations + cannot be computed. This is case for vortex data, for example. Thus, just + scalar factors are computed as mean(base_data) / mean(bias_data). Adder + factors are still written but are exactly zero. + + This calculation operates on single bias sites on a montly basis + """ + + @staticmethod + def get_linear_correction(bias_data, base_data, bias_feature, base_dset): + """Get the linear correction factors based on 1D bias and base datasets + + Parameters + ---------- + bias_data : np.ndarray + 1D array of biased data observations. + base_data : np.ndarray + 1D array of base data observations. + bias_feature : str + This is the biased feature from bias_fps to retrieve. This should + be a single feature name corresponding to base_dset + base_dset : str + A single dataset from the base_fps to retrieve. In the case of wind + components, this can be U_100m or V_100m which will retrieve + windspeed and winddirection and derive the U/V component. 
+ + Returns + ------- + out : dict + Dictionary of values defining the mean/std of the bias + base + data and the scalar + adder factors to correct the biased data + like: bias_data * scalar + adder + """ + + bias_std = np.nanstd(bias_data) + if bias_std == 0: + bias_std = np.nanstd(base_data) + + scalar = np.nanmean(base_data) / np.nanmean(bias_data) + adder = np.zeros(scalar.shape) + + out = { + f'bias_{bias_feature}_mean': np.nanmean(bias_data), + f'bias_{bias_feature}_std': bias_std, + f'base_{base_dset}_mean': np.nanmean(base_data), + f'base_{base_dset}_std': np.nanstd(base_data), + f'{bias_feature}_scalar': scalar, + f'{bias_feature}_adder': adder, + } + + return out + + class SkillAssessment(MonthlyLinearCorrection): """Calculate historical skill of one dataset compared to another.""" diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index 89f83ece1..e4b6627e8 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -50,13 +50,14 @@ def __init__(self, path_pattern, in_heights, out_heights, overwrite=False): interpolated masked monthly means. """ msg = 'path_pattern needs to have {month} and {height} format keys' - assert '{month}' in path_pattern and '{year}' in path_pattern, msg + assert '{month}' in path_pattern and '{height}' in path_pattern, msg self.path_pattern = path_pattern self.in_heights = in_heights self.out_heights = out_heights self.out_dir = os.path.dirname(path_pattern) self.overwrite = overwrite self._mask = None + self._meta = None @property def in_features(self): @@ -246,6 +247,23 @@ def get_lat_lon(self): ) return np.array(lats), np.array(lons) + @property + def meta(self): + """Get meta with latitude/longitude""" + if self._meta is None: + lats, lons = self.get_lat_lon() + self._meta = pd.DataFrame() + self._meta["latitude"] = lats.flatten()[self.mask] + self._meta["longitude"] = lons.flatten()[self.mask] + return self._meta + + @property + def time_index(self): + """Get time index so output conforms to standard format""" + times = [f'2000-{str(i).zfill(2)}' for i in range(1, 13)] + time_index = pd.DatetimeIndex(times) + return time_index + def get_all_data(self): """Get interpolated monthly means for all out heights as a dictionary to use for h5 writing. 
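A minimal sketch (illustrative values only, not part of the diff above) of the scalar-only factors computed by the new MonthlyLinearCorrectionScalarOnly class in bias_calc.py: per month, the factor is mean(base_data) / mean(bias_data), the adder is written as exactly zero, and downstream correction applies bias_data * scalar + adder.

    import numpy as np

    # hypothetical monthly observations for one site and one month
    bias_data = np.array([6.2, 5.8, 6.5, 6.1])   # biased data (e.g. vortex-like means)
    base_data = np.array([7.0, 6.9, 7.3])        # reference (base) data

    scalar = np.nanmean(base_data) / np.nanmean(bias_data)  # multiplicative factor
    adder = 0.0                                              # adder is exactly zero here
    corrected = bias_data * scalar + adder                   # corrected series
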
@@ -258,10 +276,7 @@ def get_all_data(self): flattened """ data_dict = {} - lats, lons = self.get_lat_lon() - data_dict["latitude"] = lats.flatten()[self.mask] - data_dict["longitude"] = lons.flatten()[self.mask] - s_num = len(data_dict["longitude"]) + s_num = len(self.meta) for i in range(1, 13): month = calendar.month_name[i] out = self.get_month(month) @@ -273,14 +288,14 @@ def get_all_data(self): return data_dict @property - def meta(self): - """Get a meta data dictionary on how this data is prepared""" - meta = { + def global_attrs(self): + """Get dictionary on how this data is prepared""" + attrs = { "input_files": self.input_files, "class": str(self.__class__), - "version_record": VERSION_RECORD, + "version_record": str(VERSION_RECORD), } - return meta + return attrs def write_data(self, fp_out, out): """Write monthly means for all heights to h5 file""" @@ -289,14 +304,17 @@ def write_data(self, fp_out, out): os.makedirs(os.path.dirname(fp_out), exist_ok=True) if not os.path.exists(fp_out) or self.overwrite: - with h5py.File(fp_out, "w") as f: + + OutputHandler._init_h5( + fp_out, self.time_index, self.meta, self.global_attrs + ) + with RexOutputs(fp_out, "a") as f: + for dset, data in out.items(): - f.create_dataset(dset, data=data) + OutputHandler._ensure_dset_in_output(fp_out, dset) + f[dset] = data.T logger.info(f"Added {dset} to {fp_out}.") - for k, v in self.meta.items(): - f.attrs[k] = json.dumps(v) - logger.info( f"Wrote monthly means for all out heights: {fp_out}" ) @@ -332,163 +350,6 @@ def run( vprep.write_data(fp_out, out) -class EraMeanPrepper: - """Class to compute monthly windspeed means from ERA data.""" - - def __init__(self, era_pattern, years, features): - """Parameters - ---------- - era_pattern : str - Pattern pointing to era files with u/v wind components at the given - heights. Must have a {year} format key. - years : list - List of ERA years to use for calculating means. - features : list - List of features to compute means for. e.g. ['windspeed_10m'] - """ - self.era_pattern = era_pattern - self.years = years - self.features = features - self.lats, self.lons = self.get_lat_lon() - - @property - def shape(self): - """Get shape of spatial dimensions (lats, lons)""" - return self.lats.shape - - @property - def heights(self): - """List of feature heights""" - heights = [Feature.get_height(feature) for feature in self.features] - return heights - - @property - def input_files(self): - """List of ERA input files to use for calculating means.""" - return [self.era_pattern.format(year=year) for year in self.years] - - def get_lat_lon(self): - """Get arrays of latitude and longitude for ERA domain""" - with xr.open_dataset(self.input_files[0]) as res: - lons, lats = np.meshgrid( - res["longitude"].values, res["latitude"].values - ) - return lats, lons - - def get_windspeed(self, data, height): - """Compute windspeed from u/v wind components from given data. - - Parameters - ---------- - data : xarray.Dataset - xarray dataset object for a year of ERA data. Must include u/v - components for the given height. e.g. u_{height}m, v_{height}m. - height : int - Height to compute windspeed for. - """ - return np.hypot( - data[f"u_{height}m"].values, data[f"v_{height}m"].values - ) - - def get_month_mean(self, data, height, month): - """Get windspeed_{height}m mean for the given month. - - Parameters - ---------- - data : xarray.Dataset - xarray dataset object for a year of ERA data. Must include u/v - components for the given height. e.g. u_{height}m, v_{height}m. 
- height : int - Height to compute windspeed for. - month : int - Index of month to get mean for. e.g. 1 = Jan, 2 = Feb, etc. - - Returns - ------- - out : np.ndarray - Array of time averaged windspeed data for the given month. - - """ - mask = pd.to_datetime(data["time"]).month == month - ws = self.get_windspeed(data, height)[mask] - return ws.mean(axis=0) - - def get_all_means(self, height): - """Get monthly means for all months across all given years for the - given height. - - Parameters - ---------- - height : int - Height to compute windspeed for. - - Returns - ------- - means : dict - Dictionary of windspeed_{height}m means for each month - """ - feature = self.features[self.heights.index(height)] - means = {i: [] for i in range(1, 13)} - for i, year in enumerate(self.years): - logger.info(f"Getting means for year={year}, feature={feature}.") - data = xr.open_dataset(self.input_files[i]) - for m in range(1, 13): - means[m].append(self.get_month_mean(data, height, month=m)) - means = {m: np.dstack(arr).mean(axis=-1) for m, arr in means.items()} - return means - - def write_csv(self, out, out_file): - """Write monthly means to a csv file. - - Parameters - ---------- - out : dict - Dictionary of windspeed_{height}m means for each month - out_file : str - Name of csv output file. - """ - logger.info(f"Writing means to {out_file}.") - out = { - f"{str(calendar.month_name[m])[:3]}_mean": v.flatten() - for m, v in out.items() - } - df = pd.DataFrame.from_dict(out) - df["latitude"] = self.lats.flatten() - df["longitude"] = self.lons.flatten() - df["gid"] = np.arange(len(df["latitude"])) - df.to_csv(out_file) - logger.info(f"Finished writing means for {out_file}.") - - @classmethod - def run(cls, era_pattern, years, features, out_pattern): - """Compute monthly windspeed means for the given heights, using the - given years of ERA data, and write the means to csv files for each - height. - - Parameters - ---------- - era_pattern : str - Pattern pointing to era files with u/v wind components at the given - heights. Must have a {year} format key. - years : list - List of ERA years to use for calculating means. - features : list - List of features to compute means for. e.g. ['windspeed_10m'] - out_pattern : str - Pattern pointing to csv files to write means to. Must have a - {feature} format key. - """ - em = cls(era_pattern=era_pattern, years=years, features=features) - for height, feature in zip(em.heights, em.features): - means = em.get_all_means(height) - out_file = out_pattern.format(feature=feature) - em.write_csv(means, out_file=out_file) - logger.info( - f"Finished writing means for years={years} and " - f"heights={em.heights}." - ) - - class BiasCorrectionFromMeans: """Class for getting bias correction factors from bias and base data files with precomputed monthly means. 
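As a usage aside (not part of the patch), a minimal sketch of inspecting the h5 file written by VortexMeanPrepper.run with rex, illustrating the standard meta/time_index format the commit message refers to; the output path and dataset name here are hypothetical.

    from rex import Resource

    fp_out = "./vortex_means.h5"          # hypothetical output path
    with Resource(fp_out) as res:
        print(res.time_index)             # twelve monthly timestamps (2000-01 .. 2000-12)
        print(res.meta.head())            # latitude / longitude for the masked grid cells
        ws = res["windspeed_100m"]        # monthly means per site (hypothetical dataset name)
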
diff --git a/sup3r/preprocessing/data_handling/base.py b/sup3r/preprocessing/data_handling/base.py index 24ff15e91..1ffb48b2d 100644 --- a/sup3r/preprocessing/data_handling/base.py +++ b/sup3r/preprocessing/data_handling/base.py @@ -335,6 +335,8 @@ def _run_data_init_if_needed(self): extraction""" if any(self.features): self.data = self.run_all_data_init() + mask = np.isinf(self.data) + self.data[mask] = np.nan nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: msg = 'Data has {:.3f}% NaN values!'.format(nan_perc) From 85aa1dd30da9511a1088525de68dbea6fd3ef44a Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 9 Nov 2023 12:39:27 -0700 Subject: [PATCH 43/44] removed some bc from means code --- sup3r/bias/bias_calc.py | 2 +- sup3r/bias/bias_correct_means.py | 258 ------------------------------- 2 files changed, 1 insertion(+), 259 deletions(-) diff --git a/sup3r/bias/bias_calc.py b/sup3r/bias/bias_calc.py index 2b7b91125..2ac31677f 100644 --- a/sup3r/bias/bias_calc.py +++ b/sup3r/bias/bias_calc.py @@ -894,7 +894,7 @@ def _run_single(cls, return out -class MonthlyLinearCorrectionScalarOnly(MonthlyLinearCorrection): +class MonthlyScalarCorrection(MonthlyLinearCorrection): """Calculate linear correction *scalar factors to bias correct data. This typically used when base data is just monthly means and standard deviations cannot be computed. This is case for vortex data, for example. Thus, just diff --git a/sup3r/bias/bias_correct_means.py b/sup3r/bias/bias_correct_means.py index e4b6627e8..7e63858cf 100644 --- a/sup3r/bias/bias_correct_means.py +++ b/sup3r/bias/bias_correct_means.py @@ -7,22 +7,18 @@ import calendar -import json import logging import os from concurrent.futures import ThreadPoolExecutor, as_completed -import h5py import numpy as np import pandas as pd import rioxarray import xarray as xr from rex import Resource from scipy.interpolate import interp1d -from sklearn.neighbors import BallTree from sup3r.postprocessing.file_handling import OutputHandler, RexOutputs -from sup3r.preprocessing.feature_handling import Feature from sup3r.utilities import VERSION_RECORD logger = logging.getLogger(__name__) @@ -350,260 +346,6 @@ def run( vprep.write_data(fp_out, out) -class BiasCorrectionFromMeans: - """Class for getting bias correction factors from bias and base data files - with precomputed monthly means. - """ - - MIN_DISTANCE = 1e-12 - - def __init__(self, bias_fp, base_fp, dset, leaf_size=4): - """Parameters - ---------- - bias_fp : str - Path to csv file containing means for biased data - base_fp : str - Path to csv file containing means for unbiased data - dset : str - Name of dataset to compute bias correction factor for - leaf_size : int - Leaf size for ball tree used to match bias and base grids - """ - self.dset = dset - self.bias_fp = bias_fp - self.base_fp = base_fp - self.leaf_size = leaf_size - self.bias_means = pd.read_csv(bias_fp) - self.base_means = Resource(base_fp) - self.bias_meta = self.bias_means[["latitude", "longitude"]] - self.base_meta = pd.DataFrame(columns=["latitude", "longitude"]) - self.base_meta["latitude"] = self.base_means["latitude"] - self.base_meta["longitude"] = self.base_means["longitude"] - self._base_tree = None - logger.info( - "Finished initializing BiasCorrectionFromMeans for " - f"bias_fp={bias_fp}, base_fp={base_fp}, dset={dset}." 
- ) - - @property - def base_tree(self): - """Build ball tree from source_meta""" - if self._base_tree is None: - logger.info("Building ball tree for regridding.") - self._base_tree = BallTree( - np.deg2rad(self.base_meta), - leaf_size=self.leaf_size, - metric="haversine", - ) - return self._base_tree - - @property - def meta(self): - """Get a meta data dictionary on how these bias factors were - calculated - """ - meta = { - "base_fp": self.base_fp, - "bias_fp": self.bias_fp, - "dset": self.dset, - "class": str(self.__class__), - "version_record": VERSION_RECORD, - "NOTES": ("scalar factors computed from base_data / bias_data."), - } - return meta - - @property - def height(self): - """Get feature height""" - return Feature.get_height(self.dset) - - @property - def u_name(self): - """Get corresponding u component for given height""" - return f"u_{self.height}m" - - @property - def v_name(self): - """Get corresponding v component for given height""" - return f"v_{self.height}m" - - def get_base_data(self, knn=1): - """Get means for baseline data.""" - logger.info(f"Getting base data for {self.dset}.") - dists, gids = self.base_tree.query(np.deg2rad(self.bias_meta), k=knn) - mask = dists < self.MIN_DISTANCE - if mask.sum() > 0: - logger.info( - f"{np.sum(mask)} of {np.product(mask.shape)} " - "distances are zero." - ) - dists[mask] = self.MIN_DISTANCE - weights = 1 / dists - norm = np.sum(weights, axis=-1) - out = self.base_means[self.dset, gids] - out = np.einsum("ijk,ij->ik", out, weights) / norm[:, np.newaxis] - return out - - def get_bias_data(self): - """Get means for biased data.""" - logger.info(f"Getting bias data for {self.dset}.") - cols = [col for col in self.bias_means.columns if "mean" in col] - bias_data = self.bias_means[cols].to_numpy() - return bias_data - - def get_corrections(self, global_scalar=1, knn=1): - """Get bias correction factors.""" - logger.info(f"Getting correction factors for {self.dset}.") - base_data = self.get_base_data(knn=knn) - bias_data = self.get_bias_data() - scaler = global_scalar * base_data / bias_data - adder = np.zeros(scaler.shape) - - out = { - "latitude": self.bias_meta["latitude"], - "longitude": self.bias_meta["longitude"], - f"base_{self.dset}_mean": base_data, - f"bias_{self.dset}_mean": bias_data, - f"{self.dset}_adder": adder, - f"{self.dset}_scalar": scaler, - f"{self.dset}_global_scalar": global_scalar, - } - return out - - def get_uv_corrections(self, global_scalar=1, knn=1): - """Write windspeed bias correction factors for u/v components""" - u_out = self.get_corrections(global_scalar=global_scalar, knn=knn) - v_out = u_out.copy() - u_out[f"{self.u_name}_scalar"] = u_out[f"{self.dset}_scalar"] - v_out[f"{self.v_name}_scalar"] = v_out[f"{self.dset}_scalar"] - u_out[f"{self.u_name}_adder"] = u_out[f"{self.dset}_adder"] - v_out[f"{self.v_name}_adder"] = v_out[f"{self.dset}_adder"] - return u_out, v_out - - def write_output(self, fp_out, out): - """Write bias correction factors to h5 file.""" - logger.info(f"Writing correction factors to file: {fp_out}.") - with h5py.File(fp_out, "w") as f: - for dset, data in out.items(): - f.create_dataset(dset, data=data) - logger.info(f"Added {dset} to {fp_out}.") - for k, v in self.meta.items(): - f.attrs[k] = json.dumps(v) - logger.info(f"Finished writing output to {fp_out}.") - - @classmethod - def run( - cls, - bias_fp, - base_fp, - dset, - fp_out, - leaf_size=4, - global_scalar=1.0, - knn=1, - out_shape=None, - ): - """Run bias correction factor computation and write. 
- - Parameters - ---------- - bias_fp : str - Path to csv file containing means for biased data - base_fp : str - Path to csv file containing means for unbiased data - dset : str - Name of dataset to compute bias correction factor for - fp_out : str - Name of output file containing bias correction factors - leaf_size : int - Leaf size for ball tree used to match bias and base grids - global_scalar : float - Optional global scalar to use for multiplying all bias correction - factors. This can be used to improve systemic bias against - observation data. This is just written to output files, not - included in the stored bias correction factor values. - knn : int - Number of nearest neighbors to use when matching bias and base - grids. This should be based on difference in resolution. e.g. if - bias grid is 30km and base grid is 3km then knn should be 100 to - aggregate 3km to 30km. - out_shape : tuple | None - Optional 2D shape for output. If this is provided then the bias - correction arrays will be reshaped to this shape. If not provided - the arrays will stay flattened. When using this to write bc files - that will be used in a forward-pass routine this shape should be - the same as the spatial shape of the forward-pass input data. - """ - bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset, - leaf_size=leaf_size) - out = bc.get_corrections(global_scalar=global_scalar, knn=knn) - if out_shape is not None: - out = cls._reshape_output(out, out_shape) - bc.write_output(fp_out, out) - - @classmethod - def _reshape_output(cls, out, out_shape): - """Reshape output according to given output shape""" - for k, v in out.items(): - if k in ("latitude", "longitude"): - out[k] = np.array(v).reshape(out_shape) - elif not isinstance(v, (int, float)): - out[k] = np.array(v).reshape((*out_shape, 12)) - return out - - @classmethod - def run_uv( - cls, - bias_fp, - base_fp, - dset, - fp_pattern, - global_scalar=1.0, - knn=1, - out_shape=None, - ): - """Run bias correction factor computation and write. - - Parameters - ---------- - bias_fp : str - Path to csv file containing means for biased data - base_fp : str - Path to csv file containing means for unbiased data - dset : str - Name of dataset to compute bias correction factor for - fp_pattern : str - Pattern for output file. Should contain {feature} format key. - leaf_size : int - Leaf size for ball tree used to match bias and base grids - global_scalar : float - Optional global scalar to use for multiplying all bias correction - factors. This can be used to improve systemic bias against - observation data. This is just written to output files, not - included in the stored bias correction factor values. - knn : int - Number of nearest neighbors to use when matching bias and base - grids. This should be based on difference in resolution. e.g. if - bias grid is 30km and base grid is 3km then knn should be 100 to - aggregate 3km to 30km. - out_shape : tuple | None - Optional 2D shape for output. If this is provided then the bias - correction arrays will be reshaped to this shape. If not provided - the arrays will stay flattened. When using this to write bc files - that will be used in a forward-pass routine this shape should be - the same as the spatial shape of the forward-pass input data. 
- """ - bc = cls(bias_fp=bias_fp, base_fp=base_fp, dset=dset) - out_u, out_v = bc.get_uv_corrections( - global_scalar=global_scalar, knn=knn - ) - if out_shape is not None: - out_u = cls._reshape_output(out_u, out_shape) - out_v = cls._reshape_output(out_v, out_shape) - bc.write_output(fp_pattern.format(feature=bc.u_name), out_u) - bc.write_output(fp_pattern.format(feature=bc.v_name), out_v) - - class BiasCorrectUpdate: """Class for bias correcting existing files and writing corrected files.""" From 0d81d793fd489109d9bbfcbb311c121c3caf24b4 Mon Sep 17 00:00:00 2001 From: bnb32 Date: Thu, 9 Nov 2023 12:57:04 -0700 Subject: [PATCH 44/44] allow 2d input to regridder --- sup3r/utilities/regridder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sup3r/utilities/regridder.py b/sup3r/utilities/regridder.py index 21e04534d..3e015b3dc 100644 --- a/sup3r/utilities/regridder.py +++ b/sup3r/utilities/regridder.py @@ -344,10 +344,12 @@ def __call__(self, data): Flattened regridded spatiotemporal data (spatial, temporal) """ - msg = 'Input data must be 3D (spatial_1, spatial_2, temporal)' - assert len(data.shape) == 3, msg + if len(data.shape) == 3: + data = data.reshape((data.shape[0] * data.shape[1], -1)) + msg = 'Input data must be 2D (spatial, temporal)' + assert len(data.shape) == 2, msg vals = [ - data[:, :, i].flatten()[np.array(self.indices)][np.newaxis] + data[np.array(self.indices), i][np.newaxis] for i in range(data.shape[-1]) ] vals = np.concatenate(vals, axis=0)