Nanoevents + jax looks for _mass2_kernel in the wrong spot #874
This looks like a behavior dispatch issue in awkward. Can you try boiling this down to not include coffea?
I have another example in scikit-hep/awkward#2591, which was (partially) fixed by scikit-hep/awkward#2592 (now only hits …). Is this qualitatively similar? If so, I guess I'd have to interpolate between that and the nanoevents behavior by simplifying it in coffea until it works?
I'd import some things with uproot directly, build the array you want, and then apply the behaviors from coffea by hand. You don't have to use the full machinery of nanoevents to replicate this, for sure.
Here's another reproducer that only applies

```python
import awkward as ak
from coffea.nanoevents.methods import candidate
import numpy as np
import uproot

ak.jax.register_and_check()
ak.behavior.update(candidate.behavior)

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
                              "Electron_mass", "Electron_charge"])

px = arr.Electron_pt * np.cos(arr.Electron_phi)
py = arr.Electron_pt * np.sin(arr.Electron_phi)
pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)

evtfilter = ak.num(arr["Electron_pt"]) >= 2

els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
              "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
els = ak.to_backend(els, "jax")

(els[:, 0] + els[:, 1]).mass
```

which results in the same problem. Sounds like an awkward issue then, perhaps? I'll open one there. This does work with
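For reference, `(els[:, 0] + els[:, 1]).mass` is the invariant mass of the two leading electrons in each selected event: the four-vectors are summed component-wise and `m^2 = E^2 - px^2 - py^2 - pz^2` is evaluated on the sum, which is the same arithmetic `_mass2_kernel` performs. Below is a minimal plain-NumPy sketch of that calculation for one pair per event, using the same (pt, eta, phi, energy) inputs as the reproducer. The helper names `to_cartesian` and `pair_mass`, and the clamping of small negative `m^2` values to zero, are illustrative assumptions, not coffea's implementation.

```python
import numpy as np


def to_cartesian(pt, eta, phi, energy):
    """Convert (pt, eta, phi, E) to (E, px, py, pz), as in the reproducer."""
    px = pt * np.cos(phi)
    py = pt * np.sin(phi)
    pz = pt * np.sinh(eta)
    return energy, px, py, pz


def pair_mass(cand1, cand2):
    """Invariant mass of two candidates, each given as (pt, eta, phi, E).

    Hypothetical helper: sums the two four-vectors, then evaluates
    m^2 = E^2 - px^2 - py^2 - pz^2 on the sum. Negative m^2 from floating-point
    rounding is clamped to zero (an assumption of this sketch).
    """
    e1, x1, y1, z1 = to_cartesian(*cand1)
    e2, x2, y2, z2 = to_cartesian(*cand2)
    t, x, y, z = e1 + e2, x1 + x2, y1 + y2, z1 + z2
    mass2 = t * t - x * x - y * y - z * z
    return np.sqrt(np.maximum(mass2, 0.0))


# Two massless, back-to-back particles with pt = 1 and eta = 0 give m = 2.
lead = (np.array([1.0]), np.array([0.0]), np.array([0.0]), np.array([1.0]))
sub = (np.array([1.0]), np.array([0.0]), np.array([np.pi]), np.array([1.0]))
print(pair_mass(lead, sub))  # approximately [2.]
```

In the real reproducer the inputs are jagged awkward arrays and the sum is dispatched through coffea's behaviors; the sketch only spells out the arithmetic that ends up in the mass kernel.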
I believe the underlying issue is the clash of numba and jax as described in scikit-hep/awkward#2603 (comment). The setup in the previous comment can be patched with this snippet:

```python
from coffea.nanoevents.methods import candidate, vector


# Plain-Python mass-squared kernel (no numba), so it also works when the
# inputs are jax-backed arrays.
def _mass2_kernel(t, x, y, z):
    return t * t - x * x - y * y - z * z


class PatchedLorentzVector(vector.LorentzVector):
    @property
    def mass2(self):
        """Squared `mass`"""
        return _mass2_kernel(self.t, self.x, self.y, self.z)


# Replace Candidate's base class so the patched mass2 is used.
candidate.Candidate.__bases__ = (PatchedLorentzVector,)
```

However, removing numba for everyone just to make this work is not a good idea: most of the time people are presumably not using jax, so this is not really something that can be merged.
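A short sketch of why the patched kernel helps, assuming (as the comment above suggests) that the trouble comes from the numba-compiled kernel not accepting jax inputs: a kernel written only with `*` and `-` is library-agnostic, so the same function runs on Python floats, NumPy arrays and, if jax is installed, `jax.numpy` arrays, and `jax.grad` can differentiate through it. The name `mass2_kernel` mirrors the patch; the jax part is optional and is skipped when jax is not available.

```python
import numpy as np


def mass2_kernel(t, x, y, z):
    # Same arithmetic as the patched kernel. Because it only uses Python
    # operators, it dispatches through whatever array type it receives
    # instead of going through a numba-compiled ufunc.
    return t * t - x * x - y * y - z * z


# Python scalars and NumPy arrays both work.
scalar = mass2_kernel(5.0, 3.0, 0.0, 0.0)  # 16.0
arrays = mass2_kernel(np.array([5.0, 10.0]), np.array([3.0, 6.0]),
                      np.array([0.0, 0.0]), np.array([0.0, 0.0]))  # [16., 64.]

try:
    import jax
    import jax.numpy as jnp
except ImportError:  # jax is optional for this sketch
    jax = None

if jax is not None:
    # jax.numpy arrays go through the same function unchanged ...
    print(mass2_kernel(jnp.array([5.0]), jnp.array([3.0]),
                       jnp.array([0.0]), jnp.array([0.0])))
    # ... and jax.grad differentiates with respect to t: d(m^2)/dt = 2t.
    print(jax.grad(mass2_kernel)(3.0, 1.0, 1.0, 1.0))  # 6.0
```

The trade-off raised in the thread still applies: the numba version is presumably kept for performance on the default CPU backend, which is why the eventual fix went into awkward rather than into coffea.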
I tried running the code sample on the

Edit: this should be solved by scikit-hep/awkward#3025.
This can be closed now. The issue has been resolved on the main branch of awkward:

```
In [1]: import awkward as ak
   ...: from coffea.nanoevents.methods import candidate
   ...: import numpy as np
   ...: import uproot
   ...:
   ...: ak.jax.register_and_check()
   ...: ak.behavior.update(candidate.behavior)
   ...:
   ...: ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
   ...:     "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
   ...:
   ...: with uproot.open(ttbar_file) as f:
   ...:     arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
   ...:                               "Electron_mass", "Electron_charge"])
   ...:
   ...: px = arr.Electron_pt * np.cos(arr.Electron_phi)
   ...: py = arr.Electron_pt * np.sin(arr.Electron_phi)
   ...: pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
   ...: E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)
   ...:
   ...: evtfilter = ak.num(arr["Electron_pt"]) >= 2
   ...:
   ...: els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
   ...:               "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
   ...: els = ak.to_backend(els, "jax")
   ...:
   ...: (els[:, 0] + els[:, 1]).mass
Out[1]: <Array [86.903534, 97.60412, ..., 62.408997, 50.49058] type='5 * float32'>
```
@Saransh-cpp lemme know the version of awkward this will correspond to and we will close it with a pin adjustment to coffea.
Great! I can create a PR once the fix is out in a release.
@alexander-held, can you test this with coffea main to see if your issue is resolved?
I'm trying to remind myself of the setup here. I can confirm that the example in #874 (comment) now works with the latest coffea and awkward releases. The example in the original issue at the top runs into a
I think this is happening because

```
In [1]: import awkward as ak
   ...: from coffea.nanoevents import NanoEventsFactory, NanoAODSchema
   ...:
   ...: ak.jax.register_and_check()
   ...: NanoAODSchema.warn_missing_crossrefs = False  # silences warnings about branches we will not use here
   ...:
   ...: ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
   ...:     "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
   ...:
   ...: events = NanoEventsFactory.from_root({ttbar_file: "Events"}, schemaclass=NanoAODSchema).events()
   ...: events = ak.to_backend(events.compute(), "jax")  # compute is required for switching backends
   ...:
   ...: evtfilter = ak.to_backend(ak.num(events.Jet.pt) >= 2, "jax")  # backend call is needed here!
   ...: jets = events.Jet[evtfilter]
   ...:
   ...: (jets[:, 0] + jets[:, 1]).mass
Out[1]: <Array [157.21956, 81.92088, ..., 32.363174, 223.94753] type='140 * float32'>
```
I see. Mixing Dask + jax is not something we support at the moment as far as I'm aware, so it makes sense that it would not work. From my side we can close this as fixed, thank you!
@lgray a gentle bump on closing this 🙂
Describe the bug
When combining nanoevents + jax, invariant mass calculations no longer work.
It is not very clear to me whether this is an issue in awkward or elsewhere, but `_mass2_kernel` comes from coffea, so I figured I'd start here.

To Reproduce
A full example is at https://github.com/alexander-held/agc-autodiff/blob/9bcad94a689063b130829bd33fff12e17dd43c36/nanoevents_plus_jax.ipynb. This is using `coffea==2023.7.0rc0`.

Expected behavior
Mass calculation succeeds (just like when not using `jax`).

Output

Desktop (please complete the following information):
n/a

Additional context
n/a