Optimise depends_on chain resolution by using h5py #253

Merged · 4 commits · Nov 25, 2024
45 changes: 11 additions & 34 deletions src/scippnexus/base.py
@@ -41,45 +41,12 @@ def is_dataset(obj: H5Group | H5Dataset) -> bool:
return hasattr(obj, 'shape')


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


class NXobject:
def _init_field(self, field: Field):
if field.sizes is None:
field.sizes = _squeezed_field_sizes(field.dataset)
field.dtype = _dtype_fromdataset(field.dataset)

def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
"""Subclasses should call this in their __init__ method, or ensure that they
initialize the fields in `children` with the correct sizes and dtypes."""
Comment on lines 46 to 47 — Member: The docstring is outdated now, at least in parts?
self._attrs = attrs
self._children = children
for field in children.values():
if isinstance(field, Field):
self._init_field(field)

@property
def unit(self) -> None | sc.Unit:
@@ -222,7 +189,7 @@ def nx_class(self) -> type | None:
return NXroot

@cached_property
def attrs(self) -> dict[str, Any]:
def attrs(self) -> MappingProxyType[str, Any]:
"""The attributes of the group.

Cannot be used for writing attributes, since they are cached for performance."""
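
For context on why the cached attrs cannot be written through: the new return type is a read-only mapping view. A minimal standalone sketch of that behaviour, using only the standard library (not scippnexus-specific):

from types import MappingProxyType

# A read-only view over a plain dict; writes through the proxy raise TypeError.
attrs = MappingProxyType({'NX_class': 'NXdata', 'signal': 'counts'})
print(attrs['NX_class'])      # 'NXdata'
try:
    attrs['NX_class'] = 'NXlog'
except TypeError as error:
    print('read-only view:', error)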
@@ -479,6 +446,16 @@ def dims(self) -> tuple[str, ...]:
def shape(self) -> tuple[int, ...]:
return tuple(self.sizes.values())

@property
def definitions(self) -> MappingProxyType[str, str | type] | None:
return (
None if self._definitions is None else MappingProxyType(self._definitions)
)

@property
def underlying(self) -> H5Group:
return self._group


def _create_field_params_numpy(data: np.ndarray):
return data, None, {}
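The two new Group accessors are what the transformation code further down relies on. A hypothetical usage sketch — the file name and group path are invented, and it assumes the properties behave as defined in the diff above:

import scippnexus as snx

# Invented file and path, for illustration only.
with snx.File('experiment.nxs') as f:
    detector = f['entry/instrument/detector']
    # Read-only view of the application definitions in use (or None).
    print(detector.definitions)
    # The raw h5py.Group underneath, enabling cheap low-level access such as
    # following depends_on links without constructing intermediate snx.Group objects.
    print(detector.underlying.name)
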
31 changes: 31 additions & 0 deletions src/scippnexus/field.py
@@ -81,6 +81,31 @@ def _as_datetime(obj: Any):
return None


_scipp_dtype = {
np.dtype('int8'): sc.DType.int32,
np.dtype('int16'): sc.DType.int32,
np.dtype('uint8'): sc.DType.int32,
np.dtype('uint16'): sc.DType.int32,
np.dtype('uint32'): sc.DType.int32,
np.dtype('uint64'): sc.DType.int64,
np.dtype('int32'): sc.DType.int32,
np.dtype('int64'): sc.DType.int64,
np.dtype('float32'): sc.DType.float32,
np.dtype('float64'): sc.DType.float64,
np.dtype('bool'): sc.DType.bool,
}


def _dtype_fromdataset(dataset: H5Dataset) -> sc.DType:
return _scipp_dtype.get(dataset.dtype, sc.DType.string)


def _squeezed_field_sizes(dataset: H5Dataset) -> dict[str, int]:
if (shape := dataset.shape) == (1,):
return {}
return {f'dim_{i}': size for i, size in enumerate(shape)}


@dataclass
class Field:
"""NeXus field.
@@ -93,6 +118,12 @@ class Field:
dtype: sc.DType | None = None
errors: H5Dataset | None = None

def __post_init__(self) -> None:
if self.sizes is None:
self.sizes = _squeezed_field_sizes(self.dataset)
if self.dtype is None:
self.dtype = _dtype_fromdataset(self.dataset)

Member (Author): Please check that this is sound!

Member: I feared that there might be some code that relies on field.sizes being None (e.g., to do custom init logic only if dims are not set yet) before calling NXobject.__init__, but I could not find any (looking mainly at nxdata.py). So I think it should be ok?
@cached_property
def attrs(self) -> dict[str, Any]:
"""The attributes of the dataset.
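A rough illustration of the defaulting that __post_init__ now provides, written against the public API. The file name is invented, and it assumes create_field accepts NumPy arrays (as the _create_field_params_numpy helper in base.py above suggests):

import numpy as np
import scippnexus as snx

with snx.File('demo.nxs', 'w') as f:
    f.create_field('scalarish', np.array([1.2]))                # shape (1,)
    f.create_field('counts', np.zeros((3, 4), dtype='uint16'))

with snx.File('demo.nxs') as f:
    print(f['scalarish'].sizes)   # expected: {} -- the length-1 dim is squeezed
    print(f['counts'].sizes)      # expected: {'dim_0': 3, 'dim_1': 4}
    print(f['counts'].dtype)      # expected: int32 (uint16 is promoted on load)
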
1 change: 0 additions & 1 deletion src/scippnexus/nxdata.py
@@ -481,7 +481,6 @@ def __init__(self, attrs: dict[str, Any], children: dict[str, Field | Group]):
for k in list(children):
if k.startswith(name):
field = children.pop(k)
self._init_field(field)
field.sizes = {
'time' if i == 0 else f'dim_{i}': size
for i, size in enumerate(field.dataset.shape)
49 changes: 43 additions & 6 deletions src/scippnexus/nxtransformations.py
@@ -35,15 +35,24 @@

from __future__ import annotations

import posixpath
import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field, replace
from typing import Literal

import h5py
import numpy as np
import scipp as sc
from scipp.scipy import interpolate

from .base import Group, NexusStructureError, NXobject, base_definitions_dict
from .base import (
Group,
NexusStructureError,
NXobject,
base_definitions_dict,
is_dataset,
)
from .field import DependsOn, Field


@@ -266,17 +275,45 @@ def compute(self) -> sc.Variable | sc.DataArray:
return transform


def _locate_depends_on_target(
file: h5py.File,
depends_on: DependsOn,
definitions: Mapping[str, type] | None,
) -> tuple[Field | Group, str]:
"""Find the target of a depends_on link.

The returned object is equivalent to calling ``parent[depends_on]``
in the context of transformations.
This function does not work in general because it does not process the attributes
of parent groups, which is required to fully load some groups.
"""
target_path = depends_on.absolute_path()
target = file[target_path]

if is_dataset(target):
res = Field(
target,
parent=Group(target.parent, definitions=definitions),
)
else:
res = Group(target, definitions=definitions)
return res, posixpath.dirname(target_path)


def parse_depends_on_chain(
parent: Field | Group, depends_on: DependsOn
) -> TransformationChain | None:
"""Follow a depends_on chain and return the transformations."""
chain = TransformationChain(depends_on.parent, depends_on.value)
depends_on = depends_on.value
# Use raw h5py objects to follow the chain because that avoids constructing
# expensive intermediate snx.Group objects.
file = parent.underlying.file
try:
while depends_on != '.':
transform = parent[depends_on]
parent = transform.parent
depends_on = transform.attrs['depends_on']
while depends_on.value != '.':
transform, base = _locate_depends_on_target(
file, depends_on, parent.definitions
)
depends_on = DependsOn(parent=base, value=transform.attrs['depends_on'])
chain.transformations[transform.name] = transform[()]
except KeyError as e:
warnings.warn(
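To make the optimisation concrete, here is a rough standalone sketch of the underlying idea: walking a depends_on chain with plain h5py calls instead of constructing scippnexus groups at every step. The function name and paths are invented, and the sketch skips the attribute processing that the real loader performs:

import posixpath

import h5py


def follow_depends_on(file: h5py.File, base: str, depends_on: str) -> list[str]:
    """Collect absolute paths of a transformation chain, stopping at '.'.

    base is the path of the group holding the initial depends_on value;
    relative links are resolved against the parent of the previous target,
    mirroring how depends_on paths are interpreted in NeXus.
    """
    paths = []
    while depends_on != '.':
        target = posixpath.normpath(posixpath.join(base, depends_on))
        paths.append(target)
        base = posixpath.dirname(target)
        # Transformations carry the next link as an HDF5 attribute.
        depends_on = file[target].attrs.get('depends_on', '.')
        if isinstance(depends_on, bytes):
            depends_on = depends_on.decode()
    return paths

# e.g. follow_depends_on(f, '/entry/instrument/detector', 'transformations/t1')
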
2 changes: 1 addition & 1 deletion src/scippnexus/typing.py
@@ -17,7 +17,7 @@ def name(self) -> str:
"""Name of dataset or group"""

@property
def file(self) -> list[int]:
def file(self) -> Any:
"""File of dataset or group"""

@property