Skip to content

Commit

Permalink
Try to make Spectrum Collection concatenation more efficient
Browse files Browse the repository at this point in the history
The .from_spectra() constructor has a lot of overhead, as it checks the
types and numerical agreement of axes; when we are combining Collections,
this only needs to be checked once.

If this tweak works well, it should move upstream to euphonic. Here the
2D collection refers to a private method on the 1D collection; really,
this method should live on a parent class.
  • Loading branch information
ajjackson committed May 24, 2024
1 parent b3e6c90 commit c2e4ae9
Showing 1 changed file with 50 additions and 4 deletions.
54 changes: 50 additions & 4 deletions scripts/abins/sdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def check_thresholds(
return warning_cases


# Type aliases shared by the spectrum collection classes in this module.

# Sequence of (int, str) pairs; presumably (tick position index, label) -- TODO confirm
XTickLabels = Sequence[Tuple[int, str]]
# One dictionary of string/integer values per entry; this is the type of the
# "line_data" value inside Metadata
LineData = Sequence[Dict[str, Union[str, int]]]
# Top-level metadata dictionary: plain string/integer values, plus LineData
# (stored under the "line_data" key by the collection classes)
Metadata = Dict[str, Union[str, int, LineData]]


class AbinsSpectrum1DCollection(Spectrum1DCollection):
"""Minor patch to euphonic Spectrum1DCollection, to be moved upstream"""

Expand Down Expand Up @@ -123,10 +128,41 @@ def _get_line_data_vals(self, *line_data_keys: str) -> List[tuple]:
line_data_vals.append(tuple([all_metadata[key] for key in line_data_keys]))
return line_data_vals

def __add__(self, other: Self) -> Self:
    """
    Concatenate two collections that share the same x axis

    The y_data of ``other`` is appended after the y_data of ``self`` with a
    single ``np.concatenate`` call, and the result is wrapped in a new
    instance of ``type(self)`` which reuses ``self.x_data``. The x axes of
    the two collections are compared once, here, before concatenating.

    Metadata is merged with ``_concatenate_metadata``: key/value pairs that
    are common to both collections are retained in the top level
    dictionary, any others are put in the individual 'line_data' entries.

    Raises AssertionError if the x_data magnitudes are not numerically
    close or if the x_data units differ.
    """
    assert np.allclose(self.x_data.magnitude, other.x_data.magnitude), "x_data values of the two collections do not agree"
    assert self.x_data_unit == other.x_data_unit, "x_data units of the two collections do not agree"

    # NOTE(review): y_data units and row lengths are not checked here; any
    # mismatch is left to np.concatenate / the constructor -- confirm that
    # this is the intended behaviour.
    return type(self)(
        x_data=self.x_data,
        y_data=np.concatenate((self.y_data, other.y_data)),
        metadata=self._concatenate_metadata(self.metadata, other.metadata),
    )

@staticmethod
def _concatenate_metadata(a: Metadata, b: Metadata) -> Metadata:
    """
    Merge the metadata of two collections that are being concatenated

    Top-level key-value pairs that are present in both inputs with equal
    values are retained at top level. Every other top-level pair is copied
    into each 'line_data' entry of the input it came from; if an entry
    already holds that key, the top-level value replaces it. The
    'line_data' entries of ``a`` come first in the result, followed by
    those of ``b``.

    New dictionaries are built for the result; ``a`` and ``b`` are not
    mutated.

    Raises KeyError if either input has no 'line_data' item.
    """
    # Test membership explicitly: b.get(key) returns None for a missing key,
    # which would wrongly count a None value in ``a`` as common to both.
    common_items = {key: value for (key, value) in a.items() if key != "line_data" and key in b and b[key] == value}
    a_only_items = {key: a[key] for key in a if key != "line_data" and key not in common_items}
    b_only_items = {key: b[key] for key in b if key != "line_data" and key not in common_items}

    # Push the non-shared top-level items down to each line of their own collection
    line_data = [entry | a_only_items for entry in a["line_data"]] + [entry | b_only_items for entry in b["line_data"]]

    return common_items | {"line_data": line_data}


class AbinsSpectrum2DCollection(collections.abc.Sequence, Spectrum):
Expand Down Expand Up @@ -238,7 +274,17 @@ def __add__(self, other: Self) -> Self:
retained in the top level dictionary, any others are put in the
individual 'line_data' entries
"""
return type(self).from_spectra([*self, *other])
assert np.allclose(self.x_data.magnitude, other.x_data.magnitude)
assert np.allclose(self.y_data.magnitude, other.y_data.magnitude)
assert self.x_data_unit == other.x_data_unit
assert self.y_data_unit == other.y_data_unit

return type(self)(
x_data=self.x_data,
y_data=self.y_data,
z_data=np.concatenate((self.z_data, other.z_data)),
metadata=AbinsSpectrum1DCollection._concatenate_metadata(self.metadata, other.metadata),
)

def __len__(self):
    """Length of the collection: the size of the first axis of z_data"""
    z_shape = self.z_data.shape
    return z_shape[0]
Expand Down

0 comments on commit c2e4ae9

Please sign in to comment.