diff --git a/src/scippnexus/base.py b/src/scippnexus/base.py index b3dd218a..ae7142f7 100644 --- a/src/scippnexus/base.py +++ b/src/scippnexus/base.py @@ -41,45 +41,12 @@ def is_dataset(obj: H5Group | H5Dataset) -> bool: return hasattr(obj, 'shape') -_scipp_dtype = { - np.dtype('int8'): sc.DType.int32, - np.dtype('int16'): sc.DType.int32, - np.dtype('uint8'): sc.DType.int32, - np.dtype('uint16'): sc.DType.int32, - np.dtype('uint32'): sc.DType.int32, - np.dtype('uint64'): sc.DType.int64, - np.dtype('int32'): sc.DType.int32, - np.dtype('int64'): sc.DType.int64, - np.dtype('float32'): sc.DType.float32, - np.dtype('float64'): sc.DType.float64, - np.dtype('bool'): sc.DType.bool, -} - - -def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType: - return _scipp_dtype.get(dataset.dtype, sc.DType.string) - - -def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]: - if (shape := dataset.shape) == (1,): - return {} - return {f'dim_{i}': size for i, size in enumerate(shape)} - - class NXobject: - def _init_field(self, field: Field): - if field.sizes is None: - field.sizes = _squeezed_field_sizes(field.dataset) - field.dtype = _dtype_fromdataset(field.dataset) - def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]): """Subclasses should call this in their __init__ method, or ensure that they initialize the fields in `children` with the correct sizes and dtypes.""" self._attrs = attrs self._children = children - for field in children.values(): - if isinstance(field, Field): - self._init_field(field) @property def unit(self) -> None | sc.Unit: @@ -222,7 +189,7 @@ def nx_class(self) -> type | None: return NXroot @cached_property - def attrs(self) -> dict[str, Any]: + def attrs(self) -> MappingProxyType[str, Any]: """The attributes of the group. Cannot be used for writing attributes, since they are cached for performance.""" @@ -479,6 +446,16 @@ def dims(self) -> tuple[str, ...]: def shape(self) -> tuple[int, ...]: return tuple(self.sizes.values()) + @property + def definitions(self) -> MappingProxyType[str, str | type] | None: + return ( + None if self._definitions is None else MappingProxyType(self._definitions) + ) + + @property + def underlying(self) -> H5Group: + return self._group + def _create_field_params_numpy(data: np.ndarray): return data, None, {} diff --git a/src/scippnexus/field.py b/src/scippnexus/field.py index 395d2280..d0e65d6d 100644 --- a/src/scippnexus/field.py +++ b/src/scippnexus/field.py @@ -81,6 +81,31 @@ def _as_datetime(obj: Any): return None +_scipp_dtype = { + np.dtype('int8'): sc.DType.int32, + np.dtype('int16'): sc.DType.int32, + np.dtype('uint8'): sc.DType.int32, + np.dtype('uint16'): sc.DType.int32, + np.dtype('uint32'): sc.DType.int32, + np.dtype('uint64'): sc.DType.int64, + np.dtype('int32'): sc.DType.int32, + np.dtype('int64'): sc.DType.int64, + np.dtype('float32'): sc.DType.float32, + np.dtype('float64'): sc.DType.float64, + np.dtype('bool'): sc.DType.bool, +} + + +def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType: + return _scipp_dtype.get(dataset.dtype, sc.DType.string) + + +def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]: + if (shape := dataset.shape) == (1,): + return {} + return {f'dim_{i}': size for i, size in enumerate(shape)} + + @dataclass class Field: """NeXus field. @@ -93,6 +118,12 @@ class Field: dtype: sc.DType | None = None errors: H5Dataset | None = None + def __post_init__(self) -> None: + if self.sizes is None: + self.sizes = _squeezed_field_sizes(self.dataset) + if self.dtype is None: + self.dtype = _dtype_fromdataset(self.dataset) + @cached_property def attrs(self) -> dict[str, Any]: """The attributes of the dataset. diff --git a/src/scippnexus/nxdata.py b/src/scippnexus/nxdata.py index 37352483..5d97e621 100644 --- a/src/scippnexus/nxdata.py +++ b/src/scippnexus/nxdata.py @@ -481,7 +481,6 @@ def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]): for k in list(children): if k.startswith(name): field = children.pop(k) - self._init_field(field) field.sizes = { 'time' if i == 0 else f'dim_{i}': size for i, size in enumerate(field.dataset.shape) diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index ac18deb4..ff0cc6a4 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -35,15 +35,24 @@ from __future__ import annotations +import posixpath import warnings +from collections.abc import Mapping from dataclasses import dataclass, field, replace from typing import Literal +import h5py import numpy as np import scipp as sc from scipp.scipy import interpolate -from .base import Group, NexusStructureError, NXobject, base_definitions_dict +from .base import ( + Group, + NexusStructureError, + NXobject, + base_definitions_dict, + is_dataset, +) from .field import DependsOn, Field @@ -266,17 +275,45 @@ def compute(self) -> sc.Variable | sc.DataArray: return transform +def _locate_depends_on_target( + file: h5py.File, + depends_on: DependsOn, + definitions: Mapping[str, type] | None, +) -> tuple[Field | Group, str]: + """Find the target of a depends_on link. + + The returned object is equivalent to calling ``parent[depends_on]`` + in the context of transformations. + This function does not work in general because it does not process any attributes + of parents which is required to fully load some groups. + """ + target_path = depends_on.absolute_path() + target = file[target_path] + + if is_dataset(target): + res = Field( + target, + parent=Group(target.parent, definitions=definitions), + ) + else: + res = Group(target, definitions=definitions) + return res, posixpath.dirname(target_path) + + def parse_depends_on_chain( parent: Field | Group, depends_on: DependsOn ) -> TransformationChain | None: """Follow a depends_on chain and return the transformations.""" chain = TransformationChain(depends_on.parent, depends_on.value) - depends_on = depends_on.value + # Use raw h5py objects to follow the chain because that avoids constructing + # expensive intermediate snx.Group objects. + file = parent.underlying.file try: - while depends_on != '.': - transform = parent[depends_on] - parent = transform.parent - depends_on = transform.attrs['depends_on'] + while depends_on.value != '.': + transform, base = _locate_depends_on_target( + file, depends_on, parent.definitions + ) + depends_on = DependsOn(parent=base, value=transform.attrs['depends_on']) chain.transformations[transform.name] = transform[()] except KeyError as e: warnings.warn( diff --git a/src/scippnexus/typing.py b/src/scippnexus/typing.py index 339071c7..75286cab 100644 --- a/src/scippnexus/typing.py +++ b/src/scippnexus/typing.py @@ -17,7 +17,7 @@ def name(self) -> str: """Name of dataset or group""" @property - def file(self) -> list[int]: + def file(self) -> Any: """File of dataset or group""" @property