
Add compute_mapped #170

Merged · 17 commits · Jun 21, 2024
11 changes: 11 additions & 0 deletions docs/api-reference/index.md
@@ -21,6 +21,17 @@
HandleAsComputeTimeException
```

## Top-level functions

```{eval-rst}
.. autosummary::
:toctree: ../generated/functions
:recursive:

compute_series
get_mapped_node_names
```

## Exceptions

```{eval-rst}
27 changes: 20 additions & 7 deletions docs/user-guide/parameter-tables.ipynb
@@ -146,8 +146,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Then we can compute `Result` for each index in the parameter table.\n",
"Currently there is no convenient way of accessing these; instead, we manually define the target nodes to compute:"
"We can use the [compute_series](https://scipp.github.io/sciline/generated/functions/sciline.compute_series.html) function to compute `Result` for each index in the parameter table:"
]
},
{
@@ -156,10 +155,8 @@
"metadata": {},
"outputs": [],
"source": [
"from cyclebane.graph import NodeName, IndexValues\n",
"\n",
"targets = [NodeName(Result, IndexValues(('run_id',), (i,))) for i in run_ids]\n",
"pipeline.compute(targets)"
"results = sciline.compute_series(pipeline, Result)\n",
"pd.DataFrame(results) # DataFrame for HTML rendering"
]
},
{
@@ -168,7 +165,22 @@
"source": [
"Note the use of the `run_id` index.\n",
"If the index axis of the DataFrame has no name then a default of `dim_0`, `dim_1`, etc. is used.\n",
"We can also visualize the task graph for computing the series of `Result` values:"
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"**Note**\n",
"\n",
"[compute_series](https://scipp.github.io/sciline/generated/functions/sciline.compute_series.html) depends on Pandas, which is not a dependency of Sciline and must be installed separately, e.g., using pip:\n",
"\n",
"```bash\n",
"pip install pandas\n",
"```\n",
"\n",
"</div>\n",
"\n",
"We can also visualize the task graph for computing the series of `Result` values.\n",
"For this, we need to get all the node names derived from `Result` via the `map` operation.\n",
"The [get_mapped_node_names](https://scipp.github.io/sciline/generated/functions/sciline.get_mapped_node_names.html) function can be used to get a `pandas.Series` of these node names, which we can then visualize:"
]
},
{
@@ -177,6 +189,7 @@
"metadata": {},
"outputs": [],
"source": [
"targets = sciline.get_mapped_node_names(pipeline, Result)\n",
"pipeline.visualize(targets)"
]
},
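The notebook cell above replaces manual target construction with `sciline.compute_series`. The removed approach can be sketched with minimal stand-ins for cyclebane's `NodeName` and `IndexValues` (the stand-in classes below are illustrative only; the real ones live in `cyclebane.graph`):

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class IndexValues:
    """Minimal stand-in for cyclebane.graph.IndexValues."""

    axes: tuple
    values: tuple


@dataclass(frozen=True)
class NodeName:
    """Minimal stand-in for cyclebane.graph.NodeName."""

    name: str
    index: IndexValues


# One mapped node name per run_id, mirroring the removed notebook cell.
run_ids = [102, 103, 104]
targets = [NodeName('Result', IndexValues(('run_id',), (i,))) for i in run_ids]
```

This is exactly the boilerplate that `compute_series` and `get_mapped_node_names` now hide from the user.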
2 changes: 1 addition & 1 deletion requirements/base.in
@@ -2,4 +2,4 @@
# will not be touched by ``make_base.py``
# --- END OF CUSTOM SECTION ---
# The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
cyclebane >= 24.06.0
cyclebane>=24.06.0
2 changes: 1 addition & 1 deletion requirements/base.txt
@@ -1,4 +1,4 @@
# SHA1:44b03b700447e95874de7cd4d8b11558c0c972e5
# SHA1:1b4246f703135629f3fb69e65829cfb99abca695
#
# This file is autogenerated by pip-compile-multi
# To update, run:
1 change: 1 addition & 0 deletions requirements/basetest.in
@@ -4,4 +4,5 @@
graphviz
jsonschema
numpy
pandas
pytest
18 changes: 15 additions & 3 deletions requirements/basetest.txt
@@ -1,4 +1,4 @@
# SHA1:f7d11f6aab1600c37d922ffb857a414368b800cd
# SHA1:fa9ebd4f58fe57db20baa224ad46fac634f3f046
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@@ -19,14 +19,22 @@ jsonschema==4.22.0
# via -r basetest.in
jsonschema-specifications==2023.12.1
# via jsonschema
numpy==1.26.4
# via -r basetest.in
numpy==2.0.0
# via
# -r basetest.in
# pandas
packaging==24.1
# via pytest
pandas==2.2.2
# via -r basetest.in
pluggy==1.5.0
# via pytest
pytest==8.2.2
# via -r basetest.in
python-dateutil==2.9.0.post0
# via pandas
pytz==2024.1
# via pandas
referencing==0.35.1
# via
# jsonschema
@@ -35,5 +43,9 @@ rpds-py==0.18.1
# via
# jsonschema
# referencing
six==1.16.0
# via python-dateutil
tomli==2.0.1
# via pytest
tzdata==2024.1
# via pandas
4 changes: 2 additions & 2 deletions requirements/ci.txt
@@ -17,7 +17,7 @@ colorama==0.4.6
# via tox
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via
# tox
# virtualenv
@@ -50,7 +50,7 @@ tomli==2.0.1
# tox
tox==4.15.1
# via -r ci.in
urllib3==2.2.1
urllib3==2.2.2
# via requests
virtualenv==20.26.2
# via tox
2 changes: 1 addition & 1 deletion requirements/dev.txt
@@ -92,7 +92,7 @@ prometheus-client==0.20.0
# via jupyter-server
pycparser==2.22
# via cffi
pydantic==2.7.3
pydantic==2.7.4
# via copier
pydantic-core==2.18.4
# via pydantic
6 changes: 3 additions & 3 deletions requirements/docs.txt
@@ -48,7 +48,7 @@ exceptiongroup==1.2.1
# via ipython
executing==2.0.1
# via stack-data
fastjsonschema==2.19.1
fastjsonschema==2.20.0
# via nbformat
graphviz==0.20.3
# via -r docs.in
@@ -120,7 +120,7 @@ nbsphinx==0.9.4
# via -r docs.in
nest-asyncio==1.6.0
# via ipykernel
numpy==1.26.4
numpy==2.0.0
# via pandas
packaging==24.1
# via
@@ -241,7 +241,7 @@ typing-extensions==4.12.2
# pydata-sphinx-theme
tzdata==2024.1
# via pandas
urllib3==2.2.1
urllib3==2.2.2
# via requests
wcwidth==0.2.13
# via prompt-toolkit
2 changes: 1 addition & 1 deletion requirements/static.txt
@@ -9,7 +9,7 @@ cfgv==3.4.0
# via pre-commit
distlib==0.3.8
# via virtualenv
filelock==3.14.0
filelock==3.15.1
# via virtualenv
identify==2.5.36
# via pre-commit
2 changes: 1 addition & 1 deletion requirements/test-dask.txt
@@ -10,7 +10,7 @@ click==8.1.7
# via dask
cloudpickle==3.0.0
# via dask
dask==2024.5.2
dask==2024.6.0
# via -r test-dask.in
fsspec==2024.6.0
# via dask
4 changes: 3 additions & 1 deletion src/sciline/__init__.py
@@ -17,7 +17,7 @@
HandleAsComputeTimeException,
UnsatisfiedRequirement,
)
from .pipeline import Pipeline
from .pipeline import Pipeline, compute_series, get_mapped_node_names
from .task_graph import TaskGraph

__all__ = [
@@ -30,6 +30,8 @@
"UnsatisfiedRequirement",
"HandleAsBuildTimeException",
"HandleAsComputeTimeException",
"compute_series",
"get_mapped_node_names",
]

del importlib
72 changes: 70 additions & 2 deletions src/sciline/pipeline.py
Member:

> Did you consider how you could integrate this functionality into `compute`? E.g., check whether a given key is mapped and then do the equivalent of `compute_series`?

Member (Author):

> Something like that, yes.
@@ -5,7 +5,7 @@
from collections.abc import Callable, Iterable
from itertools import chain
from types import UnionType
from typing import Any, TypeVar, get_args, get_type_hints, overload
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload

from ._provider import Provider, ToProvider
from .data_graph import DataGraph, to_task_graph
@@ -15,6 +15,11 @@
from .task_graph import TaskGraph
from .typing import Key

if TYPE_CHECKING:
import graphviz
import pandas


T = TypeVar('T')
KeyType = TypeVar('KeyType', bound=Key)

@@ -84,7 +89,7 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
return self.get(tp, **kwargs).compute()

def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
def visualize(self, tp: type | Iterable[type], **kwargs: Any) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.

@@ -194,3 +199,66 @@ def bind_and_call(
def _repr_html_(self) -> str:
nodes = ((key, data) for key, data in self._graph.nodes.items())
return pipeline_html_repr(nodes)


def get_mapped_node_names(graph: Pipeline, key: type) -> pandas.Series:
"""
    Given a graph with a key depending on mapped nodes, return a series of the
    corresponding mapped keys.

This is meant to be used in combination with :py:func:`Pipeline.map`.
If the key depends on multiple indices, the series will be a multi-index series.

Note that Pandas is not a dependency of Sciline and must be installed separately.

Parameters
----------
graph:
The pipeline to get the mapped key names from.
key:
The key to get the mapped key names for. This key must depend on mapped nodes.

Returns
-------
:
The series of mapped key names.
"""
import pandas as pd
from cyclebane.graph import IndexValues, NodeName

graph = graph[key] # Drops unrelated indices
indices = graph._cbgraph.indices
if len(indices) == 0:
raise ValueError(f"'{key}' does not depend on any mapped nodes.")
index_names = tuple(indices)
index = pd.MultiIndex.from_product(indices.values(), names=index_names)
keys = tuple(NodeName(key, IndexValues(index_names, idx)) for idx in index)
    if index.nlevels == 1:  # Avoid a more complicated MultiIndex if unnecessary
index = index.get_level_values(0)
return pd.Series(keys, index=index, name=key)
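The index construction above can be exercised in isolation. The sketch below mirrors the `MultiIndex.from_product` plus single-level flattening logic of `get_mapped_node_names`, with made-up index values standing in for a pipeline's mapped indices:

```python
import pandas as pd

# Build the product index from per-axis index values, as in
# get_mapped_node_names, then flatten when there is only one level.
indices = {'run_id': [102, 103, 104]}  # made-up example indices
index_names = tuple(indices)
index = pd.MultiIndex.from_product(indices.values(), names=index_names)
if index.nlevels == 1:  # avoid an unnecessary MultiIndex
    index = index.get_level_values(0)
```

With a single axis the result is a plain named `Index`; with two or more axes the `MultiIndex` is kept, which is what produces the multi-index series mentioned in the docstring.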


def compute_series(graph: Pipeline, key: type) -> pandas.Series:
"""
Given a graph with key depending on mapped nodes, compute a series for the key.

This is meant to be used in combination with :py:func:`Pipeline.map`.
If the key depends on multiple indices, the series will be a multi-index series.

Note that Pandas is not a dependency of Sciline and must be installed separately.

Parameters
----------
graph:
The pipeline to compute the series from.
key:
The key to compute the series for. This key must depend on mapped nodes.

Returns
-------
:
The computed series.
"""
key_series = get_mapped_node_names(graph, key)
results = graph.compute(key_series)
return key_series.apply(lambda x: results[x])
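The final lookup step of `compute_series` can be sketched standalone: a series of node names is converted into a series of computed values by looking each name up in the results mapping returned by `graph.compute`. The node names and values below are made up:

```python
import pandas as pd

# Series of mapped node names, indexed by the mapped index values.
key_series = pd.Series(['node-0', 'node-1'], index=[102, 103], name='Result')
# Results mapping as would be returned by graph.compute(key_series).
results = {'node-0': 1.5, 'node-1': 2.5}
# Replace each node name with its computed value, preserving the index.
values = key_series.apply(lambda k: results[k])
```

This preserves the (possibly multi-level) index of `key_series`, so the caller gets one computed value per mapped index entry.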