Skip to content

Commit

Permalink
Refactor nv plugins (#744)
Browse files Browse the repository at this point in the history
* Remove unneeded properties.

* Added simplified summed waveform

* Add flatpart for non zero bins

* Remove outdated test

* Add test

* Support empty events

* Add new line

* Add assert
  • Loading branch information
WenzDaniel authored Aug 27, 2024
1 parent 8e189a7 commit 6d96bb1
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 112 deletions.
21 changes: 0 additions & 21 deletions strax/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,31 +124,10 @@ def hitlet_dtype():
(("Total hit area in pe", "area"), np.float32),
(("Maximum of the PMT pulse in pe/sample", "amplitude"), np.float32),
(('Position of the Amplitude in ns (minus "time")', "time_amplitude"), np.int16),
(("Hit entropy", "entropy"), np.float32),
(("Width (in ns) of the central 50% area of the hitlet", "range_50p_area"), np.float32),
(("Width (in ns) of the central 80% area of the hitlet", "range_80p_area"), np.float32),
(("Position of the 25% area decile [ns]", "left_area"), np.float32),
(("Position of the 10% area decile [ns]", "low_left_area"), np.float32),
(
(
"Width (in ns) of the highest density region covering a 50% area of the hitlet",
"range_hdr_50p_area",
),
np.float32,
),
(
(
"Width (in ns) of the highest density region covering a 80% area of the hitlet",
"range_hdr_80p_area",
),
np.float32,
),
(("Left edge of the 50% highest density region [ns]", "left_hdr"), np.float32),
(("Left edge of the 80% highest density region [ns]", "low_left_hdr"), np.float32),
(("FWHM of the PMT pulse [ns]", "fwhm"), np.float32),
(('Left edge of the FWHM [ns] (minus "time")', "left"), np.float32),
(("FWTM of the PMT pulse [ns]", "fwtm"), np.float32),
(('Left edge of the FWTM [ns] (minus "time")', "low_left"), np.float32),
]
return dtype

Expand Down
26 changes: 0 additions & 26 deletions strax/processing/hitlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,19 +297,6 @@ def hitlet_properties(hitlets):
h["amplitude"] = height
h["time_amplitude"] = amp_time

# Computing FWHM:
left_edge, right_edge = get_fwxm(h, 0.5)
width = right_edge - left_edge

# Computing FWTM:
left_edge_low, right_edge = get_fwxm(h, 0.1)
width_low = right_edge - left_edge_low

h["fwhm"] = width
h["left"] = left_edge
h["low_left"] = left_edge_low
h["fwtm"] = width_low

# Compute area deciles & width:
if not h["area"] == 0:
# Due to noise total area can sum up to zero
Expand All @@ -323,19 +310,6 @@ def hitlet_properties(hitlets):
h["range_50p_area"] = res[2] - res[1]
h["range_80p_area"] = res[3] - res[0]

# Compute width based on HDR:
resh = highest_density_region_width(
data,
fractions_desired=np.array([0.5, 0.8]),
dt=h["dt"],
fractionl_edges=True,
)

h["left_hdr"] = resh[0, 0]
h["low_left_hdr"] = resh[1, 0]
h["range_hdr_50p_area"] = resh[0, 1] - resh[0, 0]
h["range_hdr_80p_area"] = resh[1, 1] - resh[1, 0]


@export
@numba.njit(cache=True, nogil=True)
Expand Down
50 changes: 50 additions & 0 deletions strax/processing/peak_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,56 @@ def store_downsampled_waveform(
p["data"][: p["length"]] = wv_buffer[: p["length"]]


@export
def simple_summed_waveform(records, containers, to_pe):
"""Computes simple (downsampled) summed waveform based on raw data touching a certain container.
:param container: Things for which summed waveform should be
computed. Must contain data field of desired length.
:param records: Record infromation which should be used to compute
summed waveform.
Note: To keep this function simple the floating part of the baseline
is only added if the data field in records is not zero.
This will lead to a biased representation of the summed waveform!
However, this bias is small for shape estimates, but the total
charge of the signal should be estimated in an unbiased way.
"""
if not len(containers):
return

assert np.all(records["dt"] != 0), "Records dt is not allowed to be zero"
assert np.all(containers["dt"] != 0), "Containers dt is not allowed to be zero"

touching_windows = strax.touching_windows(records, containers)
_simple_summed_waveform(records, containers, touching_windows, to_pe)


@numba.njit
def _simple_summed_waveform(records, containers, touching_windows, to_pe):
summed_wf_buffer = np.zeros(2 * containers["length"].max(), np.float32)
for (tw_s, tw_e), container in zip(touching_windows, containers):
records_in_wf = records[tw_s:tw_e]

for r in records_in_wf:
(r_start, r_end), (c_start, c_end) = strax.overlap_indices(
r["time"] // r["dt"],
r["length"],
container["time"] // container["dt"],
container["length"],
)
bl_fpart = r["baseline"] % 1
_is_not_zero = r["data"][r_start:r_end] != 0
bl_fpart = _is_not_zero.astype(np.float32) * bl_fpart
summed_wf_buffer[c_start:c_end] += (r["data"][r_start:r_end] + bl_fpart) * to_pe[
r["channel"]
]

strax.store_downsampled_waveform(container, summed_wf_buffer)
summed_wf_buffer[:] = 0


@export
@numba.jit(nopython=True, nogil=True, cache=True)
def sum_waveform(
Expand Down
64 changes: 0 additions & 64 deletions tests/test_hitlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,70 +363,6 @@ def test_hitlet_properties(hits_n_data):
assert pos_max == h["time_amplitude"], "Wrong amplitude position found!"
assert d[pos_max] == h["amplitude"], "Wrong amplitude value found!"

# Checking FHWM and FWTM:
fractions = [0.1, 0.5]
for f in fractions:
# Get field names for the correct test:
if f == 0.5:
left = "left"
fwxm = "fwhm"
else:
left = "low_left"
fwxm = "fwtm"

amplitude = np.max(d)
if np.all(d[0] == d) or np.all(d > amplitude * f):
# If all samples are either the same or greater than
# required height FWXM is not defined:
mes = "All samples are the same or larger than require height."
assert np.isnan(h[left]), mes + f" Left edge for {f} should have been np.nan."
assert np.isnan(h[left]), mes + f" FWXM for X={f} should have been np.nan."
else:
le = np.argwhere(d[:pos_max] <= amplitude * f)
if len(le):
le = le[-1, 0]
m = d[le + 1] - d[le]
le = le + 0.5 + (amplitude * f - d[le]) / m
else:
le = 0

re = np.argwhere(d[pos_max:] <= amplitude * f)

if len(re) and re[0, 0] != 0:
re = re[0, 0] + pos_max
m = d[re] - d[re - 1]
re = re + 0.5 + (amplitude * f - d[re]) / m
else:
re = len(d)

assert math.isclose(
le, h[left], rel_tol=10**-4, abs_tol=10**-4
), f"Left edge does not match for fraction {f}"
assert math.isclose(
re - le, h[fwxm], rel_tol=10**-4, abs_tol=10**-4
), f"FWHM does not match for {f}"


def test_not_defined_get_fhwm():
# This is a specific unity test for some edge-cases in which the full
# width half maximum is not defined.
odd_hitlets = np.zeros(4, dtype=strax.hitlet_with_data_dtype(10))
odd_hitlets[0]["data"][:5] = [2, 2, 3, 2, 2]
odd_hitlets[0]["length"] = 5
odd_hitlets[1]["data"][:2] = [5, 5]
odd_hitlets[1]["length"] = 2
odd_hitlets[2]["length"] = 3
odd_hitlets[3]["data"][:3] = [-1, -2, 0]
odd_hitlets[3]["length"] = 3

for oh in odd_hitlets:
res = strax.get_fwxm(oh)
mes = (
f'get_fxhm returned {res} for {oh["data"][:oh["length"]]}!'
"However, the FWHM is not defined and the return should be nan!"
)
assert np.all(np.isnan(res)), mes


# ------------------------
# Entropy test
Expand Down
40 changes: 39 additions & 1 deletion tests/test_peak_processing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from strax.testutils import fake_hits, several_fake_records
from strax.testutils import fake_hits, several_fake_records, sorted_intervals
import numpy as np
from hypothesis import given, settings, example
import hypothesis.strategies as st
Expand Down Expand Up @@ -308,3 +308,41 @@ def retrun_1(x):
assert len(peaklets) <= len(r)
# Integer overflow will manifest itself here again:
assert np.all(peaklets["dt"] > 0)


@settings(deadline=None)
@given(sorted_intervals)
def test_simple_summed_waveform(pulses):
fake_event_dtype = strax.time_dt_fields + [
("data", np.float32, 200),
("data_top", np.float32, 200),
]

records = np.zeros(len(pulses), dtype=strax.record_dtype())
records["time"] = pulses["time"]
records["length"] = pulses["length"]
records["dt"] = pulses["dt"]
records["data"] = 1

if len(pulses):
fake_event = np.zeros(1, dtype=fake_event_dtype)
fake_event["time"] = records[0]["time"]
fake_event["length"] = records[-1]["time"] + records[-1]["length"] - records[0]["time"]
fake_event["dt"] = records["dt"][0]
else:
fake_event = np.zeros(0, dtype=fake_event_dtype)

strax.simple_summed_waveform(records, fake_event, np.ones(2000))
assert fake_event["data"].sum() == records["length"].sum(), "Event has wrong total area."
msg = "Summed waveform has incorrect shape."
assert _test_simple_summed_waveform_has_correct_pattern, msg # type: ignore


def _test_simple_summed_waveform_has_correct_pattern(fake_event, fake_records):
"""Test if summed wavefrom has correct structure."""
buffer = np.zeros(len(fake_event["data"]))

for r in fake_records:
indicies = np.arange(r["time"], r["time"] + r["length"], np.int64)
buffer[indicies] += 1
return np.all(buffer == fake_event["data"])

0 comments on commit 6d96bb1

Please sign in to comment.