Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: support QNProblemSet in asdot() #313

Merged
merged 19 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
"nrows",
"nsimplify",
"pandoc",
"pbar",
"permalinks",
"phsp",
"pids",
Expand Down
7 changes: 6 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,12 @@
"python.analysis.autoImportCompletions": false,
"python.analysis.inlayHints.pytestParameters": true,
"python.terminal.activateEnvironment": false,
"python.testing.pytestArgs": ["--color=no", "--no-cov"],
"python.testing.pytestArgs": [
"--color=no",
"--no-cov",
"--verbose",
"--verbose"
],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"redhat.telemetry.enabled": false,
Expand Down
9 changes: 6 additions & 3 deletions docs/usage/visualize.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-cell"
]
Expand Down Expand Up @@ -379,7 +382,7 @@
"source_hidden": true
},
"tags": [
"hide-output"
"hide-input"
]
},
"outputs": [],
Expand All @@ -393,7 +396,7 @@
"metadata": {},
"source": [
":::{warning}\n",
"The next cell will use some (currently) internal functionality. As statet at the top, a workflow similar to this will be used in future versions of ```qrules```. Manual setup of the {obj}`.CSPSolver` like in here will then also not be necessary.\n",
"The next cell will use some (currently) internal functionality. As stated at the top, a workflow similar to this will be used in future versions of {mod}`qrules`, see e.g. [ComPWA/qrules#305](https://github.com/ComPWA/qrules/issues/305). Manual setup of the {obj}`.CSPSolver` like in here will then also not be necessary.\n",
":::"
]
},
Expand Down Expand Up @@ -789,7 +792,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.20"
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
47 changes: 39 additions & 8 deletions src/qrules/io/_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from fractions import Fraction
from functools import singledispatch
from inspect import isfunction
from types import NoneType
from typing import TYPE_CHECKING, Any, cast

import attrs
Expand Down Expand Up @@ -303,7 +304,12 @@ def as_string(obj: Any) -> str:
return str(obj)


as_string.register(str, lambda _: _) # avoid warning for str type
@as_string.register(NoneType)
@as_string.register(int)
@as_string.register(float)
@as_string.register(str)
def _(obj: Any) -> str:
return str(obj)


@as_string.register(dict)
Expand All @@ -314,13 +320,21 @@ def _(obj: dict) -> str:
key_repr = key.__name__
else:
key_repr = key
if value != 0 or any(s in key_repr for s in ["magnitude", "projection"]):
pm = not any(s in key_repr for s in ["pid", "mass", "width", "magnitude"])
value_repr = _render_fraction(value, pm)
lines.append(f"{key_repr} = {value_repr}")
if not value and not key_repr.endswith(("magnitude", "projection")):
continue
value_repr = __render_key_and_value(key_repr, value)
lines.append(f"{key_repr} = {value_repr}")
return "\n".join(lines)


def __render_key_and_value(key: str, value: Any) -> str:
if isinstance(value, (Fraction, int)):
fraction = Fraction(value)
no_pm = key.endswith("magnitude") or key == "pid"
return _render_fraction(fraction, plusminus=not no_pm)
return as_string(value)


@as_string.register(InteractionProperties)
def _(obj: InteractionProperties) -> str:
lines = []
Expand Down Expand Up @@ -358,7 +372,8 @@ def _(settings: EdgeSettings | NodeSettings) -> str:
if output:
output += "\n"
domains = sorted(
f"{qn.__name__} ∊ {domain}" for qn, domain in settings.qn_domains.items()
f"{qn.__name__} ∊ {__render_domain(domain, key=qn.__name__)}"
for qn, domain in settings.qn_domains.items()
)
output += "DOMAINS\n"
output += "\n".join(domains)
Expand Down Expand Up @@ -388,6 +403,22 @@ def __extract_priority(description: str) -> str:
return matches[1]


def __render_domain(domain: list[Any], key: str) -> str:
"""Render a domain as a `str`.

>>> half = Fraction(0.5)
>>> __render_domain([-half, +half], key="spin_projection")
'[-1/2, +1/2]'
>>> __render_domain([0, 1], key="l_magnitude")
'[0, 1]'
>>> __render_domain([None, +1, -1], key="parity")
'[-1, +1, None]'
"""
domain = sorted(domain, key=lambda x: +9999 if x is None else x)
domain_str = [__render_key_and_value(key, x) for x in domain]
return "[" + ", ".join(domain_str) + "]"


@as_string.register(Particle)
def _(particle: Particle) -> str:
return particle.name
Expand All @@ -410,10 +441,10 @@ def _state_to_str(state: State) -> str:
@as_string.register(tuple)
def _(obj: tuple) -> str:
if len(obj) == 2:
if isinstance(obj[0], Particle) and isinstance(obj[1], (float, int)):
if isinstance(obj[0], Particle) and isinstance(obj[1], (Fraction, float, int)):
state = State(*obj)
return _state_to_str(state)
if all(isinstance(o, (float, int)) for o in obj):
if all(isinstance(o, (Fraction, float, int)) for o in obj):
spin = Spin(*obj)
return _spin_to_str(spin)
return "\n".join(map(as_string, obj))
Expand Down
168 changes: 165 additions & 3 deletions tests/unit/io/test_dot.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from fractions import Fraction
from textwrap import dedent

import pydot
import pytest

import qrules
from qrules import io
from qrules.io._dot import _collapse_graphs, _get_particle_graphs, _strip_projections
from qrules.io._dot import (
_collapse_graphs,
_get_particle_graphs,
_strip_projections,
as_string,
)
from qrules.particle import Particle, ParticleCollection
from qrules.settings import InteractionType
from qrules.solving import QNProblemSet, QNResult
from qrules.topology import (
Edge,
Topology,
create_isobar_topologies,
create_n_body_topology,
)
from qrules.transition import ReactionInfo, SpinFormalism
from qrules.transition import (
ProblemSet,
ReactionInfo,
SpinFormalism,
StateTransitionManager,
)


def test_asdot(reaction: ReactionInfo):
Expand Down Expand Up @@ -85,6 +100,12 @@ def test_asdot_graphviz_attrs(reaction: ReactionInfo):
assert "bgcolor=none" not in dot_data


def test_asdot_qn_problem_set(qn_problem_and_result: tuple[QNProblemSet, QNResult]):
qn_problem_set, _ = qn_problem_and_result
dot_data = qrules.io.asdot(qn_problem_set, render_node=True)
assert pydot.graph_from_dot_data(dot_data) is not None


def test_asdot_with_styled_edges_and_nodes(reaction: ReactionInfo, output_dir):
transition = reaction.transitions[0]
dot = io.asdot(
Expand Down Expand Up @@ -117,7 +138,7 @@ def test_asdot_no_label_overwriting(reaction: ReactionInfo):
["canonical", "canonical-helicity", "helicity"],
)
def test_asdot_problemset(formalism: SpinFormalism):
stm = qrules.StateTransitionManager(
stm = StateTransitionManager(
initial_state=[("J/psi(1S)", [+1])],
final_state=["gamma", "pi0", "pi0"],
formalism=formalism,
Expand Down Expand Up @@ -149,6 +170,120 @@ def test_asdot_topology():
assert pydot.graph_from_dot_data(dot_data) is not None


def test_as_string_dict(
problem_sets: dict[float, list[ProblemSet]],
qn_problem_and_result: tuple[QNProblemSet, QNResult],
):
_, qn_result = qn_problem_and_result
problem_set = problem_sets[3600.0][0]
interaction = qn_result.solutions[1].interactions[1]
intermediate_state, *_ = qn_result.solutions[0].intermediate_states.values()
node_setting = problem_set.solving_settings.interactions[0]
intermediate_setting, *_ = problem_set.solving_settings.intermediate_states.values()

dot = as_string(intermediate_setting).strip()
expected_dot = dedent("""
RULES
spin_validity - 62
isospin_validity - 61
gellmann_nishijima - 50
DOMAINS
baryon_number ∊ [-1, +1]
bottomness ∊ [0]
c_parity ∊ [None]
charge ∊ [-1, 0, +1]
charmness ∊ [0]
electron_lepton_number ∊ [0]
g_parity ∊ [None]
isospin_magnitude ∊ [1]
isospin_projection ∊ [-1, 0, +1]
muon_lepton_number ∊ [0]
parity ∊ [-1, +1]
spin_magnitude ∊ [1/2]
spin_projection ∊ [-4, -7/2, -3, -5/2, -2, -3/2, -1, -1/2, 0, +1/2, +1, +3/2, +2, +5/2, +3, +7/2, +4]
strangeness ∊ [-1, +1]
tau_lepton_number ∊ [0]
topness ∊ [0]
""").strip()
assert dot == expected_dot

dot = as_string(node_setting).strip()
expected_dot = dedent("""
RULES
clebsch_gordan_helicity_to_canonical - NA
BaryonNumberConservation - 90
ls_spin_validity - 89
spin_magnitude_conservation - 8
CharmConservation - 70
helicity_conservation - 7
StrangenessConservation - 69
BottomnessConservation - 68
isospin_conservation - 60
parity_conservation - 6
c_parity_conservation - 5
ElectronLNConservation - 45
MuonLNConservation - 44
TauLNConservation - 43
parity_conservation_helicity - 4
g_parity_conservation - 3
identical_particle_symmetrization - 2
ChargeConservation - 100
MassConservation - 10
DOMAINS
l_magnitude ∊ [0, 1]
l_projection ∊ [0]
parity_prefactor ∊ [-1, +1]
s_magnitude ∊ [0, 1/2, 1, 3/2, 2]
s_projection ∊ [-2, -3/2, -1, -1/2, 0, +1/2, +1, +3/2, +2]
""").strip()
assert dot == expected_dot

dot = as_string(interaction).strip()
expected_dot = dedent("""
l_magnitude = 0
s_magnitude = 1/2
l_projection = 0
s_projection = -1/2
parity_prefactor = +1
""").strip()
assert dot == expected_dot

dot = as_string(intermediate_state).strip()
expected_dot = dedent("""
spin_magnitude = 1/2
spin_projection = +1/2
parity = +1
isospin_magnitude = 1
isospin_projection = -1
baryon_number = -1
charge = -1
strangeness = +1
pid = -23222
mass = 1.75
width = 0.15
""").strip()
assert dot == expected_dot


def test_as_string_spin_tuple(particle_database: ParticleCollection):
# non-spin
src = as_string(("a", "b", "c"))
assert src == "a\nb\nc"
src = as_string(("a", "b"))
assert src == "a\nb"

# spin
src = as_string((2, 1))
assert src == "|2,+1⟩"

# particle with spin projection
pion = particle_database["J/psi(1S)"]
src = as_string((pion, 1))
assert src == "J/psi(1S)[+1]"
src = as_string((pion, Fraction(-1)))
assert src == "J/psi(1S)[-1]"


class TestWrite:
def test_write_topology(self, output_dir):
output_file = output_dir + "two_body_decay_topology.gv"
Expand Down Expand Up @@ -257,3 +392,30 @@ def test_strip_projections(skh_particle_version: str):
assert stripped_transition.interactions[0].l_projection is None
assert stripped_transition.interactions[1].s_projection is None
assert stripped_transition.interactions[1].l_projection is None


@pytest.fixture
def stm() -> StateTransitionManager:
stm = StateTransitionManager(
initial_state=[("J/psi(1S)", [+1])],
final_state=["K0", ("Sigma+", [+0.5]), ("p~", [+0.5])],
allowed_intermediate_particles=["Sigma(1750)"],
formalism="canonical-helicity",
)
stm.set_allowed_interaction_types([InteractionType.STRONG, InteractionType.EM])
return stm


@pytest.fixture
def problem_sets(stm: StateTransitionManager) -> dict[float, list[ProblemSet]]:
return stm.create_problem_sets()


@pytest.fixture
def qn_problem_and_result(
stm: StateTransitionManager,
problem_sets: dict[float, list[ProblemSet]],
) -> tuple[QNProblemSet, QNResult]:
qn_solutions = stm.find_quantum_number_transitions(problem_sets)
strong_qn_solutions = qn_solutions[3600.0]
return strong_qn_solutions[1]