Add compute_mapped #170

Merged
merged 17 commits into from
Jun 21, 2024
Changes from 12 commits
11 changes: 11 additions & 0 deletions docs/api-reference/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
HandleAsComputeTimeException
```

## Top-level functions

```{eval-rst}
.. autosummary::
:toctree: ../generated/functions
:recursive:

compute_mapped
get_mapped_node_names
```

## Exceptions

```{eval-rst}
Expand Down
27 changes: 20 additions & 7 deletions docs/user-guide/parameter-tables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can compute `Result` for each index in the parameter table.\n",
"Currently there is no convenient way of accessing these, instead we manually define the target nodes to compute:"
"We can use the [compute_mapped](https://scipp.github.io/sciline/generated/functions/sciline.compute_mapped.html) function to compute `Result` for each index in the parameter table:"
]
},
{
Expand All @@ -156,10 +155,8 @@
"metadata": {},
"outputs": [],
"source": [
"from cyclebane.graph import NodeName, IndexValues\n",
"\n",
"targets = [NodeName(Result, IndexValues(('run_id',), (i,))) for i in run_ids]\n",
"pipeline.compute(targets)"
"results = sciline.compute_mapped(pipeline, Result)\n",
"pd.DataFrame(results) # DataFrame for HTML rendering"
]
},
{
Expand All @@ -168,7 +165,22 @@
"source": [
"Note the use of the `run_id` index.\n",
"If the index axis of the DataFrame has no name then a default of `dim_0`, `dim_1`, etc. is used.\n",
"We can also visualize the task graph for computing the series of `Result` values:"
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"**Note**\n",
"\n",
"[compute_mapped](https://scipp.github.io/sciline/generated/functions/sciline.compute_mapped.html) depends on Pandas, which is not a dependency of Sciline and must be installed separately, e.g., using pip:\n",
Review comment (Member):

This link and the link below should be relative links.

"\n",
"```bash\n",
"pip install pandas\n",
"```\n",
"\n",
"</div>\n",
"\n",
"We can also visualize the task graph for computing the series of `Result` values.\n",
"For this, we need to get all the node names derived from `Result` via the `map` operation.\n",
"The [get_mapped_node_names](https://scipp.github.io/sciline/generated/functions/sciline.get_mapped_node_names.html) function can be used to get a `pandas.Series` of these node names, which we can then visualize:"
]
},
{
Expand All @@ -177,6 +189,7 @@
"metadata": {},
"outputs": [],
"source": [
"targets = sciline.get_mapped_node_names(pipeline, Result)\n",
"pipeline.visualize(targets)"
]
},
Expand Down
2 changes: 1 addition & 1 deletion requirements/base.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# will not be touched by ``make_base.py``
# --- END OF CUSTOM SECTION ---
# The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
cyclebane >= 24.06.0
cyclebane>=24.06.0
2 changes: 1 addition & 1 deletion requirements/base.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SHA1:44b03b700447e95874de7cd4d8b11558c0c972e5
# SHA1:1b4246f703135629f3fb69e65829cfb99abca695
#
# This file is autogenerated by pip-compile-multi
# To update, run:
Expand Down
1 change: 1 addition & 0 deletions requirements/basetest.in
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
graphviz
jsonschema
numpy
pandas
pytest
18 changes: 15 additions & 3 deletions requirements/basetest.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# SHA1:f7d11f6aab1600c37d922ffb857a414368b800cd
# SHA1:fa9ebd4f58fe57db20baa224ad46fac634f3f046
#
# This file is autogenerated by pip-compile-multi
# To update, run:
Expand All @@ -19,14 +19,22 @@ jsonschema==4.22.0
# via -r basetest.in
jsonschema-specifications==2023.12.1
# via jsonschema
numpy==1.26.4
# via -r basetest.in
numpy==2.0.0
# via
# -r basetest.in
# pandas
packaging==24.1
# via pytest
pandas==2.2.2
# via -r basetest.in
pluggy==1.5.0
# via pytest
pytest==8.2.2
# via -r basetest.in
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
# via pandas
referencing==0.35.1
# via
# jsonschema
Expand All @@ -35,5 +43,9 @@ rpds-py==0.18.1
# via
# jsonschema
# referencing
six==1.16.0
# via python-dateutil
tomli==2.0.1
# via pytest
tzdata==2024.1
# via pandas
4 changes: 2 additions & 2 deletions requirements/ci.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ colorama==0.4.6
# via tox
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via
# tox
# virtualenv
Expand Down Expand Up @@ -50,7 +50,7 @@ tomli==2.0.1
# tox
tox==4.15.1
# via -r ci.in
urllib3==2.2.1
urllib3==2.2.2
# via requests
virtualenv==20.26.2
# via tox
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ prometheus-client==0.20.0
# via jupyter-server
pycparser==2.22
# via cffi
pydantic==2.7.3
pydantic==2.7.4
# via copier
pydantic-core==2.18.4
# via pydantic
Expand Down
6 changes: 3 additions & 3 deletions requirements/docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ exceptiongroup==1.2.1
# via ipython
executing==2.0.1
# via stack-data
fastjsonschema==2.19.1
fastjsonschema==2.20.0
# via nbformat
graphviz==0.20.3
# via -r docs.in
Expand Down Expand Up @@ -120,7 +120,7 @@ nbsphinx==0.9.4
# via -r docs.in
nest-asyncio==1.6.0
# via ipykernel
numpy==1.26.4
numpy==2.0.0
# via pandas
packaging==24.1
# via
Expand Down Expand Up @@ -241,7 +241,7 @@ typing-extensions==4.12.2
# pydata-sphinx-theme
tzdata==2024.1
# via pandas
urllib3==2.2.1
urllib3==2.2.2
# via requests
wcwidth==0.2.13
# via prompt-toolkit
Expand Down
2 changes: 1 addition & 1 deletion requirements/static.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cfgv==3.4.0
# via pre-commit
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via virtualenv
identify==2.5.36
# via pre-commit
Expand Down
2 changes: 1 addition & 1 deletion requirements/test-dask.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ click==8.1.7
# via dask
cloudpickle==3.0.0
# via dask
dask==2024.5.2
dask==2024.6.0
# via -r test-dask.in
fsspec==2024.6.0
# via dask
Expand Down
4 changes: 3 additions & 1 deletion src/sciline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
HandleAsComputeTimeException,
UnsatisfiedRequirement,
)
from .pipeline import Pipeline
from .pipeline import Pipeline, compute_mapped, get_mapped_node_names
from .task_graph import TaskGraph

__all__ = [
Expand All @@ -30,6 +30,8 @@
"UnsatisfiedRequirement",
"HandleAsBuildTimeException",
"HandleAsComputeTimeException",
"compute_mapped",
"get_mapped_node_names",
]

del importlib
101 changes: 98 additions & 3 deletions src/sciline/pipeline.py
Review comment (Member):

Did you consider how you could integrate this functionality into compute? E.g., check whether a given key is mapped and then do the equivalent of compute_series?

Reply (Member Author):

Something like that, yes.
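
To make the suggestion above concrete, here is a minimal sketch of a `compute` that dispatches on whether a key is mapped. `SketchPipeline`, `add`, and `add_mapped` are hypothetical stand-ins written for this illustration; they are not part of Sciline, and a plain dict stands in for the `pandas.Series` that `compute_mapped` returns.

```python
from typing import Any, Callable, Dict, Hashable


class SketchPipeline:
    """Hypothetical stand-in for illustration only; this is not sciline.Pipeline."""

    def __init__(self) -> None:
        # Unmapped keys: one provider per key.
        self._plain: Dict[Hashable, Callable[[], Any]] = {}
        # Mapped keys: one provider per index value.
        self._mapped: Dict[Hashable, Dict[Hashable, Callable[[], Any]]] = {}

    def add(self, key: Hashable, func: Callable[[], Any]) -> None:
        self._plain[key] = func

    def add_mapped(
        self, key: Hashable, funcs_by_index: Dict[Hashable, Callable[[], Any]]
    ) -> None:
        self._mapped[key] = dict(funcs_by_index)

    def compute(self, key: Hashable) -> Any:
        # Mapped key: compute one value per index value.
        if key in self._mapped:
            return {idx: func() for idx, func in self._mapped[key].items()}
        # Unmapped key: compute the single value.
        return self._plain[key]()


pipeline = SketchPipeline()
pipeline.add("Scale", lambda: 2.0)
pipeline.add_mapped("Result", {1: lambda: 10.0, 2: lambda: 20.0})
scalar = pipeline.compute("Scale")
series = pipeline.compute("Result")
```

One consequence of this design is that the return type of `compute` depends on the key, which may or may not be acceptable for the public interface.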

Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

from collections.abc import Callable, Iterable
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import Any, TypeVar, get_args, get_type_hints, overload
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload

from ._provider import Provider, ToProvider
from .data_graph import DataGraph, to_task_graph
Expand All @@ -15,6 +15,11 @@
from .task_graph import TaskGraph
from .typing import Key

if TYPE_CHECKING:
import graphviz
import pandas


T = TypeVar('T')
KeyType = TypeVar('KeyType', bound=Key)

Expand Down Expand Up @@ -84,7 +89,7 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
return self.get(tp, **kwargs).compute()

def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.

Expand Down Expand Up @@ -194,3 +199,93 @@ def bind_and_call(
def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self._graph.nodes.items())
return pipeline_html_repr(nodes)


def get_mapped_node_names(
Review comment (Member):

'name' or 'key'? What is the difference?

graph: Pipeline, key: type, index_names: Sequence[Hashable] | None = None
Review comment (Member):

Is it a graph or a pipeline? We should be consistent with names.

) -> pandas.Series:
"""
Given a graph with key depending on mapped nodes, return a series of the
corresponding mapped keys.

This is meant to be used in combination with :py:func:`Pipeline.map`.
If the key depends on multiple indices, the series will be a multi-index series.

Note that Pandas is not a dependency of Sciline and must be installed separately.

Parameters
----------
graph:
The pipeline to get the mapped key names from.
key:
The key to get the mapped key names for. This key must depend on mapped nodes.
index_names:
Specifies the names of the indices of the mapped node. If not given this is
inferred from the graph, but the argument may be required to disambiguate
multiple mapped nodes with the same name.

Returns
-------
:
The series of mapped key names.
"""
import pandas as pd
from cyclebane.graph import IndexValues, MappedNode, NodeName

candidates = [
node
for node in graph._cbgraph.graph.nodes
Review comment (Member):

I don't like that this uses a protected attribute of graph. Should get_mapped_node_names be a method of Pipeline?

Reply (Member Author):

This is definitely something that we need to consider for a lot of other functionality (in particular around planner graph operations) that we intend to implement. Either we need to make those properties accessible on the public interface, or potentially add a lot more methods to Pipeline. Right now I cannot say which is better. So unless you can clearly say which solution should be chosen, I'd like to keep this as it is for now.

Review comment (Member):

I don't know what operations you have in mind that need direct access to the graph. But I would say that we should expose a minimal set of primitive operations that can be composed in, e.g., methods of workflow classes.

if isinstance(node, MappedNode) and node.name == key
]
if index_names is not None:
candidates = [
node for node in candidates if set(node.indices) == set(index_names)
]
if len(candidates) == 0:
raise ValueError(f"'{key}' is not a mapped node.")
Review comment (Member):

Should this check be before filtering by index names? As written, if key refers to a mapped node that is then filtered out by the index names, the error still says that key is not a mapped node.
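
A possible reordering, as a sketch only: check for "not mapped" before filtering, and raise a separate error when the index names exclude every candidate. `FakeMappedNode` and `select_mapped_node` are hypothetical names used for illustration; the real code works on cyclebane's `MappedNode`, and the second error message is invented here.

```python
from collections.abc import Hashable, Sequence
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class FakeMappedNode:
    """Hypothetical stand-in for cyclebane's MappedNode (name plus index names)."""

    name: Hashable
    indices: Tuple[Hashable, ...]


def select_mapped_node(
    nodes: Sequence[FakeMappedNode],
    key: Hashable,
    index_names: Optional[Sequence[Hashable]] = None,
) -> FakeMappedNode:
    # First check: is the key mapped at all?
    candidates = [node for node in nodes if node.name == key]
    if len(candidates) == 0:
        raise ValueError(f"'{key}' is not a mapped node.")
    # Second check: does any mapped node match the requested index names?
    if index_names is not None:
        candidates = [
            node for node in candidates if set(node.indices) == set(index_names)
        ]
        if len(candidates) == 0:
            raise ValueError(
                f"'{key}' is mapped, but not over indices {tuple(index_names)}."
            )
    if len(candidates) > 1:
        raise ValueError(f"Multiple mapped nodes with name '{key}' found: {candidates}")
    return candidates[0]
```

With this order, a wrong `index_names` argument and a key that is not mapped at all produce different messages.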

if len(candidates) > 1:
raise ValueError(f"Multiple mapped nodes with name '{key}' found: {candidates}")
# Drops unrelated indices
graph = graph[candidates[0]] # type: ignore[index]
indices = graph._cbgraph.indices
if index_names is not None:
indices = {name: indices[name] for name in indices if name in index_names}
index_names = tuple(indices)

index = pd.MultiIndex.from_product(indices.values(), names=index_names)
keys = tuple(NodeName(key, IndexValues(index_names, idx)) for idx in index)
if index.nlevels == 1: # Avoid more complicated MultiIndex if unnecessary
index = index.get_level_values(0)
return pd.Series(keys, index=index, name=key)


def compute_mapped(
graph: Pipeline, key: type, index_names: Sequence[Hashable] | None = None
) -> pandas.Series:
"""
Given a graph with key depending on mapped nodes, compute a series for the key.

This is meant to be used in combination with :py:func:`Pipeline.map`.
If the key depends on multiple indices, the series will be a multi-index series.

Note that Pandas is not a dependency of Sciline and must be installed separately.

Parameters
----------
graph:
The pipeline to compute the series from.
key:
The key to compute the series for. This key must depend on mapped nodes.
index_names:
Specifies the names of the indices of the mapped node. If not given this is
inferred from the graph, but the argument may be required to disambiguate
multiple mapped nodes with the same name.

Returns
-------
:
The computed series.
"""
key_series = get_mapped_node_names(graph=graph, key=key, index_names=index_names)
results = graph.compute(key_series)
return key_series.apply(lambda x: results[x])
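
For readers unfamiliar with the index handling in the two functions above, the following stdlib-only sketch shows what the product over index values and the final lookup amount to. It uses `itertools.product` in place of `pd.MultiIndex.from_product`, plain tuples in place of cyclebane's `NodeName` and `IndexValues`, and a dict in place of `pandas.Series`; `mapped_keys` and `assemble` are hypothetical names, not Sciline API.

```python
from itertools import product


def mapped_keys(key, indices):
    """Return {index tuple: node key} for the Cartesian product of index values.

    ``indices`` maps index name -> sequence of values. The node key is a plain
    tuple (key, index names, index values) standing in for ``NodeName``.
    """
    names = tuple(indices)
    return {values: (key, names, values) for values in product(*indices.values())}


def assemble(keys, results):
    """Look up each node key in the computed results, keeping the index order."""
    return {idx: results[node] for idx, node in keys.items()}


indices = {"run_id": [101, 102], "sample": ["a", "b"]}
keys = mapped_keys("Result", indices)
# Pretend results: a real compute call would return one value per node key.
results = {node: f"{node[0]}{node[2]}" for node in keys.values()}
series = assemble(keys, results)
```

The last index varies fastest, so the entries appear in the order `(101, 'a')`, `(101, 'b')`, `(102, 'a')`, `(102, 'b')`.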