Add dask_delayed nc write calls (#728)
* Use dask delayed_write API with xarray to_netcdf to write files in pp write_dataset
add more preprocessor logging messages
clean up var_id references

* remove commented-out line

* add timing calls and original to_netcdf call to write_dataset comments
wrongkindofdoctor authored Jan 2, 2025
1 parent b2e4427 commit a1a14f8
Showing 1 changed file with 32 additions and 9 deletions.
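
The first bullet in the commit message above swaps the eager `to_netcdf` call in `write_dataset` for a deferred write. A minimal, self-contained sketch of that pattern follows, assuming a dask-backed xarray Dataset; the example dataset, output path, and timing printout are illustrative and are not code from this repository.

```python
# Sketch of the delayed-write pattern adopted in write_dataset (illustrative only).
import datetime
import time

import numpy as np
import xarray as xr

# Small dask-backed dataset standing in for a preprocessed variable.
ds = xr.Dataset(
    {"tas": (("time", "lat"), np.random.rand(4, 3))},
    coords={"time": np.arange(4), "lat": [-30.0, 0.0, 30.0]},
).chunk({"time": 2})

start_time = time.monotonic()
delayed_write = ds.to_netcdf(
    path="example_output.nc",  # hypothetical destination, analogous to var.dest_path
    mode="w",
    unlimited_dims=["time"],
    compute=False,             # return a dask Delayed object instead of writing now
)
delayed_write.compute()        # trigger the actual write to disk
ds.close()
end_time = time.monotonic()
print(f"Write took {datetime.timedelta(seconds=end_time - start_time)}")
```

With `compute=False`, `to_netcdf` returns a `dask.delayed.Delayed` object, so the write happens only when `.compute()` is called; in the diff below, the source dataset is closed before the deferred write is computed.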
41 changes: 32 additions & 9 deletions src/preprocessor.py
@@ -16,6 +16,7 @@
import xarray as xr
import collections
import re
import time

# TODO: Make the following lines a unit test
# import sys
@@ -28,7 +29,7 @@
import logging

_log = logging.getLogger(__name__)

write_times = []

def copy_as_alternate(old_v, **kwargs):
"""Wrapper for :py:func:`dataclasses.replace` that creates a copy of an
@@ -1046,8 +1047,8 @@ def query_catalog(self,

# define initial query dictionary with variable settings requirements that do not change if
# the variable is translated
case_d.set_query(var, path_regex)
case_d.set_query(var, path_regex)

# change realm key name if necessary
if cat.df.get('modeling_realm', None) is not None:
case_d.query['modeling_realm'] = case_d.query.pop('realm')
@@ -1105,9 +1106,9 @@ def query_catalog(self,
cat_subset.esmcat._df = self.check_group_daterange(cat_subset.df, date_range, var.log)
if cat_subset.df.empty:
raise util.DataRequestError(
f"check_group_daterange returned empty data frame for {var_id}"
f"check_group_daterange returned empty data frame for {var.name}"
f" case {case_name} in {data_catalog}, indicating issues with data continuity")
# v.log.debug("Read %d mb for %s.", cat_subset.esmcat._df.dtypes.nbytes / (1024 * 1024), v.full_name)
var.log.info(f"Converting {var.name} catalog subset to dataset dictionary")
# convert subset catalog to an xarray dataset dict
# and concatenate the result with the final dict
cat_subset_dict = cat_subset.to_dataset_dict(
@@ -1170,9 +1171,10 @@ def query_catalog(self,
# check that the trimmed variable data in the merged dataset matches the desired date range
if not var.is_static:
try:
var.log.info(f'Calling check_time_bounds for {var.name}')
self.check_time_bounds(cat_dict[case_name], var.translation, var.T.frequency)
except LookupError:
var.log.error(f'Time bounds in trimmed dataset for {var_id} in case {case_name} do not match'
var.log.error(f'Time bounds in trimmed dataset for {var.name} in case {case_name} do not match'
f'requested date_range.')
raise SystemExit("Terminating program")
return cat_dict
@@ -1401,13 +1403,32 @@ def write_dataset(self, var, ds):
unlimited_dims = []
else:
unlimited_dims = [var.T.name]
var_ds.to_netcdf(

# The following block is retained for time comparison with dask delayed write procedure
#var_ds.to_netcdf(
# path=var.dest_path,
# mode='w',
# **self.save_dataset_kwargs,
# unlimited_dims=unlimited_dims
#)
#ds.close()

# Uncomment the timing lines and log calls if desired
#start_time = time.monotonic()
delayed_write = var_ds.to_netcdf(
path=var.dest_path,
mode='w',
**self.save_dataset_kwargs,
unlimited_dims=unlimited_dims
unlimited_dims=unlimited_dims,
compute=False
)
ds.close()
delayed_write.compute()
delayed_write.close()
#end_time = time.monotonic()
#var.log.info(f'Time to write file {var.dest_path}: {str(datetime.timedelta(seconds=end_time - start_time))}')
#dt = datetime.timedelta(seconds=end_time - start_time)
#write_times.append(dt.total_seconds())
#var.log.info(f'Total write time: {str(sum(write_times))} s')

def write_ds(self, case_list: dict,
catalog_subset: collections.OrderedDict,
@@ -1480,13 +1501,15 @@ def process(self,
for v in case_list[case_name].varlist.iter_vars():
tv_name = v.translation.name
# todo: maybe skip this if no standard_name attribute for v in case_xr_dataset
v.log.info(f'Calling parse_ds for {v.name}')
var_xr_dataset = self.parse_ds(v, case_xr_dataset)
varlist_ex = [v_l.translation.name for v_l in case_list[case_name].varlist.iter_vars()]
if tv_name in varlist_ex:
varlist_ex.remove(tv_name)
for v_d in var_xr_dataset.variables:
if v_d not in varlist_ex:
cat_subset[case_name].update({v_d: var_xr_dataset[v_d]})
v.log.info(f'Calling preprocessing functions for {v.name}')
pp_func_dataset = self.execute_pp_functions(v,
cat_subset[case_name],
work_dir=model_work_dir[case_name],
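For reference, the commented timing hooks in `write_dataset` feed the module-level `write_times` list added near the top of the file. A hedged sketch of how those pieces fit together once uncommented (names mirror the diff; the helper function and log messages are illustrative, not part of the commit):

```python
import datetime
import logging
import time

_log = logging.getLogger(__name__)
write_times = []  # per-file write durations in seconds, as in the diff


def timed_compute(delayed_write, dest_path):
    """Run a deferred netCDF write and record how long it took (illustrative helper)."""
    start_time = time.monotonic()
    delayed_write.compute()
    end_time = time.monotonic()
    dt = datetime.timedelta(seconds=end_time - start_time)
    write_times.append(dt.total_seconds())
    _log.info(f"Time to write file {dest_path}: {dt}")
    _log.info(f"Total write time: {sum(write_times)} s")
```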
