Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combine variables in coherent dataset #130

Merged
merged 14 commits into from
Aug 24, 2023
2 changes: 2 additions & 0 deletions optim_esm_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
from .analyze.cmip_handler import read_ds
from .analyze.io import load_glob
from .plotting.map_maker import MapMaker
from .utils import print_versions
from .config import get_logger
1 change: 1 addition & 0 deletions optim_esm_tools/analyze/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@
from . import region_finding
from . import time_statistics
from . import concise_dataframe
from . import combine_variables
9 changes: 8 additions & 1 deletion optim_esm_tools/analyze/cmip_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def read_ds(
_ma_window: ty.Optional[int] = None,
_cache: bool = True,
_file_name: str = None,
_skip_folder_info: bool = False,
**kwargs,
) -> xr.Dataset:
"""Read a dataset from a folder called "base".
Expand All @@ -105,6 +106,8 @@ def read_ds(
_cache (bool, optional): cache the dataset with it's extra fields to alow faster
(re)loading. Defaults to True.
_file_name (str, optional): name to match. Defaults to configs settings.
_skip_folder_info (bool, optional): if set to True, do not infer the properties from the
(synda) path of the file

kwargs:
any kwargs are passed onto transform_ds.
Expand Down Expand Up @@ -173,7 +176,11 @@ def read_ds(
folders = base.split(os.sep)

# start with -1 (for i==0)
metadata = {k: folders[-i - 1] for i, k in enumerate(_FOLDER_FMT[::-1])}
metadata = (
{}
if _skip_folder_info
else {k: folders[-i - 1] for i, k in enumerate(_FOLDER_FMT[::-1])}
)
metadata.update(dict(path=base, file=res_file, running_mean_period=_ma_window))

data_set.attrs.update(metadata)
Expand Down
226 changes: 226 additions & 0 deletions optim_esm_tools/analyze/combine_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import os
import typing as ty
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

import optim_esm_tools as oet
from optim_esm_tools.analyze.time_statistics import default_thresholds


class VariableMerger:
full_paths = None
source_files: ty.Mapping
common_mask: xr.DataArray

def __init__(self, paths, other_paths=None, merge_method='logical_or'):
self.mask_paths = paths
self.other_paths = other_paths or []
self.merge_method = merge_method
source_files, common_mask = self.process_masks()
self.source_files = source_files
self.common_mask = common_mask

def squash_sources(self) -> xr.Dataset:
common_mask = (
self.common_mask
if self.common_mask.dims == ('lat', 'lon')
else oet.analyze.xarray_tools.reverse_name_mask_coords(self.common_mask)
)

new_ds = defaultdict(dict)
new_ds['data_vars']['global_mask'] = common_mask
for var, path in self.source_files.items():
_ds = oet.load_glob(path)
new_ds['data_vars'][var] = (
_ds[var]
.where(common_mask)
.mean(oet.config.config['analyze']['lon_lat_dim'].split(','))
)
new_ds['data_vars'][var].attrs = _ds[var].attrs

# Make one copy - just use the last dataset
new_ds['data_vars']['cell_area'] = _ds['cell_area']
keys = sorted(list(self.source_files.keys()))
new_ds['attrs'] = dict(
variables=keys,
source_files=[self.source_files[k] for k in keys],
mask_files=sorted(self.mask_paths),
)
try:
new_ds = xr.Dataset(**new_ds)
except TypeError as e: # pragma: no cover
oet.get_logger.warning(f'Ran into {e} fallback method because of cftime')
# Stupid cftime can't compare it's own formats
data_vars = new_ds.pop('data_vars')
new_ds = xr.Dataset(**new_ds)

# But xarray can fudge something along the way!
for k, v in data_vars.items():
new_ds[k] = v
return new_ds

def make_fig(self, new_ds=None, fig_kw=None):
new_ds = new_ds or self.squash_sources()
variables = list(new_ds.attrs['variables'])
mapping = {str(i): v for i, v in enumerate(variables)}
keys = list(mapping.keys()) + ['t']

fig_kw = fig_kw or dict(
mosaic=''.join(f'{k}.\n' for k in keys),
figsize=(17, 4 * ((1 + len(keys)) / 3)),
gridspec_kw=dict(width_ratios=[1, 1], wspace=0.1, hspace=0.05),
)

_, axes = plt.subplot_mosaic(**fig_kw)

if len(keys) > 1:
for k in keys[1:]:
axes[k].sharex(axes[keys[0]])

for key, var in mapping.items():
plt.sca(axes[key])
plot_kw = dict(label=var)
oet.plotting.map_maker.plot_simple(new_ds, var, **plot_kw)
plt.legend(loc='center left')

ax = plt.gcf().add_subplot(
1, 2, 2, projection=oet.plotting.plot.get_cartopy_projection()
)
oet.plotting.map_maker.overlay_area_mask(
new_ds.where(new_ds['global_mask']).copy(), ax=ax
)
res_f, tips = result_table(new_ds)
add_table(res_f=res_f, tips=tips, ax=axes['t'])

def process_masks(self) -> ty.Tuple[dict, xr.DataArray]:
source_files = {}
common_mask = None
for path in self.mask_paths:
ds = oet.load_glob(path)
# Source files may be non-unique!
source_files[ds.attrs['variable_id']] = ds.attrs['file']
common_mask = self.combine_masks(common_mask, ds)
for other_path in self.other_paths:
if other_path == '': # pragma: no cover
continue
ds = oet.load_glob(other_path)
# Source files may be non-unique!
var = ds.attrs['variable_id']
if var not in source_files:
source_files[var] = ds.attrs['file']
return source_files, common_mask

def combine_masks(
self,
common_mask: ty.Optional[xr.DataArray],
other_dataset: xr.Dataset,
field: ty.Optional[str] = None,
) -> xr.DataArray:
field = field or (
'global_mask' if 'global_mask' in other_dataset else 'cell_area'
)
is_the_first_instance = common_mask is None
if is_the_first_instance:
return other_dataset[field].astype(np.bool_)
if self.merge_method == 'logical_or':
return common_mask | other_dataset[field].astype(np.bool_)

raise NotImplementedError(f'No such method as {self.merge_method}')


def change_plt_table_height():
"""Increase the height of rows in plt.table

Unfortunately, the options that you can pass to plt.table are insufficient to render a table
that has rows with sufficient heights that work with a font that is not the default. From the
plt.table implementation, I figured I could change these (rather patchy) lines in the source
code:
https://github.com/matplotlib/matplotlib/blob/b7dfdc5c97510733770429f38870a623426d0cdc/lib/matplotlib/table.py#L391

Matplotlib version matplotlib==3.7.2
"""
import matplotlib

print('Change default plt.table row height')

def _approx_text_height(self):
return 1.5 * (
self.FONTSIZE / 72.0 * self.figure.dpi / self._axes.bbox.height * 1.2
)

matplotlib.table.Table._approx_text_height = _approx_text_height


def add_table(res_f, tips, ax=None, fontsize=16, pass_color=(0.75, 1, 0.75)):
ax = ax or plt.gcf().add_subplot(2, 2, 4)
ax.axis('off')
ax.axis('tight')

table = ax.table(
cellText=res_f.values,
rowLabels=res_f.index,
colLabels=res_f.columns,
cellColours=[
[(pass_color if v else [1, 1, 1]) for v in row] for row in tips.values
],
loc='bottom',
colLoc='center',
rowLoc='center',
cellLoc='center',
)
table.set_fontsize(fontsize)


def result_table(ds, formats=None):
res = {
field: summarize_stats(ds, field, path)
for field, path in zip(ds.attrs['variables'], ds.attrs['source_files'])
}
thrs = default_thresholds()
is_tip = pd.DataFrame(
{
k: {
t: (thrs[t][0](v, thrs[t][1]) if v is not None else False)
for t, v in d.items()
}
for k, d in res.items()
}
).T

formats = formats or dict(
n_breaks='.0f',
p_symmetry='.3f',
p_dip='.3f',
max_jump='.1f',
n_std_global='.1f',
)
res_f = pd.DataFrame(res).T
for k, f in formats.items():
res_f[k] = res_f[k].map(f'{{:,{f}}}'.format)

order = list(formats.keys())
return res_f[order], is_tip[order]


def summarize_stats(ds, field, path):
return {
'n_breaks': oet.analyze.time_statistics.calculate_n_breaks(ds, field=field),
'p_symmetry': oet.analyze.time_statistics.calculate_symmetry_test(
ds, field=field
),
'p_dip': oet.analyze.time_statistics.calculate_dip_test(ds, field=field),
'n_std_global': oet.analyze.time_statistics.n_times_global_std(
ds=oet.load_glob(path).where(ds['global_mask'])
),
'max_jump': oet.analyze.time_statistics.calculate_max_jump_in_std_history(
ds=oet.load_glob(path).where(ds['global_mask']), mask=ds['global_mask']
),
}


if __name__ == '__main__':
change_plt_table_height()
33 changes: 33 additions & 0 deletions optim_esm_tools/analyze/time_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing as ty
from functools import partial
import os
import operator


class TimeStatistics:
Expand Down Expand Up @@ -77,6 +78,38 @@ def calculate_statistics(self) -> ty.Dict[str, ty.Optional[float]]:
}


def default_thresholds(
max_jump=None,
p_dip=None,
p_symmetry=None,
n_breaks=None,
n_std_global=None,
):
return dict(
max_jump=(
operator.ge,
max_jump or float(oet.config.config['tipping_thresholds']['max_jump']),
),
p_dip=(
operator.le,
p_dip or float(oet.config.config['tipping_thresholds']['p_dip']),
),
p_symmetry=(
operator.le,
p_symmetry or float(oet.config.config['tipping_thresholds']['p_symmetry']),
),
n_breaks=(
operator.ge,
n_breaks or float(oet.config.config['tipping_thresholds']['n_breaks']),
),
n_std_global=(
operator.ge,
n_std_global
or float(oet.config.config['tipping_thresholds']['n_std_global']),
),
)


def _get_ds_global(ds, **read_kw):
path = ds.attrs['file']
if os.path.exists(path):
Expand Down
7 changes: 6 additions & 1 deletion optim_esm_tools/optim_esm_conf.ini
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ excluded =
; # Projection fails
; DKRZ MPI-ESM1-2-LR ssp119 r1i1p1f1 siconc * * v20210901


[tipping_thresholds]
max_jump=4
p_dip=0.01
p_symmetry=0.001
n_breaks=1
n_std_global=3

[log]
logging_level = WARNING
Expand Down
64 changes: 64 additions & 0 deletions test/test_combine_variables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import tempfile
import optim_esm_tools as oet
import pandas as pd
import os
from unittest import TestCase
import numpy as np


class TestCombineVariables(TestCase):
def test_merge_two(self, nx=5, ny=20, is_match=(True, True)):
with tempfile.TemporaryDirectory() as temp_dir:
# temp_dir = '/home/aangevaare/software/paper_oet/notebooks/'
kw = dict(len_x=nx, len_y=ny, len_time=20, add_nans=False)
names = list('abcefg')[: len(is_match)]
paths = [os.path.join(temp_dir, f'{x}.nc') for x in names]
post_path = []
for name, path in zip(names, paths):
ds = oet._test_utils.minimal_xr_ds(**kw)
ds = ds.rename(dict(var=name))
import cftime

ds['time'] = [
cftime.datetime(y + 2000, 1, 1) for y in range(len(ds['time']))
]
ds['lat'].attrs.update(
{
'standard_name': 'latitude',
'long_name': 'Latitude',
'units': 'degrees_north',
'axis': 'Y',
}
)
ds['lon'].attrs.update(
{
'standard_name': 'longitude',
'long_name': 'Longitude',
'units': 'degrees_east',
'axis': 'X',
}
)
assert name in ds

ds.attrs.update(dict(file=path, variable_id=name))
ds.to_netcdf(path)
head, tail = os.path.split(path)
post_ds = oet.read_ds(head, _file_name=tail, _skip_folder_info=True)
post_path.append(post_ds.attrs['file'])

merger = oet.analyze.combine_variables.VariableMerger(
paths=[p for p, m in zip(post_path, is_match) if m],
other_paths=[p for p, m in zip(post_path, is_match) if not m],
merge_method='logical_or',
)
merged = merger.squash_sources()
for n, m in zip(names, is_match):
if m:
assert n in merged.data_vars
oet.analyze.combine_variables.change_plt_table_height()
merger.make_fig(merged)
return merger

def test_merge_three(self):
merger = self.test_merge_two(is_match=(True, True, False))
assert merger.other_paths