Skip to content

Commit

Permalink
Be compatible with strax >= 2 and straxen >= 3 (#96)
Browse files Browse the repository at this point in the history
* Update dependencies

* Be compatible with strax >= 2 and straxen >= 3

* Use master branch of straxen
  • Loading branch information
dachengx authored Dec 28, 2024
1 parent 2999264 commit e3ba229
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 27 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ jobs:
python -m pip install --upgrade pip
python -m pip install pytest coverage coveralls
pip install git+https://github.com/XENONnT/[email protected]_wimp_unblind --force-reinstall
pip install git+https://github.com/AxFoundation/strax.git@master --force-reinstall
pip install git+https://github.com/XENONnT/straxen.git@master --force-reinstall
- name: Start MongoDB
uses: supercharge/[email protected]
Expand Down
18 changes: 11 additions & 7 deletions axidence/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import strax
from strax import LoopPlugin, CutPlugin, CutList
import straxen
from straxen import EventBasics, EventInfoDouble
from straxen import EventBasicsSOM, EventInfoDouble

from axidence import RunMeta, EventsSalting, PeaksSalted
from axidence import (
Expand All @@ -16,7 +16,7 @@
)
from axidence import (
EventsSalted,
EventBasicsSalted,
EventBasicsSOMSalted,
EventShadowSalted,
EventAmbienceSalted,
EventNearestTriggeringSalted,
Expand All @@ -43,7 +43,7 @@


default_assign_attributes = {
EventBasics: ["peak_properties", "posrec_save"],
EventBasicsSOM: ["peak_properties", "posrec_save"],
EventInfoDouble: ["input_dtype"],
}

Expand Down Expand Up @@ -203,15 +203,19 @@ def infer_dtype(self):

if not issubclass(plugin, LoopPlugin):

def _fix_output(self, result, start, end, _dtype=None):
def _fix_output(self, result, start, end, superrun, subruns, _dtype=None):
if self.multi_output and _dtype is None:
result = keys_attach_suffix(result, self.suffix)
return {
d: super(plugin, self)._fix_output(result[d], start, end, _dtype=d)
d: super(plugin, self)._fix_output(
result[d], start, end, superrun, subruns, _dtype=d
)
for d in self.provides
}
else:
return super()._fix_output(result, start, end, _dtype=_dtype)
return super()._fix_output(
result, start, end, superrun, subruns, _dtype=_dtype
)

def do_compute(self, chunk_i=None, **kwargs):
return super().do_compute(
Expand Down Expand Up @@ -300,7 +304,7 @@ def _salt_to_context(self):
PeakNearestTriggeringSalted,
PeakSEScoreSalted,
EventsSalted,
EventBasicsSalted,
EventBasicsSOMSalted,
EventShadowSalted,
EventAmbienceSalted,
EventNearestTriggeringSalted,
Expand Down
9 changes: 3 additions & 6 deletions axidence/dtypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import strax
import straxen
from straxen.plugins.peaklets.peaklet_classification_som import som_additional_fields

from straxen.misc import kind_colors

Expand All @@ -21,9 +22,7 @@


def peak_positions_dtype():
st = strax.Context(
config=straxen.contexts.xnt_common_config, **straxen.contexts.xnt_common_opts
)
st = strax.Context(config=straxen.contexts.common_config, **straxen.contexts.common_opts)
data_name = "peak_positions"
PeakPositionsNT0 = st._get_plugins((data_name,), "0")[data_name]
return strax.unpack_dtype(PeakPositionsNT0.dtype)
Expand Down Expand Up @@ -68,9 +67,7 @@ def positioned_peak_dtype(n_channels=straxen.n_tpc_pmts):
f"{direction}_area",
]

peak_misc_fields = [
"center_time",
"area_fraction_top",
peak_misc_fields = list(strax.to_numpy_dtype(som_additional_fields).names) + [
"n_competing",
"n_competing_left",
]
Expand Down
4 changes: 2 additions & 2 deletions axidence/plugins/pairing/events_paired.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import strax
from strax import OverlapWindowPlugin
import straxen
from straxen import Events, EventBasics
from straxen import Events

from ...utils import copy_dtype

Expand Down Expand Up @@ -133,7 +133,7 @@ def compute(self, events_paired, peaks_paired):
result = np.zeros(len(events_paired), dtype=self.dtype)

# assign the additional fields
EventBasics.set_nan_defaults(result)
strax.set_nan_defaults(result)

# assign the features already in EventInfo
for q in self.deps["event_info_paired"].dtype_for("event_info_paired").names:
Expand Down
8 changes: 4 additions & 4 deletions axidence/plugins/salting/event_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import strax
from strax import ExhaustPlugin
import straxen
from straxen import Events, EventBasics
from straxen import Events, EventBasicsSOM

from ...utils import needed_dtype, merge_salted_real

Expand Down Expand Up @@ -139,7 +139,7 @@ def compute(self, peaks_salted, peaks, start, end):
return result


class EventBasicsSalted(EventBasics, ExhaustPlugin):
class EventBasicsSOMSalted(EventBasicsSOM, ExhaustPlugin):
__version__ = "0.1.0"
child_plugin = True
depends_on: Tuple[str, ...] = (
Expand Down Expand Up @@ -191,14 +191,14 @@ def compute(self, events_salted, peaks_salted, peaks):
_, index, counts = np.unique(events_salted["time"], return_index=True, return_counts=True)

_result = np.zeros(len(index), dtype=self.dtype)
self.set_nan_defaults(_result)
strax.set_nan_defaults(_result)

split_peaks = strax.split_by_containment(_peaks, events_salted[index])

_result["time"] = events_salted["time"][index]
_result["endtime"] = events_salted["endtime"][index]

self.fill_events(_result, events_salted[index], split_peaks)
self.fill_events(_result, split_peaks)

for i in [1, 2]:
if np.all(_result[f"s{i}_salt_number"] < 0):
Expand Down
4 changes: 2 additions & 2 deletions axidence/plugins/salting/events_salting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import strax
from strax import ExhaustPlugin, DownChunkingPlugin
import straxen
from straxen import units, EventBasics, EventPositions
from straxen import units, EventBasicsSOM, EventPositions

from ...utils import copy_dtype
from ...samplers import SAMPLERS


class EventsSalting(ExhaustPlugin, DownChunkingPlugin, EventPositions, EventBasics):
class EventsSalting(ExhaustPlugin, DownChunkingPlugin, EventPositions, EventBasicsSOM):
__version__ = "0.0.2"
child_plugin = True
depends_on = "run_meta"
Expand Down
4 changes: 2 additions & 2 deletions axidence/plugins/salting/peaks_salted.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import numpy as np
import strax
import straxen
from straxen import PeakBasics
from straxen import PeakBasicsSOM

from ...utils import copy_dtype


class PeaksSalted(PeakBasics):
class PeaksSalted(PeakBasicsSOM):
__version__ = "0.0.1"
child_plugin = True
depends_on = "events_salting"
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ documentation = "https://readthedocs.org/projects/axidence/"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
strax = ">=1.6.3,<=1.6.5"
straxen = ">=2.2.6,<=2.2.7"
strax = ">=2.0.2"
straxen = { git = "https://github.com/XENONnT/straxen.git@master"}
GOFevaluation = ">=0.1.5"

[build-system]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pairing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_pairing(self):
subrun_ids = [self.run_id]
data_type = "event_basics"
self.st.make(self.run_id, data_type, save=data_type)
meta = self.st.get_meta(self.run_id, data_type)
meta = self.st.get_metadata(self.run_id, data_type)
self.st.storage[0] = strax.DataDirectory(self.st.storage[0].path, provide_run_metadata=True)
_write_run_doc(
self.st,
Expand All @@ -42,7 +42,7 @@ def test_pairing(self):
meta["end"],
)
self.st.define_run(hyperrun_name, subrun_ids)
self.st.check_hyperrun()
self.st.check_superrun()
plugins = [
"peaks_paired",
"event_info_paired",
Expand Down

0 comments on commit e3ba229

Please sign in to comment.