-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add compute_mapped
#170
Add compute_mapped
#170
Changes from 12 commits
5348538
6680c5c
2041c3f
07b5e07
c7f95e2
e2c2ff2
eaefb27
194beb5
12442de
0df6300
303bfb7
1997df4
4ae04f7
9d4bd36
151fd4c
c88585b
f6d8058
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,4 +4,5 @@ | |
graphviz | ||
jsonschema | ||
numpy | ||
pandas | ||
pytest |
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you consider how you could integrate this functionality into There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Something like that, yes. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,10 @@ | |
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) | ||
from __future__ import annotations | ||
|
||
from collections.abc import Callable, Iterable | ||
from collections.abc import Callable, Hashable, Iterable, Sequence | ||
from itertools import chain | ||
from types import UnionType | ||
from typing import Any, TypeVar, get_args, get_type_hints, overload | ||
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload | ||
|
||
from ._provider import Provider, ToProvider | ||
from .data_graph import DataGraph, to_task_graph | ||
|
@@ -15,6 +15,11 @@ | |
from .task_graph import TaskGraph | ||
from .typing import Key | ||
|
||
if TYPE_CHECKING: | ||
import graphviz | ||
import pandas | ||
|
||
|
||
T = TypeVar('T') | ||
KeyType = TypeVar('KeyType', bound=Key) | ||
|
||
|
@@ -84,7 +89,7 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any: | |
""" | ||
return self.get(tp, **kwargs).compute() | ||
|
||
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 | ||
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: | ||
""" | ||
Return a graphviz Digraph object representing the graph for the given keys. | ||
|
||
|
@@ -194,3 +199,93 @@ def bind_and_call( | |
def _repr_html_(self) -> str: | ||
nodes = ((key, data) for key, data in self._graph.nodes.items()) | ||
return pipeline_html_repr(nodes) | ||
|
||
|
||
def get_mapped_node_names(
    graph: Pipeline, key: type, index_names: Sequence[Hashable] | None = None
) -> pandas.Series:
    """
    Given a graph with key depending on mapped nodes, return a series of the
    corresponding mapped keys.

    This is meant to be used in combination with :py:func:`Pipeline.map`.
    If the key depends on multiple indices, the series will be a multi-index series.

    Note that Pandas is not a dependency of Sciline and must be installed separately.

    Parameters
    ----------
    graph:
        The pipeline to get the mapped key names from.
    key:
        The key to get the mapped key names for. This key must depend on mapped nodes.
    index_names:
        Specifies the names of the indices of the mapped node. If not given this is
        inferred from the graph, but the argument may be required to disambiguate
        multiple mapped nodes with the same name.

    Returns
    -------
    :
        The series of mapped key names.

    Raises
    ------
    ValueError
        If no mapped node is named ``key``, if ``index_names`` is given and no
        mapped node named ``key`` has exactly those indices, or if more than one
        mapped node matches.
    """
    # NOTE(review): reviewers asked whether 'name' and 'key' mean different things
    # here, and whether the first parameter should be called 'pipeline' instead of
    # 'graph'. The names are kept as they are so that existing callers keep working.
    import pandas as pd
    from cyclebane.graph import IndexValues, MappedNode, NodeName

    # NOTE(review): this reads the protected ``_cbgraph`` attribute of the pipeline;
    # reviewers would prefer a public accessor for the underlying graph.
    candidates = [
        node
        for node in graph._cbgraph.graph.nodes
        if isinstance(node, MappedNode) and node.name == key
    ]
    # Check before filtering by index names: otherwise a key that *is* mapped, but
    # over different indices, would be reported as "not a mapped node".
    if not candidates:
        raise ValueError(f"'{key}' is not a mapped node.")
    if index_names is not None:
        requested = tuple(index_names)
        requested_set = set(requested)  # built once instead of once per candidate
        candidates = [
            node for node in candidates if set(node.indices) == requested_set
        ]
        if not candidates:
            raise ValueError(
                f"'{key}' is not a mapped node with index names {requested}."
            )
    if len(candidates) > 1:
        raise ValueError(
            f"Multiple mapped nodes with name '{key}' found: {candidates}"
        )
    # Drops unrelated indices
    graph = graph[candidates[0]]  # type: ignore[index]
    indices = graph._cbgraph.indices
    if index_names is not None:
        # Keep the graph's own ordering of the indices, restricted to the request.
        indices = {name: indices[name] for name in indices if name in index_names}
    # From here on ``index_names`` is always a tuple in the order of ``indices``.
    index_names = tuple(indices)

    index = pd.MultiIndex.from_product(indices.values(), names=index_names)
    keys = tuple(NodeName(key, IndexValues(index_names, idx)) for idx in index)
    if index.nlevels == 1:  # Avoid more complicated MultiIndex if unnecessary
        index = index.get_level_values(0)
    return pd.Series(keys, index=index, name=key)
|
||
|
||
def compute_mapped(
    graph: Pipeline, key: type, index_names: Sequence[Hashable] | None = None
) -> pandas.Series:
    """
    Compute the results for every mapped instance of a key and return them as a
    series.

    Intended for pipelines set up with :py:func:`Pipeline.map`. The series is
    indexed like the one returned by :py:func:`get_mapped_node_names`, so a key
    that depends on several indices yields a multi-index series.

    Pandas is not a dependency of Sciline; it has to be installed separately.

    Parameters
    ----------
    graph:
        The pipeline to compute the series from.
    key:
        The key to compute the series for. It must depend on mapped nodes.
    index_names:
        Names of the indices of the mapped node. Inferred from the graph when
        omitted, but may be needed to disambiguate several mapped nodes that
        share the same name.

    Returns
    -------
    :
        The computed series.
    """
    mapped_keys = get_mapped_node_names(
        graph=graph, key=key, index_names=index_names
    )
    # A single compute call covers all mapped keys.
    computed = graph.compute(mapped_keys)

    def _lookup(node_name: Any) -> Any:
        return computed[node_name]

    # Replace each mapped key by its computed value, keeping index and name.
    return mapped_keys.apply(_lookup)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This link and the link below should be relative links.