diff --git a/src/scippnexus/nxtransformations.py b/src/scippnexus/nxtransformations.py index 5a4c9e23..0b18b93e 100644 --- a/src/scippnexus/nxtransformations.py +++ b/src/scippnexus/nxtransformations.py @@ -69,7 +69,12 @@ class Transform: value: sc.Variable | sc.DataArray | sc.DataGroup vector: sc.Variable depends_on: DependsOn - offset: sc.Variable | None + offset: sc.Variable | None = None + + @property + def sizes(self) -> dict[str, int]: + """Convenience property to access sizes of the value.""" + return self.value.sizes def __post_init__(self): if self.transformation_type not in ['translation', 'rotation']: diff --git a/tests/transform_test.py b/tests/transform_test.py new file mode 100644 index 00000000..dabb0007 --- /dev/null +++ b/tests/transform_test.py @@ -0,0 +1,38 @@ +import pytest +import scipp as sc + +from scippnexus import DependsOn +from scippnexus.nxtransformations import Transform, TransformationError + + +@pytest.fixture() +def depends_on() -> DependsOn: + return DependsOn(parent='/', value='.') + + +@pytest.fixture() +def z_vector() -> sc.Variable: + return sc.vector(value=[0, 0, 1], unit='m') + + +def test_init_raises_if_transformation_type_is_invalid(depends_on, z_vector) -> None: + with pytest.raises(TransformationError, match='transformation_type'): + Transform( + name='t1', + transformation_type='trans', + value=sc.ones(dims=['x', 'y', 'z'], shape=(2, 3, 4)), + vector=z_vector, + depends_on=depends_on, + ) + + +def test_sizes_returns_sizes_of_value(depends_on, z_vector) -> None: + value = sc.ones(dims=['x', 'y', 'z'], shape=(2, 3, 4)) + transform = Transform( + name='t1', + transformation_type='translation', + value=value, + vector=z_vector, + depends_on=depends_on, + ) + assert transform.sizes == {'x': 2, 'y': 3, 'z': 4}