Skip to content

Commit

Permalink
Merge pull request #167 from scipp/save-variances
Browse files Browse the repository at this point in the history
Save errors in create_field
  • Loading branch information
jl-wynen authored Oct 10, 2023
2 parents 54642fe + 55d6ff3 commit d007479
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 15 deletions.
17 changes: 17 additions & 0 deletions docs/about/release-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,23 @@ Release Notes
Deprecations
~~~~~~~~~~~~
vrelease
--------

Features
~~~~~~~~

Breaking changes
~~~~~~~~~~~~~~~~

Bugfixes
~~~~~~~~

* Save errors when writing variables using ``create_field`` or ``Group.__setitem__`` `#167 <https://github.com/scipp/scippnexus/pull/167>`_.

Deprecations
~~~~~~~~~~~~

v23.08.0
--------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __write_to_nexus_group__(self, group: H5Group):
)

coord.attrs[self._variances] = 'Q_errors'
create_field(group, 'Q_errors', sc.stddevs(da.coords['Q']))
# The errors are written automatically by create_field.


class _SASdata(NXdata):
Expand Down
58 changes: 44 additions & 14 deletions src/scippnexus/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _make_child(obj: Union[H5Dataset, H5Group]) -> Union[Field, Group]:
if (
isinstance(values, Field)
and isinstance(errors, Field)
and values.unit == errors.unit
and (values.unit == errors.unit or errors.unit is None)
and values.dataset.shape == errors.dataset.shape
):
values.errors = errors.dataset
Expand Down Expand Up @@ -390,7 +390,14 @@ def create_field(self, key: str, value: sc.Variable) -> H5Dataset:
"""Create a child dataset with given name and value.
Note that due to the caching mechanisms in this class, reading the group
or its children may not reflect the changes made by this method."""
or its children may not reflect the changes made by this method.
Returns
-------
:
The created dataset of the values.
If errors are written to the file, their dataset is not returned.
"""
return create_field(self._group, key, value)

def create_class(self, name: str, class_name: str) -> Group:
Expand Down Expand Up @@ -424,23 +431,46 @@ def shape(self) -> Tuple[int, ...]:
return tuple(self.sizes.values())


def _create_field_params_numpy(data: np.ndarray):
return data, None, {}


def _create_field_params_string(data: sc.Variable):
    """Return write-parameters for a string variable.

    h5py needs object-dtype arrays for variable-length strings, so the
    values are converted; there are no errors and no extra attributes.
    """
    string_values = np.array(data.values, dtype=object)
    return string_values, None, {}


def _create_field_params_datetime(data: sc.Variable):
    """Return write-parameters for a datetime variable.

    Datetimes are stored as integer offsets from an epoch in the variable's
    unit; the epoch itself is recorded in the ``start`` attribute.
    """
    epoch = sc.epoch(unit=data.unit)
    offsets = (data - epoch).values
    return offsets, None, {'start': str(epoch.value)}


def _create_field_params_number(data: sc.Variable):
    """Return write-parameters for a numeric variable.

    Variances, when present, are converted to standard deviations so they
    can be written to a separate ``*_errors`` dataset (NeXus convention).
    """
    if data.variances is None:
        errors = None
    else:
        errors = sc.stddevs(data).values
    return data.values, errors, {}


def create_field(
    group: H5Group, name: str, data: Union[np.ndarray, sc.Variable], **kwargs
) -> H5Dataset:
    """Create a dataset ``name`` in ``group`` from an array or variable.

    Dispatches on the input's dtype to one of the ``_create_field_params_*``
    helpers, which each return a ``(values, errors, attrs)`` triple.
    If the input is a variable with variances, its standard deviations are
    written to an additional ``<name>_errors`` dataset carrying the same
    attributes (units, etc.) as the values dataset.

    Parameters
    ----------
    group:
        HDF5 group to create the dataset(s) in.
    name:
        Name of the values dataset; errors go to ``name + '_errors'``.
    data:
        Plain numpy array or scipp variable to write.
    **kwargs:
        Forwarded to ``h5py``'s ``create_dataset`` for both datasets.

    Returns
    -------
    :
        The created values dataset. The errors dataset, if any, is not
        returned.
    """
    if not isinstance(data, sc.Variable):
        values, errors, attrs = _create_field_params_numpy(data)
    elif data.dtype == sc.DType.string:
        values, errors, attrs = _create_field_params_string(data)
    elif data.dtype == sc.DType.datetime64:
        values, errors, attrs = _create_field_params_datetime(data)
    else:
        values, errors, attrs = _create_field_params_number(data)

    if isinstance(data, sc.Variable) and data.unit:
        attrs['units'] = str(data.unit)

    values_dataset = group.create_dataset(name, data=values, **kwargs)
    values_dataset.attrs.update(attrs)
    if errors is not None:
        # Errors share the attributes (notably 'units') of the values.
        errors_dataset = group.create_dataset(name + '_errors', data=errors, **kwargs)
        errors_dataset.attrs.update(attrs)

    return values_dataset


def create_class(group: H5Group, name: str, nx_class: Union[str, type]) -> H5Group:
Expand Down
31 changes: 31 additions & 0 deletions tests/nexus_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,24 @@ def test_errors_read_as_variances(h5root):
assert np.array_equal(dg['time'].variances, np.arange(5.0) ** 2)


def test_does_not_require_unit_of_errors(h5root):
    """An errors dataset without a 'units' attr is still paired with its values."""
    data = h5root.create_group('entry').create_group('data')
    for name, unit, length in (('signal', 'm', 4), ('time', 's', 5)):
        data[name] = np.arange(float(length))
        data[name].attrs['units'] = unit
        # Deliberately write no 'units' attribute on the errors dataset.
        data[name + '_errors'] = np.arange(float(length))
    group = snx.Group(data)
    # The *_errors datasets are folded into their value fields.
    assert set(group._children.keys()) == {'signal', 'time'}
    loaded = group[()]
    assert loaded['signal'].unit == 'm'
    assert loaded['time'].unit == 's'


def test_read_field(h5root):
entry = h5root.create_group('entry')
data = entry.create_group('data')
Expand Down Expand Up @@ -520,3 +538,16 @@ def test_nxdata_with_bin_edges_positional_indexing_returns_correct_slice(h5root)
obj = snx.Group(data, definitions=snx.base_definitions())
da = obj['temperature', 0:2]
assert sc.identical(da, ref['temperature', 0:2])


def test_create_field_saves_errors(nxroot):
    """A variable with variances round-trips through create_field."""
    original = sc.array(
        dims=['d0'], values=[1.2, 3.4, 5.6], variances=[0.9, 0.8, 0.7], unit='cm'
    )
    nxroot['entry'].create_field('signal', original)

    loaded = nxroot['entry']['signal'][()]
    # Variances are stored on disk as stddevs, which loses precision,
    # so compare with allclose rather than identical.
    assert sc.allclose(loaded, original.rename_dims(d0='dim_0'))

0 comments on commit d007479

Please sign in to comment.