Skip to content

Commit

Permalink
FEAT: implement quantum problem set filter (#287)
Browse files Browse the repository at this point in the history
* FEAT: implement `dict_set_intersection()`
* FEAT: implement `filter_quantum_number_problem_set()`
  • Loading branch information
grayson-helmholz authored Nov 6, 2024
1 parent c156ea3 commit 846d4ad
Show file tree
Hide file tree
Showing 6 changed files with 303 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def pick_newtype_attrs(some_type: type) -> list:
add_module_names = False
api_github_repo = f"{ORGANIZATION}/{REPO_NAME}"
api_target_substitutions: dict[str, str | tuple[str, str]] = {
"EdgeQuantumNumberTypes": ("obj", "qrules.quantum_numbers.EdgeQuantumNumberTypes"),
"EdgeType": "typing.TypeVar",
"GraphEdgePropertyMap": ("obj", "qrules.argument_handling.GraphEdgePropertyMap"),
"GraphElementProperties": ("obj", "qrules.solving.GraphElementProperties"),
Expand All @@ -56,11 +57,13 @@ def pick_newtype_attrs(some_type: type) -> list:
"NewEdgeType": "typing.TypeVar",
"NewNodeType": "typing.TypeVar",
"NodeQuantumNumber": ("obj", "qrules.quantum_numbers.NodeQuantumNumber"),
"NodeQuantumNumberTypes": ("obj", "qrules.quantum_numbers.NodeQuantumNumberTypes"),
"NodeType": "typing.TypeVar",
"ParticleWithSpin": ("obj", "qrules.particle.ParticleWithSpin"),
"Path": "pathlib.Path",
"qrules.topology.EdgeType": "typing.TypeVar",
"qrules.topology.NodeType": "typing.TypeVar",
"Rule": ("obj", "qrules.argument_handling.Rule"),
"SpinFormalism": ("obj", "qrules.transition.SpinFormalism"),
"StateDefinition": ("obj", "qrules.combinatorics.StateDefinition"),
"StateTransition": ("obj", "qrules.transition.StateTransition"),
Expand Down
119 changes: 118 additions & 1 deletion docs/usage/visualize.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
":::{warning}\n",
"Currently the main user-interface is the ```StateTransitionManager```. There is work in progress to remove it and split its functionality into several functions/classes to separate concerns\n",
"and to facilitate the modification of intermediate results like the filtering of ```QNProblemSet```s, setting allowed interaction types, etc. (see below)\n",
":::"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -103,7 +113,18 @@
"from IPython.display import display\n",
"\n",
"import qrules\n",
"from qrules.conservation_rules import (\n",
" parity_conservation,\n",
" spin_magnitude_conservation,\n",
" spin_validity,\n",
")\n",
"from qrules.particle import Spin\n",
"from qrules.quantum_numbers import EdgeQuantumNumbers, NodeQuantumNumbers\n",
"from qrules.solving import (\n",
" CSPSolver,\n",
" dict_set_intersection,\n",
" filter_quantum_number_problem_set,\n",
")\n",
"from qrules.topology import create_isobar_topologies, create_n_body_topology\n",
"from qrules.transition import State"
]
Expand Down Expand Up @@ -315,6 +336,102 @@
"graphviz.Source(dot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Filtering quantum number problem sets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sometimes, only a certain subset of quantum numbers and conservation rules are relevant, or the number of solutions the {class}`.StateTransitionManager` gives by default is too large for the follow-up analysis.\n",
"The {func}`.filter_quantum_number_problem_set` function can be used to produce a {class}`.QNProblemSet` where only the desired quantum numbers and conservation rules are considered when fed back to the solver."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"desired_edge_properties = {EdgeQuantumNumbers.spin_magnitude, EdgeQuantumNumbers.parity}\n",
"desired_node_properties = {\n",
" NodeQuantumNumbers.l_magnitude,\n",
" NodeQuantumNumbers.s_magnitude,\n",
"} # has to be reused in the CSPSolver-constructor\n",
"filtered_qn_problem_set = filter_quantum_number_problem_set(\n",
" qn_problem_set,\n",
" edge_rules={spin_validity},\n",
" node_rules={spin_magnitude_conservation, parity_conservation},\n",
" edge_properties=desired_edge_properties,\n",
" node_properties=desired_node_properties,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-output"
]
},
"outputs": [],
"source": [
"dot = qrules.io.asdot(filtered_qn_problem_set, render_node=True)\n",
"graphviz.Source(dot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
":::{warning}\n",
"The next cell will use some (currently) internal functionality. As statet at the top, a workflow similar to this will be used in future versions of ```qrules```. Manual setup of the {obj}`.CSPSolver` like in here will then also not be necessary.\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"solver = CSPSolver([\n",
" dict_set_intersection(\n",
" qrules.system_control.create_edge_properties(part),\n",
" desired_edge_properties,\n",
" )\n",
" for part in qrules.particle.load_pdg()\n",
"])\n",
"\n",
"filtered_qn_solutions = solver.find_solutions(filtered_qn_problem_set)\n",
"filtered_qn_result = filtered_qn_solutions.solutions[6]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"source_hidden": true
},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"dot = qrules.io.asdot(filtered_qn_result, render_node=True)\n",
"graphviz.Source(dot)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -672,7 +789,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.9.20"
}
},
"nbformat": 4,
Expand Down
1 change: 1 addition & 0 deletions src/qrules/argument_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Scalar = Union[int, float]

Rule = Union[GraphElementRule, EdgeQNConservationRule, ConservationRule]
"""Any type of rule"""

_ElementType = TypeVar("_ElementType")

Expand Down
5 changes: 3 additions & 2 deletions src/qrules/quantum_numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ class EdgeQuantumNumbers:
edge_qn_type.__module__ = __name__


# for static typing
EdgeQuantumNumber = Union[
EdgeQuantumNumbers.pid,
EdgeQuantumNumbers.mass,
Expand All @@ -126,8 +125,8 @@ class EdgeQuantumNumbers:
EdgeQuantumNumbers.c_parity,
EdgeQuantumNumbers.g_parity,
]
"""Type hint for quantum numbers of edges"""

# for accessing the keys of the dicts in EdgeSettings
EdgeQuantumNumberTypes = Union[
type[EdgeQuantumNumbers.pid],
type[EdgeQuantumNumbers.mass],
Expand All @@ -149,6 +148,7 @@ class EdgeQuantumNumbers:
type[EdgeQuantumNumbers.c_parity],
type[EdgeQuantumNumbers.g_parity],
]
"""Type-Union for accessing the keys of the dicts in `.EdgeSettings`"""


@frozen(init=False)
Expand Down Expand Up @@ -186,6 +186,7 @@ class NodeQuantumNumbers:
type[NodeQuantumNumbers.s_projection],
type[NodeQuantumNumbers.parity_prefactor],
]
"""Type-Union for accessing the keys of the dicts in `.NodeSettings`"""


def _to_optional_float(optional_float: float | None) -> float | None:
Expand Down
87 changes: 87 additions & 0 deletions src/qrules/solving.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,86 @@ def topology(self) -> Topology:
return self.initial_facts.topology


def filter_quantum_number_problem_set(
quantum_number_problem_set: QNProblemSet,
edge_rules: set[GraphElementRule],
node_rules: set[Rule],
edge_properties: Iterable[EdgeQuantumNumberTypes],
node_properties: Iterable[NodeQuantumNumberTypes],
) -> QNProblemSet:
"""Filter `QNProblemSet` for desired conservation rules, settings and domains.
Currently it is the responsibility of the user to provide fitting properties
and domains for the correspinding conservation rules.
Args:
quantum_number_problem_set: `QNProblemSet` as generated by `CSPSolver`.
edge_rules: Conservation rules regarding the edges.
node_rules: Conservation rules regarding the nodes.
edge_properties: Edge settings, properties and domains.
node_properties: Node settings, properties and domains.
"""
old_edge_settings = quantum_number_problem_set.solving_settings.states
old_node_settings = quantum_number_problem_set.solving_settings.interactions
old_edge_properties = quantum_number_problem_set.initial_facts.states
old_node_properties = quantum_number_problem_set.initial_facts.interactions
new_edge_settings = {
edge_id: EdgeSettings(
conservation_rules=edge_rules,
rule_priorities=edge_setting.rule_priorities,
qn_domains=({
key: val
for key, val in edge_setting.qn_domains.items()
if key in set(edge_properties)
}),
)
for edge_id, edge_setting in old_edge_settings.items()
}
new_node_settings = {
node_id: NodeSettings(
conservation_rules=node_rules,
rule_priorities=node_setting.rule_priorities,
qn_domains=({
key: val
for key, val in node_setting.qn_domains.items()
if key in set(node_properties)
}),
)
for node_id, node_setting in old_node_settings.items()
}
new_combined_settings = MutableTransition(
topology=quantum_number_problem_set.solving_settings.topology,
states=new_edge_settings,
interactions=new_node_settings,
)
new_edge_properties = {
edge_id: {
edge_quantum_number: scalar
for edge_quantum_number, scalar in graph_edge_property_map.items()
if edge_quantum_number in edge_properties
}
for edge_id, graph_edge_property_map in old_edge_properties.items()
}
new_node_properties = {
node_id: {
node_quantum_number: scalar
for node_quantum_number, scalar in graph_node_property_map.items()
if node_quantum_number in node_properties
}
for node_id, graph_node_property_map in old_node_properties.items()
}
new_combined_properties = MutableTransition(
topology=quantum_number_problem_set.initial_facts.topology,
states=new_edge_properties,
interactions=new_node_properties,
)
return attrs.evolve(
quantum_number_problem_set,
solving_settings=new_combined_settings,
initial_facts=new_combined_properties,
)


QuantumNumberSolution = MutableTransition[GraphEdgePropertyMap, GraphNodePropertyMap]


Expand Down Expand Up @@ -747,6 +827,13 @@ def __convert_solution_keys(
return converted_solutions


def dict_set_intersection(
base_dict: dict[Any, Any],
set_of_keys: set[Any],
) -> dict[Any, Any]:
return {key: value for key, value in base_dict.items() if key in set_of_keys}


class Scoresheet:
def __init__(self) -> None:
self.__rule_calls: dict[tuple[int, Rule], int] = {}
Expand Down
Loading

0 comments on commit 846d4ad

Please sign in to comment.