\n",
+ "\n",
+ "**Note**\n",
"\n",
"We use a peculiar-looking syntax for defining \"generic type aliases\".\n",
"We would love to use [typing.NewType](https://docs.python.org/3/library/typing.html#typing.NewType) for this, but it does not allow for definition of generic aliases.\n",
@@ -360,9 +370,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "### Generic providers and parameter tables\n",
+ "### Generic providers and map operations\n",
"\n",
- "As a more complex example of where generic providers are useful, we may add a parameter table, so we can process multiple samples:"
+ "As a more complex example of where generic providers are useful, we may add `map` operation, so we can process multiple samples:"
]
},
{
@@ -371,33 +381,17 @@
"metadata": {},
"outputs": [],
"source": [
- "RunID = NewType('RunID', int)\n",
"run_ids = [102, 103, 104, 105]\n",
"filenames = [f'file{i}.txt' for i in run_ids]\n",
- "param_table = sciline.ParamTable(RunID, {Filename[Sample]: filenames}, index=run_ids)\n",
- "param_table"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We can now create a parametrized pipeline:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
"params = {\n",
" ScaleFactor: 2.0,\n",
" Filename[Background]: 'background.txt',\n",
"}\n",
"pipeline = sciline.Pipeline(providers, params=params)\n",
- "pipeline.set_param_table(param_table)\n",
- "graph = pipeline.get(sciline.Series[RunID, Result])\n",
+ "pipeline = pipeline.map({Filename[Sample]: filenames})\n",
+ "\n",
+ "# We can collect the results into a list for simplicity in this example\n",
+ "graph = pipeline.reduce(func=lambda *x: list[x], name='collected').get('collected')\n",
"graph.visualize()"
]
},
@@ -409,15 +403,6 @@
"source": [
"graph.compute()"
]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "With the generics mechanism, we could have added a line for the background to the parameter table.\n",
- "However, the `subtrack_background` function would then have to be modified to accept a `Series[RunID, CleanedData]`.\n",
- "More importantly, this would have resulted in a synchronization point in the computation graph, preventing efficient scheduling of the subsequent computation, with potentially disastrous effects on memory consumption."
- ]
}
],
"metadata": {
diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb
index aeae0ba8..bb7c70b3 100644
--- a/docs/user-guide/parameter-tables.ipynb
+++ b/docs/user-guide/parameter-tables.ipynb
@@ -9,9 +9,10 @@
"\n",
"## Overview\n",
"\n",
- "Parameter tables provide a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n",
+ "Sciline supports a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n",
"This allows for a variety of use cases, similar to *map*, *reduce*, and *groupby* operations in other systems.\n",
- "We illustrate each of these in the follow three chapters."
+ "We illustrate each of these in the follow three chapters.\n",
+ "Sciline's implementation is based on [Cyclebane](scipp.github.io/cyclebane)."
]
},
{
@@ -20,9 +21,10 @@
"source": [
"## Computing results for series of parameters\n",
"\n",
- "This chapter illustrates how to implement *map* operations with Sciline.\n",
+ "This chapter illustrates how to perform *map* operations with Sciline.\n",
"\n",
- "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we can replace the fixed `Filename` parameter with a series of filenames listed in a [ParamTable](../generated/classes/sciline.ParamTable.rst):"
+ "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we would like to replace the fixed `Filename` parameter with a series of filenames listed in a \"parameter table\".\n",
+ "We begin by defining the base pipeline:"
]
},
{
@@ -74,23 +76,17 @@
"\n",
"# 3. Create pipeline\n",
"\n",
- "# 3.a Providers and normal parameters\n",
"providers = [load, clean, process]\n",
"params = {ScaleFactor: 2.0}\n",
- "\n",
- "# 3.b Parameter table\n",
- "RunID = NewType('RunID', int)\n",
- "run_ids = [102, 103, 104, 105]\n",
- "filenames = [f'file{i}.txt' for i in run_ids]\n",
- "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)"
+ "base = sciline.Pipeline(providers, params=params)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note how steps 1.) and 2.) are identical to those from the example without parameter table.\n",
- "Above we have created the following parameter table:"
+ "Aside from not having defined a value for the `Filename` parameter, this is identical to the example in [Getting Started](getting-started.ipynb).\n",
+ "The task-graph visualization indicates this missing parameter:"
]
},
{
@@ -99,14 +95,14 @@
"metadata": {},
"outputs": [],
"source": [
- "param_table"
+ "base.visualize(Result, graph_attr={'rankdir': 'LR'})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can now create the pipeline and set the parameter table:"
+ "We now define a \"parameter table\" listing the filenames we would like to process:"
]
},
{
@@ -115,16 +111,26 @@
"metadata": {},
"outputs": [],
"source": [
- "# 3.c Setup pipeline\n",
- "pipeline = sciline.Pipeline(providers, params=params)\n",
- "pipeline.set_param_table(param_table)"
+ "import pandas as pd\n",
+ "\n",
+ "run_ids = [102, 103, 104, 105]\n",
+ "filenames = [f'file{i}.txt' for i in run_ids]\n",
+ "param_table = pd.DataFrame({Filename: filenames}, index=run_ids).rename_axis(\n",
+ " index='run_id'\n",
+ ")\n",
+ "param_table"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Then we can compute `Result` for each index in the parameter table:"
+ "Note how we used a node name of the pipeline as the column name in the parameter table.\n",
+ "For convenience we used a `pandas.DataFrame` to represent the table above, but the use of Pandas is entirely optional.\n",
+ "Equivalently the table could be represented as a `dict`, where each key corresponds to a column header and each value is a list of values for that column, i.e., `{Filename: filenames}`.\n",
+ "Specifying an index is currently not possible in this case, and it will default to a range index.\n",
+ "\n",
+ "We can now use [Pipeline.map](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.map) to create a modified pipeline that processes each row in the parameter table:"
]
},
{
@@ -133,18 +139,35 @@
"metadata": {},
"outputs": [],
"source": [
- "pipeline.compute(sciline.Series[RunID, Result])"
+ "pipeline = base.map(param_table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
+ "Then we can compute `Result` for each index in the parameter table.\n",
+ "Currently there is no convenient way of accessing these, instead we manually define the target nodes to compute:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from cyclebane.graph import NodeName, IndexValues\n",
"\n",
- "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n",
- "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n",
- "The second argument specifies the result to be computed.\n",
- "\n",
+ "targets = [NodeName(Result, IndexValues(('run_id',), (i,))) for i in run_ids]\n",
+ "pipeline.compute(targets)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Note the use of the `run_id` index.\n",
+ "If the index axis of the DataFrame has no name then a default of `dim_0`, `dim_1`, etc. is used.\n",
"We can also visualize the task graph for computing the series of `Result` values:"
]
},
@@ -154,15 +177,14 @@
"metadata": {},
"outputs": [],
"source": [
- "pipeline.visualize(sciline.Series[RunID, Result])"
+ "pipeline.visualize(targets)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Nodes that depend on values from a parameter table are drawn with the parameter index name (the row dimension of the parameter table) and value given in parenthesis.\n",
- "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n",
+ "Nodes that depend on values from a parameter table are drawn with the parameter index name (the row dimension of the parameter table) and index value (defaulting to a range index starting at 0 if no index if given) shown in parenthesis.\n",
"\n",
"
\n",
"\n",
@@ -185,8 +207,7 @@
"\n",
"This chapter illustrates how to implement *reduce* operations with Sciline.\n",
"\n",
- "Instead of requesting a series of results as above, we can also build pipelines with providers that depend on such series.\n",
- "We can create a new pipeline, or extend the existing one by inserting a new provider:\n"
+ "Instead of requesting a series of results as above, we use the [Pipeline.reduce](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.reduce) method and pass a function that combines the results from each parameter into a single result:"
]
},
{
@@ -195,23 +216,31 @@
"metadata": {},
"outputs": [],
"source": [
- "MergedResult = NewType('MergedResult', float)\n",
- "\n",
+ "graph = pipeline.reduce(func=lambda *result: sum(result), name='merged').get('merged')\n",
+ "graph.visualize()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
\n",
"\n",
- "def merge_runs(runs: sciline.Series[RunID, Result]) -> MergedResult:\n",
- " return MergedResult(sum(runs.values()))\n",
+ "**Note**\n",
"\n",
+ "The `func` passed to `reduce` is *not* making use of Sciline's mechanism of assembling a graph based on type hints.\n",
+ "In particular, the input type may be identical to the output type.\n",
+ "The [Pipeline.reduce](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.reduce) method adds a *new* node, attached at a unique (but mapped) sink node of the graph.\n",
+ "[Pipeline.__getitem__](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.__getitem__) and [Pipeline.__setitem__](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.__setitem__) can be used to compose more complex graphs where the reduction is not the final step.\n",
"\n",
- "pipeline.insert(merge_runs)\n",
- "graph = pipeline.get(MergedResult)\n",
- "graph.visualize()"
+ "
"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that this is identical to the example in the previous section, except for the last two nodes in the graph.\n",
+ "Note that the graph shown above is identical to the example in the previous section, except for the last two nodes in the graph.\n",
"The computation now returns a single result:"
]
},
@@ -231,12 +260,29 @@
"This is useful if we need to continue computation after gathering results without setting up a second pipeline."
]
},
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
\n",
+ "\n",
+ "**Note**\n",
+ "\n",
+ "For the `reduce` operation, all inputs to the reduction function have to be kept in memory simultaneously.\n",
+ "This can be very memory intensive.\n",
+ "We intend to support, e.g., hierarchical reduction operations in the future, where intermediate results are combined in a tree-like fashion to avoid excessive memory consumption..\n",
+ "\n",
+ "
"
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Grouping intermediate results based on secondary parameters\n",
"\n",
+ "**Cyclebane and Sciline do not support `groupby` yet, this is work in progress so this example is not functional yet.**\n",
+ "\n",
"This chapter illustrates how to implement *groupby* operations with Sciline.\n",
"\n",
"Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:"
@@ -250,123 +296,35 @@
"source": [
"Material = NewType('Material', str)\n",
"\n",
- "# 3.a Providers and normal parameters\n",
- "providers = [load, clean, process, merge_runs]\n",
- "params = {ScaleFactor: 2.0}\n",
- "\n",
- "# 3.b Parameter table\n",
"run_ids = [102, 103, 104, 105]\n",
"sample = ['diamond', 'graphite', 'graphite', 'graphite']\n",
"filenames = [f'file{i}.txt' for i in run_ids]\n",
- "param_table = sciline.ParamTable(\n",
- " RunID, {Filename: filenames, Material: sample}, index=run_ids\n",
- ")\n",
+ "param_table = pd.DataFrame(\n",
+ " {Filename: filenames, Material: sample}, index=run_ids\n",
+ ").rename_axis(index='run_id')\n",
"param_table"
]
},
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "# 3.c Setup pipeline\n",
- "pipeline = sciline.Pipeline(providers, params=params)\n",
- "pipeline.set_param_table(param_table)"
- ]
- },
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can now compute `MergedResult` for a series of \"materials\":"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pipeline.compute(sciline.Series[Material, MergedResult])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "The computation looks as show below.\n",
- "Note how the initial steps of the computation depend on the `RunID` parameter, while later steps depend on `Material`:\n",
- "The files for each run ID have been grouped by their material and then merged:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pipeline.visualize(sciline.Series[Material, MergedResult])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## More examples"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Using tables for series of parameters\n",
- "\n",
- "Sometimes the parameter of interest is the index of a parameter table itself.\n",
- "If there are no further parameters, the param table may have no columns (aside from the index).\n",
- "In this case we can bypass the manual creation of a parameter table and use the [Pipeline.set_param_series](../generated/classes/sciline.Pipeline.rst#sciline.Pipeline.set_param_series) function instead:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from typing import NewType\n",
- "import sciline as sl\n",
- "\n",
- "Param = NewType(\"Param\", int)\n",
- "Sum = NewType(\"Sum\", float)\n",
- "\n",
- "\n",
- "def compute(x: Param) -> float:\n",
- " return 0.5 * x\n",
- "\n",
+ "Future releases of Sciline will support a `groupby` operation, roughly as follows:\n",
"\n",
- "def gather(x: sl.Series[Param, float]) -> Sum:\n",
- " return Sum(sum(x.values()))\n",
+ "```python\n",
+ "pipeline = base.map(param_table).groupby(Material).reduce(func=merge)\n",
+ "```\n",
"\n",
- "\n",
- "pl = sl.Pipeline([gather, compute])\n",
- "pl.set_param_series(Param, [1, 4, 9])\n",
- "pl.visualize(Sum)"
+ "We can then compute the merged result, grouped by the value of `Material`.\n",
+ "Note how the initial steps of the computation depend on the `run_id` index name, while later steps depend on `Material`, a new index name defined by the `groupby` operation.\n",
+ "The files for each run ID have been grouped by their material and then merged."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note that `pl.set_param_series(Param, [1, 4, 9])` above is equivalent to `pl.set_param_table(sl.ParamTable(Param, columns={}, index=[1, 4, 9]))`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pl.compute(Sum)"
+ "## More examples"
]
},
{
@@ -387,21 +345,19 @@
"Sum = NewType(\"Sum\", float)\n",
"Param1 = NewType(\"Param1\", int)\n",
"Param2 = NewType(\"Param2\", int)\n",
- "Row = NewType(\"Run\", int)\n",
"\n",
"\n",
- "def gather(\n",
- " x: sl.Series[Row, float],\n",
- ") -> Sum:\n",
- " return Sum(sum(x.values()))\n",
+ "def gather(*x: float) -> Sum:\n",
+ " return Sum(sum(x))\n",
"\n",
"\n",
"def product(x: Param1, y: Param2) -> float:\n",
" return x / y\n",
"\n",
"\n",
- "pl = sl.Pipeline([gather, product])\n",
- "pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))\n",
+ "params = pd.DataFrame({Param1: [1, 4, 9], Param2: [1, 2, 3]})\n",
+ "pl = sl.Pipeline([product])\n",
+ "pl = pl.map(params).reduce(func=gather, name=Sum)\n",
"\n",
"pl.visualize(Sum)"
]
@@ -432,11 +388,10 @@
"Param = NewType(\"Param\", int)\n",
"Param1 = NewType(\"Param1\", int)\n",
"Param2 = NewType(\"Param2\", int)\n",
- "Row = NewType(\"Run\", int)\n",
"\n",
"\n",
- "def gather(x: sl.Series[Row, float]) -> Sum:\n",
- " return Sum(sum(x.values()))\n",
+ "def gather(*x: float) -> float:\n",
+ " return sum(x)\n",
"\n",
"\n",
"def to_param1(x: Param) -> Param1:\n",
@@ -451,8 +406,9 @@
" return x * y\n",
"\n",
"\n",
- "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n",
- "pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))\n",
+ "pl = sl.Pipeline([product, to_param1, to_param2])\n",
+ "params = pd.DataFrame({Param: [1, 2, 3]})\n",
+ "pl = pl.map(params).reduce(func=gather, name=Sum)\n",
"pl.visualize(Sum)"
]
},
@@ -469,40 +425,37 @@
"metadata": {},
"outputs": [],
"source": [
+ "from typing import Any\n",
"import sciline as sl\n",
"\n",
- "List1 = NewType(\"List1\", float)\n",
- "List2 = NewType(\"List2\", float)\n",
"Param1 = NewType(\"Param1\", int)\n",
"Param2 = NewType(\"Param2\", int)\n",
- "Row1 = NewType(\"Row1\", int)\n",
- "Row2 = NewType(\"Row2\", int)\n",
- "\n",
"\n",
- "def gather1(x: sl.Series[Row1, float]) -> List1:\n",
- " return List1(list(x.values()))\n",
"\n",
- "\n",
- "def gather2(x: sl.Series[Row2, List1]) -> List2:\n",
- " return List2(list(x.values()))\n",
+ "def gather(*x: Any) -> list[Any]:\n",
+ " return list(x)\n",
"\n",
"\n",
"def product(x: Param1, y: Param2) -> float:\n",
" return x * y\n",
"\n",
"\n",
- "pl = sl.Pipeline([gather1, gather2, product])\n",
- "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n",
- "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2]}))\n",
+ "base = sl.Pipeline([product])\n",
+ "pl = (\n",
+ " base.map({Param1: [1, 4, 9]})\n",
+ " .map({Param2: [1, 2]})\n",
+ " .reduce(func=gather, name='reduce_1', index='dim_1')\n",
+ " .reduce(func=gather, name='reduce_0')\n",
+ ")\n",
"\n",
- "pl.visualize(List2)"
+ "pl.visualize('reduce_0')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
- "Note how intermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph."
+ "Note how intermediates such as `float(dim_1, dim_0)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph."
]
},
{
@@ -511,7 +464,29 @@
"metadata": {},
"outputs": [],
"source": [
- "pl.compute(List2)"
+ "pl.compute('reduce_0')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "It is also possible to reduce multiple axes at once.\n",
+ "For example, `reduce` will reduce all axes if no `index` or `axis` is specified:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pl = (\n",
+ " base.map({Param1: [1, 4, 9]})\n",
+ " .map({Param2: [1, 2]})\n",
+ " .reduce(func=gather, name='reduce_both')\n",
+ ")\n",
+ "pl.visualize('reduce_both')"
]
}
],
diff --git a/pyproject.toml b/pyproject.toml
index 13ab13fd..d4be62be 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ requires-python = ">=3.10"
# Run 'tox -e deps' after making changes here. This will update requirement files.
# Make sure to list one dependency per line.
dependencies = [
+ "cyclebane",
]
dynamic = ["version"]
diff --git a/requirements/base.in b/requirements/base.in
index b801db0e..3baebf5e 100644
--- a/requirements/base.in
+++ b/requirements/base.in
@@ -2,4 +2,4 @@
# will not be touched by ``make_base.py``
# --- END OF CUSTOM SECTION ---
# The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
-
+cyclebane
diff --git a/requirements/base.txt b/requirements/base.txt
index c24fd27f..eaed248d 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,8 +1,11 @@
-# SHA1:da39a3ee5e6b4b0d3255bfef95601890afd80709
+# SHA1:dd62fc84fce7fde1639783710b14ec7a54d8cdc5
#
# This file is autogenerated by pip-compile-multi
# To update, run:
#
# pip-compile-multi
#
-
+cyclebane==24.5.0
+ # via -r base.in
+networkx==3.3
+ # via cyclebane
diff --git a/requirements/basetest.txt b/requirements/basetest.txt
index b36e420b..b5c71f98 100644
--- a/requirements/basetest.txt
+++ b/requirements/basetest.txt
@@ -9,13 +9,13 @@ attrs==23.2.0
# via
# jsonschema
# referencing
-exceptiongroup==1.2.0
+exceptiongroup==1.2.1
# via pytest
graphviz==0.20.3
# via -r basetest.in
iniconfig==2.0.0
# via pytest
-jsonschema==4.21.1
+jsonschema==4.22.0
# via -r basetest.in
jsonschema-specifications==2023.12.1
# via jsonschema
@@ -23,15 +23,15 @@ numpy==1.26.4
# via -r basetest.in
packaging==24.0
# via pytest
-pluggy==1.4.0
+pluggy==1.5.0
# via pytest
-pytest==8.1.1
+pytest==8.2.1
# via -r basetest.in
-referencing==0.34.0
+referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
-rpds-py==0.18.0
+rpds-py==0.18.1
# via
# jsonschema
# referencing
diff --git a/requirements/ci.txt b/requirements/ci.txt
index 88d49ee6..ab682a51 100644
--- a/requirements/ci.txt
+++ b/requirements/ci.txt
@@ -17,7 +17,7 @@ colorama==0.4.6
# via tox
distlib==0.3.8
# via virtualenv
-filelock==3.13.4
+filelock==3.14.0
# via
# tox
# virtualenv
@@ -32,15 +32,15 @@ packaging==24.0
# -r ci.in
# pyproject-api
# tox
-platformdirs==4.2.0
+platformdirs==4.2.2
# via
# tox
# virtualenv
-pluggy==1.4.0
+pluggy==1.5.0
# via tox
pyproject-api==1.6.1
# via tox
-requests==2.31.0
+requests==2.32.3
# via -r ci.in
smmap==5.0.1
# via gitdb
@@ -48,9 +48,9 @@ tomli==2.0.1
# via
# pyproject-api
# tox
-tox==4.14.2
+tox==4.15.0
# via -r ci.in
urllib3==2.2.1
# via requests
-virtualenv==20.25.1
+virtualenv==20.26.2
# via tox
diff --git a/requirements/dev.txt b/requirements/dev.txt
index adba5090..ab785ebb 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -12,9 +12,9 @@
-r static.txt
-r test.txt
-r wheels.txt
-annotated-types==0.6.0
+annotated-types==0.7.0
# via pydantic
-anyio==4.3.0
+anyio==4.4.0
# via
# httpx
# jupyter-server
@@ -34,7 +34,7 @@ click==8.1.7
# pip-tools
copier==9.2.0
# via -r dev.in
-dunamai==1.20.0
+dunamai==1.21.1
# via copier
fqdn==1.5.1
# via jsonschema
@@ -54,7 +54,7 @@ json5==0.9.25
# via jupyterlab-server
jsonpointer==2.4
# via jsonschema
-jsonschema[format-nongpl]==4.21.1
+jsonschema[format-nongpl]==4.22.0
# via
# -r basetest.in
# jupyter-events
@@ -72,9 +72,9 @@ jupyter-server==2.14.0
# notebook-shim
jupyter-server-terminals==0.5.3
# via jupyter-server
-jupyterlab==4.1.6
+jupyterlab==4.2.1
# via -r dev.in
-jupyterlab-server==2.26.0
+jupyterlab-server==2.27.2
# via jupyterlab
notebook-shim==0.2.4
# via jupyterlab
@@ -86,15 +86,15 @@ pip-compile-multi==2.6.3
# via -r dev.in
pip-tools==7.4.1
# via pip-compile-multi
-plumbum==1.8.2
+plumbum==1.8.3
# via copier
prometheus-client==0.20.0
# via jupyter-server
pycparser==2.22
# via cffi
-pydantic==2.7.0
+pydantic==2.7.2
# via copier
-pydantic-core==2.18.1
+pydantic-core==2.18.3
# via pydantic
python-json-logger==2.0.7
# via jupyter-events
@@ -128,7 +128,7 @@ uri-template==1.3.0
# via jsonschema
webcolors==1.13
# via jsonschema
-websocket-client==1.7.0
+websocket-client==1.8.0
# via jupyter-server
wheel==0.43.0
# via pip-tools
diff --git a/requirements/docs.in b/requirements/docs.in
index 8a2366c9..f512d9c6 100644
--- a/requirements/docs.in
+++ b/requirements/docs.in
@@ -4,6 +4,7 @@ ipykernel
ipython!=8.7.0 # Breaks syntax highlighting in Jupyter code cells.
myst-parser
nbsphinx
+pandas
pydata-sphinx-theme>=0.14
sphinx
sphinx-autodoc-typehints
diff --git a/requirements/docs.txt b/requirements/docs.txt
index 6ee43684..1d0313b1 100644
--- a/requirements/docs.txt
+++ b/requirements/docs.txt
@@ -1,4 +1,4 @@
-# SHA1:69ee98dd11fd5e6fe2f8e3b546660b033b8839c4
+# SHA1:d646fa4c965681f36669a3ad404405afcd35ce97
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@@ -6,7 +6,7 @@
# pip-compile-multi
#
-r base.txt
-accessible-pygments==0.0.4
+accessible-pygments==0.0.5
# via pydata-sphinx-theme
alabaster==0.7.16
# via sphinx
@@ -16,7 +16,7 @@ attrs==23.2.0
# via
# jsonschema
# referencing
-babel==2.14.0
+babel==2.15.0
# via
# pydata-sphinx-theme
# sphinx
@@ -38,13 +38,13 @@ decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via nbconvert
-docutils==0.20.1
+docutils==0.21.2
# via
# myst-parser
# nbsphinx
# pydata-sphinx-theme
# sphinx
-exceptiongroup==1.2.0
+exceptiongroup==1.2.1
# via ipython
executing==2.0.1
# via stack-data
@@ -58,23 +58,23 @@ imagesize==1.4.1
# via sphinx
ipykernel==6.29.4
# via -r docs.in
-ipython==8.23.0
+ipython==8.24.0
# via
# -r docs.in
# ipykernel
jedi==0.19.1
# via ipython
-jinja2==3.1.3
+jinja2==3.1.4
# via
# myst-parser
# nbconvert
# nbsphinx
# sphinx
-jsonschema==4.21.1
+jsonschema==4.22.0
# via nbformat
jsonschema-specifications==2023.12.1
# via jsonschema
-jupyter-client==8.6.1
+jupyter-client==8.6.2
# via
# ipykernel
# nbclient
@@ -95,46 +95,50 @@ markupsafe==2.1.5
# via
# jinja2
# nbconvert
-matplotlib-inline==0.1.6
+matplotlib-inline==0.1.7
# via
# ipykernel
# ipython
-mdit-py-plugins==0.4.0
+mdit-py-plugins==0.4.1
# via myst-parser
mdurl==0.1.2
# via markdown-it-py
mistune==3.0.2
# via nbconvert
-myst-parser==2.0.0
+myst-parser==3.0.1
# via -r docs.in
nbclient==0.10.0
# via nbconvert
-nbconvert==7.16.3
+nbconvert==7.16.4
# via nbsphinx
nbformat==5.10.4
# via
# nbclient
# nbconvert
# nbsphinx
-nbsphinx==0.9.3
+nbsphinx==0.9.4
# via -r docs.in
nest-asyncio==1.6.0
# via ipykernel
+numpy==1.26.4
+ # via pandas
packaging==24.0
# via
# ipykernel
# nbconvert
# pydata-sphinx-theme
# sphinx
+pandas==2.2.2
+ # via -r docs.in
pandocfilters==1.5.1
# via nbconvert
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
-platformdirs==4.2.0
+platformdirs==4.2.2
# via jupyter-core
-prompt-toolkit==3.0.43
+prompt-toolkit==3.0.45
# via ipython
psutil==5.9.8
# via ipykernel
@@ -142,9 +146,9 @@ ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
-pydata-sphinx-theme==0.15.2
+pydata-sphinx-theme==0.15.3
# via -r docs.in
-pygments==2.17.2
+pygments==2.18.0
# via
# accessible-pygments
# ipython
@@ -152,20 +156,24 @@ pygments==2.17.2
# pydata-sphinx-theme
# sphinx
python-dateutil==2.9.0.post0
- # via jupyter-client
+ # via
+ # jupyter-client
+ # pandas
+pytz==2024.1
+ # via pandas
pyyaml==6.0.1
# via myst-parser
-pyzmq==26.0.0
+pyzmq==26.0.3
# via
# ipykernel
# jupyter-client
-referencing==0.34.0
+referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
-requests==2.31.0
+requests==2.32.3
# via sphinx
-rpds-py==0.18.0
+rpds-py==0.18.1
# via
# jsonschema
# referencing
@@ -178,7 +186,7 @@ snowballstemmer==2.2.0
# via sphinx
soupsieve==2.5
# via beautifulsoup4
-sphinx==7.2.6
+sphinx==7.3.7
# via
# -r docs.in
# myst-parser
@@ -187,11 +195,11 @@ sphinx==7.2.6
# sphinx-autodoc-typehints
# sphinx-copybutton
# sphinx-design
-sphinx-autodoc-typehints==2.0.1
+sphinx-autodoc-typehints==2.1.0
# via -r docs.in
sphinx-copybutton==0.5.2
# via -r docs.in
-sphinx-design==0.5.0
+sphinx-design==0.6.0
# via -r docs.in
sphinxcontrib-applehelp==1.0.8
# via sphinx
@@ -207,13 +215,15 @@ sphinxcontrib-serializinghtml==1.1.10
# via sphinx
stack-data==0.6.3
# via ipython
-tinycss2==1.2.1
+tinycss2==1.3.0
# via nbconvert
+tomli==2.0.1
+ # via sphinx
tornado==6.4
# via
# ipykernel
# jupyter-client
-traitlets==5.14.2
+traitlets==5.14.3
# via
# comm
# ipykernel
@@ -225,10 +235,12 @@ traitlets==5.14.2
# nbconvert
# nbformat
# nbsphinx
-typing-extensions==4.11.0
+typing-extensions==4.12.0
# via
# ipython
# pydata-sphinx-theme
+tzdata==2024.1
+ # via pandas
urllib3==2.2.1
# via requests
wcwidth==0.2.13
diff --git a/requirements/mypy.txt b/requirements/mypy.txt
index 8d293074..f0e7b3bf 100644
--- a/requirements/mypy.txt
+++ b/requirements/mypy.txt
@@ -6,9 +6,9 @@
# pip-compile-multi
#
-r test.txt
-mypy==1.9.0
+mypy==1.10.0
# via -r mypy.in
mypy-extensions==1.0.0
# via mypy
-typing-extensions==4.11.0
+typing-extensions==4.12.0
# via mypy
diff --git a/requirements/nightly.in b/requirements/nightly.in
index 6b1ebcc2..55a65bd8 100644
--- a/requirements/nightly.in
+++ b/requirements/nightly.in
@@ -1,4 +1,4 @@
-r basetest.in
# --- END OF CUSTOM SECTION ---
# The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY!
-
+cyclebane
diff --git a/requirements/nightly.txt b/requirements/nightly.txt
index a98564b2..e697c13b 100644
--- a/requirements/nightly.txt
+++ b/requirements/nightly.txt
@@ -1,4 +1,4 @@
-# SHA1:e8b11c1210855f07eaedfbcfb3ecd1aec3595dee
+# SHA1:6e913d4c1e64cd8d5433e7de17be332a12fefe11
#
# This file is autogenerated by pip-compile-multi
# To update, run:
@@ -6,3 +6,7 @@
# pip-compile-multi
#
-r basetest.txt
+cyclebane==24.5.0
+ # via -r nightly.in
+networkx==3.3
+ # via cyclebane
diff --git a/requirements/static.txt b/requirements/static.txt
index 99028a0d..ea619e41 100644
--- a/requirements/static.txt
+++ b/requirements/static.txt
@@ -9,20 +9,17 @@ cfgv==3.4.0
# via pre-commit
distlib==0.3.8
# via virtualenv
-filelock==3.13.4
+filelock==3.14.0
# via virtualenv
-identify==2.5.35
+identify==2.5.36
# via pre-commit
-nodeenv==1.8.0
+nodeenv==1.9.0
# via pre-commit
-platformdirs==4.2.0
+platformdirs==4.2.2
# via virtualenv
-pre-commit==3.7.0
+pre-commit==3.7.1
# via -r static.in
pyyaml==6.0.1
# via pre-commit
-virtualenv==20.25.1
+virtualenv==20.26.2
# via pre-commit
-
-# The following packages are considered to be unsafe in a requirements file:
-# setuptools
diff --git a/requirements/test-dask.txt b/requirements/test-dask.txt
index 6a4f2eeb..1362167d 100644
--- a/requirements/test-dask.txt
+++ b/requirements/test-dask.txt
@@ -10,15 +10,15 @@ click==8.1.7
# via dask
cloudpickle==3.0.0
# via dask
-dask==2024.4.1
+dask==2024.5.1
# via -r test-dask.in
-fsspec==2024.3.1
+fsspec==2024.5.0
# via dask
importlib-metadata==7.1.0
# via dask
locket==1.0.0
# via partd
-partd==1.4.1
+partd==1.4.2
# via dask
pyyaml==6.0.1
# via dask
@@ -26,5 +26,5 @@ toolz==0.12.1
# via
# dask
# partd
-zipp==3.18.1
+zipp==3.19.0
# via importlib-metadata
diff --git a/requirements/wheels.txt b/requirements/wheels.txt
index ff60a184..d1a95de4 100644
--- a/requirements/wheels.txt
+++ b/requirements/wheels.txt
@@ -9,9 +9,7 @@ build==1.2.1
# via -r wheels.in
packaging==24.0
# via build
-pyproject-hooks==1.0.0
+pyproject-hooks==1.1.0
# via build
tomli==2.0.1
- # via
- # build
- # pyproject-hooks
+ # via build
diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py
index 5a87ae80..8e2ec6c6 100644
--- a/src/sciline/__init__.py
+++ b/src/sciline/__init__.py
@@ -17,17 +17,12 @@
HandleAsComputeTimeException,
UnsatisfiedRequirement,
)
-from .param_table import ParamTable
-from .pipeline import AmbiguousProvider, Pipeline
-from .series import Series
+from .pipeline import Pipeline
from .task_graph import TaskGraph
__all__ = [
- "AmbiguousProvider",
- "ParamTable",
"Pipeline",
"scheduler",
- "Series",
"Scope",
"ScopeTwoParams",
'TaskGraph',
diff --git a/src/sciline/_provider.py b/src/sciline/_provider.py
index d21e8907..e999d474 100644
--- a/src/sciline/_provider.py
+++ b/src/sciline/_provider.py
@@ -11,6 +11,7 @@
Any,
Callable,
Generator,
+ Hashable,
Literal,
Optional,
TypeVar,
@@ -138,14 +139,6 @@ def bind_type_vars(self, bound: dict[TypeVar, Key]) -> Provider:
kind=self._kind,
)
- def map_arg_keys(self, transform: Callable[[Key], Key]) -> Provider:
- """Return a new provider with transformed argument keys."""
- return Provider(
- func=self._func,
- arg_spec=self._arg_spec.map_keys(transform),
- kind=self._kind,
- )
-
def __str__(self) -> str:
return f"Provider('{self.location.name}')"
@@ -155,7 +148,7 @@ def __repr__(self) -> str:
f"func={self._func})"
)
- def call(self, values: dict[Key, Any]) -> Any:
+ def call(self, values: dict[Hashable, Any]) -> Any:
"""Call the provider with arguments extracted from ``values``."""
return self._func(
*(values[arg] for arg in self._arg_spec.args),
@@ -170,10 +163,19 @@ def __init__(
self, *, args: dict[str, Key], kwargs: dict[str, Key], return_: Optional[Key]
) -> None:
"""Build from components, use dedicated creation functions instead."""
+ # Duplicate type hints could be allowed in principle, but it makes structure
+ # analysis and checks more difficult and error prone. As there is likely
+ # little utility in supporting this, we disallow it.
+ if len(set(args.values()) | set(kwargs.values())) != len(args) + len(kwargs):
+ raise ValueError("Duplicate type hints found in args and/or kwargs")
self._args = args
self._kwargs = kwargs
self._return = return_
+ def __len__(self) -> int:
+ """Number of args and kwargs, not counting return value."""
+ return len(self._args) + len(self._kwargs)
+
@classmethod
def from_function(cls, provider: ToProvider) -> ArgSpec:
"""Parse the argument spec of a provider."""
@@ -227,7 +229,7 @@ def map_keys(self, transform: Callable[[Key], Key]) -> ArgSpec:
return ArgSpec(
args={name: transform(arg) for name, arg in self._args.items()},
kwargs={name: transform(arg) for name, arg in self._kwargs.items()},
- return_=self._return,
+ return_=self._return if self._return is None else transform(self._return),
)
diff --git a/src/sciline/_utils.py b/src/sciline/_utils.py
index 024edb57..c8c1ca0f 100644
--- a/src/sciline/_utils.py
+++ b/src/sciline/_utils.py
@@ -5,7 +5,7 @@
from typing import Any, Callable, DefaultDict, Iterable, TypeVar, Union, get_args
from ._provider import Provider
-from .typing import Item, Key
+from .typing import Key
T = TypeVar('T')
G = TypeVar('G')
@@ -31,11 +31,6 @@ def full_qualname(obj: Any) -> str:
def key_name(key: Union[Key, TypeVar]) -> str:
- if isinstance(key, Item):
- parameters = ", ".join(
- f'{key_name(label.tp)}:{label.index}' for label in key.label
- )
- return f'{key_name(key.tp)}({parameters})'
args = get_args(key)
if len(args):
parameters = ', '.join(map(key_name, args))
@@ -43,15 +38,12 @@ def key_name(key: Union[Key, TypeVar]) -> str:
return f'{getattr(key, "__name__", "<unknown>")}[{parameters}]'
if isinstance(key, TypeVar):
return str(key)
- return key.__name__
+ if hasattr(key, '__name__'):
+ return key.__name__
+ return str(key)
def key_full_qualname(key: Union[Key, TypeVar]) -> str:
- if isinstance(key, Item):
- parameters = ", ".join(
- f'{key_full_qualname(label.tp)}:{label.index}' for label in key.label
- )
- return f'{key_full_qualname(key.tp)}({parameters})'
args = get_args(key)
if len(args):
origin = key.__origin__ # type: ignore[union-attr] # key is a TypeVar
diff --git a/src/sciline/data_graph.py b/src/sciline/data_graph.py
new file mode 100644
index 00000000..b2c778e5
--- /dev/null
+++ b/src/sciline/data_graph.py
@@ -0,0 +1,205 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2024 Scipp contributors (https://github.com/scipp)
+from __future__ import annotations
+
+import itertools
+from collections.abc import Iterable
+from types import NoneType
+from typing import Any, Callable, Generator, TypeVar, Union, get_args
+
+import cyclebane as cb
+import networkx as nx
+
+from ._provider import ArgSpec, Provider, ToProvider, _bind_free_typevars
+from .handler import ErrorHandler, HandleAsBuildTimeException
+from .typing import Graph, Key
+
+
+def _as_graph(key: Key, value: Any) -> cb.Graph:
+ """Create a cyclebane.Graph with a single value."""
+ graph = nx.DiGraph()
+ graph.add_node(key, value=value)
+ return cb.Graph(graph)
+
+
+def _find_all_typevars(t: type | TypeVar) -> set[TypeVar]:
+ """Returns the set of all TypeVars in a type expression."""
+ if isinstance(t, TypeVar):
+ return {t}
+ return set(itertools.chain(*map(_find_all_typevars, get_args(t))))
+
+
+def _get_typevar_constraints(t: TypeVar) -> set[type]:
+ """Returns the set of constraints of a TypeVar."""
+ return set(t.__constraints__)
+
+
+def _mapping_to_constrained(
+ type_vars: set[TypeVar],
+) -> Generator[dict[TypeVar, type], None, None]:
+ constraints = [_get_typevar_constraints(t) for t in type_vars]
+ if any(len(c) == 0 for c in constraints):
+ raise ValueError('Typevars must have constraints')
+ for combination in itertools.product(*constraints):
+ yield dict(zip(type_vars, combination, strict=True))
+
+
+T = TypeVar('T', bound='DataGraph')
+
+
+class DataGraph:
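+    """Graph of data dependencies, backed by a ``cyclebane.Graph``.
+
+    Nodes are type keys holding a provider or a value; edges record provider arguments.
+    """
+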
+ def __init__(self, providers: None | Iterable[ToProvider | Provider]) -> None:
+ self._cbgraph = cb.Graph(nx.DiGraph())
+ for provider in providers or []:
+ self.insert(provider)
+
+ @classmethod
+ def _from_cyclebane(cls: type[T], graph: cb.Graph) -> T:
+ out = cls([])
+ out._cbgraph = graph
+ return out
+
+ def copy(self: T) -> T:
+ return self._from_cyclebane(self._cbgraph.copy())
+
+ def __copy__(self: T) -> T:
+ return self.copy()
+
+ @property
+ def _graph(self) -> nx.DiGraph:
+ return self._cbgraph.graph
+
+ def _get_clean_node(self, key: Key) -> Any:
+ """Return node ready for setting value or provider."""
+ if key is NoneType:
+ raise ValueError('Key must not be None')
+ if key in self._graph:
+ self._graph.remove_edges_from(list(self._graph.in_edges(key)))
+ self._graph.nodes[key].pop('value', None)
+ self._graph.nodes[key].pop('provider', None)
+ self._graph.nodes[key].pop('reduce', None)
+ else:
+ self._graph.add_node(key)
+ return self._graph.nodes[key]
+
+ def insert(self, provider: Union[ToProvider, Provider], /) -> None:
+ """
+ Insert a callable into the graph that provides its return value.
+
+ Parameters
+ ----------
+ provider:
+ Either a callable that provides its return value. Its arguments
+ and return value must be annotated with type hints.
+ Or a ``Provider`` object that has been constructed from such a callable.
+ """
+ if not isinstance(provider, Provider):
+ provider = Provider.from_function(provider)
+ return_type = provider.deduce_key()
+ if typevars := _find_all_typevars(return_type):
+ for bound in _mapping_to_constrained(typevars):
+ self.insert(provider.bind_type_vars(bound))
+ return
+ # Trigger UnboundTypeVar error if any input typevars are not bound
+ provider = provider.bind_type_vars({})
+ self._get_clean_node(return_type)['provider'] = provider
+ for dep in provider.arg_spec.keys():
+ self._graph.add_edge(dep, return_type, key=dep)
+
+ def __setitem__(self, key: Key, value: DataGraph | Any) -> None:
+ """
+ Provide a concrete value for a type.
+
+ Parameters
+ ----------
+ key:
+ Type to provide a value for.
+        value:
+ Concrete value to provide.
+ """
+ # This is a questionable approach: Using MyGeneric[T] as a key will actually
+ # not pass mypy [valid-type] checks. What we do on our side is ok, but the
+ # calling code is not.
+ if typevars := _find_all_typevars(key):
+ for bound in _mapping_to_constrained(typevars):
+ self[_bind_free_typevars(key, bound)] = value
+ return
+
+ # TODO If key is generic, should we support multi-sink case and update all?
+ # Would imply that we need the same for __getitem__.
+ self._cbgraph[key] = (
+ value._cbgraph if isinstance(value, DataGraph) else _as_graph(key, value)
+ )
+
+ def __getitem__(self: T, key: Key) -> T:
+ return self._from_cyclebane(self._cbgraph[key])
+
+ def map(self: T, node_values: dict[Key, Any]) -> T:
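+        """Map the graph over the given values for each node key, returning a new graph (see ``cyclebane.Graph.map``)."""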
+ return self._from_cyclebane(self._cbgraph.map(node_values))
+
+ def reduce(self: T, *, func: Callable[..., Any], **kwargs: Any) -> T:
+ # Note that the type hints of `func` are not checked here. As we are explicit
+ # about the modification, this is in line with __setitem__ which does not
+ # perform such checks and allows for using generic reduction functions.
+ return self._from_cyclebane(
+ self._cbgraph.reduce(attrs={'reduce': func}, **kwargs)
+ )
+
+ def to_networkx(self) -> nx.DiGraph:
+ return self._cbgraph.to_networkx()
+
+ def visualize_data_graph(
+ self, **kwargs: Any
+ ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
+ import graphviz
+
+ dot = graphviz.Digraph(strict=True, **kwargs)
+ for node in self._graph.nodes:
+            attrs = self._graph.nodes[node]
+            attr_text = '\n'.join(f'{k}={v}' for k, v in attrs.items())
+            dot.node(str(node), label=f'{node}\n{attr_text}', shape='box')
+ for edge in self._graph.edges:
+ key = self._graph.edges[edge].get('key')
+ label = str(key) if key is not None else ''
+ dot.edge(str(edge[0]), str(edge[1]), label=label)
+ return dot
+
+
+_no_value = object()
+
+
+def to_task_graph(
+ data_graph: DataGraph, targets: tuple[Key, ...], handler: ErrorHandler | None = None
+) -> Graph:
+ graph = data_graph.to_networkx()
+ handler = handler or HandleAsBuildTimeException()
+ ancestors = list(targets)
+ for target in targets:
+ if target not in graph:
+ handler.handle_unsatisfied_requirement(target)
+ ancestors.extend(nx.ancestors(graph, target))
+ graph = graph.subgraph(set(ancestors))
+ out = {}
+
+ for key in graph.nodes:
+ node = graph.nodes[key]
+ input_nodes = list(graph.predecessors(key))
+ input_edges = list(graph.in_edges(key, data=True))
+ orig_keys = [edge[2].get('key', None) for edge in input_edges]
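+        # Each in-edge's 'key' attribute records the provider's original argument
+        # type; below it is paired with the (possibly renamed) mapped input nodes.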
+ if (value := node.get('value', _no_value)) is not _no_value:
+ out[key] = Provider.parameter(value)
+ elif (provider := node.get('provider')) is not None:
+ new_key = dict(zip(orig_keys, input_nodes))
+ spec = provider.arg_spec.map_keys(new_key.get)
+ if len(spec) != len(input_nodes):
+ # This should be caught by __setitem__, but we check here to be safe.
+ raise ValueError("Corrupted graph")
+ # TODO also kwargs
+ out[key] = Provider(func=provider.func, arg_spec=spec, kind='function')
+ elif (func := node.get('reduce')) is not None:
+ spec = ArgSpec.from_args(*input_nodes)
+ out[key] = Provider(func=func, arg_spec=spec, kind='function')
+ else:
+ out[key] = handler.handle_unsatisfied_requirement(key)
+ return out
diff --git a/src/sciline/display.py b/src/sciline/display.py
index 4a066e6b..770a982a 100644
--- a/src/sciline/display.py
+++ b/src/sciline/display.py
@@ -1,11 +1,10 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from html import escape
-from typing import Iterable, List, Tuple, TypeVar, Union
+from typing import Any, Iterable, Tuple
-from ._provider import Provider
-from ._utils import groupby, key_name
-from .typing import Item, Key
+from ._utils import key_name
+from .typing import Key
def _details(summary: str, body: str) -> str:
@@ -17,27 +16,13 @@ def _details(summary: str, body: str) -> str:
'''
-def _provider_name(
- p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
-) -> str:
- key, args, _ = p
- if args:
- # This is always the case, but mypy complains
- if hasattr(key, '__getitem__'):
- return escape(key_name(key[args]))
- return escape(key_name(key))
+def _provider_name(p: Tuple[Key, dict[str, Any]]) -> str:
+ return escape(key_name(p[0]))
-def _provider_source(
- p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
-) -> str:
- key, _, (v, *rest) = p
- if v.kind == 'table_cell':
- # This is always the case, but mypy complains
- if isinstance(key, Item):
- return escape(
- f'ParamTable({key_name(key.label[0].tp)}, length={len((v, *rest))})'
- )
+def _provider_source(data: dict[str, Any]) -> str:
+ if (v := data.get('provider', None)) is None:
+ return ''
if v.kind == 'function':
return _details(
escape(v.location.name),
@@ -46,46 +31,23 @@ def _provider_source(
return ''
-def _provider_value(
- p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
-) -> str:
- _, _, (v, *_) = p
- if v.kind == 'parameter':
- html = escape(str(v.call({}))).strip()
- return _details(f'{html[:30]}...', html) if len(html) > 30 else html
- return ''
-
+def _provider_value(data: dict[str, Any]) -> str:
+ if (value := data.get('value', None)) is None:
+ return ''
+ html = escape(str(value)).strip()
+ return _details(f'{html[:30]}...', html) if len(html) > 30 else html
-def pipeline_html_repr(
- providers: Iterable[Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]]
-) -> str:
- def associate_table_values(
- p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]
- ) -> Tuple[Key, Union[type, Tuple[Union[Key, TypeVar], ...]]]:
- key, args, v = p
- if isinstance(key, Item):
- return (key.label[0].tp, key.tp)
- return (key, args)
- providers_collected = (
- (key, args, [value, *(v for _, _, v in rest)])
- for ((key, args, value), *rest) in groupby(
- associate_table_values,
- providers,
- ).values()
- )
+def pipeline_html_repr(nodes: Iterable[Tuple[Key, dict[str, Any]]]) -> str:
provider_rows = '\n'.join(
(
f'''
-            <tr><td>{_provider_name(p)}</td>
-            <td>{_provider_value(p)}</td>
-            <td>{_provider_source(p)}</td>
+            <tr><td>{_provider_name(item)}</td>
+            <td>{_provider_value(item[1])}</td>
+            <td>{_provider_source(item[1])}</td>
            </tr>
            '''
- for p in sorted(
- providers_collected,
- key=_provider_name,
- )
+ for item in sorted(nodes, key=_provider_name)
)
)
return f'''
diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py
deleted file mode 100644
index 96e76f5f..00000000
--- a/src/sciline/param_table.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from __future__ import annotations
-
-from typing import Any, Collection, Dict, Mapping, Optional
-
-
-class ParamTable(Mapping[type, Collection[Any]]):
- """A table of parameters with a row index and named row dimension."""
-
- def __init__(
- self,
- row_dim: type,
- columns: Dict[type, Collection[Any]],
- *,
- index: Optional[Collection[Any]] = None,
- ):
- """
- Create a new param table.
-
- Parameters
- ----------
- row_dim:
- The row dimension. This must be a type or a type-alias (not an instance),
- and is used by :py:class:`sciline.Pipeline` to identify each parameter
- table.
- columns:
- The columns of the table. The keys (column names) must be types or type-
- aliases matching the values in the respective columns.
- index:
- The row index of the table. If not given, a default index will be
- generated, as the integer range of the column length.
- """
- sizes = set(len(v) for v in columns.values())
- if len(sizes) > 1:
- raise ValueError(
- f"Columns in param table must all have same size, got {sizes}"
- )
- if sizes:
- size = sizes.pop()
- elif index is None:
- raise ValueError("Cannot create param table with zero columns and no index")
- else:
- size = len(index)
- if index is not None:
- if len(index) != size:
- raise ValueError(
- f"Index length not matching columns, got {len(index)} and {size}"
- )
- if len(set(index)) != len(index):
- raise ValueError(f"Index must be unique, got {index}")
- self._row_dim = row_dim
- self._columns = columns
- self._index = index or list(range(size))
-
- @property
- def row_dim(self) -> type:
- """The row dimension of the table."""
- return self._row_dim
-
- @property
- def index(self) -> Collection[Any]:
- """The row index of the table."""
- return self._index
-
- def __contains__(self, key: Any) -> bool:
- return self._columns.__contains__(key)
-
- def __getitem__(self, key: Any) -> Any:
- return self._columns.__getitem__(key)
-
- def __iter__(self) -> Any:
- return self._columns.__iter__()
-
- def __len__(self) -> int:
- return self._columns.__len__()
-
- def __repr__(self) -> str:
- return f"ParamTable(row_dim={self.row_dim.__name__}, columns={self._columns})"
-
- def _repr_html_(self) -> str:
- return (
-            f"<table><tr><th>{self.row_dim.__name__}</th>"
-            + "".join(
-                f"<th>{getattr(k, '__name__', str(k).split('.')[-1])}</th>"
-                for k in self._columns.keys()
-            )
-            + "</tr>"
-            + "".join(
-                f"<tr><td>{idx}</td>" + "".join(f"<td>{v}</td>" for v in row) + "</tr>"
-                for idx, row in zip(self.index, zip(*self._columns.values()))
-            )
-            + "</table>"
- )
diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py
index 1ae2b270..3e20cd59 100644
--- a/src/sciline/pipeline.py
+++ b/src/sciline/pipeline.py
@@ -2,155 +2,36 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations
-from collections import defaultdict
from collections.abc import Iterable
from itertools import chain
from types import UnionType
from typing import (
Any,
Callable,
- Collection,
Dict,
- Generic,
- List,
- Mapping,
Optional,
- Set,
Tuple,
Type,
TypeVar,
Union,
get_args,
- get_origin,
get_type_hints,
overload,
)
-from sciline.task_graph import TaskGraph
-
-from ._provider import ArgSpec, Provider, ProviderLocation, ToProvider
-from ._utils import key_name
+from ._provider import Provider, ToProvider
+from .data_graph import DataGraph, to_task_graph
from .display import pipeline_html_repr
-from .handler import (
- ErrorHandler,
- HandleAsBuildTimeException,
- HandleAsComputeTimeException,
- UnsatisfiedRequirement,
-)
-from .param_table import ParamTable
+from .handler import ErrorHandler, HandleAsComputeTimeException
from .scheduler import Scheduler
-from .series import Series
-from .typing import Graph, Item, Key, Label
+from .task_graph import TaskGraph
+from .typing import Key
T = TypeVar('T')
KeyType = TypeVar('KeyType', bound=Key)
-ValueType = TypeVar('ValueType', bound=Key)
-IndexType = TypeVar('IndexType', bound=Key)
-LabelType = TypeVar('LabelType', bound=Key)
-
-
-class AmbiguousProvider(Exception):
- """Raised when multiple providers are found for a type."""
-
-
-def _extract_typevars_from_generic_type(t: type) -> Tuple[TypeVar, ...]:
- """Returns the typevars that were used in the definition of a Generic type."""
- if not hasattr(t, '__orig_bases__'):
- return ()
- return tuple(
- chain(*(get_args(b) for b in t.__orig_bases__ if get_origin(b) == Generic))
- )
-
-
-def _find_all_typevars(t: Union[type, TypeVar]) -> Set[TypeVar]:
- """Returns the set of all TypeVars in a type expression."""
- if isinstance(t, TypeVar):
- return {t}
- return set(chain(*map(_find_all_typevars, get_args(t))))
-
-
-def _find_bounds_to_make_compatible_type(
- requested: Key,
- provided: Key | TypeVar,
-) -> Optional[Dict[TypeVar, Key]]:
- """
- Check if a type is compatible to a provided type.
- If the types are compatible, return a mapping from typevars to concrete types
- that makes the provided type equal to the requested type.
- """
- if provided == requested:
- ret: Dict[TypeVar, Key] = {}
- return ret
- if isinstance(provided, TypeVar):
- # If the type var has no constraints, accept anything
- if not provided.__constraints__:
- return {provided: requested}
- for c in provided.__constraints__:
- if _find_bounds_to_make_compatible_type(requested, c) is not None:
- return {provided: requested}
- if get_origin(provided) is not None:
- if get_origin(provided) == get_origin(requested):
- return _find_bounds_to_make_compatible_type_tuple(
- get_args(requested), get_args(provided)
- )
- return None
-
-
-def _find_bounds_to_make_compatible_type_tuple(
- requested: tuple[Key, ...],
- provided: tuple[Key | TypeVar, ...],
-) -> Optional[Dict[TypeVar, Key]]:
- """
- Check if a tuple of requested types is compatible with a tuple of provided types
- and return a mapping from type vars to concrete types that makes all provided
- types equal to their corresponding requested type.
- If any of the types is not compatible, return None.
- """
- union: Dict[TypeVar, Key] = {}
- for bound in map(_find_bounds_to_make_compatible_type, requested, provided):
- # If no mapping from the type-var to a concrete type was found,
- # or if the mapping is inconsistent,
- # interrupt the search and report that no compatible types were found.
- if bound is None or any(k in union and union[k] != bound[k] for k in bound):
- return None
- union.update(bound)
- return union
-
-
-def _find_all_paths(
- dependencies: Mapping[Key, Collection[Key]], start: Key, end: Key
-) -> List[List[Key]]:
- """Find all paths from start to end in a DAG."""
- if start == end:
- return [[start]]
- if start not in dependencies:
- return []
- paths = []
- for node in dependencies[start]:
- if start == node:
- continue
- for path in _find_all_paths(dependencies, node, end):
- paths.append([start] + path)
- return paths
-
-
-def _find_nodes_in_paths(graph: Graph, end: Key) -> List[Key]:
- """
- Helper for Pipeline. Finds all nodes that need to be duplicated since they depend
- on a value from a param table.
- """
- start = next(iter(graph))
- dependencies = {k: tuple(p.arg_spec.keys()) for k, p in graph.items()}
- paths = _find_all_paths(dependencies, start, end)
- nodes = set()
- for path in paths:
- nodes.update(path)
- return list(nodes)
-def _is_multiple_keys(
- keys: type | Iterable[type] | Item[T] | object,
-) -> bool:
+def _is_multiple_keys(keys: type | Iterable[type] | UnionType) -> bool:
# Cannot simply use isinstance(keys, Iterable) because that is True for
# generic aliases of iterable types, e.g.,
#
@@ -159,196 +40,14 @@ def _is_multiple_keys(
#
# And isinstance(keys, type) does not work on its own because
# it is False for the above type.
+ if isinstance(keys, str):
+ return False
return (
not isinstance(keys, type) and not get_args(keys) and isinstance(keys, Iterable)
)
-class ReplicatorBase(Generic[IndexType]):
- def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
- if len(path) == 0:
- raise UnsatisfiedRequirement(
- 'Could not find path to param in param table. This is likely caused '
- 'by requesting a Series that does not depend directly or transitively '
- 'on any param from a table.'
- )
- self._index_name = index_name
- self.index = index
- self._path = path
-
- def __contains__(self, key: Key) -> bool:
- return key in self._path
-
- def replicate(
- self,
- key: Key,
- provider: Provider,
- get_provider: Callable[..., Tuple[Provider, Dict[TypeVar, Key]]],
- ) -> Graph:
- graph: Graph = {}
- for idx in self.index:
- subkey = self.key(idx, key)
- if isinstance(provider, _ParamSentinel):
- graph[subkey] = get_provider(subkey)[0]
- else:
- graph[subkey] = self._copy_node(key, provider, idx)
- return graph
-
- def _copy_node(
- self,
- key: Key,
- provider: Union[Provider, SeriesProvider[IndexType]],
- idx: IndexType,
- ) -> Provider:
- return provider.map_arg_keys(
- lambda arg: self.key(idx, arg) if arg in self else arg
- )
-
- def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]:
- label = Label(self._index_name, i)
- if isinstance(value_name, Item):
- return Item(value_name.label + (label,), value_name.tp)
- else:
- return Item((label,), value_name)
-
-
-class Replicator(ReplicatorBase[IndexType]):
- r"""
- Helper for rewriting the graph to map over a given index.
-
- See Pipeline._build_series for context. Given a graph template, this makes a
- transformation as follows:
-
- S P1[0] P1[1] P1[2]
- | | | |
- A -> A[0] A[1] A[2]
- | | | |
- B B[0] B[1] B[2]
-
- Where S is a sentinel value, P1 are parameters from a parameter table, 0,1,2
- are indices of the param table rows, and A and B are arbitrary nodes in the graph.
- """
-
- def __init__(self, param_table: ParamTable, graph_template: Graph) -> None:
- index_name = param_table.row_dim
- super().__init__(
- index_name=index_name,
- index=param_table.index,
- path=_find_nodes_in_paths(graph_template, index_name),
- )
-
-
-class GroupingReplicator(ReplicatorBase[LabelType], Generic[IndexType, LabelType]):
- r"""
- Helper for rewriting the graph to group by a given index.
-
- See Pipeline._build_series for context. Given a graph template, this makes a
- transformation as follows:
-
- P1[0] P1[1] P1[2] P1[0] P1[1] P1[2]
- | | | | | |
- A[0] A[1] A[2] A[0] A[1] A[2]
- | | | | | |
- B[0] B[1] B[2] -> B[0] B[1] B[2]
- \______|______/ \______/ |
- | | |
- SB SB[x] SB[y]
- | | |
- C C[x] C[y]
-
- Where SB is Series[Idx,B]. Here, the upper half of the graph originates from a
- prior transformation of a graph template using `Replicator`. The output of this
- combined with further nodes is the graph template passed to this class. x and y
- are the labels used in a grouping operation, based on the values of a ParamTable
- column P2.
- """
-
- def __init__(
- self, param_table: ParamTable, graph_template: Graph, label_name: type
- ) -> None:
- self._label_name = label_name
- self._group_node = self._find_grouping_node(param_table.row_dim, graph_template)
- self._groups: Dict[LabelType, List[IndexType]] = defaultdict(list)
- for idx, label in zip(param_table.index, param_table[label_name]):
- self._groups[label].append(idx)
- super().__init__(
- index_name=label_name,
- index=self._groups,
- path=_find_nodes_in_paths(graph_template, self._group_node),
- )
-
- def _copy_node(
- self,
- key: Key,
- provider: Union[Provider, SeriesProvider[IndexType]],
- idx: LabelType,
- ) -> Provider:
- if (not isinstance(provider, SeriesProvider)) or key != self._group_node:
- return super()._copy_node(key, provider, idx)
- labels = self._groups[idx]
- if set(labels) - set(provider.labels):
- raise ValueError(f'{labels} is not a subset of {provider.labels}')
- if tuple(provider.arg_spec.kwargs):
- raise RuntimeError(
- 'A Series was provided with keyword arguments. This should not happen '
- 'and is an internal error of Sciline.'
- )
- selected = {
- label: arg
- for label, arg in zip(provider.labels, provider.arg_spec.args)
- if label in labels
- }
- return SeriesProvider(selected.keys(), provider.row_dim, args=selected.values())
-
- def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type:
- ends: List[type] = []
- for key in subgraph:
- if get_origin(key) == Series and get_args(key)[0] == index_name:
-                # Because get_origin succeeded, we know the key is a type
- ends.append(key) # type: ignore[arg-type]
- if len(ends) == 1:
- return ends[0]
- raise ValueError(f"Could not find unique grouping node, found {ends}")
-
-
-class SeriesProvider(Generic[KeyType], Provider):
- """
- Internal provider for combining results obtained based on different rows in a
- param table into a single object.
- """
-
- def __init__(
- self,
- labels: Iterable[KeyType],
- row_dim: Type[KeyType],
- *,
- args: Optional[Iterable[Key]] = None,
- ) -> None:
- super().__init__(
- func=self._call,
- arg_spec=ArgSpec.from_args(*(args if args is not None else labels)),
- kind='series',
- )
- self.labels = labels
- self.row_dim = row_dim
-
- def _call(self, *vals: ValueType) -> Series[KeyType, ValueType]:
- return Series(self.row_dim, dict(zip(self.labels, vals)))
-
-
-class _ParamSentinel(Provider):
- def __init__(self, key: Key) -> None:
- super().__init__(
- func=lambda: None,
- arg_spec=ArgSpec.from_args(key),
- kind='sentinel',
- location=ProviderLocation(
- name=f'param_sentinel({type(key).__name__})', module='sciline'
- ),
- )
-
-
-class Pipeline:
+class Pipeline(DataGraph):
"""A container for providers that can be assembled into a task graph."""
def __init__(
@@ -368,348 +67,10 @@ def __init__(
params:
Dictionary of concrete values to provide for types.
"""
- self._providers: Dict[Key, Provider] = {}
- self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {}
- self._param_tables: Dict[Key, ParamTable] = {}
- self._param_name_to_table_key: Dict[Key, Key] = {}
- for provider in providers or []:
- self.insert(provider)
+ super().__init__(providers)
for tp, param in (params or {}).items():
self[tp] = param
- def insert(self, provider: Union[ToProvider, Provider], /) -> None:
- """
- Add a callable that provides its return value to the pipeline.
-
- Parameters
- ----------
- provider:
- Either a callable that provides its return value. Its arguments
- and return value must be annotated with type hints.
- Or a ``Provider`` object that has been constructed from such a callable.
- """
- if not isinstance(provider, Provider):
- provider = Provider.from_function(provider)
- self._set_provider(provider.deduce_key(), provider)
-
- def __setitem__(self, key: Type[T], param: T) -> None:
- """
- Provide a concrete value for a type.
-
- Parameters
- ----------
- key:
- Type to provide a value for.
- param:
- Concrete value to provide.
- """
- self._set_provider(key, Provider.parameter(param))
-
- def set_param_table(self, params: ParamTable) -> None:
- """
- Set a parameter table for a row dimension.
-
- Values in the parameter table provide concrete values for a type given by the
- respective column header.
-
- A pipeline can have multiple parameter tables, but only one per row dimension.
- Column names must be unique across all parameter tables.
-
- Parameters
- ----------
- params:
- Parameter table to set.
- """
- for param_name in params:
- if (existing := self._param_name_to_table_key.get(param_name)) is not None:
- if (
- existing == params.row_dim
- and param_name in self._param_tables[existing]
- ):
- # Column will be removed by del_param_table below, clash is ok
- continue
- raise ValueError(f'Parameter {param_name} already set')
- if params.row_dim in self._param_tables:
- self.del_param_table(params.row_dim)
- self._param_tables[params.row_dim] = params
- for param_name in params:
- self._param_name_to_table_key[param_name] = params.row_dim
- for param_name, values in params.items():
- for index, label in zip(params.index, values):
- self._set_provider(
- Item((Label(tp=params.row_dim, index=index),), param_name),
- Provider.table_cell(label),
- )
- for index, label in zip(params.index, params.index):
- self._set_provider(
- Item((Label(tp=params.row_dim, index=index),), params.row_dim),
- Provider.table_cell(label),
- )
-
- def del_param_table(self, row_dim: type) -> None:
- """
- Remove a parameter table.
-
- Parameters
- ----------
- row_dim:
- Row dimension of the parameter table to remove.
- """
- # 1. Remove providers pointing to table cells
- params = self._param_tables[row_dim]
- for index in params.index:
- label = (Label(tp=row_dim, index=index),)
- for param_name in params:
- del self._providers[Item(label, param_name)]
- del self._providers[Item(label, row_dim)]
- # 2. Remove column to table mapping
- for param_name in list(self._param_name_to_table_key):
- if self._param_name_to_table_key[param_name] == row_dim:
- del self._param_name_to_table_key[param_name]
- # 3. Remove table
- del self._param_tables[row_dim]
-
- def set_param_series(self, row_dim: type, index: Collection[Any]) -> None:
- """
- Set a series of parameters.
-
- This is a convenience method for creating and setting a parameter table with
- no columns and an index given by `index`.
-
- Parameters
- ----------
- row_dim:
- Row dimension of the parameter table to set.
- index:
- Index of the parameter table to set.
- """
- self.set_param_table(ParamTable(row_dim, columns={}, index=index))
-
- def _set_provider(
- self,
- key: Key,
- provider: Provider,
- ) -> None:
-        # isinstance does not work here, and types.NoneType is available only in 3.10+
- if key == type(None): # noqa: E721
- raise ValueError(f'Provider {provider} returning `None` is not allowed')
- if get_origin(key) == Series:
- raise ValueError(
- f'Provider {provider} returning a sciline.Series is not allowed. '
- 'Series is a special container reserved for use in conjunction with '
- 'sciline.ParamTable and must not be provided directly.'
- )
- if (origin := get_origin(key)) is not None:
- subproviders = self._subproviders.setdefault(origin, {})
- args = get_args(key)
- subproviders[args] = provider
- else:
- self._providers[key] = provider
-
- def _get_provider(
- self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None
- ) -> Tuple[Provider, Dict[TypeVar, Key]]:
- handler = handler or HandleAsBuildTimeException()
- explanation: List[str] = []
- if (provider := self._providers.get(tp)) is not None:
- return provider, {}
- elif (origin := get_origin(tp)) is not None and (
- subproviders := self._subproviders.get(origin)
- ) is not None:
- requested = get_args(tp)
- matches = [
- (subprovider, bound)
- for args, subprovider in subproviders.items()
- if (
- bound := _find_bounds_to_make_compatible_type_tuple(requested, args)
- )
- is not None
- ]
- typevar_counts = [len(bound) for _, bound in matches]
- min_typevar_count = min(typevar_counts, default=0)
- matches = [
- m
- for count, m in zip(typevar_counts, matches)
- if count == min_typevar_count
- ]
-
- if len(matches) == 1:
- provider, bound = matches[0]
- return provider, bound
- elif len(matches) > 1:
- matching_providers = [provider.location.name for provider, _ in matches]
- raise AmbiguousProvider(
- f"Multiple providers found for type {tp}."
- f" Matching providers are: {matching_providers}."
- )
- else:
- typevars_in_expression = _extract_typevars_from_generic_type(origin)
- if typevars_in_expression:
- explanation = [
- ''.join(
- map(
- str,
- (
- 'Note that ',
- key_name(origin[typevars_in_expression]),
- ' has constraints ',
- (
- {
- key_name(tv): tuple(
- map(key_name, tv.__constraints__)
- )
- for tv in typevars_in_expression
- }
- ),
- ),
- )
- )
- ]
- return handler.handle_unsatisfied_requirement(tp, *explanation), {}
-
- def build(
- self,
- tp: Union[Type[T], Item[T]],
- /,
- *,
- handler: ErrorHandler,
- search_param_tables: bool = False,
- ) -> Graph:
- """
- Return a dict of providers required for building the requested type `tp`.
-
- This is mainly for internal and low-level use. Prefer using :py:meth:`get`.
-
- The values are tuples containing the provider and the dict of arguments for
- the provider. The values in the latter dict reference other keys in the returned
- graph.
-
- Parameters
- ----------
- tp:
- Type to build the graph for.
- search_param_tables:
- Whether to search parameter tables for concrete keys.
- """
- graph: Graph = {}
- stack: List[Union[Type[T], Item[T]]] = [tp]
- while stack:
- tp = stack.pop()
- # First look in column labels of param tables
- if search_param_tables and tp in self._param_name_to_table_key:
- graph[tp] = _ParamSentinel(self._param_name_to_table_key[tp])
- continue
- # Then also indices of param tables. This comes second because we need to
- # prefer column labels over indices for multi-level grouping.
- if search_param_tables and tp in self._param_tables:
- graph[tp] = _ParamSentinel(tp)
- continue
- if get_origin(tp) == Series:
- sub = self._build_series(tp, handler=handler) # type: ignore[arg-type]
- graph.update(sub)
- continue
- provider, bound = self._get_provider(tp, handler=handler)
- provider = provider.bind_type_vars(bound)
- graph[tp] = provider
- stack.extend(provider.arg_spec.keys() - graph.keys())
- return graph
-
- def _build_series(
- self, tp: Type[Series[KeyType, ValueType]], handler: ErrorHandler
- ) -> Graph:
- """
- Build (sub)graph for a Series type implementing ParamTable-based functionality.
-
- We illustrate this with an example. Given a ParamTable with row_dim 'Idx':
-
- Idx | P1 | P2
- 0 | a | x
- 1 | b | x
- 2 | c | y
-
- and providers for A depending on P1 and B depending on A. Calling
- build(Series[Idx,B]) will call _build_series(Series[Idx,B]). This results in
- the following procedure here:
-
- 1. Call build(P1), resulting, e.g., in a graph S->A->B, where S is a sentinel.
- The sentinel is used because build() cannot find a unique P1, since it is
- not a single value but a column in a table.
- 2. Instantiation of `Replicator`, which will be used to replicate the
- relevant parts of the graph (see illustration there).
- 3. Insert a special `SeriesProvider` node, which will gather the duplicates of
- the 'B' node and provides the requested Series[Idx,B].
- 4. Replicate the graph. Nodes that do not directly or indirectly depend on P1
- are not replicated.
-
- Conceptually, the final result will be {
- 0: B(A(a)),
- 1: B(A(b)),
- 2: B(A(c))
- }.
-
- In more complex cases, we may be dealing with multiple levels of Series,
- which is used for grouping operations. Consider the above example, but with
-        an additional provider for C depending on Series[Idx,B]. Calling
- build(Series[P2,C]) will call _build_series(Series[P2,C]). This results in
- the following procedure here:
-
- a. Call build(C), which results in the procedure above, i.e., a nested call
- to _build_series(Series[Idx,B]) and the resulting graph as explained above.
- b. Instantiation of `GroupingReplicator`, which will be used to replicate the
- relevant parts of the graph (see illustration there).
- c. Insert a special `SeriesProvider` node, which will gather the duplicates of
-           the 'C' node and provides the requested Series[P2,C].
-        d. Replicate the graph. Nodes that do not directly or indirectly depend on
- the special `SeriesProvider` node (from step 3.) are not replicated.
-
- Conceptually, the final result will be {
- x: C({
- 0: B(A(a)),
- 1: B(A(b))
- }),
- y: C({
- 2: B(A(c))
- })
- }.
- """
- index_name: Type[KeyType]
- value_type: Type[ValueType]
- index_name, value_type = get_args(tp)
-
- subgraph = self.build(value_type, search_param_tables=True, handler=handler)
-
- replicator: ReplicatorBase[KeyType]
- if (
- # For multi-level grouping a type is an index as well as a column label.
- # In this case we do not want to replicate the graph, but group it (elif).
- index_name not in self._param_name_to_table_key
- and (params := self._param_tables.get(index_name)) is not None
- ):
- replicator = Replicator(param_table=params, graph_template=subgraph)
- elif (table_key := self._param_name_to_table_key.get(index_name)) is not None:
- replicator = GroupingReplicator(
- param_table=self._param_tables[table_key],
- graph_template=subgraph,
- label_name=index_name,
- )
- else:
- raise KeyError(f'No parameter table found for label {index_name}')
-
- graph: Graph = {
- tp: SeriesProvider(
- list(replicator.index),
- index_name,
- args=(replicator.key(idx, value_type) for idx in replicator.index),
- )
- }
-
- for key, provider in subgraph.items():
- if key in replicator:
- graph.update(replicator.replicate(key, provider, self._get_provider))
- else:
- graph[key] = provider
- return graph
-
@overload
def compute(self, tp: Type[T], **kwargs: Any) -> T:
...
@@ -718,17 +79,11 @@ def compute(self, tp: Type[T], **kwargs: Any) -> T:
def compute(self, tp: Iterable[Type[T]], **kwargs: Any) -> Dict[Type[T], T]:
...
- @overload
- def compute(self, tp: Item[T], **kwargs: Any) -> T:
- ...
-
@overload
def compute(self, tp: UnionType, **kwargs: Any) -> Any:
...
- def compute(
- self, tp: type | Iterable[type] | Item[T] | UnionType, **kwargs: Any
- ) -> Any:
+ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
"""
Compute result for the given keys.
@@ -764,7 +119,7 @@ def visualize(
def get(
self,
- keys: type | Iterable[type] | Item[T] | object,
+ keys: type | Iterable[type] | UnionType,
*,
scheduler: Optional[Scheduler] = None,
handler: Optional[ErrorHandler] = None,
@@ -788,16 +143,15 @@ def get(
raises an exception only when the graph is computed. This can be achieved
by passing :py:class:`HandleAsComputeTimeException` as the handler.
"""
- handler = handler or HandleAsBuildTimeException()
- if _is_multiple_keys(keys):
- keys = tuple(keys) # type: ignore[arg-type]
- graph: Graph = {}
- for t in keys: # type: ignore[var-annotated]
- graph.update(self.build(t, handler=handler))
+ if multi := _is_multiple_keys(keys):
+ targets = tuple(keys) # type: ignore[arg-type]
else:
- graph = self.build(keys, handler=handler) # type: ignore[arg-type]
+ targets = (keys,) # type: ignore[assignment]
+ graph = to_task_graph(self, targets=targets, handler=handler)
return TaskGraph(
- graph=graph, targets=keys, scheduler=scheduler # type: ignore[arg-type]
+ graph=graph,
+ targets=targets if multi else keys, # type: ignore[arg-type]
+ scheduler=scheduler,
)
@overload
@@ -854,29 +208,6 @@ def bind_and_call(
return results[0]
return results
- def copy(self) -> Pipeline:
- """
- Make a copy of the pipeline.
- """
- out = Pipeline()
- out._providers = self._providers.copy()
- out._subproviders = {k: v.copy() for k, v in self._subproviders.items()}
- out._param_tables = self._param_tables.copy()
- out._param_name_to_table_key = self._param_name_to_table_key.copy()
- return out
-
- def __copy__(self) -> Pipeline:
- return self.copy()
-
def _repr_html_(self) -> str:
- providers_without_parameters = (
- (origin, tuple(), value) for origin, value in self._providers.items()
- ) # type: ignore[var-annotated]
- providers_with_parameters = (
- (origin, args, value)
- for origin in self._subproviders
- for args, value in self._subproviders[origin].items()
- )
- return pipeline_html_repr(
- chain(providers_without_parameters, providers_with_parameters)
- )
+ nodes = ((key, data) for key, data in self._graph.nodes.items())
+ return pipeline_html_repr(nodes)
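
For orientation, here is a minimal sketch of the map/reduce API that replaces the ParamTable mechanism removed above. It is based on the tests added further down in this diff; the type names and values are illustrative.

    from typing import NewType

    import sciline as sl

    A = NewType('A', int)
    B = NewType('B', int)
    Total = NewType('Total', int)

    def a_to_b(a: A) -> B:
        return B(a + 1)

    pipeline = sl.Pipeline((a_to_b,))
    # map() duplicates the branch for each value; reduce() gathers the mapped
    # sink values into a single new node named Total.
    mapped = pipeline.map({A: [A(1), A(2), A(3)]})
    reduced = mapped.reduce(func=lambda *b: Total(sum(b)), name=Total)
    assert reduced.compute(Total) == Total(9)  # (1+1) + (2+1) + (3+1)
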
diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py
index 9d658671..e74f83c0 100644
--- a/src/sciline/scheduler.py
+++ b/src/sciline/scheduler.py
@@ -5,14 +5,14 @@
Any,
Callable,
Dict,
- List,
+ Hashable,
Optional,
Protocol,
Tuple,
runtime_checkable,
)
-from sciline.typing import Graph, Key
+from sciline.typing import Graph
class CycleError(Exception):
@@ -25,7 +25,7 @@ class Scheduler(Protocol):
Scheduler interface compatible with :py:class:`sciline.Pipeline`.
"""
- def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
+ def get(self, graph: Graph, keys: list[Hashable]) -> Tuple[Any, ...]:
"""
Compute the result for given keys from the graph.
@@ -44,7 +44,7 @@ class NaiveScheduler:
:py:class:`DaskScheduler` instead.
"""
- def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
+ def get(self, graph: Graph, keys: list[Hashable]) -> Tuple[Any, ...]:
import graphlib
dependencies = {
@@ -56,7 +56,7 @@ def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
tasks = list(ts.static_order())
except graphlib.CycleError as e:
raise CycleError from e
- results: Dict[Key, Any] = {}
+ results: Dict[Hashable, Any] = {}
for t in tasks:
results[t] = graph[t].call(results)
return tuple(results[key] for key in keys)
@@ -87,7 +87,7 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None:
else:
self._dask_get = scheduler
- def get(self, graph: Graph, keys: List[Key]) -> Any:
+ def get(self, graph: Graph, keys: list[Hashable]) -> Any:
from dask.utils import apply
# Use `apply` to allow passing keyword arguments.
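
The scheduler now keys on plain Hashable values. As a reminder of what NaiveScheduler does with them, a standard-library sketch of the topological ordering it relies on; the dict literal stands in for the dependencies derived from each provider's arg_spec.

    import graphlib

    # Each key maps to the keys it depends on.
    dependencies = {'a': (), 'b': ('a',), 'c': ('b',)}
    ts = graphlib.TopologicalSorter(dependencies)
    # Providers would be called in this order, results cached by key.
    assert list(ts.static_order()) == ['a', 'b', 'c']
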
diff --git a/src/sciline/series.py b/src/sciline/series.py
deleted file mode 100644
index 2b5b5607..00000000
--- a/src/sciline/series.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from __future__ import annotations
-
-from typing import Iterator, Mapping, Type, TypeVar
-
-Key = TypeVar('Key')
-Value = TypeVar('Value')
-
-
-class Series(Mapping[Key, Value]):
- """A series of values with labels (row index) and named row dimension."""
-
- def __init__(self, row_dim: Type[Key], items: Mapping[Key, Value]) -> None:
- """
- Create a new series.
-
- Parameters
- ----------
- row_dim:
- The row dimension. This must be a type or a type-alias (not an instance).
- items:
- The items of the series.
- """
- self._row_dim = row_dim
- self._map: Mapping[Key, Value] = items
-
- @property
- def row_dim(self) -> type:
- """The row dimension of the series."""
- return self._row_dim
-
- def __contains__(self, item: object) -> bool:
- return item in self._map
-
- def __iter__(self) -> Iterator[Key]:
- return iter(self._map)
-
- def __len__(self) -> int:
- return len(self._map)
-
- def __getitem__(self, key: Key) -> Value:
- return self._map[key]
-
- def __repr__(self) -> str:
- return f"Series(row_dim={self.row_dim.__name__}, {self._map})"
-
-    def _repr_html_(self) -> str:
-        return (
-            f"<table><tr><th>{self.row_dim.__name__}</th><th>Value</th></tr>"
-            + "".join(
-                f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in self._map.items()
-            )
-            + "</table>"
-        )
-
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, Series):
- return NotImplemented
- return self.row_dim == other.row_dim and self._map == other._map
diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py
index e9662b93..cf80dc35 100644
--- a/src/sciline/task_graph.py
+++ b/src/sciline/task_graph.py
@@ -3,12 +3,12 @@
from __future__ import annotations
from html import escape
-from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Any, Generator, Hashable, Sequence, TypeVar
from ._utils import key_name
from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
from .serialize import json_serialize_task_graph
-from .typing import Graph, Item, Json, Key
+from .typing import Graph, Json, Key
T = TypeVar("T")
@@ -59,6 +59,9 @@ def wrap(s: str) -> str:
)
+Targets = Hashable | tuple[Hashable, ...]
+
+
class TaskGraph:
"""
Holds a concrete task graph and keys to compute.
@@ -68,11 +71,7 @@ class TaskGraph:
"""
def __init__(
- self,
- *,
- graph: Graph,
- targets: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]],
- scheduler: Optional[Scheduler] = None,
+ self, *, graph: Graph, targets: Targets, scheduler: Scheduler | None = None
) -> None:
self._graph = graph
self._keys = targets
@@ -87,12 +86,7 @@ def __init__(
)
self._scheduler = scheduler
- def compute(
- self,
- targets: Optional[
- Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]]
- ] = None,
- ) -> Any:
+ def compute(self, targets: Targets | None = None) -> Any:
"""
Compute the result of the graph.
@@ -162,7 +156,7 @@ def serialize(self) -> dict[str, Json]:
def _repr_html_(self) -> str:
leafs = sorted(
[
- escape(key_name(key))
+ escape(key_name(key)) # type: ignore[arg-type]
for key in (
self._keys if isinstance(self._keys, tuple) else [self._keys]
)
diff --git a/src/sciline/typing.py b/src/sciline/typing.py
index 93b1625d..883ddd88 100644
--- a/src/sciline/typing.py
+++ b/src/sciline/typing.py
@@ -1,14 +1,11 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from dataclasses import dataclass
from typing import (
Any,
Dict,
- Generic,
List,
Optional,
Tuple,
- Type,
TypeVar,
Union,
get_args,
@@ -17,23 +14,10 @@
from ._provider import Provider
-
-@dataclass(frozen=True)
-class Label:
- tp: type
- index: Any
-
-
T = TypeVar('T')
-@dataclass(frozen=True)
-class Item(Generic[T]):
- label: Tuple[Label, ...]
- tp: Type[T]
-
-
-Key = Union[type, Item[Any]]
+Key = type
Graph = dict[Key, Provider]
diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py
index d6a5d855..a7e9ddae 100644
--- a/src/sciline/visualize.py
+++ b/src/sciline/visualize.py
@@ -1,23 +1,13 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from dataclasses import dataclass
-from typing import (
- Any,
- Dict,
- List,
- Optional,
- Tuple,
- Type,
- TypeVar,
- Union,
- get_args,
- get_origin,
-)
+from typing import Any, Dict, Hashable, get_args, get_origin
+import cyclebane
from graphviz import Digraph
from ._provider import Provider, ProviderKind
-from .typing import Graph, Item, Key, get_optional
+from .typing import Graph, Key, get_optional
@dataclass
@@ -41,7 +31,7 @@ def to_graphviz(
graph: Graph,
compact: bool = False,
cluster_generics: bool = True,
- cluster_color: Optional[str] = '#f0f0ff',
+ cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
) -> Digraph:
"""
@@ -149,19 +139,15 @@ def _format_provider(provider: Provider, ret: Key, compact: bool) -> str:
return f'{provider.location.qualname}_{_format_type(ret, compact=compact).name}'
-T = TypeVar('T')
-
-
def _extract_type_and_labels(
- key: Union[Item[T], Type[T]], compact: bool
-) -> Tuple[Type[T], List[Union[type, Tuple[type, Any]]]]:
- if isinstance(key, Item):
- label = key.label
- return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label]
+ key: Hashable | cyclebane.graph.NodeName, compact: bool
+) -> tuple[Hashable, list[Hashable | tuple[Hashable, Hashable]]]:
+ if isinstance(key, cyclebane.graph.NodeName):
+ return key.name, list(key.index.axes if compact else key.index.to_tuple())
return key, []
-def _format_type(tp: Key, compact: bool = False) -> Node:
+def _format_type(tp: Hashable, compact: bool = False) -> Node:
"""
Helper for _format_graph.
@@ -169,16 +155,15 @@ def _format_type(tp: Key, compact: bool = False) -> Node:
but strip all module prefixes from the type name as well as the params.
We may make this configurable in the future.
"""
-
tp, labels = _extract_type_and_labels(tp, compact=compact)
- if (tp_ := get_optional(tp)) is not None:
+ if (tp_ := get_optional(tp)) is not None: # type: ignore[arg-type]
tp = tp_
- def get_base(tp: Key) -> str:
- return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1]
+ def get_base(tp: Hashable) -> str:
+ return str(tp.__name__) if hasattr(tp, '__name__') else str(tp).split('.')[-1]
- def format_label(label: Union[type, Tuple[type, Any]]) -> str:
+ def format_label(label: Hashable | tuple[Hashable, Any]) -> str:
if isinstance(label, tuple):
tp, index = label
return f'{get_base(tp)}={index}'
diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py
index 0d68862b..ace0b0a7 100644
--- a/tests/complex_workflow_test.py
+++ b/tests/complex_workflow_test.py
@@ -22,7 +22,7 @@ class RawData:
DirectBeam = NewType('DirectBeam', npt.NDArray[np.float64])
SolidAngle = NewType('SolidAngle', npt.NDArray[np.float64])
-Run = TypeVar('Run')
+Run = TypeVar('Run', SampleRun, BackgroundRun)
# TODO Giving the base twice works with mypy, how can we avoid typing it twice?
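
The change above (and the many similar ones in the test files below) reflects that generic providers are now expected to use constrained TypeVars, so every allowed instantiation is explicit. A minimal sketch, with hypothetical run types:

    from typing import NewType, TypeVar

    import sciline as sl

    SampleRun = NewType('SampleRun', int)
    BackgroundRun = NewType('BackgroundRun', int)
    # Each permitted instantiation is listed as a constraint.
    Run = TypeVar('Run', SampleRun, BackgroundRun)

    class RawData(sl.Scope[Run, int], int):
        ...
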
diff --git a/tests/param_table_test.py b/tests/param_table_test.py
deleted file mode 100644
index 931c247b..00000000
--- a/tests/param_table_test.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-import pytest
-
-import sciline as sl
-
-
-def test_raises_with_zero_columns() -> None:
- with pytest.raises(ValueError):
- sl.ParamTable(row_dim=int, columns={})
-
-
-def test_raises_with_inconsistent_column_sizes() -> None:
- with pytest.raises(ValueError):
- sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]})
-
-
-def test_raises_with_inconsistent_index_length() -> None:
- with pytest.raises(ValueError):
- sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3])
-
-
-def test_raises_with_non_unique_index() -> None:
- with pytest.raises(ValueError):
- sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2])
-
-
-def test_contains_includes_all_columns() -> None:
- pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]})
- assert int in pt
- assert float in pt
- assert str not in pt
-
-
-def test_contains_does_not_include_index() -> None:
- pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]})
- assert int not in pt
-
-
-def test_len_is_number_of_columns() -> None:
- pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]})
- assert len(pt) == 2
-
-
-def test_defaults_to_range_index() -> None:
- pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]})
- assert pt.index == [0, 1, 2]
-
-
-def test_index_with_no_columns() -> None:
- pt = sl.ParamTable(row_dim=int, columns={}, index=[1, 2, 3])
- assert pt.index == [1, 2, 3]
- assert len(pt) == 0
diff --git a/tests/pipeline_map_reduce_test.py b/tests/pipeline_map_reduce_test.py
new file mode 100644
index 00000000..9bbffffa
--- /dev/null
+++ b/tests/pipeline_map_reduce_test.py
@@ -0,0 +1,49 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from typing import NewType
+
+import pytest
+
+import sciline as sl
+
+A = NewType('A', int)
+B = NewType('B', int)
+C = NewType('C', int)
+D = NewType('D', int)
+X = NewType('X', int)
+
+
+def a_to_b(a: A) -> B:
+ return B(a + 1)
+
+
+def b_to_c(b: B) -> C:
+ return C(b + 2)
+
+
+def c_to_d(c: C) -> D:
+ return D(c + 4)
+
+
+def test_map_returns_pipeline_that_can_compute_for_each_value() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ mapped = ab.map({A: [A(10 * i) for i in range(3)]})
+ with pytest.raises(sl.UnsatisfiedRequirement):
+ # B is not in the graph any more, since it has been duplicated
+ mapped.compute(B)
+ from cyclebane.graph import IndexValues, NodeName
+
+ for i in range(3):
+ index = IndexValues(('dim_0',), (i,))
+ assert mapped.compute(NodeName(A, index)) == A(10 * i) # type: ignore[call-overload] # noqa: E501
+ assert mapped.compute(NodeName(B, index)) == B(10 * i + 1) # type: ignore[call-overload] # noqa: E501
+
+
+def test_reduce_returns_pipeline_passing_mapped_branches_to_reducing_func() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ mapped = ab.map({A: [A(10 * i) for i in range(3)]})
+    # Define a key to make mypy happy. This could actually be anything, but
+    # currently the type-hinting of Pipeline is too specific and disallows, e.g., strings.
+ Result = NewType('Result', int)
+ assert mapped.reduce(func=min, name=Result).compute(Result) == Result(1)
+ assert mapped.reduce(func=max, name=Result).compute(Result) == Result(21)
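
A self-contained variant of the addressing used in the test above: mapped branches are reachable under Cyclebane NodeName keys. 'dim_0' is assumed here to be the automatic index name assigned by map(), as in the test.

    from typing import NewType

    import sciline as sl
    from cyclebane.graph import IndexValues, NodeName

    A = NewType('A', int)
    B = NewType('B', int)

    def a_to_b(a: A) -> B:
        return B(a + 1)

    mapped = sl.Pipeline((a_to_b,)).map({A: [A(0), A(10), A(20)]})
    index = IndexValues(('dim_0',), (1,))
    # Address the branch mapped from A(10) directly.
    assert mapped.compute(NodeName(B, index)) == B(11)
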
diff --git a/tests/pipeline_setitem_test.py b/tests/pipeline_setitem_test.py
new file mode 100644
index 00000000..65e37bce
--- /dev/null
+++ b/tests/pipeline_setitem_test.py
@@ -0,0 +1,135 @@
+# SPDX-License-Identifier: BSD-3-Clause
+# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+from typing import NewType
+
+import pytest
+
+import sciline as sl
+
+A = NewType('A', int)
+B = NewType('B', int)
+C = NewType('C', int)
+D = NewType('D', int)
+X = NewType('X', int)
+
+
+def a_to_b(a: A) -> B:
+ return B(a + 1)
+
+
+def ab_to_c(a: A, b: B) -> C:
+ return C(a + b)
+
+
+def b_to_c(b: B) -> C:
+ return C(b + 2)
+
+
+def c_to_d(c: C) -> D:
+ return D(c + 4)
+
+
+def test_setitem_can_compose_pipelines() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 0
+ bc = sl.Pipeline((b_to_c,))
+ bc[B] = ab
+ assert bc.compute(C) == C(3)
+
+
+def test_setitem_raises_if_value_pipeline_has_no_unique_output() -> None:
+ abx = sl.Pipeline((a_to_b,))
+ abx[X] = 666
+ bc = sl.Pipeline((b_to_c,))
+ with pytest.raises(ValueError, match='Graph must have exactly one sink node'):
+ bc[B] = abx
+
+
+def test_setitem_can_add_value_pipeline_at_new_node_given_by_key() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 0
+ empty = sl.Pipeline(())
+ empty[B] = ab
+ assert empty.compute(B) == B(1)
+
+
+def test_setitem_replaces_existing_node_value_with_value_pipeline() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 0
+ bc = sl.Pipeline((b_to_c,))
+ bc[B] = 111
+ bc[B] = ab
+ assert bc.compute(C) == C(3)
+
+
+def test_setitem_replaces_existing_branch_with_value() -> None:
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 0
+ bc = sl.Pipeline((b_to_c,))
+ bc[B] = ab
+ bc[B] = 111
+ assert bc.compute(C) == C(113)
+ # __setitem__ prunes the entire branch, instead of just cutting the edges
+ with pytest.raises(sl.UnsatisfiedRequirement):
+ bc.compute(A)
+
+
+def test_setitem_with_conflicting_nodes_in_value_pipeline_raises_on_data_mismatch() -> (
+ None
+):
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 100
+ abc = sl.Pipeline((ab_to_c,))
+ abc[A] = 666
+ with pytest.raises(ValueError, match="Node data differs"):
+ abc[B] = ab
+
+
+def test_setitem_with_conflicting_nodes_in_value_pipeline_accepts_on_data_match() -> (
+ None
+):
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 100
+ abc = sl.Pipeline((ab_to_c,))
+ abc[A] = 100
+ abc[B] = ab
+ assert abc.compute(C) == C(201)
+
+
+def test_setitem_with_conflicting_nodes_in_value_pipeline_raises_on_unique_data() -> (
+ None
+):
+ ab = sl.Pipeline((a_to_b,))
+ ab[A] = 100
+ abc = sl.Pipeline((ab_to_c,))
+ # Missing data is just as bad as conflicting data. At first glance, it might seem
+ # like this should be allowed, but it would be a source of bugs and confusion.
+    # In particular, it would allow adding providers or parents to an indirect
+    # dependency of the key given to setitem, which would be very confusing.
+ with pytest.raises(ValueError, match="Node data differs"):
+ abc[B] = ab
+
+
+def test_setitem_with_conflicting_node_inputs_in_value_pipeline_raises() -> None:
+ def x_to_b(x: X) -> B:
+ return B(x + 1)
+
+ def bc_to_d(b: B, c: C) -> D:
+ return D(b + c)
+
+ xbc = sl.Pipeline((x_to_b, b_to_c))
+ xbc[X] = 666
+ abcd = sl.Pipeline((a_to_b, bc_to_d))
+ abcd[A] = 100
+ with pytest.raises(ValueError, match="Node inputs differ"):
+ # If this worked naively by composing the NetworkX graph, it would look as
+        # below, giving B two inputs instead of one, even though its provider has one argument,
+ # corrupting the graph:
+ # A X
+ # \ /
+ # B
+ # |\
+ # C |
+ # |/
+ # D
+ abcd[C] = xbc
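
The core composition behaviour exercised above, condensed into one runnable sketch (same helper names as in the tests):

    from typing import NewType

    import sciline as sl

    A = NewType('A', int)
    B = NewType('B', int)
    C = NewType('C', int)

    def a_to_b(a: A) -> B:
        return B(a + 1)

    def b_to_c(b: B) -> C:
        return C(b + 2)

    ab = sl.Pipeline((a_to_b,))
    ab[A] = 0
    bc = sl.Pipeline((b_to_c,))
    bc[B] = ab  # splice ab in as the provider of B
    assert bc.compute(C) == C(3)
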
diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py
index f8531260..68c5bbdc 100644
--- a/tests/pipeline_test.py
+++ b/tests/pipeline_test.py
@@ -83,7 +83,7 @@ def func2(x: int) -> str:
def test_Scope_subclass_can_be_set_as_param() -> None:
- Param = TypeVar('Param')
+ Param = TypeVar('Param', int, float)
class Str(sl.Scope[Param, str], str):
...
@@ -95,7 +95,7 @@ class Str(sl.Scope[Param, str], str):
def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar() -> None:
- Param = TypeVar('Param')
+ Param = TypeVar('Param', int, float)
class Str(sl.Scope[Param, str], str):
...
@@ -107,8 +107,8 @@ class Str(sl.Scope[Param, str], str):
def test_ScopeTwoParam_subclass_can_be_set_as_param() -> None:
- Param1 = TypeVar('Param1')
- Param2 = TypeVar('Param2')
+ Param1 = TypeVar('Param1', int, float)
+ Param2 = TypeVar('Param2', int, float)
class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
...
@@ -120,8 +120,8 @@ class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
def test_ScopeTwoParam_subclass_can_be_set_as_param_with_unbound_typevar() -> None:
- Param1 = TypeVar('Param1')
- Param2 = TypeVar('Param2')
+ Param1 = TypeVar('Param1', int, float)
+ Param2 = TypeVar('Param2', int, float)
class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
...
@@ -133,7 +133,7 @@ class Str(sl.ScopeTwoParams[Param1, Param2, str], str):
def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None:
- Param = TypeVar('Param')
+ Param = TypeVar('Param', int, float)
class Str(sl.Scope[Param, str], str):
...
@@ -161,7 +161,11 @@ def provide_int() -> int:
ncall += 1
return 3
- Param = TypeVar('Param')
+ Run1 = NewType('Run1', int)
+ Run2 = NewType('Run2', int)
+ Result = NewType('Result', str)
+
+ Param = TypeVar('Param', Run1, Run2, Result)
class Float(sl.Scope[Param, float], float):
...
@@ -172,10 +176,6 @@ class Str(sl.Scope[Param, str], str):
def int_float_to_str(x: int, y: Float[Param]) -> Str[Param]:
return Str(f"{x};{y}")
- Run1 = NewType('Run1', int)
- Run2 = NewType('Run2', int)
- Result = NewType('Result', str)
-
def float1() -> Float[Run1]:
return Float[Run1](1.5)
@@ -193,7 +193,7 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result:
def test_subclasses_of_generic_provider_defined_with_Scope_work() -> None:
- Param = TypeVar('Param')
+ Param = TypeVar('Param', int, float)
class StrT(sl.Scope[Param, str], str):
...
@@ -234,7 +234,8 @@ def make_str2() -> Str2[Param]:
def test_subclasses_of_generic_array_provider_defined_with_Scope_work() -> None:
- Param = TypeVar('Param')
+ # int is unused, but a single constraint is not allowed by Python
+ Param = TypeVar('Param', str, int)
class ArrayT(sl.Scope[Param, npt.NDArray[np.int64]], npt.NDArray[np.int64]):
...
@@ -268,6 +269,12 @@ def provide_none() -> None:
pipeline.insert(provide_none)
+def test_setting_None_param_raises() -> None:
+ pipeline = sl.Pipeline()
+ with pytest.raises(ValueError):
+ pipeline[None] = 3 # type: ignore[index]
+
+
def test_inserting_provider_with_no_return_type_raises() -> None:
def provide_none(): # type: ignore[no-untyped-def]
return None
@@ -280,31 +287,30 @@ def provide_none(): # type: ignore[no-untyped-def]
def test_TypeVar_requirement_of_provider_can_be_bound() -> None:
- T = TypeVar('T')
+ T = TypeVar('T', int, float)
def provider_int() -> int:
return 3
- def provider(x: T) -> List[T]:
+ def provider(x: T) -> list[T]:
return [x, x]
pipeline = sl.Pipeline([provider_int, provider])
- assert pipeline.compute(List[int]) == [3, 3]
+ assert pipeline.compute(list[int]) == [3, 3]
def test_TypeVar_that_cannot_be_bound_raises_UnboundTypeVar() -> None:
- T = TypeVar('T')
+ T = TypeVar('T', int, float)
- def provider(_: T) -> int:
- return 1
+ def provider(_: T) -> str:
+ return 'abc'
- pipeline = sl.Pipeline([provider])
with pytest.raises(sl.UnboundTypeVar):
- pipeline.compute(int)
+ sl.Pipeline([provider])
def test_unsatisfiable_TypeVar_requirement_of_provider_raises() -> None:
- T = TypeVar('T')
+ T = TypeVar('T', int, float)
def provider_int() -> int:
return 3
@@ -318,8 +324,8 @@ def provider(x: T) -> List[T]:
def test_TypeVar_params_are_not_associated_unless_they_match() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', int, float)
class A(Generic[T1]):
...
@@ -336,15 +342,15 @@ def not_matching(x: A[T1]) -> B[T2]:
def matching(x: A[T1]) -> B[T1]:
return B[T1]()
- pipeline = sl.Pipeline([source, not_matching])
with pytest.raises(sl.UnboundTypeVar):
- pipeline.compute(B[int])
+ sl.Pipeline([source, not_matching])
pipeline = sl.Pipeline([source, matching])
pipeline.compute(B[int])
def test_multi_Generic_with_fully_bound_arguments() -> None:
+    # Note that no constraints are necessary here; Sciline never sees the typevars
T1 = TypeVar('T1')
T2 = TypeVar('T2')
@@ -361,8 +367,8 @@ def source() -> A[int, float]:
def test_multi_Generic_with_partially_bound_arguments() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', int, float)
@dataclass
class A(Generic[T1, T2]):
@@ -380,29 +386,30 @@ def partially_bound(x: T1) -> A[int, T1]:
def test_multi_Generic_with_multiple_unbound() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', bool, str)
@dataclass
class A(Generic[T1, T2]):
first: T1
second: T2
- def int_source() -> int:
- return 1
-
- def float_source() -> float:
- return 2.0
-
def unbound(x: T1, y: T2) -> A[T1, T2]:
return A[T1, T2](x, y)
- pipeline = sl.Pipeline([int_source, float_source, unbound])
- assert pipeline.compute(A[int, float]) == A[int, float](1, 2.0)
- assert pipeline.compute(A[float, int]) == A[float, int](2.0, 1)
+ pipeline = sl.Pipeline([unbound])
+ pipeline[int] = 1
+ pipeline[float] = 2.0
+ pipeline[bool] = True
+ pipeline[str] = 'a'
+ assert pipeline.compute(A[int, bool]) == A[int, bool](1, True)
+ assert pipeline.compute(A[int, str]) == A[int, str](1, 'a')
+ assert pipeline.compute(A[float, bool]) == A[float, bool](2.0, True)
+ assert pipeline.compute(A[float, str]) == A[float, str](2.0, 'a')
def test_distinct_fully_bound_instances_yield_distinct_results() -> None:
+    # Note that no constraints are necessary here; Sciline never sees the typevar
T1 = TypeVar('T1')
@dataclass
@@ -421,8 +428,8 @@ def float_source() -> A[float]:
def test_distinct_partially_bound_instances_yield_distinct_results() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', int, str)
@dataclass
class A(Generic[T1, T2]):
@@ -432,20 +439,20 @@ class A(Generic[T1, T2]):
def str_source() -> str:
return 'a'
- def int_source(x: T1) -> A[int, T1]:
- return A[int, T1](1, x)
+ def int_source(x: T2) -> A[int, T2]:
+ return A[int, T2](1, x)
- def float_source(x: T1) -> A[float, T1]:
- return A[float, T1](2.0, x)
+ def float_source(x: T2) -> A[float, T2]:
+ return A[float, T2](2.0, x)
pipeline = sl.Pipeline([str_source, int_source, float_source])
assert pipeline.compute(A[int, str]) == A[int, str](1, 'a')
assert pipeline.compute(A[float, str]) == A[float, str](2.0, 'a')
-def test_multiple_matching_partial_providers_raises() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+def test_multiple_matching_partial_providers_uses_latest() -> None:
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', int, float)
@dataclass
class A(Generic[T1, T2]):
@@ -456,7 +463,7 @@ def int_source() -> int:
return 1
def float_source() -> float:
- return 2.0
+ return 1.0
def provider1(x: T1) -> A[int, T1]:
return A[int, T1](1, x)
@@ -466,14 +473,14 @@ def provider2(x: T2) -> A[T2, float]:
pipeline = sl.Pipeline([int_source, float_source, provider1, provider2])
assert pipeline.compute(A[int, int]) == A[int, int](1, 1)
- assert pipeline.compute(A[float, float]) == A[float, float](2.0, 2.0)
- with pytest.raises(sl.AmbiguousProvider):
- pipeline.compute(A[int, float])
+ assert pipeline.compute(A[float, float]) == A[float, float](1.0, 2.0)
+ # Multiple matches, but the latest one (provider2) is used
+ assert pipeline.compute(A[int, float]) == A[int, float](1, 2.0)
def test_TypeVar_params_track_to_multiple_sources() -> None:
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
+ T1 = TypeVar('T1', int, float)
+ T2 = TypeVar('T2', int, float)
@dataclass
class A(Generic[T1]):
@@ -658,7 +665,8 @@ def func(x: T) -> A[T]:
def test_setitem_can_replace_generic_param_with_generic_param() -> None:
- T = TypeVar('T')
+ # float unused, but must have more than one constraint
+ T = TypeVar('T', int, float)
@dataclass
class A(Generic[T]):
@@ -924,7 +932,7 @@ def d(_f: float) -> None:
assert calls.index('d') in (2, 3)
-def test_prioritizes_specialized_provider_over_generic() -> None:
+def test_inserting_generic_provider_replaces_specialized_provider() -> None:
A = NewType('A', str)
B = NewType('B', str)
V = TypeVar('V', A, B)
@@ -938,68 +946,11 @@ def p1(x: V) -> H[V]:
def p2(x: B) -> H[B]:
return H[B]("Special")
- pl = sl.Pipeline([p1, p2], params={A: 'A', B: 'B'})
+ # p2 will be replaced immediately by p1.
+ pl = sl.Pipeline([p2, p1], params={A: 'A', B: 'B'})
assert str(pl.compute(H[A])) == "Generic"
- assert str(pl.compute(H[B])) == "Special"
-
-
-def test_prioritizes_specialized_provider_over_generic_several_typevars() -> None:
- A = NewType('A', str)
- B = NewType('B', str)
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
-
- @dataclass
- class C(Generic[T1, T2]):
- first: T1
- second: T2
- third: str
-
- def p1(x: T1, y: T2) -> C[T1, T2]:
- return C(x, y, 'generic')
-
- def p2(x: A, y: T2) -> C[A, T2]:
- return C(x, y, 'medium generic')
-
- def p3(x: T2, y: B) -> C[T2, B]:
- return C(x, y, 'generic medium')
-
- def p4(x: A, y: B) -> C[A, B]:
- return C(x, y, 'special')
-
- pl = sl.Pipeline([p1, p2, p3, p4], params={A: A('A'), B: B('B')})
-
- assert pl.compute(C[B, A]) == C('B', 'A', 'generic')
- assert pl.compute(C[A, A]) == C('A', 'A', 'medium generic')
- assert pl.compute(C[B, B]) == C('B', 'B', 'generic medium')
- assert pl.compute(C[A, B]) == C('A', 'B', 'special')
-
-
-def test_prioritizes_specialized_provider_raises() -> None:
- A = NewType('A', str)
- B = NewType('B', str)
- T1 = TypeVar('T1')
- T2 = TypeVar('T2')
-
- @dataclass
- class C(Generic[T1, T2]):
- first: T1
- second: T2
-
- def p1(x: A, y: T1) -> C[A, T1]:
- return C(x, y)
-
- def p2(x: T1, y: B) -> C[T1, B]:
- return C(x, y)
-
- pl = sl.Pipeline([p1, p2], params={A: A('A'), B: B('B')})
-
- with pytest.raises(sl.AmbiguousProvider):
- pl.compute(C[A, B])
-
- with pytest.raises(sl.UnsatisfiedRequirement):
- pl.compute(C[B, A])
+ assert str(pl.compute(H[B])) == "Generic"
def test_compute_time_handler_allows_for_building_but_not_computing() -> None:
@@ -1052,7 +1003,7 @@ def test_pipeline_copy_after_setitem() -> None:
def test_copy_with_generic_providers() -> None:
- Param = TypeVar('Param')
+ Param = TypeVar('Param', int, float)
class Str(sl.Scope[Param, str], str):
...
@@ -1118,7 +1069,11 @@ def test_pipeline_setitem_on_original_does_not_affect_copy() -> None:
def test_pipeline_with_generics_setitem_on_original_does_not_affect_copy() -> None:
- RunType = TypeVar('RunType')
+ Sample = NewType('Sample', int)
+ Background = NewType('Background', int)
+ Result = NewType('Result', int)
+
+ RunType = TypeVar('RunType', Sample, Background)
class RawData(sl.Scope[RunType, int], int):
...
@@ -1126,10 +1081,6 @@ class RawData(sl.Scope[RunType, int], int):
class SquaredData(sl.Scope[RunType, int], int):
...
- Sample = NewType('Sample', int)
- Background = NewType('Background', int)
- Result = NewType('Result', int)
-
def square(x: RawData[RunType]) -> SquaredData[RunType]:
return SquaredData[RunType](x * x)
@@ -1154,7 +1105,11 @@ def process(
def test_pipeline_with_generics_setitem_on_copy_does_not_affect_original() -> None:
- RunType = TypeVar('RunType')
+ Sample = NewType('Sample', int)
+ Background = NewType('Background', int)
+ Result = NewType('Result', int)
+
+ RunType = TypeVar('RunType', Sample, Background)
class RawData(sl.Scope[RunType, int], int):
...
@@ -1162,10 +1117,6 @@ class RawData(sl.Scope[RunType, int], int):
class SquaredData(sl.Scope[RunType, int], int):
...
- Sample = NewType('Sample', int)
- Background = NewType('Background', int)
- Result = NewType('Result', int)
-
def square(x: RawData[RunType]) -> SquaredData[RunType]:
return SquaredData[RunType](x * x)
@@ -1339,57 +1290,6 @@ def p2(v1: S2, v2: T2) -> N[S2, T2]:
pipeline.get(N[M[str], float])
-def test_number_of_type_vars_defines_most_specialized() -> None:
- Green = NewType('Green', str)
- Blue = NewType('Blue', str)
- Color = TypeVar('Color', Green, Blue)
-
- @dataclass
- class Likes(Generic[Color]):
- color: Color
-
- Preference = TypeVar('Preference')
-
- @dataclass
- class Person(Generic[Preference, Color]):
- preference: Preference
- hatcolor: Color
- provided_by: str
-
- def p(c: Color) -> Likes[Color]:
- return Likes(c)
-
- def p0(p: Preference, c: Color) -> Person[Preference, Color]:
- return Person(p, c, 'p0')
-
- def p1(c: Color) -> Person[Likes[Color], Color]:
- return Person(Likes(c), c, 'p1')
-
- def p2(p: Preference) -> Person[Preference, Green]:
- return Person(p, Green('g'), 'p2')
-
- pipeline = sl.Pipeline((p, p0, p1, p2))
- pipeline[Blue] = 'b'
- pipeline[Green] = 'g'
-
- # only provided by p0
- assert pipeline.compute(Person[Likes[Green], Blue]) == Person(
- Likes(Green('g')), Blue('b'), 'p0'
- )
- # provided by p1 and p0 but p1 is preferred because it has fewer typevars
- assert pipeline.compute(Person[Likes[Blue], Blue]) == Person(
- Likes(Blue('b')), Blue('b'), 'p1'
- )
- # provided by p2 and p0 but p2 is preferred because it has fewer typevars
- assert pipeline.compute(Person[Likes[Blue], Green]) == Person(
- Likes(Blue('b')), Green('g'), 'p2'
- )
-
- with pytest.raises(sl.AmbiguousProvider):
- # provided by p1 and p2 with the same number of typevars
- pipeline.get(Person[Likes[Green], Green])
-
-
def test_pipeline_with_decorated_provider() -> None:
R = TypeVar('R')
@@ -1526,3 +1426,12 @@ def __new__(cls, x: int) -> str: # type: ignore[misc]
with pytest.raises(TypeError):
sl.Pipeline([C], params={int: 3})
+
+
+def test_inserting_provider_with_duplicate_arguments_raises() -> None:
+ def bad(x: int, y: int) -> float:
+ return float(x + y)
+
+ pipeline = sl.Pipeline()
+ with pytest.raises(ValueError, match="Duplicate type hints"):
+ pipeline.insert(bad)
diff --git a/tests/pipeline_with_optional_test.py b/tests/pipeline_with_optional_test.py
index 9ca926eb..e7550d24 100644
--- a/tests/pipeline_with_optional_test.py
+++ b/tests/pipeline_with_optional_test.py
@@ -64,7 +64,7 @@ def use_optional(*, x: Optional[int]) -> str:
pipeline.get(str)
-def test_Union_argument_order_matters() -> None:
+def test_Union_argument_order_does_not_matter() -> None:
def use_union(x: int | float) -> str:
return f'{x}'
@@ -73,8 +73,10 @@ def use_union(x: int | float) -> str:
assert pipeline.compute(str) == '1'
pipeline = sl.Pipeline([use_union])
pipeline[float | int] = 1 # type: ignore[index]
- with pytest.raises(sl.UnsatisfiedRequirement):
- pipeline.get(str)
+ assert pipeline.compute(str) == '1'
+ # Note that the above works because the hashes are the same:
+ assert hash(int | float) == hash(float | int)
+ assert hash(Union[int, float]) == hash(Union[float, int])
def test_optional_dependency_cannot_be_filled_transitively() -> None:
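
A standalone check of the hashing fact noted in the test above. On Python 3.10+, `X | Y` creates a types.UnionType, and both it and typing.Union compare (and hash) order-insensitively:

    from typing import Union

    assert (int | float) == (float | int)
    assert hash(int | float) == hash(float | int)
    assert Union[int, float] == Union[float, int]
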
diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py
deleted file mode 100644
index b052f598..00000000
--- a/tests/pipeline_with_param_table_test.py
+++ /dev/null
@@ -1,741 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from typing import List, NewType, Optional, TypeVar
-
-import pytest
-
-import sciline as sl
-from sciline.typing import Item, Label
-
-
-def test_set_param_table_raises_if_param_names_are_duplicate() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- with pytest.raises(ValueError):
- pl.set_param_table(sl.ParamTable(str, {float: [4.0, 5.0, 6.0]}))
- assert pl.compute(Item((Label(int, 1),), float)) == 2.0
- with pytest.raises(sl.UnsatisfiedRequirement):
- pl.compute(Item((Label(str, 1),), float))
-
-
-def test_set_param_table_removes_columns_of_replaced_table() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- # We could imagine that this would be allowed if the index
- # (here: automatic index as range(3)) is the same. For now we do not.
- pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']}))
- assert pl.compute(Item((Label(int, 1),), str)) == 'b'
- with pytest.raises(sl.UnsatisfiedRequirement):
- pl.compute(Item((Label(int, 1),), float))
-
-
-def test_can_get_elements_of_param_table() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(Item((Label(int, 1),), float)) == 2.0
-
-
-def test_can_get_elements_of_param_table_with_explicit_index() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13]))
- assert pl.compute(Item((Label(int, 12),), float)) == 2.0
-
-
-def test_can_replace_param_table() -> None:
- pl = sl.Pipeline()
- table1 = sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13])
- pl.set_param_table(table1)
- assert pl.compute(Item((Label(int, 12),), float)) == 2.0
- table2 = sl.ParamTable(int, {float: [4.0, 5.0, 6.0]}, index=[21, 22, 23])
- pl.set_param_table(table2)
- assert pl.compute(Item((Label(int, 22),), float)) == 5.0
-
-
-def test_can_replace_param_table_with_table_of_different_length() -> None:
- pl = sl.Pipeline()
- table1 = sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13])
- pl.set_param_table(table1)
- assert pl.compute(Item((Label(int, 13),), float)) == 3.0
- table2 = sl.ParamTable(int, {float: [4.0, 5.0]}, index=[21, 22])
- pl.set_param_table(table2)
- assert pl.compute(Item((Label(int, 22),), float)) == 5.0
- # Make sure rows beyond the new table are not accessible
- with pytest.raises(sl.UnsatisfiedRequirement):
- pl.compute(Item((Label(int, 13),), float))
-
-
-def test_failed_replace_due_to_column_clash_in_other_table() -> None:
- Row1 = NewType("Row1", int)
- Row2 = NewType("Row2", int)
- table1 = sl.ParamTable(Row1, {float: [1.0, 2.0, 3.0]})
- table2 = sl.ParamTable(Row2, {str: ['a', 'b', 'c']})
- table1_replacement = sl.ParamTable(
- Row1, {float: [1.1, 2.2, 3.3], str: ['a', 'b', 'c']}
- )
- pl = sl.Pipeline()
- pl.set_param_table(table1)
- pl.set_param_table(table2)
- with pytest.raises(ValueError):
- pl.set_param_table(table1_replacement)
- # Make sure the original table is still accessible
- assert pl.compute(Item((Label(Row1, 1),), float)) == 2.0
-
-
-def test_can_depend_on_elements_of_param_table() -> None:
- # This is not a valid type annotation, not sure why it works with get_type_hints
- def use_elem(x: Item((Label(int, 1),), float)) -> str: # type: ignore[valid-type]
- return str(x)
-
- pl = sl.Pipeline([use_elem])
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(str) == "2.0"
-
-
-def test_can_depend_on_elements_of_param_table_kwarg() -> None:
- # This is not a valid type annotation, not sure why it works with get_type_hints
- def use_elem(
- *, x: Item((Label(int, 1),), float) # type: ignore[valid-type]
- ) -> str:
- return str(x)
-
- pl = sl.Pipeline([use_elem])
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(str) == "2.0"
-
-
-def test_can_compute_series_of_param_values() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-
-def test_cannot_compute_series_of_non_table_param() -> None:
- pl = sl.Pipeline()
- # Table for defining length
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- pl[str] = 'abc'
- # The alternative option would be to expect to return
- # sl.Series(int, {0: 'abc', 1: 'abc', 2: 'abc'})
- # For now, we are not supporting this since it is unclear if this would be
- # conceptually sound and risk free.
- with pytest.raises(sl.UnsatisfiedRequirement):
- pl.compute(sl.Series[int, str])
-
-
-def test_can_compute_series_of_derived_values() -> None:
- def process(x: float) -> str:
- return str(x)
-
- pl = sl.Pipeline([process])
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(sl.Series[int, str]) == sl.Series(
- int, {0: "1.0", 1: "2.0", 2: "3.0"}
- )
-
-
-def test_can_compute_series_of_derived_values_kwarg() -> None:
- def process(*, x: float) -> str:
- return str(x)
-
- pl = sl.Pipeline([process])
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl.compute(sl.Series[int, str]) == sl.Series(
- int, {0: "1.0", 1: "2.0", 2: "3.0"}
- )
-
-
-def test_creating_pipeline_with_provider_of_series_raises() -> None:
- series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
- def make_series() -> sl.Series[int, float]:
- return series
-
- with pytest.raises(ValueError):
- sl.Pipeline([make_series])
-
-
-def test_creating_pipeline_with_series_param_raises() -> None:
- series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
- with pytest.raises(ValueError):
- sl.Pipeline([], params={sl.Series[int, float]: series})
-
-
-def test_explicit_index_of_param_table_is_forwarded_correctly() -> None:
- def process(x: float) -> int:
- return int(x)
-
- pl = sl.Pipeline([process])
- pl.set_param_table(
- sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c'])
- )
- assert pl.compute(sl.Series[str, int]) == sl.Series(str, {'a': 1, 'b': 2, 'c': 3})
-
-
-def test_can_gather_index() -> None:
- Sum = NewType("Sum", float)
- Name = NewType("Name", str)
-
- def gather(x: sl.Series[Name, float]) -> Sum:
- return Sum(sum(x.values()))
-
- def make_float(x: str) -> float:
- return float(x)
-
- pl = sl.Pipeline([gather, make_float])
- pl.set_param_table(sl.ParamTable(Name, {str: ["1.0", "2.0", "3.0"]}))
- assert pl.compute(Sum) == 6.0
-
-
-def test_can_zip() -> None:
- Sum = NewType("Sum", str)
- Str = NewType("Str", str)
- Run = NewType("Run", int)
-
- def gather_zip(x: sl.Series[Run, Str], y: sl.Series[Run, int]) -> Sum:
- z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())]
- return Sum(str(z))
-
- def use_str(x: str) -> Str:
- return Str(x)
-
- pl = sl.Pipeline([gather_zip, use_str])
- pl.set_param_table(sl.ParamTable(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}))
-
- assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']"
-
-
-def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> None:
- Sum = NewType("Sum", float)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Row = NewType("Row", int)
-
- def gather(x: sl.Series[Row, float]) -> Sum:
- return Sum(sum(x.values()))
-
- def join(x: Param1, y: Param2) -> float:
- return x / y
-
- pl = sl.Pipeline([gather, join])
- pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))
-
- assert pl.compute(Sum) == Sum(6)
-
-
-def test_diamond_dependency_on_same_column() -> None:
- Sum = NewType("Sum", float)
- Param = NewType("Param", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Row = NewType("Row", int)
-
- def gather(x: sl.Series[Row, float]) -> Sum:
- return Sum(sum(x.values()))
-
- def to_param1(x: Param) -> Param1:
- return Param1(x)
-
- def to_param2(x: Param) -> Param2:
- return Param2(x)
-
- def join(x: Param1, y: Param2) -> float:
- return x / y
-
- pl = sl.Pipeline([gather, join, to_param1, to_param2])
- pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))
-
- assert pl.compute(Sum) == Sum(3)
-
-
-def test_dependencies_on_different_param_tables_broadcast() -> None:
- Row1 = NewType("Row1", int)
- Row2 = NewType("Row2", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Product = NewType("Product", str)
-
- def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product:
- broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()]
- return Product(str(broadcast))
-
- pl = sl.Pipeline([gather_both])
- pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
- pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
- assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]"
-
-
-def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None:
- Row1 = NewType("Row1", int)
- Row2 = NewType("Row2", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Summed2 = NewType("Summed2", int)
- Product = NewType("Product", str)
-
- def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2:
- return Summed2(x * sum(y.values()))
-
- def gather1(x: sl.Series[Row1, Summed2]) -> Product:
- return Product(str(list(x.values())))
-
- pl = sl.Pipeline([gather1, gather2_and_combine])
- pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
- pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
- assert pl.compute(Product) == "[9, 18, 27]"
-
-
-def test_dependency_on_other_param_table_in_grandparent_broadcasts_branch() -> None:
- Row1 = NewType("Row1", int)
- Row2 = NewType("Row2", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Summed2 = NewType("Summed2", int)
- Combined = NewType("Combined", int)
- Product = NewType("Product", str)
-
- def gather2(x: sl.Series[Row2, Param2]) -> Summed2:
- return Summed2(sum(x.values()))
-
- def combine(x: Param1, y: Summed2) -> Combined:
- return Combined(x * y)
-
- def gather1(x: sl.Series[Row1, Combined]) -> Product:
- return Product(str(list(x.values())))
-
- pl = sl.Pipeline([gather1, gather2, combine])
- pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
- pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
- assert pl.compute(Product) == "[9, 18, 27]"
-
-
-def test_nested_dependencies_on_different_param_tables() -> None:
- Row1 = NewType("Row1", int)
- Row2 = NewType("Row2", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Combined = NewType("Combined", int)
-
- def combine(x: Param1, y: Param2) -> Combined:
- return Combined(x * y)
-
- pl = sl.Pipeline([combine])
- pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
- pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
- assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == sl.Series(
- Row1,
- {
- 0: sl.Series(Row2, {0: 4, 1: 5}),
- 1: sl.Series(Row2, {0: 8, 1: 10}),
- 2: sl.Series(Row2, {0: 12, 1: 15}),
- },
- )
- assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == sl.Series(
- Row2,
- {
- 0: sl.Series(Row1, {0: 4, 1: 8, 2: 12}),
- 1: sl.Series(Row1, {0: 5, 1: 10, 2: 15}),
- },
- )
-
-
-def test_can_groupby_by_requesting_series_of_series() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
-
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
- expected = sl.Series(
- Param1,
- {1: sl.Series(Row, {0: 4, 1: 5}), 3: sl.Series(Row, {2: 6})},
- )
- assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == expected
-
-
-def test_groupby_by_requesting_series_of_series_preserves_indices() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
-
- pl = sl.Pipeline()
- pl.set_param_table(
- sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13])
- )
- assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == sl.Series(
- Param1, {1: sl.Series(Row, {11: 4, 12: 5}), 3: sl.Series(Row, {13: 6})}
- )
-
-
-def test_can_groupby_by_param_used_in_ancestor() -> None:
- Row = NewType("Row", int)
- Param = NewType("Param", str)
-
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']}))
- expected = sl.Series(
- Param,
- {"x": sl.Series(Row, {0: "x", 1: "x"}), "y": sl.Series(Row, {2: "y"})},
- )
- assert pl.compute(sl.Series[Param, sl.Series[Row, Param]]) == expected
-
-
-def test_multi_level_groupby_raises_with_params_from_same_table() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- pl.set_param_table(
- sl.ParamTable(
- Row, {Param1: [1, 1, 1, 3], Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}
- )
- )
- with pytest.raises(ValueError, match='Could not find unique grouping node'):
- pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
-
-
-def test_multi_level_groupby_with_params_from_different_table() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- grouping1 = sl.ParamTable(Row, {Param2: [0, 1, 1, 2], Param3: [7, 8, 9, 10]})
- # We are not providing an explicit index here, so this only happens to work because
- # the values of Param2 match range(3).
- grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]})
- pl.set_param_table(grouping1)
- pl.set_param_table(grouping2)
- assert pl.compute(
- sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
- ) == sl.Series(
- Param1,
- {
- 1: sl.Series(
- Param2, {0: sl.Series(Row, {0: 7}), 1: sl.Series(Row, {1: 8, 2: 9})}
- ),
- 3: sl.Series(Param2, {2: sl.Series(Row, {3: 10})}),
- },
- )
-
-
-def test_multi_level_groupby_with_params_from_different_table_can_select() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]})
- # Note the missing index "6" here.
- grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5])
- pl.set_param_table(grouping1)
- pl.set_param_table(grouping2)
- assert pl.compute(
- sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
- ) == sl.Series(
- Param1,
- {
- 1: sl.Series(
- Param2, {4: sl.Series(Row, {0: 7}), 5: sl.Series(Row, {1: 8, 2: 9})}
- )
- },
- )
-
-
-def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- grouping1 = sl.ParamTable(
- Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
- )
- grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6])
- pl.set_param_table(grouping1)
- pl.set_param_table(grouping2)
- assert pl.compute(
- sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
- ) == sl.Series(
- Param1,
- {
- 1: sl.Series(
- Param2,
- {4: sl.Series(Row, {100: 7}), 5: sl.Series(Row, {200: 8, 300: 9})},
- ),
- 3: sl.Series(Param2, {6: sl.Series(Row, {400: 10})}),
- },
- )
-
-
-def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- grouping1 = sl.ParamTable(
- Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
- )
- grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4])
- pl.set_param_table(grouping1)
- pl.set_param_table(grouping2)
- assert pl.compute(
- sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
- ) == sl.Series(
- Param1,
- {
- 1: sl.Series(
- Param2,
- {6: sl.Series(Row, {400: 10}), 5: sl.Series(Row, {200: 8, 300: 9})},
- ),
- 3: sl.Series(Param2, {4: sl.Series(Row, {100: 7})}),
- },
- )
-
-
-@pytest.mark.parametrize("index", [None, [4, 5, 7]])
-def test_multi_level_groupby_raises_on_index_mismatch(
- index: Optional[List[int]],
-) -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
- Param3 = NewType("Param3", int)
-
- pl = sl.Pipeline()
- grouping1 = sl.ParamTable(
- Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
- )
- grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=index)
- pl.set_param_table(grouping1)
- pl.set_param_table(grouping2)
- with pytest.raises(ValueError):
- pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
-
-
-@pytest.mark.parametrize("index", [None, [4, 5, 6]])
-def test_groupby_over_param_table(index: Optional[List[int]]) -> None:
- Index = NewType("Index", int)
- Name = NewType("Name", str)
- Param = NewType("Param", int)
- ProcessedParam = NewType("ProcessedParam", int)
- SummedGroup = NewType("SummedGroup", int)
- ProcessedGroup = NewType("ProcessedGroup", int)
-
- def process_param(x: Param) -> ProcessedParam:
- return ProcessedParam(x + 1)
-
- def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup:
- return SummedGroup(sum(group.values()))
-
- def process(x: SummedGroup) -> ProcessedGroup:
- return ProcessedGroup(2 * x)
-
- params = sl.ParamTable(
- Index, {Param: [1, 2, 3], Name: ['a', 'a', 'b']}, index=index
- )
- pl = sl.Pipeline([process_param, sum_group, process])
- pl.set_param_table(params)
-
- graph = pl.get(sl.Series[Name, ProcessedGroup])
- assert graph.compute() == sl.Series(Name, {'a': 10, 'b': 8})
-
-
-def test_requesting_series_index_that_is_not_in_param_table_raises() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
-
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
-
- with pytest.raises(KeyError):
- pl.compute(sl.Series[int, Param2])
-
-
-def test_requesting_series_index_that_is_a_param_raises_if_not_grouping() -> None:
- Row = NewType("Row", int)
- Param1 = NewType("Param1", int)
- Param2 = NewType("Param2", int)
-
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
- with pytest.raises(ValueError, match='Could not find unique grouping node'):
- pl.compute(sl.Series[Param1, Param2])
-
-
-def test_generic_providers_work_with_param_tables() -> None:
- Param = TypeVar('Param')
- Row = NewType("Row", int)
-
- class Str(sl.Scope[Param, str], str):
- ...
-
- def parametrized(x: Param) -> Str[Param]:
- return Str(f'{x}')
-
- def make_float() -> float:
- return 1.5
-
- pipeline = sl.Pipeline([make_float, parametrized])
- pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
-
- assert pipeline.compute(Str[float]) == Str[float]('1.5')
- with pytest.raises(sl.UnsatisfiedRequirement):
- pipeline.compute(Str[int])
- assert pipeline.compute(sl.Series[Row, Str[int]]) == sl.Series(
- Row,
- {
- 0: Str[int]('1'),
- 1: Str[int]('2'),
- 2: Str[int]('3'),
- },
- )
-
-
-def test_generic_provider_can_depend_on_param_series() -> None:
- Param = TypeVar('Param')
- Row = NewType("Row", int)
-
- class Str(sl.Scope[Param, str], str):
- ...
-
- def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]:
- return Str(f'{list(x.values())}')
-
- pipeline = sl.Pipeline([parametrized_gather])
- pipeline.set_param_table(
- sl.ParamTable(Row, {int: [1, 2, 3], float: [1.5, 2.5, 3.5]})
- )
-
- assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]')
- assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
-
-
-def test_generic_provider_can_depend_on_derived_param_series() -> None:
- T = TypeVar('T')
- Row = NewType("Row", int)
-
- class Str(sl.Scope[T, str], str):
- ...
-
- def use_param(x: int) -> float:
- return x + 0.5
-
- def parametrized_gather(x: sl.Series[Row, T]) -> Str[T]:
- return Str(f'{list(x.values())}')
-
- pipeline = sl.Pipeline([parametrized_gather, use_param])
- pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
-
- assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
-
-
-def test_params_in_table_can_be_generic() -> None:
- T = TypeVar('T')
- Row = NewType("Row", int)
-
- class Str(sl.Scope[T, str], str):
- ...
-
- class Param(sl.Scope[T, str], str):
- ...
-
- def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]:
- return Str(','.join(x.values()))
-
- pipeline = sl.Pipeline([parametrized_gather])
- pipeline.set_param_table(
- sl.ParamTable(Row, {Param[int]: ["1", "2"], Param[float]: ["1.5", "2.5"]})
- )
-
- assert pipeline.compute(Str[int]) == Str[int]('1,2')
- assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5')
-
-
-def test_compute_time_handler_works_alongside_param_table() -> None:
- Missing = NewType("Missing", str)
-
- def process(x: float, missing: Missing) -> str:
- return str(x) + missing
-
- pl = sl.Pipeline([process])
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- pl.get(sl.Series[int, str], handler=sl.HandleAsComputeTimeException())
-
-
-def test_param_table_column_and_param_of_same_type_can_coexist() -> None:
- pl = sl.Pipeline()
- pl[float] = 1.0
- pl.set_param_table(sl.ParamTable(int, {float: [2.0, 3.0]}))
- assert pl.compute(float) == 1.0
- assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 2.0, 1: 3.0})
-
-
-def test_pipeline_copy_with_param_table() -> None:
- a = sl.Pipeline()
- a.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- b = a.copy()
- assert b.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-
-def test_pipeline_set_param_table_on_original_does_not_affect_copy() -> None:
- a = sl.Pipeline()
- b = a.copy()
- a.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert a.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
- with pytest.raises(sl.UnsatisfiedRequirement):
- b.compute(sl.Series[int, float])
-
-
-def test_pipeline_set_param_table_on_copy_does_not_affect_original() -> None:
- a = sl.Pipeline()
- b = a.copy()
- b.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert b.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
- with pytest.raises(sl.UnsatisfiedRequirement):
- a.compute(sl.Series[int, float])
-
-
-def test_can_make_html_repr_with_param_table() -> None:
- pl = sl.Pipeline()
- pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
- assert pl._repr_html_()
-
-
-def test_set_param_series_sets_up_pipeline_so_derived_series_can_be_computed() -> None:
- ints = [1, 2, 3]
-
- def to_str(x: int) -> str:
- return str(x)
-
- pl = sl.Pipeline((to_str,))
- pl.set_param_series(int, ints)
- assert pl.compute(sl.Series[int, str]) == sl.Series(int, {1: '1', 2: '2', 3: '3'})
-
-
-def test_multiple_param_series_can_be_broadcast() -> None:
- ints = [1, 2, 3]
- floats = [1.0, 2.0]
-
- def to_str(x: int, y: float) -> str:
- return str(x) + str(y)
-
- pl = sl.Pipeline((to_str,))
- pl.set_param_series(int, ints)
- pl.set_param_series(float, floats)
- assert pl.compute(sl.Series[int, sl.Series[float, str]]) == sl.Series(
- int,
- {
- 1: sl.Series(float, {1.0: '11.0', 2.0: '12.0'}),
- 2: sl.Series(float, {1.0: '21.0', 2.0: '22.0'}),
- 3: sl.Series(float, {1.0: '31.0', 2.0: '32.0'}),
- },
- )
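
The tests removed above exercised map-, reduce-, and groupby-style workflows built on
`ParamTable` and `sl.Series`. As a minimal sketch of how the simple gather/sum case
translates to the new Cyclebane-based API, mirroring the `map`/`reduce` calls used in
the updated docs elsewhere in this patch (the 'sum' target name and the reduction
lambda below are illustrative, not taken from the new test suite):

    import sciline as sl

    def make_float(x: str) -> float:
        return float(x)

    pl = sl.Pipeline([make_float])
    # map() takes over from set_param_table(): one branch per parameter value.
    pl = pl.map({str: ["1.0", "2.0", "3.0"]})
    # reduce() takes over from providers consuming sl.Series: an explicit,
    # named reduction over the mapped branches of the sink node.
    graph = pl.reduce(func=lambda *x: sum(x), name='sum').get('sum')
    assert graph.compute() == 6.0
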
diff --git a/tests/pipeline_with_postponed_annotations_test.py b/tests/pipeline_with_postponed_annotations_test.py
index 12db530b..8af481d5 100644
--- a/tests/pipeline_with_postponed_annotations_test.py
+++ b/tests/pipeline_with_postponed_annotations_test.py
@@ -18,7 +18,7 @@
Int = NewType('Int', int)
Str = NewType('Str', str)
-T = TypeVar('T')
+T = TypeVar('T', Int, Str)
class Number(sl.Scope[T, int], int):
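
For context on the constraint added above: this test module uses postponed annotations,
so every return annotation is a plain string, and the reworked pipeline apparently needs
constrained TypeVars in order to enumerate the concrete instantiations of a generic
provider. A minimal sketch of the pattern, with names mirroring the test file (the
`make_number` provider and its doubling logic are illustrative assumptions):

    from __future__ import annotations

    from typing import NewType, TypeVar

    import sciline as sl

    Int = NewType('Int', int)
    Str = NewType('Str', str)
    T = TypeVar('T', Int, Str)  # constrained, as in the change above

    class Number(sl.Scope[T, int], int):
        ...

    def make_number(x: T) -> Number[T]:  # a string annotation until resolved
        return Number(2 * x)

    pl = sl.Pipeline([make_number], params={Int: Int(3)})
    assert pl.compute(Number[Int]) == Number[Int](6)
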
diff --git a/tests/serialize/json_test.py b/tests/serialize/json_test.py
index 465dfe03..834ab948 100644
--- a/tests/serialize/json_test.py
+++ b/tests/serialize/json_test.py
@@ -223,124 +223,6 @@ def test_serialize_kwonlyargs() -> None:
assert res == expected_serialized_kwonlyargs_graph
-def repeated_arg(a: str, b: str) -> list[str]:
- return [a, b]
-
-
-# Ids correspond to the result of assign_predictable_ids
-expected_serialized_repeated_arg_nodes = [
- {
- 'id': '0',
- 'label': 'list[str]',
- 'kind': 'data',
- 'type': 'builtins.list[builtins.str]',
- },
- {
- 'id': '1',
- 'label': 'repeated_arg',
- 'kind': 'function',
- 'function': 'tests.serialize.json_test.repeated_arg',
- 'args': ['101', '102'],
- 'kwargs': {},
- },
- {
- 'id': '2',
- 'label': 'str',
- 'kind': 'data',
- 'type': 'builtins.str',
- },
-]
-expected_serialized_repeated_arg_edges = [
- {'id': '100', 'source': '1', 'target': '0'},
- # The edge is repeated and disambiguated by the `args` of the function node.
- {'id': '101', 'source': '2', 'target': '1'},
- {'id': '102', 'source': '2', 'target': '1'},
-]
-expected_serialized_repeated_arg_graph = {
- 'directed': True,
- 'multigraph': False,
- 'nodes': expected_serialized_repeated_arg_nodes,
- 'edges': expected_serialized_repeated_arg_edges,
-}
-
-
-def test_serialize_repeated_arg() -> None:
- pl = sl.Pipeline([repeated_arg], params={str: 'abc'})
- graph = pl.get(list[str])
- res = graph.serialize()
- res = make_graph_predictable(res)
- assert res == expected_serialized_repeated_arg_graph
-
-
-def repeated_arg_kwonlyarg(a: str, *, b: str) -> list[str]:
- return [a, b]
-
-
-def repeated_kwonlyargs(*, x: int, b: int) -> str:
- return str(x + b)
-
-
-# Ids correspond to the result of assign_predictable_ids
-expected_serialized_repeated_kwonlyarg_nodes = [
- {
- 'id': '0',
- 'label': 'int',
- 'kind': 'data',
- 'type': 'builtins.int',
- },
- {
- 'id': '1',
- 'label': 'list[str]',
- 'kind': 'data',
- 'type': 'builtins.list[builtins.str]',
- },
- {
- 'id': '2',
- 'label': 'repeated_arg_kwonlyarg',
- 'kind': 'function',
- 'function': 'tests.serialize.json_test.repeated_arg_kwonlyarg',
- 'args': ['103'],
- 'kwargs': {'b': '104'},
- },
- {
- 'id': '3',
- 'label': 'str',
- 'kind': 'data',
- 'type': 'builtins.str',
- },
- {
- 'id': '4',
- 'label': 'repeated_kwonlyargs',
- 'kind': 'function',
- 'function': 'tests.serialize.json_test.repeated_kwonlyargs',
- 'args': [],
- 'kwargs': {'x': '100', 'b': '101'},
- },
-]
-expected_serialized_repeated_kwonlyarg_edges = [
- {'id': '100', 'source': '0', 'target': '4'},
- {'id': '101', 'source': '0', 'target': '4'},
- {'id': '102', 'source': '2', 'target': '1'},
- {'id': '103', 'source': '3', 'target': '2'},
- {'id': '104', 'source': '3', 'target': '2'},
- {'id': '105', 'source': '4', 'target': '3'},
-]
-expected_serialized_repeated_kwonlyarg_graph = {
- 'directed': True,
- 'multigraph': False,
- 'nodes': expected_serialized_repeated_kwonlyarg_nodes,
- 'edges': expected_serialized_repeated_kwonlyarg_edges,
-}
-
-
-def test_serialize_repeated_kwonlyarg() -> None:
- pl = sl.Pipeline([repeated_arg_kwonlyarg, repeated_kwonlyargs], params={int: 4})
- graph = pl.get(list[str])
- res = graph.serialize()
- res = make_graph_predictable(res)
- assert res == expected_serialized_repeated_kwonlyarg_graph
-
-
# Ids correspond to the result of assign_predictable_ids
expected_serialized_lambda_nodes = [
{
@@ -398,14 +280,6 @@ def __call__(self, x: int) -> float:
graph.serialize()
-def test_serialize_param_table() -> None:
- pl = sl.Pipeline([as_float])
- pl.set_param_table(sl.ParamTable(str, {int: [3, -5]}))
- graph = pl.get(sl.Series[str, float])
- with pytest.raises(ValueError):
- graph.serialize()
-
-
def test_serialize_validate_schema() -> None:
pl = sl.Pipeline([make_int_b, zeros, to_string], params={Int[A]: 3})
graph = pl.get(str)
diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py
index 799c231c..4a31b5cf 100644
--- a/tests/task_graph_test.py
+++ b/tests/task_graph_test.py
@@ -31,7 +31,9 @@ def as_float(x: int) -> float:
def make_task_graph() -> Graph:
pl = sl.Pipeline([as_float], params={int: 1})
- return pl.build(float, handler=sl.HandleAsBuildTimeException())
+ return sl.data_graph.to_task_graph(
+ pl, (float,), handler=sl.HandleAsBuildTimeException()
+ )
def test_default_scheduler_is_dask_when_dask_available() -> None:
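
For reference, the updated test above builds the task graph via the module-level helper
instead of `Pipeline.build`; note that the requested keys are passed as a tuple rather
than a single key. A short usage sketch (the comment on the return value is an
assumption, not taken from the patch):

    import sciline as sl

    def as_float(x: int) -> float:
        return float(x)

    pl = sl.Pipeline([as_float], params={int: 1})
    graph = sl.data_graph.to_task_graph(
        pl, (float,), handler=sl.HandleAsBuildTimeException()
    )
    # Assumption: `graph` is the flat mapping from keys to tasks that the
    # schedulers consume.
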
diff --git a/tests/utils_test.py b/tests/utils_test.py
index 96054fc2..bd3d77c1 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -8,7 +8,6 @@
import sciline
from sciline import _utils
from sciline._provider import Provider
-from sciline.typing import Item, Label
def module_foo(x: list[str]) -> str:
@@ -50,11 +49,6 @@ def test_key_name_type_var() -> None:
assert _utils.key_name(MyType) == '~MyType' # type: ignore[arg-type]
-def test_key_name_item() -> None:
- item = Item(tp=int, label=(Label(tp=float, index=0), Label(tp=str, index=1)))
- assert _utils.key_name(item) == 'int(float:0, str:1)'
-
-
def test_key_name_builtin_generic() -> None:
MyType = NewType('MyType', str)
assert _utils.key_name(list) == 'list'
@@ -128,14 +122,7 @@ def test_key_full_qualname_type_var() -> None:
assert res == 'tests.utils_test.~MyType'
-def test_key_full_qualname_item() -> None:
- item = Item(tp=int, label=(Label(tp=float, index=0), Label(tp=str, index=1)))
- assert (
- _utils.key_full_qualname(item)
- == 'builtins.int(builtins.float:0, builtins.str:1)'
- )
-
-
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
def test_key_full_qualname_builtin_generic() -> None:
MyType = NewType('MyType', str)
assert _utils.key_full_qualname(list) == 'builtins.list'
diff --git a/tests/visualize_test.py b/tests/visualize_test.py
index ea404af2..266c0cf9 100644
--- a/tests/visualize_test.py
+++ b/tests/visualize_test.py
@@ -43,5 +43,5 @@ class SubA(A[T]):
def test_optional_types_formatted_as_their_content() -> None:
- formatted = sl.visualize._format_type(Optional[float]) # type: ignore[arg-type]
+ formatted = sl.visualize._format_type(Optional[float])
assert formatted.name == 'float'