Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache flag-related objects and refactor extraction of flags #520

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 93 additions & 83 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,41 +1114,6 @@ def __getattr__(self, attr):
)


def create_flag_dict(da) -> Mapping[Hashable, FlagParam]:
"""
Return possible flag meanings and associated bitmask/values.

The mapping values are a tuple containing a bitmask and a value. Either
can be None.
If only a bitmask: Independent flags.
If only a value: Mutually exclusive flags.
If both: Mix of independent and mutually exclusive flags.
"""
if not da.cf.is_flag_variable:
raise ValueError(
"Comparisons are only supported for DataArrays that represent "
"CF flag variables. .attrs must contain 'flag_meanings' and "
"'flag_values' or 'flag_masks'."
)

flag_meanings = da.attrs["flag_meanings"].split(" ")
n_flag = len(flag_meanings)

flag_values = da.attrs.get("flag_values", [None] * n_flag)
flag_masks = da.attrs.get("flag_masks", [None] * n_flag)

if not (n_flag == len(flag_values) == len(flag_masks)):
raise ValueError(
"Not as many flag meanings as values or masks. "
"Please check the flag_meanings, flag_values, flag_masks attributes "
)

flag_params = tuple(
FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values)
)
return dict(zip(flag_meanings, flag_params))


class CFAccessor:
"""
Common Dataset and DataArray accessor functionality.
Expand All @@ -1157,23 +1122,64 @@ class CFAccessor:
def __init__(self, obj):
self._obj = obj
self._all_cell_measures = None
self._flag_dict: Mapping[Hashable, FlagParam] | None = None

def __setstate__(self, d):
self.__dict__ = d

def _assert_valid_other_comparison(self, other):
# TODO cache this property
flag_dict = create_flag_dict(self._obj)
if other not in flag_dict:
@property
def flag_dict(self) -> Mapping[Hashable, FlagParam]:
"""
Return possible flag meanings and associated bitmask/values.

The mapping values are a tuple containing a bitmask and a value. Either
can be None.
If only a bitmask: Independent flags.
If only a value: Mutually exclusive flags.
If both: Mix of independent and mutually exclusive flags.
"""
if self._flag_dict is not None:
return self._flag_dict
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem is that you can change the underlying values in the DataArray and then the flags will be wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot that if for DataArray.data most operations are not done in place (you get a new DataArray), changing DataArray.attrs is easily done in place. I only see either checking if any of the attributes have changed, or urging to use DataArray.assign_attrs() in the documentation...


da = self._obj

if not da.cf.is_flag_variable:
raise ValueError(
f"Did not find flag value meaning [{other}] in known flag meanings: [{flag_dict.keys()!r}]"
"Comparisons are only supported for DataArrays that represent "
"CF flag variables. .attrs must contain 'flag_meanings' and "
"'flag_values' or 'flag_masks'."
)
if flag_dict[other].flag_mask is not None:

flag_meanings = da.attrs["flag_meanings"].split(" ")
n_flag = len(flag_meanings)

flag_values = da.attrs.get("flag_values", [None] * n_flag)
flag_masks = da.attrs.get("flag_masks", [None] * n_flag)

if not (n_flag == len(flag_values) == len(flag_masks)):
raise ValueError(
"Not as many flag meanings as values or masks. "
"Please check the flag_meanings, flag_values, flag_masks attributes "
)

flag_params = tuple(
FlagParam(mask, value) for mask, value in zip(flag_masks, flag_values)
)
return dict(zip(flag_meanings, flag_params))

def _assert_valid_other_comparison(
self, other: Hashable
) -> Mapping[Hashable, FlagParam]:
if other not in self.flag_dict:
raise ValueError(
f"Did not find flag value meaning [{other}] in known flag meanings: [{self.flag_dict.keys()!r}]"
)
if self.flag_dict[other].flag_mask is not None:
raise NotImplementedError(
"Only equals and not-equals comparisons with flag masks are supported."
" Please open an issue."
)
return flag_dict
return self.flag_dict

def __eq__(self, other) -> DataArray: # type: ignore[override]
"""
Expand Down Expand Up @@ -1320,15 +1326,13 @@ def isin(self, test_elements) -> DataArray:
raise ValueError(
".cf.isin is only supported on DataArrays that contain CF flag attributes."
)
# TODO cache this property
flag_dict = create_flag_dict(self._obj)
mapped_test_elements = []
for elem in test_elements:
if elem not in flag_dict:
if elem not in self.flag_dict:
raise ValueError(
f"Did not find flag value meaning [{elem}] in known flag meanings: [{flag_dict.keys()!r}]"
f"Did not find flag value meaning [{elem}] in known flag meanings: [{self.flag_dict.keys()!r}]"
)
mapped_test_elements.append(flag_dict[elem].flag_value)
mapped_test_elements.append(self.flag_dict[elem].flag_value)
return self._obj.isin(mapped_test_elements)

def _drop_missing_variables(self, variables: list[Hashable]) -> list[Hashable]:
Expand Down Expand Up @@ -2828,6 +2832,11 @@ def decode_vertical_coords(self, *, outnames=None, prefix=None):

@xr.register_dataarray_accessor("cf")
class CFDataArrayAccessor(CFAccessor):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._flags: Dataset | None = None

@property
def formula_terms(self) -> dict[str, str]: # numpydoc ignore=SS06
"""
Expand Down Expand Up @@ -2973,6 +2982,8 @@ def flags(self) -> Dataset:
"""
Dataset containing boolean masks of available flags.
"""
if self._flags is not None:
return self._flags
return self._extract_flags()

def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset:
Expand All @@ -2982,48 +2993,47 @@ def _extract_flags(self, flags: Sequence[Hashable] | None = None) -> Dataset:
Parameters
----------
flags: Sequence[str]
Flags to extract. If empty (string or list), return all flags in
`flag_meanings`.
Flags to extract. If None, return all flags in `flag_meanings`.
"""
# TODO cache this property
flag_dict = create_flag_dict(self._obj)

flag_dict = self.flag_dict
if flags is None:
flags = tuple(flag_dict.keys())
flags = list(self.flag_dict.keys())
else:
for flag in flags:
if flag not in self.flag_dict:
raise ValueError(
f"Did not find flag meaning [{flag}] in known flag meanings:"
f" [{self.flag_dict.keys()!r}]"
)
flag_dict = {f: flag_dict[f] for f in flags}

# Check if we are in simplified cases
all_mutually_exclusive = any(f.flag_mask is None for f in flag_dict.values())
all_indep = any(f.flag_value is None for f in flag_dict.values())

out = {} # Output arrays

masks = [] # Bitmasks and values for asked flags
values = []
flags_reduced = [] # Flags left after removing mutually excl. flags
for flag in flags:
if flag not in flag_dict:
raise ValueError(
f"Did not find flag value meaning [{flag}] in known flag meanings:"
f" [{flag_dict.keys()!r}]"
)
mask, value = flag_dict[flag]
if mask is None:
out[flag] = self._obj == value
if all_mutually_exclusive:
for flag, params in flag_dict.items():
out[flag] = self._obj == params.flag_value
return Dataset(out)

# We cast both masks and flag variable as integers to make the
# bitwise comparison.
# TODO We could probably restrict the integer size
bit_mask = DataArray(
[f.flag_mask for f in flag_dict.values()], dims=["_mask"]
).astype("i")
x = self._obj.astype("i")

bit_comp = x & bit_mask

for i, (flag, params) in enumerate(flag_dict.items()):
bit = bit_comp.isel(_mask=i)
if all_indep:
out[flag] = bit.astype(bool)
else:
masks.append(mask)
values.append(value)
flags_reduced.append(flag)

if len(masks) > 0: # If independant masks are left
# We cast both masks and flag variable as integers to make the
# bitwise comparison. We could probably restrict the integer size
# but it's difficult to make it safely for mixed type flags.
bit_mask = DataArray(masks, dims=["_mask"]).astype("i")
x = self._obj.astype("i")
bit_comp = x & bit_mask

for i, (flag, value) in enumerate(zip(flags_reduced, values)):
bit = bit_comp.isel(_mask=i)
if value is not None:
out[flag] = bit == value
else:
out[flag] = bit.astype(bool)
out[flag] = bit == params.flag_value

return Dataset(out)

Expand Down
4 changes: 1 addition & 3 deletions cf_xarray/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,8 @@ def find_set_bits(mask, value, repeated_masks, bit_length):


def _format_flags(accessor, rich):
from .accessor import create_flag_dict

try:
flag_dict = create_flag_dict(accessor._obj)
flag_dict = accessor.flag_dict
except ValueError:
return _print_rows(
"Flag Meanings", ["Invalid Mapping. Check attributes."], rich
Expand Down
Loading