diff --git a/src/preprocessor.py b/src/preprocessor.py
index 53dff151d..cdce8e22e 100644
--- a/src/preprocessor.py
+++ b/src/preprocessor.py
@@ -713,7 +713,9 @@ def cast_to_cftime(self, dt: datetime.datetime, calendar):
                   ('tm_year', 'tm_mon', 'tm_mday', 'tm_hour', 'tm_min', 'tm_sec'))
         return cftime.datetime(*tt, calendar=calendar)
 
-    def check_time_bounds(self, ds, var: translation.TranslatedVarlistEntry, freq: str):
+    def check_time_bounds(self, ds: xr.Dataset,
+                          var: translation.TranslatedVarlistEntry,
+                          freq: str):
         """Parse quantities related to the calendar for time-dependent data and
         truncate the date range of model dataset *ds*.
 
@@ -746,20 +748,20 @@ def check_time_bounds(self, ds, var: translation.TranslatedVarlistEntry, freq: str):
         # do not begin at hour zero
         if dt_range.start.lower.hour != t_start.hour:
             var.log.info("Variable %s data starts at hour %s", var.full_name, t_start.hour)
-            dt_start_upper_new = datetime.datetime(dt_range.start.upper.year,
-                                                   dt_range.start.upper.month,
-                                                   dt_range.start.upper.day,
+            dt_start_lower_new = datetime.datetime(t_start.year,
+                                                   t_start.month,
+                                                   t_start.day,
                                                    t_start.hour,
                                                    t_start.minute,
                                                    t_start.second)
-            dt_start_upper = self.cast_to_cftime(dt_start_upper_new, cal)
+            dt_start_lower = self.cast_to_cftime(dt_start_lower_new, cal)
         else:
-            dt_start_upper = self.cast_to_cftime(dt_range.start.upper, cal)
+            dt_start_lower = self.cast_to_cftime(dt_range.start.lower, cal)
         if dt_range.end.lower.hour != t_end.hour:
             var.log.info("Variable %s data ends at hour %s", var.full_name, t_end.hour)
-            dt_end_lower_new = datetime.datetime(dt_range.end.lower.year,
-                                                 dt_range.end.lower.month,
-                                                 dt_range.end.lower.day,
+            dt_end_lower_new = datetime.datetime(t_end.year,
+                                                 t_end.month,
+                                                 t_end.day,
                                                  t_end.hour,
                                                  t_end.minute,
                                                  t_end.second)
@@ -769,10 +771,10 @@ def check_time_bounds(self, ds, var: translation.TranslatedVarlistEntry, freq: str):
 
         # only check that up to monthly precision for monthly or longer data
         if freq in ['mon', 'year']:
-            if t_start.year > dt_start_upper.year or \
-                    t_start.year == dt_start_upper.year and t_start.month > dt_start_upper.month:
+            if t_start.year > dt_start_lower.year or \
+                    t_start.year == dt_start_lower.year and t_start.month > dt_start_lower.month:
                 err_str = (f"Error: dataset start ({t_start}) is after "
-                           f"requested date range start ({dt_start_upper}).")
+                           f"requested date range start ({dt_start_lower}).")
                 var.log.error(err_str)
                 raise IndexError(err_str)
             if t_end.year < dt_end_lower.year or \
@@ -782,9 +784,9 @@ def check_time_bounds(self, ds, var: translation.TranslatedVarlistEntry, freq: str):
                 var.log.error(err_str)
                 raise IndexError(err_str)
         else:
-            if t_start > dt_start_upper:
+            if t_start > dt_start_lower:
                 err_str = (f"Error: dataset start ({t_start}) is after "
-                           f"requested date range start ({dt_start_upper}).")
+                           f"requested date range start ({dt_start_lower}).")
                 var.log.error(err_str)
                 raise IndexError(err_str)
             if t_end < dt_end_lower:
@@ -816,75 +818,126 @@ def check_multichunk(self, group_df: pd.DataFrame, case_dr, log) -> pd.DataFrame:
             case_dr: requested daterange of POD
             log: log file
         """
-        if 'chunk_freq' in group_df:
-            chunks = group_df['chunk_freq'].unique()
-            if len(chunks) > 1:
-                for i, c in enumerate(chunks):
-                    chunks[i] = int(c.replace('yr', ''))
-                chunks = -np.sort(-chunks)
-                case_dt = int(str(case_dr.end)[:4]) - int(str(case_dr.start)[:4]) + 1
-                for c in chunks:
-                    if case_dt % c == 0:
-                        grabbed_chunk = str(c) + 'yr'
-                        log.warning("Multiple values for 'chunk_freq' found in dataset "
-                                    "only grabbing data with 'chunk_freq': %s", grabbed_chunk)
-                        break
-                group_df = group_df[group_df['chunk_freq'] == grabbed_chunk]
+        chunks = group_df['chunk_freq'].unique()
+        if len(chunks) > 1:
+            for i, c in enumerate(chunks):
+                chunks[i] = int(c.replace('yr', ''))
+            chunks = -np.sort(-chunks)
+            case_dt = int(str(case_dr.end)[:4]) - int(str(case_dr.start)[:4]) + 1
+            for c in chunks:
+                if case_dt % c == 0:
+                    grabbed_chunk = str(c) + 'yr'
+                    log.warning("Multiple values for 'chunk_freq' found in dataset; "
+                                "only grabbing data with 'chunk_freq': %s", grabbed_chunk)
+                    break
+            group_df = group_df[group_df['chunk_freq'] == grabbed_chunk]
         return pd.DataFrame.from_dict(group_df).reset_index()
 
-    def check_group_daterange(self, group_df: pd.DataFrame, case_dr,
+    def crop_date_range(self, case_date_range: util.DateRange, xr_ds, time_coord) -> xr.Dataset:
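+        """Trim the time dimension of *xr_ds* to the portion that overlaps
+        *case_date_range*; return None if the dataset and the requested
+        date range do not overlap at all.
+
+        Args:
+            case_date_range: requested daterange of the case
+            xr_ds: xarray Dataset opened from a single file in the catalog query
+            time_coord: time coordinate object for *xr_ds*
+        """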
+        xr_ds = xr.decode_cf(xr_ds,
+                             decode_coords=True,  # parse coords attr
+                             decode_times=True,
+                             use_cftime=True  # use cftime instead of np.datetime64
+                             )
+        cal = xr_ds[time_coord.name].attrs.get('calendar', 'noleap')
+
+        ds_date_time = xr_ds[time_coord.name].values
+        ds_start_time = ds_date_time[0]
+        ds_end_time = ds_date_time[-1]
+        # force hours in dataset to match date range if frequency is daily, monthly, annual
+        if ds_start_time.hour != case_date_range.start_datetime.hour and case_date_range.precision < 4:
+            dt_start_new = datetime.datetime(ds_start_time.year,
+                                             ds_start_time.month,
+                                             ds_start_time.day,
+                                             case_date_range.start_datetime.hour,
+                                             ds_start_time.minute,
+                                             ds_start_time.second)
+            ds_start = self.cast_to_cftime(dt_start_new, cal)
+        else:
+            ds_start = self.cast_to_cftime(ds_start_time, cal)
+        if ds_end_time.hour != case_date_range.end_datetime.hour and case_date_range.precision < 4:
+            dt_end_new = datetime.datetime(ds_end_time.year,
+                                           ds_end_time.month,
+                                           ds_end_time.day,
+                                           case_date_range.end_datetime.hour,
+                                           ds_end_time.minute,
+                                           ds_end_time.second)
+            ds_end = self.cast_to_cftime(dt_end_new, cal)
+        else:
+            ds_end = self.cast_to_cftime(ds_end_time, cal)
+        date_range_cf_start = self.cast_to_cftime(case_date_range.start.lower, cal)
+        date_range_cf_end = self.cast_to_cftime(case_date_range.end.lower, cal)
+
+        # dataset falls entirely outside the user-specified date range
+        if (ds_start < date_range_cf_start and ds_end < date_range_cf_start) or \
+                (ds_start > date_range_cf_end and ds_end > date_range_cf_end):
+            new_xr_ds = None
+        # dataset falls entirely within user-specified date range
+        elif ds_start >= date_range_cf_start and ds_end <= date_range_cf_end:
+            new_xr_ds = xr_ds.sel({time_coord.name: slice(ds_start, ds_end)})
+        # dataset overlaps user-specified date range start
+        elif ds_start < date_range_cf_start and \
+                date_range_cf_start <= ds_end <= date_range_cf_end:
+            new_xr_ds = xr_ds.sel({time_coord.name: slice(date_range_cf_start, ds_end)})
+        # dataset overlaps user-specified date range end
+        elif date_range_cf_start < ds_start <= date_range_cf_end <= ds_end:
+            new_xr_ds = xr_ds.sel({time_coord.name: slice(ds_start, date_range_cf_end)})
+        # dataset contains all of requested date range
+        elif date_range_cf_start >= ds_start and date_range_cf_end <= ds_end:
+            new_xr_ds = xr_ds.sel({time_coord.name: slice(date_range_cf_start, date_range_cf_end)})
+
+        return new_xr_ds
+
+    def check_group_daterange(self, df: pd.DataFrame, date_range: util.DateRange,
                               log=_log) -> pd.DataFrame:
         """Sort the files found for each experiment by date, verify that
         the date ranges contained in the files are contiguous in time and that
         the date range of the files spans the query date range.
 
         Args:
-            group_df (Pandas Dataframe):
-            case_dr: requested daterange of POD
+            df (Pandas Dataframe): catalog entries found for the experiment
+            date_range: requested daterange of POD
             log: log file
         """
         date_col = "date_range"
-        delimiters = ",.!?/&-:;@_'\\s+"
-        if hasattr(group_df, 'time_range'):
+        if hasattr(df, 'time_range'):
             start_times = []
             end_times = []
-            for tr in group_df['time_range'].values:
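+            # time_range entries are assumed to be fixed-width '<start><end>'
+            # stamps once spaces and separators are stripped, so splitting at
+            # the midpoint recovers the start and end dates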
+            for tr in df['time_range'].values:
                 tr = tr.replace(' ', '').replace('-', '').replace(':', '')
                 start_times.append(tr[0:len(tr)//2])
                 end_times.append(tr[len(tr)//2:])
-            group_df['start_time'] = pd.Series(start_times)
-            group_df['end_time'] = pd.Series(end_times)
+            df['start_time'] = pd.Series(start_times)
+            df['end_time'] = pd.Series(end_times)
         else:
             raise AttributeError('Data catalog is missing the attribute `time_range`;'
                                  ' this is a required entry.')
         try:
-            start_time_vals = self.normalize_group_time_vals(group_df['start_time'].values.astype(str))
-            end_time_vals = self.normalize_group_time_vals(group_df['end_time'].values.astype(str))
+            start_time_vals = self.normalize_group_time_vals(df['start_time'].values.astype(str))
+            end_time_vals = self.normalize_group_time_vals(df['end_time'].values.astype(str))
             if not isinstance(start_time_vals[0], datetime.date):
                 date_format = dl.date_fmt(start_time_vals[0])
                 # convert start_times to date_format for all files in query
-                group_df['start_time'] = start_time_vals
-                group_df['start_time'] = group_df['start_time'].apply(lambda x:
+                df['start_time'] = start_time_vals
+                df['start_time'] = df['start_time'].apply(lambda x:
                                                           datetime.datetime.strptime(x, date_format))
                 # convert end_times to date_format for all files in query
-                group_df['end_time'] = end_time_vals
-                group_df['end_time'] = group_df['end_time'].apply(lambda x:
-                                                                  datetime.datetime.strptime(x, date_format))
+                df['end_time'] = end_time_vals
+                df['end_time'] = df['end_time'].apply(lambda x:
+                                                      datetime.datetime.strptime(x, date_format))
             # method throws ValueError if ranges aren't contiguous
-            dates_df = group_df.loc[:, ['start_time', 'end_time']]
+            dates_df = df.loc[:, ['start_time', 'end_time']]
             date_range_vals = []
-            for idx, x in enumerate(group_df.values):
+            for idx, x in enumerate(df.values):
                 st = dates_df.at[idx, 'start_time']
                 en = dates_df.at[idx, 'end_time']
                 date_range_vals.append(util.DateRange(st, en))
-            group_df = group_df.assign(date_range=date_range_vals)
+            group_df = df.assign(date_range=date_range_vals)
             sorted_df = group_df.sort_values(by=date_col)
             files_date_range = util.DateRange.from_contiguous_span(
                 *(sorted_df[date_col].to_list())
             )
             # throws AssertionError if we don't span the query range
-            # TODO: define self.attrs.DateRange from runtime config info
             # assert files_date_range.contains(self.attrs.date_range)
             # throw out df entries not in date_range
             return_df = []
@@ -893,12 +946,13 @@ def check_group_daterange(self, group_df: pd.DataFrame, case_dr,
                 if pd.isnull(cat_row['start_time']):
                     continue
                 else:
-                    st = dl.dt_to_str(cat_row['start_time'])
-                    et = dl.dt_to_str(cat_row['end_time'])
-                    stin = dl.Date(st) in case_dr
-                    etin = dl.Date(et) in case_dr
-                    if stin and etin:
-                        return_df.append(cat_row.to_dict())
+                    ds_st = cat_row['start_time']
+                    ds_et = cat_row['end_time']
+                    # date range includes entire or part of dataset
+                    if ds_st >= date_range.start.lower and ds_et <= date_range.end.lower or \
+                            ds_st <= date_range.start.lower <= ds_et or \
+                            ds_st <= date_range.end.lower < ds_et:
+                        return_df.append(cat_row.to_dict())
 
             return pd.DataFrame.from_dict(return_df)
         except ValueError:
@@ -1030,8 +1084,9 @@ def query_catalog(self,
         # Get files in specified date range
         # https://intake-esm.readthedocs.io/en/stable/how-to/modify-catalog.html
         if not var.is_static:
-            cat_subset.esmcat._df = self.check_multichunk(cat_subset.df, date_range, var.log)
-            cat_subset.esmcat._df = self.check_group_daterange(cat_subset.df, date_range)
+            if "chunk_freq" in cat_subset.df:
+                cat_subset.esmcat._df = self.check_multichunk(cat_subset.df, date_range, var.log)
+            cat_subset.esmcat._df = self.check_group_daterange(cat_subset.df, date_range, var.log)
             if cat_subset.df.empty:
                 raise util.DataRequestError(
                     f"check_group_daterange returned empty data frame for {var_id}"
                 )
@@ -1039,9 +1094,10 @@
             # v.log.debug("Read %d mb for %s.", cat_subset.esmcat._df.dtypes.nbytes / (1024 * 1024), v.full_name)
             # convert subset catalog to an xarray dataset dict
             # and concatenate the result with the final dict
-        cat_subset_df = cat_subset.to_dataset_dict(
+        cat_subset_dict = cat_subset.to_dataset_dict(
             progressbar=False,
-            xarray_open_kwargs=self.open_dataset_kwargs
+            xarray_open_kwargs=self.open_dataset_kwargs,
+            aggregate=False
         )
         # NOTE: The time_range of each file in cat_subset_df must be in a specific
         # order in order for xr.concat() to work correctly. In the current implementation,
         # tl;dr hic sunt dracones
         var_xr = []
         if not var.is_static:
-            time_sort_dict = {f: cat_subset_df[f].time.values[0]
-                              for f in list(cat_subset_df)}
+            time_sort_dict = {f: cat_subset_dict[f].time.values[0]
+                              for f in list(cat_subset_dict)}
             time_sort_dict = dict(sorted(time_sort_dict.items(), key=lambda item: item[1]))
             for k in list(time_sort_dict):
-                if not var_xr:
-                    var_xr = cat_subset_df[k]
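+                # crop each per-file dataset to the requested date range before
+                # concatenating; crop_date_range returns None for files that fall
+                # entirely outside the range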
+                cat_subset_dict[k] = self.crop_date_range(date_range,
+                                                          cat_subset_dict[k],
+                                                          var.T)
+                if cat_subset_dict[k] is None:
+                    continue
                 else:
-                    var_xr = xr.concat([var_xr, cat_subset_df[k]], "time")
+                    if not var_xr:
+                        var_xr = cat_subset_dict[k]
+                    else:
+                        var_xr = xr.concat([var_xr, cat_subset_dict[k]], var.T.name)
         else:
             # get xarray dataset for static variable
-            cat_index = [k for k in cat_subset_df.keys()][0]
+            cat_index = [k for k in cat_subset_dict.keys()][0]
             if not var_xr:
-                var_xr = cat_subset_df[cat_index]
+                var_xr = cat_subset_dict[cat_index]
             else:
                 if var.Y is not None:
-                    var_xr = xr.concat([var_xr, cat_subset_df[cat_index]], var.Y.name)
+                    var_xr = xr.concat([var_xr, cat_subset_dict[cat_index]], var.Y.name)
                 elif var.X is not None:
-                    var_xr = xr.concat([var_xr, cat_subset_df[cat_index]], var.X.name)
+                    var_xr = xr.concat([var_xr, cat_subset_dict[cat_index]], var.X.name)
                 else:
-                    var_xr = xr.concat([var_xr, cat_subset_df.values[cat_index]], var.N.name)
+                    var_xr = xr.concat([var_xr, cat_subset_dict[cat_index]], var.N.name)
         for att in drop_atts:
             if var_xr.get(att, None) is not None:
                 var_xr = var_xr.drop_vars(att)
@@ -1090,14 +1152,14 @@ def query_catalog(self,
                 cat_dict[case_name] = var_xr
             else:
                 cat_dict[case_name] = xr.merge([cat_dict[case_name], var_xr], compat='no_conflicts')
-        # check that start and end times include runtime startdate and enddate
+
+        # check that the trimmed variable data in the merged dataset matches the desired date range
         if not var.is_static:
-            var_obj = var.translation
             try:
-                self.check_time_bounds(cat_dict[case_name], var_obj, freq)
+                self.check_time_bounds(cat_dict[case_name], var.translation, freq)
             except LookupError:
-                var.log.error(f'Data not found in catalog query for {var_id}'
-                              f' for requested date_range.')
+                var.log.error(f'Time bounds in trimmed dataset for {var_id} in case {case_name} do not match '
+                              f'requested date_range.')
                 raise SystemExit("Terminating program")
         return cat_dict
 
@@ -1153,6 +1215,7 @@ def open_dataset_kwargs(self):
             "decode_times": False,
             "use_cftime": False,
             "chunks": "auto"
+        }
 
     @property
@@ -1435,6 +1498,7 @@ def write_pp_catalog(self,
         # each key is a case
         for case_name, case_dict in cases.items():
             ds_match = input_catalog_ds[case_name]
+            # sort a copy of the time values for the start/end stamps so the
+            # dataset's time coordinate is not reordered in place
+            time_vals = np.sort(ds_match.time.values)
             for var in case_dict.varlist.iter_vars():
                 var_name = var.translation.name
                 ds_var = ds_match.data_vars.get(var_name, None)
                 for c in columns:
                     if key.split('intake_esm_attrs:')[1] == c:
                         d[c] = val
-            if var.translation.convention == 'no_translation':
-                d.update({'project_id': var.convention})
-            else:
-                d.update({'project_id': var.translation.convention})
+
+            d.update({'project_id': var.translation.convention})
             d.update({'path': var.dest_path})
-            d.update({'start_time': util.cftime_to_str(ds_match.time.values[0])})
-            d.update({'end_time': util.cftime_to_str(ds_match.time.values[-1])})
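+            # time_range uses '-' to separate the start and end stamps and ':'
+            # within each stamp; datelabel.DateRange strips the ':' back out
+            # when parsing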
+            d.update({'time_range': f'{util.cftime_to_str(time_vals[0]).replace("-", ":")}-'
+                                    f'{util.cftime_to_str(time_vals[-1]).replace("-", ":")}'})
             d.update({'standard_name': ds_match[var.name].attrs['standard_name']})
             cat_entries.append(d)
diff --git a/src/util/catalog.py b/src/util/catalog.py
index 268f75b3c..9618b3a75 100644
--- a/src/util/catalog.py
+++ b/src/util/catalog.py
@@ -69,7 +69,7 @@ def define_pp_catalog_assets(config, cat_file_name: str) -> dict:
     )
 
     # add columns required for GFDL/CESM institutions and MDTF-diagnostics functionality
-    append_atts = ['chunk_freq', 'path', 'standard_name', 'start_time', 'end_time']
+    append_atts = ['chunk_freq', 'path', 'standard_name', 'time_range']
     for att in append_atts:
         cat_dict["attributes"].append(
             dict(column_name=att)
diff --git a/src/util/datelabel.py b/src/util/datelabel.py
index 9d816c77d..b70473835 100644
--- a/src/util/datelabel.py
+++ b/src/util/datelabel.py
@@ -625,11 +625,15 @@ def __init__(self, start, end=None, precision=None, log=_log):
         # start: split_str[start index of 0: nelem_half elements total], end[start index at nelem_half,
             (start, end) = ''.join(split_str[:nelem_half]), ''.join(split_str[nelem_half:])
 
+
         elif len(start) == 2:
             (start, end) = start
         else:
             raise ValueError('Bad input ({},{})'.format(start, end))
-
+        # strip the ':' separators used inside catalog time_range stamps
+        if isinstance(start, str):
+            start = start.replace(':', '')
+        if isinstance(end, str):
+            end = end.replace(':', '')
         dt0, prec0 = self._coerce_to_datetime(start, is_lower=True)
         dt1, prec1 = self._coerce_to_datetime(end, is_lower=False)
         if not (dt0 < dt1):