diff --git a/dkist/conftest.py b/dkist/conftest.py index 6d5485d5..71641ef7 100644 --- a/dkist/conftest.py +++ b/dkist/conftest.py @@ -94,8 +94,7 @@ def dataset(array, identity_gwcs): assert ds.data is array assert ds.wcs is identity_gwcs - ds._array_container = DaskFITSArrayContainer([ExternalArrayReference('test1.fits', 0, 'float', (10, 10)), - ExternalArrayReference('test2.fits', 0, 'float', (10, 10))], + ds._array_container = DaskFITSArrayContainer([ExternalArrayReference('test1.fits', 0, 'float', array.shape)], loader=AstropyFITSLoader) return ds diff --git a/dkist/io/array_containers.py b/dkist/io/array_containers.py index b6d1ec05..db523f13 100644 --- a/dkist/io/array_containers.py +++ b/dkist/io/array_containers.py @@ -45,7 +45,10 @@ def __init__(self, reference_array, *, loader, **kwargs): if reference_shape[0] == 1: reference_shape = reference_shape[1:] - self.shape = tuple(list(reference_array.shape) + list(reference_shape)) + if len(reference_array) == 1: + self.shape = reference_shape + else: + self.shape = tuple(list(reference_array.shape) + list(reference_shape)) loader_array = np.empty_like(reference_array, dtype=object) for i, ele in enumerate(reference_array.flat):