diff --git a/docs/api-reference/index.md b/docs/api-reference/index.md index 9b072629..0982117f 100644 --- a/docs/api-reference/index.md +++ b/docs/api-reference/index.md @@ -10,14 +10,15 @@ :template: class-template.rst :recursive: - ParamTable Pipeline Scope - Series + ScopeTwoParams scheduler.Scheduler scheduler.DaskScheduler scheduler.NaiveScheduler TaskGraph + HandleAsBuildTimeException + HandleAsComputeTimeException ``` ## Exceptions @@ -28,7 +29,6 @@ :template: class-template.rst :recursive: - AmbiguousProvider UnboundTypeVar UnsatisfiedRequirement ``` diff --git a/docs/developer/architecture-and-design/rewrite.ipynb b/docs/developer/architecture-and-design/rewrite.ipynb new file mode 100644 index 00000000..b8d9f43d --- /dev/null +++ b/docs/developer/architecture-and-design/rewrite.ipynb @@ -0,0 +1,281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Rewrite of Sciline's Pipeline as a Data Graph\n", + "\n", + "## Introduction\n", + "\n", + "There has been a series of issues and discussions about Sciline's `Pipeline` class and its implementation.\n", + "\n", + "- Detect unused parameters [#43](https://github.com/scipp/sciline/issues/43).\n", + "- More helpful error messages when pipeline fails to build or compute? [#74](https://github.com/scipp/sciline/issues/74).\n", + "- Get missing params from a pipeline [#83](https://github.com/scipp/sciline/issues/83).\n", + "- Support for graph operations [#107](https://github.com/scipp/sciline/issues/107).\n", + "- Supporting different file handle types is too difficult [#140](https://github.com/scipp/sciline/issues/140).\n", + "- A new approach for \"parameter tables\" [#141](https://github.com/scipp/sciline/issues/141).\n", + "- Pruning for repeated workflow calls [#148](https://github.com/scipp/sciline/issues/148)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Current implementation\n", "\n", "- `sciline.Pipeline` is a box that can be filled with providers (a provider is a callable that can compute a value) as well as values.\n", "- Providers can provide generic types.\n", " The concrete types and values that such providers compute are determined *later*, when the pipeline is built, based on which instances of the generic outputs are requested (by other providers or by the user when building the pipeline).\n", "- Parameter tables and a special `sciline.Series` type are supported to create task graphs with duplicate branches and \"reduction\" or grouping operations.\n", "- The pipeline is built by calling `build` on it, which returns a `sciline.TaskGraph`.\n", " Most of the complexity is handled in this step.\n", "\n", "The presence of generic providers as well as parameter tables makes the implementation of the pipeline quite complex.\n", "It implies that internally a pipeline is *not* representable as a graph, as (1) generics lead to a task-graph structure that is in principle undefined until the pipeline is built, and (2) parameter tables lead to implicit duplication of task graph branches, which means that if `Pipeline` were to use a graph representation internally, adding or replacing providers would conflict with the duplicated structure."
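To make the second bullet concrete, here is a minimal sketch of the *current* (pre-rewrite) behavior, mirroring the old `generic-providers` documentation that is modified further down in this diff: the concrete outputs of a generic provider only come into existence when a concrete instance is requested.

```python
from typing import List, TypeVar

import sciline

T = TypeVar("T")  # unconstrained: the set of possible concrete types is open-ended


def duplicate(x: T) -> List[T]:
    """A generic provider; its concrete outputs are determined at build time."""
    return [x, x]


pipeline = sciline.Pipeline([duplicate], params={int: 1, str: "3"})
# Only when building/computing does T get bound, here to int and str:
print(pipeline.compute(List[int]))  # [1, 1]
print(pipeline.compute(List[str]))  # ['3', '3']
```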
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Proposal\n", + "\n", + "The key idea of this proposal is to introduce `sciline.DataGraph`, a directed acyclic graph (DAG), which can roughly be thought of as a graph representation of the pipeline.[1](#f1)\n", + "The data graph describes dependencies between data, defined via the type-hints of providers.\n", + "Providers (or values) are stored as node data.\n", + "\n", + "As the support for generic providers was a hindrance in the current implementation, we propose to restrict support to generic return types *with constraints*.\n", + "This means that such a provider defines a *known* set of outputs, and the data graph can thus be updated with multiple nodes, each with the same provider.\n", + "\n", + "The support for parameter tables would be replaced by using `map` and `reduce` operations on the data graph.[2](#f2)\n", + "\n", + "1. [^](#a1)\n", + " Whether `Pipeline` will be kept as a wrapper around `DataGraph` or whether `DataGraph` will be the main interface is not yet clear.\n", + "2. [^](#a2)\n", + " This has been prototyped in the `cyclebane` library.\n", + " Whether this would be *integrated into* or *used by* Sciline is not yet clear.\n", + "\n", + "### Note on chosen implementation\n", + "\n", + "The existing `Pipeline` interface has been kept; the new functionality has been added in the `DataGraph` class, which has been made a base class of `Pipeline`.\n", + "`DataGraph` is implemented as a wrapper for `cyclebane.Graph`, a new and generic support library based on NetworkX.\n", + "\n", + "### Example 1: Basic DataGraph" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline\n", + "\n", + "\n", + "def f1() -> float:\n", + " return 1.0\n", + "\n", + "\n", + "def f2(a: float, b: str) -> int:\n", + " return int(a) + len(b)\n", + "\n", + "\n", + "def f3(a: int) -> list[int]:\n", + " return list(range(a))\n", + "\n", + "\n", + "data_graph = sciline.Pipeline([f1, f3, f2])\n", + "data_graph.visualize_data_graph(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can add a value for `str` using `__setitem__`, build a `sciline.TaskGraph`, and compute the result:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_graph[str] = 'abcde'\n", + "task_graph = data_graph.get(list[int])\n", + "task_graph.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "task_graph.visualize(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 2: DataGraph with generic provider" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import TypeVar\n", + "import sciline\n", + "\n", + "T = TypeVar('T', int, float) # The constraints are mandatory now!\n", + "\n", + "\n", + "def make_list(length: T) -> list[T]:\n", + " return [length, length + length]\n", + "\n", + "\n", + "def make_dict(key: list[int], value: list[float]) -> dict[int, float]:\n", + " return dict(zip(key, value))\n", + "\n", + "\n", + "data_graph = sciline.Pipeline([make_list, make_dict])\n", + "data_graph.visualize_data_graph(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_graph[int] = 
3\n", + "data_graph[float] = 1.2\n", + "data_graph.get(dict[int, float]).compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Example 3: DataGraph with map and reduce" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sciline\n", + "\n", + "\n", + "def f1(x: float) -> str:\n", + " return str(x)\n", + "\n", + "\n", + "def f2(x: str) -> int:\n", + " return len(x)\n", + "\n", + "\n", + "def f3(a: int) -> list[int]:\n", + " return list(range(a))\n", + "\n", + "\n", + "data_graph = sciline.Pipeline([f1, f2, f3])\n", + "data_graph.visualize_data_graph(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "params = pd.DataFrame({float: [0.1, 1.0, 10.0]})\n", + "params" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def concat_strings(*strings: str) -> str:\n", + " return '+'.join(strings)\n", + "\n", + "\n", + "data_graph[str] = data_graph[str].map(params).reduce(func=concat_strings)\n", + "data_graph.visualize_data_graph(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tg = data_graph.get(list[int])\n", + "tg.visualize(graph_attr={'rankdir': 'LR'})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tg.compute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Criticism\n", + "\n", + "The `map` and `reduce` operations kind of break out of the core idea of Sciline.\n", + "It is some sort of intermediate state between declarative and imperative programming (as in Sciline and Dask, respectively).\n", + "The example above may be re-imagined as something along the lines of\n", + "\n", + "```python\n", + "# Assuming with_value returns a copy of the graph with the value set\n", + "branches = map(data_graph[str].with_value, params[float])\n", + "# Not actually `dask.delayed`, but you get the idea\n", + "data_graph[str] = dask.delayed(concat_strings)(branches)\n", + "```\n", + "\n", + "The graph could then be optimized to remove duplicate nodes (part of `data_graph[str]`, but not an descendant of `float`)." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev310", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/user-guide/generic-providers.ipynb b/docs/user-guide/generic-providers.ipynb index 10b8ddc0..1dd89be1 100644 --- a/docs/user-guide/generic-providers.ipynb +++ b/docs/user-guide/generic-providers.ipynb @@ -33,22 +33,22 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import TypeVar, List\n", + "from typing import TypeVar\n", "import sciline\n", "\n", - "T = TypeVar(\"T\")\n", + "T = TypeVar(\"T\", int, float, str)\n", "\n", "\n", - "def duplicate(x: T) -> List[T]:\n", + "def duplicate(x: T) -> list[T]:\n", " \"\"\"A generic provider that can make any list.\"\"\"\n", " return [x, x]\n", "\n", "\n", "pipeline = sciline.Pipeline([duplicate], params={int: 1, float: 2.0, str: \"3\"})\n", "\n", - "print(pipeline.compute(List[int]))\n", - "print(pipeline.compute(List[float]))\n", - "print(pipeline.compute(List[str]))" + "print(pipeline.compute(list[int]))\n", + "print(pipeline.compute(list[float]))\n", + "print(pipeline.compute(list[str]))" ] }, { @@ -198,25 +198,22 @@ "\n", "# 1. Define domain types\n", "\n", - "# 1.a Define generic domain types\n", - "RunType = TypeVar('RunType')\n", + "# 1.a Define concrete RunType values we will use.\n", + "Sample = NewType('Sample', int)\n", + "Background = NewType('Background', int)\n", "\n", + "# 1.b Define generic domain types\n", + "RunType = TypeVar('RunType', Sample, Background)\n", "\n", - "class Filename(sciline.Scope[RunType, str], str):\n", - " ...\n", "\n", + "class Filename(sciline.Scope[RunType, str], str): ...\n", "\n", - "class RawData(sciline.Scope[RunType, dict], dict):\n", - " ...\n", "\n", + "class RawData(sciline.Scope[RunType, dict], dict): ...\n", "\n", - "class CleanedData(sciline.Scope[RunType, list], list):\n", - " ...\n", "\n", + "class CleanedData(sciline.Scope[RunType, list], list): ...\n", "\n", - "# 1.b Define concrete RunType values we will use.\n", - "Sample = NewType('Sample', int)\n", - "Background = NewType('Background', int)\n", "\n", "# 1.c Define normal domain types\n", "ScaleFactor = NewType('ScaleFactor', float)\n", @@ -274,7 +271,20 @@ "\n", "
\n", "\n", - "Note\n", + "**Note**\n", + "\n", + "Sciline requires type variables that are used as part of keys to have [constraints](https://docs.python.org/3/library/typing.html#typing.TypeVar).\n", + "In the above example, we define a type variable `RunType` that is constrained to be either a `Sample` or a `Background`:\n", + "\n", + "```python\n", + "RunType = TypeVar('RunType', Sample, Background)\n", + "```\n", + "\n", + "
\n", + "\n", + "
\n", + "\n", + "**Note**\n", "\n", "We use a peculiar-looking syntax for defining \"generic type aliases\".\n", "We would love to use [typing.NewType](https://docs.python.org/3/library/typing.html#typing.NewType) for this, but it does not allow for definition of generic aliases.\n", @@ -360,9 +370,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Generic providers and parameter tables\n", + "### Generic providers and map operations\n", "\n", - "As a more complex example of where generic providers are useful, we may add a parameter table, so we can process multiple samples:" + "As a more complex example of where generic providers are useful, we may add `map` operation, so we can process multiple samples:" ] }, { @@ -371,33 +381,17 @@ "metadata": {}, "outputs": [], "source": [ - "RunID = NewType('RunID', int)\n", "run_ids = [102, 103, 104, 105]\n", "filenames = [f'file{i}.txt' for i in run_ids]\n", - "param_table = sciline.ParamTable(RunID, {Filename[Sample]: filenames}, index=run_ids)\n", - "param_table" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can now create a parametrized pipeline:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "params = {\n", " ScaleFactor: 2.0,\n", " Filename[Background]: 'background.txt',\n", "}\n", "pipeline = sciline.Pipeline(providers, params=params)\n", - "pipeline.set_param_table(param_table)\n", - "graph = pipeline.get(sciline.Series[RunID, Result])\n", + "pipeline = pipeline.map({Filename[Sample]: filenames})\n", + "\n", + "# We can collect the results into a list for simplicity in this example\n", + "graph = pipeline.reduce(func=lambda *x: list[x], name='collected').get('collected')\n", "graph.visualize()" ] }, @@ -409,15 +403,6 @@ "source": [ "graph.compute()" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "With the generics mechanism, we could have added a line for the background to the parameter table.\n", - "However, the `subtrack_background` function would then have to be modified to accept a `Series[RunID, CleanedData]`.\n", - "More importantly, this would have resulted in a synchronization point in the computation graph, preventing efficient scheduling of the subsequent computation, with potentially disastrous effects on memory consumption." - ] } ], "metadata": { diff --git a/docs/user-guide/parameter-tables.ipynb b/docs/user-guide/parameter-tables.ipynb index aeae0ba8..bb7c70b3 100644 --- a/docs/user-guide/parameter-tables.ipynb +++ b/docs/user-guide/parameter-tables.ipynb @@ -9,9 +9,10 @@ "\n", "## Overview\n", "\n", - "Parameter tables provide a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n", + "Sciline supports a mechanism for repeating parts of or all of a computation with different values for one or more parameters.\n", "This allows for a variety of use cases, similar to *map*, *reduce*, and *groupby* operations in other systems.\n", - "We illustrate each of these in the follow three chapters." + "We illustrate each of these in the follow three chapters.\n", + "Sciline's implementation is based on [Cyclebane](scipp.github.io/cyclebane)." 
] }, { @@ -20,9 +21,10 @@ "source": [ "## Computing results for series of parameters\n", "\n", - "This chapter illustrates how to implement *map* operations with Sciline.\n", + "This chapter illustrates how to perform *map* operations with Sciline.\n", "\n", - "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we can replace the fixed `Filename` parameter with a series of filenames listed in a [ParamTable](../generated/classes/sciline.ParamTable.rst):" + "Starting with the model workflow introduced in [Getting Started](getting-started.ipynb), we would like to replace the fixed `Filename` parameter with a series of filenames listed in a \"parameter table\".\n", + "We begin by defining the base pipeline:" ] }, { @@ -74,23 +76,17 @@ "\n", "# 3. Create pipeline\n", "\n", - "# 3.a Providers and normal parameters\n", "providers = [load, clean, process]\n", "params = {ScaleFactor: 2.0}\n", - "\n", - "# 3.b Parameter table\n", - "RunID = NewType('RunID', int)\n", - "run_ids = [102, 103, 104, 105]\n", - "filenames = [f'file{i}.txt' for i in run_ids]\n", - "param_table = sciline.ParamTable(RunID, {Filename: filenames}, index=run_ids)" + "base = sciline.Pipeline(providers, params=params)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note how steps 1.) and 2.) are identical to those from the example without parameter table.\n", - "Above we have created the following parameter table:" + "Aside from not having defined a value for the `Filename` parameter, this is identical to the example in [Getting Started](getting-started.ipynb).\n", + "The task-graph visualization indicates this missing parameter:" ] }, { @@ -99,14 +95,14 @@ "metadata": {}, "outputs": [], "source": [ - "param_table" + "base.visualize(Result, graph_attr={'rankdir': 'LR'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can now create the pipeline and set the parameter table:" + "We now define a \"parameter table\" listing the filenames we would like to process:" ] }, { @@ -115,16 +111,26 @@ "metadata": {}, "outputs": [], "source": [ - "# 3.c Setup pipeline\n", - "pipeline = sciline.Pipeline(providers, params=params)\n", - "pipeline.set_param_table(param_table)" + "import pandas as pd\n", + "\n", + "run_ids = [102, 103, 104, 105]\n", + "filenames = [f'file{i}.txt' for i in run_ids]\n", + "param_table = pd.DataFrame({Filename: filenames}, index=run_ids).rename_axis(\n", + " index='run_id'\n", + ")\n", + "param_table" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Then we can compute `Result` for each index in the parameter table:" + "Note how we used a node name of the pipeline as the column name in the parameter table.\n", + "For convenience we used a `pandas.DataFrame` to represent the table above, but the use of Pandas is entirely optional.\n", + "Equivalently the table could be represented as a `dict`, where each key corresponds to a column header and each value is a list of values for that column, i.e., `{Filename: filenames}`.\n", + "Specifying an index is currently not possible in this case, and it will default to a range index.\n", + "\n", + "We can now use [Pipeline.map](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.map) to create a modified pipeline that processes each row in the parameter table:" ] }, { @@ -133,18 +139,35 @@ "metadata": {}, "outputs": [], "source": [ - "pipeline.compute(sciline.Series[RunID, Result])" + "pipeline = base.map(param_table)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ + "Then we can compute `Result` for each index in the parameter table.\n", + "Currently there is no convenient way of accessing these, instead we manually define the target nodes to compute:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from cyclebane.graph import NodeName, IndexValues\n", "\n", - "`sciline.Series` is a special `dict`-like type that signals to Sciline that the values of the series are based on values from one or more columns of a parameter table.\n", - "The parameter table is identified using the first argument to `Series`, in this case `RunID`.\n", - "The second argument specifies the result to be computed.\n", - "\n", + "targets = [NodeName(Result, IndexValues(('run_id',), (i,))) for i in run_ids]\n", + "pipeline.compute(targets)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note the use of the `run_id` index.\n", + "If the index axis of the DataFrame has no name then a default of `dim_0`, `dim_1`, etc. is used.\n", "We can also visualize the task graph for computing the series of `Result` values:" ] }, @@ -154,15 +177,14 @@ "metadata": {}, "outputs": [], "source": [ - "pipeline.visualize(sciline.Series[RunID, Result])" + "pipeline.visualize(targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Nodes that depend on values from a parameter table are drawn with the parameter index name (the row dimension of the parameter table) and value given in parenthesis.\n", - "The dashed arrow indicates and internal transformation that gathers result from each branch and combines them into a single output, here `Series[RunID, Result]`.\n", + "Nodes that depend on values from a parameter table are drawn with the parameter index name (the row dimension of the parameter table) and index value (defaulting to a range index starting at 0 if no index if given) shown in parenthesis.\n", "\n", "
\n", "\n", @@ -185,8 +207,7 @@ "\n", "This chapter illustrates how to implement *reduce* operations with Sciline.\n", "\n", - "Instead of requesting a series of results as above, we can also build pipelines with providers that depend on such series.\n", - "We can create a new pipeline, or extend the existing one by inserting a new provider:\n" + "Instead of requesting a series of results as above, we use the [Pipeline.reduce](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.reduce) method and pass a function that combines the results from each parameter into a single result:" ] }, { @@ -195,23 +216,31 @@ "metadata": {}, "outputs": [], "source": [ - "MergedResult = NewType('MergedResult', float)\n", - "\n", + "graph = pipeline.reduce(func=lambda *result: sum(result), name='merged').get('merged')\n", + "graph.visualize()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", "\n", - "def merge_runs(runs: sciline.Series[RunID, Result]) -> MergedResult:\n", - " return MergedResult(sum(runs.values()))\n", + "**Note**\n", "\n", + "The `func` passed to `reduce` is *not* making use of Sciline's mechanism of assembling a graph based on type hints.\n", + "In particular, the input type may be identical to the output type.\n", + "The [Pipeline.reduce](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.reduce) method adds a *new* node, attached at a unique (but mapped) sink node of the graph.\n", + "[Pipeline.__getitem__](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.__getitem__) and [Pipeline.__setitem__](https://scipp.github.io/sciline/generated/classes/sciline.Pipeline.html#sciline.Pipeline.__setitem__) can be used to compose more complex graphs where the reduction is not the final step.\n", "\n", - "pipeline.insert(merge_runs)\n", - "graph = pipeline.get(MergedResult)\n", - "graph.visualize()" + "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that this is identical to the example in the previous section, except for the last two nodes in the graph.\n", + "Note that the graph shown above is identical to the example in the previous section, except for the last two nodes in the graph.\n", "The computation now returns a single result:" ] }, @@ -231,12 +260,29 @@ "This is useful if we need to continue computation after gathering results without setting up a second pipeline." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "
\n", + "\n", + "**Note**\n", + "\n", + "For the `reduce` operation, all inputs to the reduction function have to be kept in memory simultaneously.\n", + "This can be very memory intensive.\n", + "We intend to support, e.g., hierarchical reduction operations in the future, where intermediate results are combined in a tree-like fashion to avoid excessive memory consumption..\n", + "\n", + "
" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Grouping intermediate results based on secondary parameters\n", "\n", + "**Cyclebane and Sciline do not support `groupby` yet, this is work in progress so this example is not functional yet.**\n", + "\n", "This chapter illustrates how to implement *groupby* operations with Sciline.\n", "\n", "Continuing from the examples for *map* and *reduce*, we can introduce a secondary parameter in the table, such as the material of the sample:" @@ -250,123 +296,35 @@ "source": [ "Material = NewType('Material', str)\n", "\n", - "# 3.a Providers and normal parameters\n", - "providers = [load, clean, process, merge_runs]\n", - "params = {ScaleFactor: 2.0}\n", - "\n", - "# 3.b Parameter table\n", "run_ids = [102, 103, 104, 105]\n", "sample = ['diamond', 'graphite', 'graphite', 'graphite']\n", "filenames = [f'file{i}.txt' for i in run_ids]\n", - "param_table = sciline.ParamTable(\n", - " RunID, {Filename: filenames, Material: sample}, index=run_ids\n", - ")\n", + "param_table = pd.DataFrame(\n", + " {Filename: filenames, Material: sample}, index=run_ids\n", + ").rename_axis(index='run_id')\n", "param_table" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# 3.c Setup pipeline\n", - "pipeline = sciline.Pipeline(providers, params=params)\n", - "pipeline.set_param_table(param_table)" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We can now compute `MergedResult` for a series of \"materials\":" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipeline.compute(sciline.Series[Material, MergedResult])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The computation looks as show below.\n", - "Note how the initial steps of the computation depend on the `RunID` parameter, while later steps depend on `Material`:\n", - "The files for each run ID have been grouped by their material and then merged:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pipeline.visualize(sciline.Series[Material, MergedResult])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## More examples" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Using tables for series of parameters\n", - "\n", - "Sometimes the parameter of interest is the index of a parameter table itself.\n", - "If there are no further parameters, the param table may have no columns (aside from the index).\n", - "In this case we can bypass the manual creation of a parameter table and use the [Pipeline.set_param_series](../generated/classes/sciline.Pipeline.rst#sciline.Pipeline.set_param_series) function instead:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from typing import NewType\n", - "import sciline as sl\n", - "\n", - "Param = NewType(\"Param\", int)\n", - "Sum = NewType(\"Sum\", float)\n", - "\n", - "\n", - "def compute(x: Param) -> float:\n", - " return 0.5 * x\n", - "\n", + "Future releases of Sciline will support a `groupby` operation, roughly as follows:\n", "\n", - "def gather(x: sl.Series[Param, float]) -> Sum:\n", - " return Sum(sum(x.values()))\n", + "```python\n", + "pipeline = base.map(param_table).groupby(Material).reduce(func=merge)\n", + "```\n", "\n", - "\n", - "pl = sl.Pipeline([gather, compute])\n", - 
"pl.set_param_series(Param, [1, 4, 9])\n", - "pl.visualize(Sum)" + "We can then compute the merged result, grouped by the value of `Material`.\n", + "Note how the initial steps of the computation depend on the `run_id` index name, while later steps depend on `Material`, a new index name defined by the `groupby` operation.\n", + "The files for each run ID have been grouped by their material and then merged." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Note that `pl.set_param_series(Param, [1, 4, 9])` above is equivalent to `pl.set_param_table(sl.ParamTable(Param, columns={}, index=[1, 4, 9]))`." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pl.compute(Sum)" + "## More examples" ] }, { @@ -387,21 +345,19 @@ "Sum = NewType(\"Sum\", float)\n", "Param1 = NewType(\"Param1\", int)\n", "Param2 = NewType(\"Param2\", int)\n", - "Row = NewType(\"Run\", int)\n", "\n", "\n", - "def gather(\n", - " x: sl.Series[Row, float],\n", - ") -> Sum:\n", - " return Sum(sum(x.values()))\n", + "def gather(*x: float) -> Sum:\n", + " return Sum(sum(x))\n", "\n", "\n", "def product(x: Param1, y: Param2) -> float:\n", " return x / y\n", "\n", "\n", - "pl = sl.Pipeline([gather, product])\n", - "pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))\n", + "params = pd.DataFrame({Param1: [1, 4, 9], Param2: [1, 2, 3]})\n", + "pl = sl.Pipeline([product])\n", + "pl = pl.map(params).reduce(func=gather, name=Sum)\n", "\n", "pl.visualize(Sum)" ] @@ -432,11 +388,10 @@ "Param = NewType(\"Param\", int)\n", "Param1 = NewType(\"Param1\", int)\n", "Param2 = NewType(\"Param2\", int)\n", - "Row = NewType(\"Run\", int)\n", "\n", "\n", - "def gather(x: sl.Series[Row, float]) -> Sum:\n", - " return Sum(sum(x.values()))\n", + "def gather(*x: float) -> float:\n", + " return sum(x)\n", "\n", "\n", "def to_param1(x: Param) -> Param1:\n", @@ -451,8 +406,9 @@ " return x * y\n", "\n", "\n", - "pl = sl.Pipeline([gather, product, to_param1, to_param2])\n", - "pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))\n", + "pl = sl.Pipeline([product, to_param1, to_param2])\n", + "params = pd.DataFrame({Param: [1, 2, 3]})\n", + "pl = pl.map(params).reduce(func=gather, name=Sum)\n", "pl.visualize(Sum)" ] }, @@ -469,40 +425,37 @@ "metadata": {}, "outputs": [], "source": [ + "from typing import Any\n", "import sciline as sl\n", "\n", - "List1 = NewType(\"List1\", float)\n", - "List2 = NewType(\"List2\", float)\n", "Param1 = NewType(\"Param1\", int)\n", "Param2 = NewType(\"Param2\", int)\n", - "Row1 = NewType(\"Row1\", int)\n", - "Row2 = NewType(\"Row2\", int)\n", - "\n", "\n", - "def gather1(x: sl.Series[Row1, float]) -> List1:\n", - " return List1(list(x.values()))\n", "\n", - "\n", - "def gather2(x: sl.Series[Row2, List1]) -> List2:\n", - " return List2(list(x.values()))\n", + "def gather(*x: Any) -> list[Any]:\n", + " return list(x)\n", "\n", "\n", "def product(x: Param1, y: Param2) -> float:\n", " return x * y\n", "\n", "\n", - "pl = sl.Pipeline([gather1, gather2, product])\n", - "pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 4, 9]}))\n", - "pl.set_param_table(sl.ParamTable(Row2, {Param2: [1, 2]}))\n", + "base = sl.Pipeline([product])\n", + "pl = (\n", + " base.map({Param1: [1, 4, 9]})\n", + " .map({Param2: [1, 2]})\n", + " .reduce(func=gather, name='reduce_1', index='dim_1')\n", + " .reduce(func=gather, name='reduce_0')\n", + ")\n", "\n", - "pl.visualize(List2)" + "pl.visualize('reduce_0')" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ - "Note how intermediates such as `float(Row1, Row2)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." + "Note how intermediates such as `float(dim_1, dim_0)` depend on two parameters, i.e., we are dealing with a 2-D array of branches in the graph." ] }, { @@ -511,7 +464,29 @@ "metadata": {}, "outputs": [], "source": [ - "pl.compute(List2)" + "pl.compute('reduce_0')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "It is also possible to reduce multiple axes at once.\n", + "For example, `reduce` will reduce all axes if no `index` or `axis` is specified:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pl = (\n", + " base.map({Param1: [1, 4, 9]})\n", + " .map({Param2: [1, 2]})\n", + " .reduce(func=gather, name='reduce_both')\n", + ")\n", + "pl.visualize('reduce_both')" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 13ab13fd..d4be62be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ requires-python = ">=3.10" # Run 'tox -e deps' after making changes here. This will update requirement files. # Make sure to list one dependency per line. dependencies = [ + "cyclebane", ] dynamic = ["version"] diff --git a/requirements/base.in b/requirements/base.in index b801db0e..3baebf5e 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -2,4 +2,4 @@ # will not be touched by ``make_base.py`` # --- END OF CUSTOM SECTION --- # The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY! - +cyclebane diff --git a/requirements/base.txt b/requirements/base.txt index c24fd27f..eaed248d 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,8 +1,11 @@ -# SHA1:da39a3ee5e6b4b0d3255bfef95601890afd80709 +# SHA1:dd62fc84fce7fde1639783710b14ec7a54d8cdc5 # # This file is autogenerated by pip-compile-multi # To update, run: # # pip-compile-multi # - +cyclebane==24.5.0 + # via -r base.in +networkx==3.3 + # via cyclebane diff --git a/requirements/basetest.txt b/requirements/basetest.txt index b36e420b..b5c71f98 100644 --- a/requirements/basetest.txt +++ b/requirements/basetest.txt @@ -9,13 +9,13 @@ attrs==23.2.0 # via # jsonschema # referencing -exceptiongroup==1.2.0 +exceptiongroup==1.2.1 # via pytest graphviz==0.20.3 # via -r basetest.in iniconfig==2.0.0 # via pytest -jsonschema==4.21.1 +jsonschema==4.22.0 # via -r basetest.in jsonschema-specifications==2023.12.1 # via jsonschema @@ -23,15 +23,15 @@ numpy==1.26.4 # via -r basetest.in packaging==24.0 # via pytest -pluggy==1.4.0 +pluggy==1.5.0 # via pytest -pytest==8.1.1 +pytest==8.2.1 # via -r basetest.in -referencing==0.34.0 +referencing==0.35.1 # via # jsonschema # jsonschema-specifications -rpds-py==0.18.0 +rpds-py==0.18.1 # via # jsonschema # referencing diff --git a/requirements/ci.txt b/requirements/ci.txt index 88d49ee6..ab682a51 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -17,7 +17,7 @@ colorama==0.4.6 # via tox distlib==0.3.8 # via virtualenv -filelock==3.13.4 +filelock==3.14.0 # via # tox # virtualenv @@ -32,15 +32,15 @@ packaging==24.0 # -r ci.in # pyproject-api # tox -platformdirs==4.2.0 +platformdirs==4.2.2 # via # tox # virtualenv -pluggy==1.4.0 +pluggy==1.5.0 # via tox pyproject-api==1.6.1 # via tox -requests==2.31.0 +requests==2.32.3 # via -r ci.in smmap==5.0.1 # via gitdb @@ -48,9 +48,9 @@ tomli==2.0.1 # via # pyproject-api # tox -tox==4.14.2 +tox==4.15.0 # via -r ci.in urllib3==2.2.1 # via requests 
-virtualenv==20.25.1 +virtualenv==20.26.2 # via tox diff --git a/requirements/dev.txt b/requirements/dev.txt index adba5090..ab785ebb 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -12,9 +12,9 @@ -r static.txt -r test.txt -r wheels.txt -annotated-types==0.6.0 +annotated-types==0.7.0 # via pydantic -anyio==4.3.0 +anyio==4.4.0 # via # httpx # jupyter-server @@ -34,7 +34,7 @@ click==8.1.7 # pip-tools copier==9.2.0 # via -r dev.in -dunamai==1.20.0 +dunamai==1.21.1 # via copier fqdn==1.5.1 # via jsonschema @@ -54,7 +54,7 @@ json5==0.9.25 # via jupyterlab-server jsonpointer==2.4 # via jsonschema -jsonschema[format-nongpl]==4.21.1 +jsonschema[format-nongpl]==4.22.0 # via # -r basetest.in # jupyter-events @@ -72,9 +72,9 @@ jupyter-server==2.14.0 # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server -jupyterlab==4.1.6 +jupyterlab==4.2.1 # via -r dev.in -jupyterlab-server==2.26.0 +jupyterlab-server==2.27.2 # via jupyterlab notebook-shim==0.2.4 # via jupyterlab @@ -86,15 +86,15 @@ pip-compile-multi==2.6.3 # via -r dev.in pip-tools==7.4.1 # via pip-compile-multi -plumbum==1.8.2 +plumbum==1.8.3 # via copier prometheus-client==0.20.0 # via jupyter-server pycparser==2.22 # via cffi -pydantic==2.7.0 +pydantic==2.7.2 # via copier -pydantic-core==2.18.1 +pydantic-core==2.18.3 # via pydantic python-json-logger==2.0.7 # via jupyter-events @@ -128,7 +128,7 @@ uri-template==1.3.0 # via jsonschema webcolors==1.13 # via jsonschema -websocket-client==1.7.0 +websocket-client==1.8.0 # via jupyter-server wheel==0.43.0 # via pip-tools diff --git a/requirements/docs.in b/requirements/docs.in index 8a2366c9..f512d9c6 100644 --- a/requirements/docs.in +++ b/requirements/docs.in @@ -4,6 +4,7 @@ ipykernel ipython!=8.7.0 # Breaks syntax highlighting in Jupyter code cells. 
myst-parser nbsphinx +pandas pydata-sphinx-theme>=0.14 sphinx sphinx-autodoc-typehints diff --git a/requirements/docs.txt b/requirements/docs.txt index 6ee43684..1d0313b1 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,4 +1,4 @@ -# SHA1:69ee98dd11fd5e6fe2f8e3b546660b033b8839c4 +# SHA1:d646fa4c965681f36669a3ad404405afcd35ce97 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -6,7 +6,7 @@ # pip-compile-multi # -r base.txt -accessible-pygments==0.0.4 +accessible-pygments==0.0.5 # via pydata-sphinx-theme alabaster==0.7.16 # via sphinx @@ -16,7 +16,7 @@ attrs==23.2.0 # via # jsonschema # referencing -babel==2.14.0 +babel==2.15.0 # via # pydata-sphinx-theme # sphinx @@ -38,13 +38,13 @@ decorator==5.1.1 # via ipython defusedxml==0.7.1 # via nbconvert -docutils==0.20.1 +docutils==0.21.2 # via # myst-parser # nbsphinx # pydata-sphinx-theme # sphinx -exceptiongroup==1.2.0 +exceptiongroup==1.2.1 # via ipython executing==2.0.1 # via stack-data @@ -58,23 +58,23 @@ imagesize==1.4.1 # via sphinx ipykernel==6.29.4 # via -r docs.in -ipython==8.23.0 +ipython==8.24.0 # via # -r docs.in # ipykernel jedi==0.19.1 # via ipython -jinja2==3.1.3 +jinja2==3.1.4 # via # myst-parser # nbconvert # nbsphinx # sphinx -jsonschema==4.21.1 +jsonschema==4.22.0 # via nbformat jsonschema-specifications==2023.12.1 # via jsonschema -jupyter-client==8.6.1 +jupyter-client==8.6.2 # via # ipykernel # nbclient @@ -95,46 +95,50 @@ markupsafe==2.1.5 # via # jinja2 # nbconvert -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via # ipykernel # ipython -mdit-py-plugins==0.4.0 +mdit-py-plugins==0.4.1 # via myst-parser mdurl==0.1.2 # via markdown-it-py mistune==3.0.2 # via nbconvert -myst-parser==2.0.0 +myst-parser==3.0.1 # via -r docs.in nbclient==0.10.0 # via nbconvert -nbconvert==7.16.3 +nbconvert==7.16.4 # via nbsphinx nbformat==5.10.4 # via # nbclient # nbconvert # nbsphinx -nbsphinx==0.9.3 +nbsphinx==0.9.4 # via -r docs.in nest-asyncio==1.6.0 # via ipykernel +numpy==1.26.4 + # via pandas packaging==24.0 # via # ipykernel # nbconvert # pydata-sphinx-theme # sphinx +pandas==2.2.2 + # via -r docs.in pandocfilters==1.5.1 # via nbconvert parso==0.8.4 # via jedi pexpect==4.9.0 # via ipython -platformdirs==4.2.0 +platformdirs==4.2.2 # via jupyter-core -prompt-toolkit==3.0.43 +prompt-toolkit==3.0.45 # via ipython psutil==5.9.8 # via ipykernel @@ -142,9 +146,9 @@ ptyprocess==0.7.0 # via pexpect pure-eval==0.2.2 # via stack-data -pydata-sphinx-theme==0.15.2 +pydata-sphinx-theme==0.15.3 # via -r docs.in -pygments==2.17.2 +pygments==2.18.0 # via # accessible-pygments # ipython @@ -152,20 +156,24 @@ pygments==2.17.2 # pydata-sphinx-theme # sphinx python-dateutil==2.9.0.post0 - # via jupyter-client + # via + # jupyter-client + # pandas +pytz==2024.1 + # via pandas pyyaml==6.0.1 # via myst-parser -pyzmq==26.0.0 +pyzmq==26.0.3 # via # ipykernel # jupyter-client -referencing==0.34.0 +referencing==0.35.1 # via # jsonschema # jsonschema-specifications -requests==2.31.0 +requests==2.32.3 # via sphinx -rpds-py==0.18.0 +rpds-py==0.18.1 # via # jsonschema # referencing @@ -178,7 +186,7 @@ snowballstemmer==2.2.0 # via sphinx soupsieve==2.5 # via beautifulsoup4 -sphinx==7.2.6 +sphinx==7.3.7 # via # -r docs.in # myst-parser @@ -187,11 +195,11 @@ sphinx==7.2.6 # sphinx-autodoc-typehints # sphinx-copybutton # sphinx-design -sphinx-autodoc-typehints==2.0.1 +sphinx-autodoc-typehints==2.1.0 # via -r docs.in sphinx-copybutton==0.5.2 # via -r docs.in -sphinx-design==0.5.0 +sphinx-design==0.6.0 # via -r docs.in 
sphinxcontrib-applehelp==1.0.8 # via sphinx @@ -207,13 +215,15 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx stack-data==0.6.3 # via ipython -tinycss2==1.2.1 +tinycss2==1.3.0 # via nbconvert +tomli==2.0.1 + # via sphinx tornado==6.4 # via # ipykernel # jupyter-client -traitlets==5.14.2 +traitlets==5.14.3 # via # comm # ipykernel @@ -225,10 +235,12 @@ traitlets==5.14.2 # nbconvert # nbformat # nbsphinx -typing-extensions==4.11.0 +typing-extensions==4.12.0 # via # ipython # pydata-sphinx-theme +tzdata==2024.1 + # via pandas urllib3==2.2.1 # via requests wcwidth==0.2.13 diff --git a/requirements/mypy.txt b/requirements/mypy.txt index 8d293074..f0e7b3bf 100644 --- a/requirements/mypy.txt +++ b/requirements/mypy.txt @@ -6,9 +6,9 @@ # pip-compile-multi # -r test.txt -mypy==1.9.0 +mypy==1.10.0 # via -r mypy.in mypy-extensions==1.0.0 # via mypy -typing-extensions==4.11.0 +typing-extensions==4.12.0 # via mypy diff --git a/requirements/nightly.in b/requirements/nightly.in index 6b1ebcc2..55a65bd8 100644 --- a/requirements/nightly.in +++ b/requirements/nightly.in @@ -1,4 +1,4 @@ -r basetest.in # --- END OF CUSTOM SECTION --- # The following was generated by 'tox -e deps', DO NOT EDIT MANUALLY! - +cyclebane diff --git a/requirements/nightly.txt b/requirements/nightly.txt index a98564b2..e697c13b 100644 --- a/requirements/nightly.txt +++ b/requirements/nightly.txt @@ -1,4 +1,4 @@ -# SHA1:e8b11c1210855f07eaedfbcfb3ecd1aec3595dee +# SHA1:6e913d4c1e64cd8d5433e7de17be332a12fefe11 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -6,3 +6,7 @@ # pip-compile-multi # -r basetest.txt +cyclebane==24.5.0 + # via -r nightly.in +networkx==3.3 + # via cyclebane diff --git a/requirements/static.txt b/requirements/static.txt index 99028a0d..ea619e41 100644 --- a/requirements/static.txt +++ b/requirements/static.txt @@ -9,20 +9,17 @@ cfgv==3.4.0 # via pre-commit distlib==0.3.8 # via virtualenv -filelock==3.13.4 +filelock==3.14.0 # via virtualenv -identify==2.5.35 +identify==2.5.36 # via pre-commit -nodeenv==1.8.0 +nodeenv==1.9.0 # via pre-commit -platformdirs==4.2.0 +platformdirs==4.2.2 # via virtualenv -pre-commit==3.7.0 +pre-commit==3.7.1 # via -r static.in pyyaml==6.0.1 # via pre-commit -virtualenv==20.25.1 +virtualenv==20.26.2 # via pre-commit - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/requirements/test-dask.txt b/requirements/test-dask.txt index 6a4f2eeb..1362167d 100644 --- a/requirements/test-dask.txt +++ b/requirements/test-dask.txt @@ -10,15 +10,15 @@ click==8.1.7 # via dask cloudpickle==3.0.0 # via dask -dask==2024.4.1 +dask==2024.5.1 # via -r test-dask.in -fsspec==2024.3.1 +fsspec==2024.5.0 # via dask importlib-metadata==7.1.0 # via dask locket==1.0.0 # via partd -partd==1.4.1 +partd==1.4.2 # via dask pyyaml==6.0.1 # via dask @@ -26,5 +26,5 @@ toolz==0.12.1 # via # dask # partd -zipp==3.18.1 +zipp==3.19.0 # via importlib-metadata diff --git a/requirements/wheels.txt b/requirements/wheels.txt index ff60a184..d1a95de4 100644 --- a/requirements/wheels.txt +++ b/requirements/wheels.txt @@ -9,9 +9,7 @@ build==1.2.1 # via -r wheels.in packaging==24.0 # via build -pyproject-hooks==1.0.0 +pyproject-hooks==1.1.0 # via build tomli==2.0.1 - # via - # build - # pyproject-hooks + # via build diff --git a/src/sciline/__init__.py b/src/sciline/__init__.py index 5a87ae80..8e2ec6c6 100644 --- a/src/sciline/__init__.py +++ b/src/sciline/__init__.py @@ -17,17 +17,12 @@ HandleAsComputeTimeException, UnsatisfiedRequirement, ) -from 
.param_table import ParamTable -from .pipeline import AmbiguousProvider, Pipeline -from .series import Series +from .pipeline import Pipeline from .task_graph import TaskGraph __all__ = [ - "AmbiguousProvider", - "ParamTable", "Pipeline", "scheduler", - "Series", "Scope", "ScopeTwoParams", 'TaskGraph', diff --git a/src/sciline/_provider.py b/src/sciline/_provider.py index d21e8907..e999d474 100644 --- a/src/sciline/_provider.py +++ b/src/sciline/_provider.py @@ -11,6 +11,7 @@ Any, Callable, Generator, + Hashable, Literal, Optional, TypeVar, @@ -138,14 +139,6 @@ def bind_type_vars(self, bound: dict[TypeVar, Key]) -> Provider: kind=self._kind, ) - def map_arg_keys(self, transform: Callable[[Key], Key]) -> Provider: - """Return a new provider with transformed argument keys.""" - return Provider( - func=self._func, - arg_spec=self._arg_spec.map_keys(transform), - kind=self._kind, - ) - def __str__(self) -> str: return f"Provider('{self.location.name}')" @@ -155,7 +148,7 @@ def __repr__(self) -> str: f"func={self._func})" ) - def call(self, values: dict[Key, Any]) -> Any: + def call(self, values: dict[Hashable, Any]) -> Any: """Call the provider with arguments extracted from ``values``.""" return self._func( *(values[arg] for arg in self._arg_spec.args), @@ -170,10 +163,19 @@ def __init__( self, *, args: dict[str, Key], kwargs: dict[str, Key], return_: Optional[Key] ) -> None: """Build from components, use dedicated creation functions instead.""" + # Duplicate type hints could be allowed in principle, but it makes structure + # analysis and checks more difficult and error prone. As there is likely + # little utility in supporting this, we disallow it. + if len(set(args.values()) | set(kwargs.values())) != len(args) + len(kwargs): + raise ValueError("Duplicate type hints found in args and/or kwargs") self._args = args self._kwargs = kwargs self._return = return_ + def __len__(self) -> int: + """Number of args and kwargs, not counting return value.""" + return len(self._args) + len(self._kwargs) + @classmethod def from_function(cls, provider: ToProvider) -> ArgSpec: """Parse the argument spec of a provider.""" @@ -227,7 +229,7 @@ def map_keys(self, transform: Callable[[Key], Key]) -> ArgSpec: return ArgSpec( args={name: transform(arg) for name, arg in self._args.items()}, kwargs={name: transform(arg) for name, arg in self._kwargs.items()}, - return_=self._return, + return_=self._return if self._return is None else transform(self._return), ) diff --git a/src/sciline/_utils.py b/src/sciline/_utils.py index 024edb57..c8c1ca0f 100644 --- a/src/sciline/_utils.py +++ b/src/sciline/_utils.py @@ -5,7 +5,7 @@ from typing import Any, Callable, DefaultDict, Iterable, TypeVar, Union, get_args from ._provider import Provider -from .typing import Item, Key +from .typing import Key T = TypeVar('T') G = TypeVar('G') @@ -31,11 +31,6 @@ def full_qualname(obj: Any) -> str: def key_name(key: Union[Key, TypeVar]) -> str: - if isinstance(key, Item): - parameters = ", ".join( - f'{key_name(label.tp)}:{label.index}' for label in key.label - ) - return f'{key_name(key.tp)}({parameters})' args = get_args(key) if len(args): parameters = ', '.join(map(key_name, args)) @@ -43,15 +38,12 @@ def key_name(key: Union[Key, TypeVar]) -> str: return f'{getattr(key, "__name__", "")}[{parameters}]' if isinstance(key, TypeVar): return str(key) - return key.__name__ + if hasattr(key, '__name__'): + return key.__name__ + return str(key) def key_full_qualname(key: Union[Key, TypeVar]) -> str: - if isinstance(key, Item): - parameters = ", 
".join( - f'{key_full_qualname(label.tp)}:{label.index}' for label in key.label - ) - return f'{key_full_qualname(key.tp)}({parameters})' args = get_args(key) if len(args): origin = key.__origin__ # type: ignore[union-attr] # key is a TypeVar diff --git a/src/sciline/data_graph.py b/src/sciline/data_graph.py new file mode 100644 index 00000000..b2c778e5 --- /dev/null +++ b/src/sciline/data_graph.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2024 Scipp contributors (https://github.com/scipp) +from __future__ import annotations + +import itertools +from collections.abc import Iterable +from types import NoneType +from typing import Any, Callable, Generator, TypeVar, Union, get_args + +import cyclebane as cb +import networkx as nx + +from ._provider import ArgSpec, Provider, ToProvider, _bind_free_typevars +from .handler import ErrorHandler, HandleAsBuildTimeException +from .typing import Graph, Key + + +def _as_graph(key: Key, value: Any) -> cb.Graph: + """Create a cyclebane.Graph with a single value.""" + graph = nx.DiGraph() + graph.add_node(key, value=value) + return cb.Graph(graph) + + +def _find_all_typevars(t: type | TypeVar) -> set[TypeVar]: + """Returns the set of all TypeVars in a type expression.""" + if isinstance(t, TypeVar): + return {t} + return set(itertools.chain(*map(_find_all_typevars, get_args(t)))) + + +def _get_typevar_constraints(t: TypeVar) -> set[type]: + """Returns the set of constraints of a TypeVar.""" + return set(t.__constraints__) + + +def _mapping_to_constrained( + type_vars: set[TypeVar], +) -> Generator[dict[TypeVar, type], None, None]: + constraints = [_get_typevar_constraints(t) for t in type_vars] + if any(len(c) == 0 for c in constraints): + raise ValueError('Typevars must have constraints') + for combination in itertools.product(*constraints): + yield dict(zip(type_vars, combination, strict=True)) + + +T = TypeVar('T', bound='DataGraph') + + +class DataGraph: + def __init__(self, providers: None | Iterable[ToProvider | Provider]) -> None: + self._cbgraph = cb.Graph(nx.DiGraph()) + for provider in providers or []: + self.insert(provider) + + @classmethod + def _from_cyclebane(cls: type[T], graph: cb.Graph) -> T: + out = cls([]) + out._cbgraph = graph + return out + + def copy(self: T) -> T: + return self._from_cyclebane(self._cbgraph.copy()) + + def __copy__(self: T) -> T: + return self.copy() + + @property + def _graph(self) -> nx.DiGraph: + return self._cbgraph.graph + + def _get_clean_node(self, key: Key) -> Any: + """Return node ready for setting value or provider.""" + if key is NoneType: + raise ValueError('Key must not be None') + if key in self._graph: + self._graph.remove_edges_from(list(self._graph.in_edges(key))) + self._graph.nodes[key].pop('value', None) + self._graph.nodes[key].pop('provider', None) + self._graph.nodes[key].pop('reduce', None) + else: + self._graph.add_node(key) + return self._graph.nodes[key] + + def insert(self, provider: Union[ToProvider, Provider], /) -> None: + """ + Insert a callable into the graph that provides its return value. + + Parameters + ---------- + provider: + Either a callable that provides its return value. Its arguments + and return value must be annotated with type hints. + Or a ``Provider`` object that has been constructed from such a callable. 
+ """ + if not isinstance(provider, Provider): + provider = Provider.from_function(provider) + return_type = provider.deduce_key() + if typevars := _find_all_typevars(return_type): + for bound in _mapping_to_constrained(typevars): + self.insert(provider.bind_type_vars(bound)) + return + # Trigger UnboundTypeVar error if any input typevars are not bound + provider = provider.bind_type_vars({}) + self._get_clean_node(return_type)['provider'] = provider + for dep in provider.arg_spec.keys(): + self._graph.add_edge(dep, return_type, key=dep) + + def __setitem__(self, key: Key, value: DataGraph | Any) -> None: + """ + Provide a concrete value for a type. + + Parameters + ---------- + key: + Type to provide a value for. + param: + Concrete value to provide. + """ + # This is a questionable approach: Using MyGeneric[T] as a key will actually + # not pass mypy [valid-type] checks. What we do on our side is ok, but the + # calling code is not. + if typevars := _find_all_typevars(key): + for bound in _mapping_to_constrained(typevars): + self[_bind_free_typevars(key, bound)] = value + return + + # TODO If key is generic, should we support multi-sink case and update all? + # Would imply that we need the same for __getitem__. + self._cbgraph[key] = ( + value._cbgraph if isinstance(value, DataGraph) else _as_graph(key, value) + ) + + def __getitem__(self: T, key: Key) -> T: + return self._from_cyclebane(self._cbgraph[key]) + + def map(self: T, node_values: dict[Key, Any]) -> T: + return self._from_cyclebane(self._cbgraph.map(node_values)) + + def reduce(self: T, *, func: Callable[..., Any], **kwargs: Any) -> T: + # Note that the type hints of `func` are not checked here. As we are explicit + # about the modification, this is in line with __setitem__ which does not + # perform such checks and allows for using generic reduction functions. 
+ return self._from_cyclebane( + self._cbgraph.reduce(attrs={'reduce': func}, **kwargs) + ) + + def to_networkx(self) -> nx.DiGraph: + return self._cbgraph.to_networkx() + + def visualize_data_graph( + self, **kwargs: Any + ) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821 + import graphviz + + dot = graphviz.Digraph(strict=True, **kwargs) + for node in self._graph.nodes: + attrs = self._graph.nodes[node] + # Include the node data (value/provider/reduce) in the label + attr_str = '\n'.join(f'{k}={v}' for k, v in attrs.items()) + dot.node(str(node), label=f'{node}\n{attr_str}', shape='box') + for edge in self._graph.edges: + key = self._graph.edges[edge].get('key') + label = str(key) if key is not None else '' + dot.edge(str(edge[0]), str(edge[1]), label=label) + return dot + + +_no_value = object() + + +def to_task_graph( + data_graph: DataGraph, targets: tuple[Key, ...], handler: ErrorHandler | None = None +) -> Graph: + graph = data_graph.to_networkx() + handler = handler or HandleAsBuildTimeException() + ancestors = list(targets) + for target in targets: + if target not in graph: + handler.handle_unsatisfied_requirement(target) + ancestors.extend(nx.ancestors(graph, target)) + graph = graph.subgraph(set(ancestors)) + out = {} + + for key in graph.nodes: + node = graph.nodes[key] + input_nodes = list(graph.predecessors(key)) + input_edges = list(graph.in_edges(key, data=True)) + orig_keys = [edge[2].get('key', None) for edge in input_edges] + if (value := node.get('value', _no_value)) is not _no_value: + out[key] = Provider.parameter(value) + elif (provider := node.get('provider')) is not None: + new_key = dict(zip(orig_keys, input_nodes)) + spec = provider.arg_spec.map_keys(new_key.get) + if len(spec) != len(input_nodes): + # This should be caught by __setitem__, but we check here to be safe. 
+ raise ValueError("Corrupted graph") + # TODO also kwargs + out[key] = Provider(func=provider.func, arg_spec=spec, kind='function') + elif (func := node.get('reduce')) is not None: + spec = ArgSpec.from_args(*input_nodes) + out[key] = Provider(func=func, arg_spec=spec, kind='function') + else: + out[key] = handler.handle_unsatisfied_requirement(key) + return out diff --git a/src/sciline/display.py b/src/sciline/display.py index 4a066e6b..770a982a 100644 --- a/src/sciline/display.py +++ b/src/sciline/display.py @@ -1,11 +1,10 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from html import escape -from typing import Iterable, List, Tuple, TypeVar, Union +from typing import Any, Iterable, Tuple -from ._provider import Provider -from ._utils import groupby, key_name -from .typing import Item, Key +from ._utils import key_name +from .typing import Key def _details(summary: str, body: str) -> str: @@ -17,27 +16,13 @@ def _details(summary: str, body: str) -> str: ''' -def _provider_name( - p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] -) -> str: - key, args, _ = p - if args: - # This is always the case, but mypy complains - if hasattr(key, '__getitem__'): - return escape(key_name(key[args])) - return escape(key_name(key)) +def _provider_name(p: Tuple[Key, dict[str, Any]]) -> str: + return escape(key_name(p[0])) -def _provider_source( - p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] -) -> str: - key, _, (v, *rest) = p - if v.kind == 'table_cell': - # This is always the case, but mypy complains - if isinstance(key, Item): - return escape( - f'ParamTable({key_name(key.label[0].tp)}, length={len((v, *rest))})' - ) +def _provider_source(data: dict[str, Any]) -> str: + if (v := data.get('provider', None)) is None: + return '' if v.kind == 'function': return _details( escape(v.location.name), @@ -46,46 +31,23 @@ def _provider_source( return '' -def _provider_value( - p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] -) -> str: - _, _, (v, *_) = p - if v.kind == 'parameter': - html = escape(str(v.call({}))).strip() - return _details(f'{html[:30]}...', html) if len(html) > 30 else html - return '' - +def _provider_value(data: dict[str, Any]) -> str: + if (value := data.get('value', None)) is None: + return '' + html = escape(str(value)).strip() + return _details(f'{html[:30]}...', html) if len(html) > 30 else html -def pipeline_html_repr( - providers: Iterable[Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]] -) -> str: - def associate_table_values( - p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider] - ) -> Tuple[Key, Union[type, Tuple[Union[Key, TypeVar], ...]]]: - key, args, v = p - if isinstance(key, Item): - return (key.label[0].tp, key.tp) - return (key, args) - providers_collected = ( - (key, args, [value, *(v for _, _, v in rest)]) - for ((key, args, value), *rest) in groupby( - associate_table_values, - providers, - ).values() - ) +def pipeline_html_repr(nodes: Iterable[Tuple[Key, dict[str, Any]]]) -> str: provider_rows = '\n'.join( ( f''' - {_provider_name(p)} - {_provider_value(p)} - {_provider_source(p)} + {_provider_name(item)} + {_provider_value(item[1])} + {_provider_source(item[1])} ''' - for p in sorted( - providers_collected, - key=_provider_name, - ) + for item in sorted(nodes, key=_provider_name) ) ) return f''' diff --git a/src/sciline/param_table.py b/src/sciline/param_table.py deleted file mode 100644 index 96e76f5f..00000000 --- a/src/sciline/param_table.py 
+++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from __future__ import annotations - -from typing import Any, Collection, Dict, Mapping, Optional - - -class ParamTable(Mapping[type, Collection[Any]]): - """A table of parameters with a row index and named row dimension.""" - - def __init__( - self, - row_dim: type, - columns: Dict[type, Collection[Any]], - *, - index: Optional[Collection[Any]] = None, - ): - """ - Create a new param table. - - Parameters - ---------- - row_dim: - The row dimension. This must be a type or a type-alias (not an instance), - and is used by :py:class:`sciline.Pipeline` to identify each parameter - table. - columns: - The columns of the table. The keys (column names) must be types or type- - aliases matching the values in the respective columns. - index: - The row index of the table. If not given, a default index will be - generated, as the integer range of the column length. - """ - sizes = set(len(v) for v in columns.values()) - if len(sizes) > 1: - raise ValueError( - f"Columns in param table must all have same size, got {sizes}" - ) - if sizes: - size = sizes.pop() - elif index is None: - raise ValueError("Cannot create param table with zero columns and no index") - else: - size = len(index) - if index is not None: - if len(index) != size: - raise ValueError( - f"Index length not matching columns, got {len(index)} and {size}" - ) - if len(set(index)) != len(index): - raise ValueError(f"Index must be unique, got {index}") - self._row_dim = row_dim - self._columns = columns - self._index = index or list(range(size)) - - @property - def row_dim(self) -> type: - """The row dimension of the table.""" - return self._row_dim - - @property - def index(self) -> Collection[Any]: - """The row index of the table.""" - return self._index - - def __contains__(self, key: Any) -> bool: - return self._columns.__contains__(key) - - def __getitem__(self, key: Any) -> Any: - return self._columns.__getitem__(key) - - def __iter__(self) -> Any: - return self._columns.__iter__() - - def __len__(self) -> int: - return self._columns.__len__() - - def __repr__(self) -> str: - return f"ParamTable(row_dim={self.row_dim.__name__}, columns={self._columns})" - - def _repr_html_(self) -> str: - return ( - f"<table><tr><th>{self.row_dim.__name__}</th>" - + "".join( - f"<th>{getattr(k, '__name__', str(k).split('.')[-1])}</th>" - for k in self._columns.keys() - ) - + "</tr>" - + "".join( - f"<tr><td>{idx}</td>" + "".join(f"<td>{v}</td>" for v in row) + "</tr>" - for idx, row in zip(self.index, zip(*self._columns.values())) - ) - + "</table>" - ) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 1ae2b270..3e20cd59 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -2,155 +2,36 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from collections import defaultdict from collections.abc import Iterable from itertools import chain from types import UnionType from typing import ( Any, Callable, - Collection, Dict, - Generic, - List, - Mapping, Optional, - Set, Tuple, Type, TypeVar, Union, get_args, - get_origin, get_type_hints, overload, ) -from sciline.task_graph import TaskGraph - -from ._provider import ArgSpec, Provider, ProviderLocation, ToProvider -from ._utils import key_name +from ._provider import Provider, ToProvider +from .data_graph import DataGraph, to_task_graph from .display import pipeline_html_repr -from .handler import ( - ErrorHandler, - HandleAsBuildTimeException, - HandleAsComputeTimeException, - UnsatisfiedRequirement, -) -from .param_table import ParamTable +from .handler import ErrorHandler, HandleAsComputeTimeException from .scheduler import Scheduler -from .series import Series -from .typing import Graph, Item, Key, Label +from .task_graph import TaskGraph +from .typing import Key T = TypeVar('T') KeyType = TypeVar('KeyType', bound=Key) -ValueType = TypeVar('ValueType', bound=Key) -IndexType = TypeVar('IndexType', bound=Key) -LabelType = TypeVar('LabelType', bound=Key) - - -class AmbiguousProvider(Exception): - """Raised when multiple providers are found for a type.""" - - -def _extract_typevars_from_generic_type(t: type) -> Tuple[TypeVar, ...]: - """Returns the typevars that were used in the definition of a Generic type.""" - if not hasattr(t, '__orig_bases__'): - return () - return tuple( - chain(*(get_args(b) for b in t.__orig_bases__ if get_origin(b) == Generic)) - ) - - -def _find_all_typevars(t: Union[type, TypeVar]) -> Set[TypeVar]: - """Returns the set of all TypeVars in a type expression.""" - if isinstance(t, TypeVar): - return {t} - return set(chain(*map(_find_all_typevars, get_args(t)))) - - -def _find_bounds_to_make_compatible_type( - requested: Key, - provided: Key | TypeVar, -) -> Optional[Dict[TypeVar, Key]]: - """ - Check if a type is compatible to a provided type. - If the types are compatible, return a mapping from typevars to concrete types - that makes the provided type equal to the requested type. - """ - if provided == requested: - ret: Dict[TypeVar, Key] = {} - return ret - if isinstance(provided, TypeVar): - # If the type var has no constraints, accept anything - if not provided.__constraints__: - return {provided: requested} - for c in provided.__constraints__: - if _find_bounds_to_make_compatible_type(requested, c) is not None: - return {provided: requested} - if get_origin(provided) is not None: - if get_origin(provided) == get_origin(requested): - return _find_bounds_to_make_compatible_type_tuple( - get_args(requested), get_args(provided) - ) - return None - - -def _find_bounds_to_make_compatible_type_tuple( - requested: tuple[Key, ...], - provided: tuple[Key | TypeVar, ...], -) -> Optional[Dict[TypeVar, Key]]: - """ - Check if a tuple of requested types is compatible with a tuple of provided types - and return a mapping from type vars to concrete types that makes all provided - types equal to their corresponding requested type. - If any of the types is not compatible, return None.
- """ - union: Dict[TypeVar, Key] = {} - for bound in map(_find_bounds_to_make_compatible_type, requested, provided): - # If no mapping from the type-var to a concrete type was found, - # or if the mapping is inconsistent, - # interrupt the search and report that no compatible types were found. - if bound is None or any(k in union and union[k] != bound[k] for k in bound): - return None - union.update(bound) - return union - - -def _find_all_paths( - dependencies: Mapping[Key, Collection[Key]], start: Key, end: Key -) -> List[List[Key]]: - """Find all paths from start to end in a DAG.""" - if start == end: - return [[start]] - if start not in dependencies: - return [] - paths = [] - for node in dependencies[start]: - if start == node: - continue - for path in _find_all_paths(dependencies, node, end): - paths.append([start] + path) - return paths - - -def _find_nodes_in_paths(graph: Graph, end: Key) -> List[Key]: - """ - Helper for Pipeline. Finds all nodes that need to be duplicated since they depend - on a value from a param table. - """ - start = next(iter(graph)) - dependencies = {k: tuple(p.arg_spec.keys()) for k, p in graph.items()} - paths = _find_all_paths(dependencies, start, end) - nodes = set() - for path in paths: - nodes.update(path) - return list(nodes) -def _is_multiple_keys( - keys: type | Iterable[type] | Item[T] | object, -) -> bool: +def _is_multiple_keys(keys: type | Iterable[type] | UnionType) -> bool: # Cannot simply use isinstance(keys, Iterable) because that is True for # generic aliases of iterable types, e.g., # @@ -159,196 +40,14 @@ def _is_multiple_keys( # # And isinstance(keys, type) does not work on its own because # it is False for the above type. + if isinstance(keys, str): + return False return ( not isinstance(keys, type) and not get_args(keys) and isinstance(keys, Iterable) ) -class ReplicatorBase(Generic[IndexType]): - def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]): - if len(path) == 0: - raise UnsatisfiedRequirement( - 'Could not find path to param in param table. This is likely caused ' - 'by requesting a Series that does not depend directly or transitively ' - 'on any param from a table.' - ) - self._index_name = index_name - self.index = index - self._path = path - - def __contains__(self, key: Key) -> bool: - return key in self._path - - def replicate( - self, - key: Key, - provider: Provider, - get_provider: Callable[..., Tuple[Provider, Dict[TypeVar, Key]]], - ) -> Graph: - graph: Graph = {} - for idx in self.index: - subkey = self.key(idx, key) - if isinstance(provider, _ParamSentinel): - graph[subkey] = get_provider(subkey)[0] - else: - graph[subkey] = self._copy_node(key, provider, idx) - return graph - - def _copy_node( - self, - key: Key, - provider: Union[Provider, SeriesProvider[IndexType]], - idx: IndexType, - ) -> Provider: - return provider.map_arg_keys( - lambda arg: self.key(idx, arg) if arg in self else arg - ) - - def key(self, i: IndexType, value_name: Union[Type[T], Item[T]]) -> Item[T]: - label = Label(self._index_name, i) - if isinstance(value_name, Item): - return Item(value_name.label + (label,), value_name.tp) - else: - return Item((label,), value_name) - - -class Replicator(ReplicatorBase[IndexType]): - r""" - Helper for rewriting the graph to map over a given index. - - See Pipeline._build_series for context. 
Given a graph template, this makes a - transformation as follows: - - S P1[0] P1[1] P1[2] - | | | | - A -> A[0] A[1] A[2] - | | | | - B B[0] B[1] B[2] - - Where S is a sentinel value, P1 are parameters from a parameter table, 0,1,2 - are indices of the param table rows, and A and B are arbitrary nodes in the graph. - """ - - def __init__(self, param_table: ParamTable, graph_template: Graph) -> None: - index_name = param_table.row_dim - super().__init__( - index_name=index_name, - index=param_table.index, - path=_find_nodes_in_paths(graph_template, index_name), - ) - - -class GroupingReplicator(ReplicatorBase[LabelType], Generic[IndexType, LabelType]): - r""" - Helper for rewriting the graph to group by a given index. - - See Pipeline._build_series for context. Given a graph template, this makes a - transformation as follows: - - P1[0] P1[1] P1[2] P1[0] P1[1] P1[2] - | | | | | | - A[0] A[1] A[2] A[0] A[1] A[2] - | | | | | | - B[0] B[1] B[2] -> B[0] B[1] B[2] - \______|______/ \______/ | - | | | - SB SB[x] SB[y] - | | | - C C[x] C[y] - - Where SB is Series[Idx,B]. Here, the upper half of the graph originates from a - prior transformation of a graph template using `Replicator`. The output of this - combined with further nodes is the graph template passed to this class. x and y - are the labels used in a grouping operation, based on the values of a ParamTable - column P2. - """ - - def __init__( - self, param_table: ParamTable, graph_template: Graph, label_name: type - ) -> None: - self._label_name = label_name - self._group_node = self._find_grouping_node(param_table.row_dim, graph_template) - self._groups: Dict[LabelType, List[IndexType]] = defaultdict(list) - for idx, label in zip(param_table.index, param_table[label_name]): - self._groups[label].append(idx) - super().__init__( - index_name=label_name, - index=self._groups, - path=_find_nodes_in_paths(graph_template, self._group_node), - ) - - def _copy_node( - self, - key: Key, - provider: Union[Provider, SeriesProvider[IndexType]], - idx: LabelType, - ) -> Provider: - if (not isinstance(provider, SeriesProvider)) or key != self._group_node: - return super()._copy_node(key, provider, idx) - labels = self._groups[idx] - if set(labels) - set(provider.labels): - raise ValueError(f'{labels} is not a subset of {provider.labels}') - if tuple(provider.arg_spec.kwargs): - raise RuntimeError( - 'A Series was provided with keyword arguments. This should not happen ' - 'and is an internal error of Sciline.' - ) - selected = { - label: arg - for label, arg in zip(provider.labels, provider.arg_spec.args) - if label in labels - } - return SeriesProvider(selected.keys(), provider.row_dim, args=selected.values()) - - def _find_grouping_node(self, index_name: Key, subgraph: Graph) -> type: - ends: List[type] = [] - for key in subgraph: - if get_origin(key) == Series and get_args(key)[0] == index_name: - # Because of the succeeded get_origin we know it is a type - ends.append(key) # type: ignore[arg-type] - if len(ends) == 1: - return ends[0] - raise ValueError(f"Could not find unique grouping node, found {ends}") - - -class SeriesProvider(Generic[KeyType], Provider): - """ - Internal provider for combining results obtained based on different rows in a - param table into a single object. 
- """ - - def __init__( - self, - labels: Iterable[KeyType], - row_dim: Type[KeyType], - *, - args: Optional[Iterable[Key]] = None, - ) -> None: - super().__init__( - func=self._call, - arg_spec=ArgSpec.from_args(*(args if args is not None else labels)), - kind='series', - ) - self.labels = labels - self.row_dim = row_dim - - def _call(self, *vals: ValueType) -> Series[KeyType, ValueType]: - return Series(self.row_dim, dict(zip(self.labels, vals))) - - -class _ParamSentinel(Provider): - def __init__(self, key: Key) -> None: - super().__init__( - func=lambda: None, - arg_spec=ArgSpec.from_args(key), - kind='sentinel', - location=ProviderLocation( - name=f'param_sentinel({type(key).__name__})', module='sciline' - ), - ) - - -class Pipeline: +class Pipeline(DataGraph): """A container for providers that can be assembled into a task graph.""" def __init__( @@ -368,348 +67,10 @@ def __init__( params: Dictionary of concrete values to provide for types. """ - self._providers: Dict[Key, Provider] = {} - self._subproviders: Dict[type, Dict[Tuple[Key | TypeVar, ...], Provider]] = {} - self._param_tables: Dict[Key, ParamTable] = {} - self._param_name_to_table_key: Dict[Key, Key] = {} - for provider in providers or []: - self.insert(provider) + super().__init__(providers) for tp, param in (params or {}).items(): self[tp] = param - def insert(self, provider: Union[ToProvider, Provider], /) -> None: - """ - Add a callable that provides its return value to the pipeline. - - Parameters - ---------- - provider: - Either a callable that provides its return value. Its arguments - and return value must be annotated with type hints. - Or a ``Provider`` object that has been constructed from such a callable. - """ - if not isinstance(provider, Provider): - provider = Provider.from_function(provider) - self._set_provider(provider.deduce_key(), provider) - - def __setitem__(self, key: Type[T], param: T) -> None: - """ - Provide a concrete value for a type. - - Parameters - ---------- - key: - Type to provide a value for. - param: - Concrete value to provide. - """ - self._set_provider(key, Provider.parameter(param)) - - def set_param_table(self, params: ParamTable) -> None: - """ - Set a parameter table for a row dimension. - - Values in the parameter table provide concrete values for a type given by the - respective column header. - - A pipeline can have multiple parameter tables, but only one per row dimension. - Column names must be unique across all parameter tables. - - Parameters - ---------- - params: - Parameter table to set. 
- """ - for param_name in params: - if (existing := self._param_name_to_table_key.get(param_name)) is not None: - if ( - existing == params.row_dim - and param_name in self._param_tables[existing] - ): - # Column will be removed by del_param_table below, clash is ok - continue - raise ValueError(f'Parameter {param_name} already set') - if params.row_dim in self._param_tables: - self.del_param_table(params.row_dim) - self._param_tables[params.row_dim] = params - for param_name in params: - self._param_name_to_table_key[param_name] = params.row_dim - for param_name, values in params.items(): - for index, label in zip(params.index, values): - self._set_provider( - Item((Label(tp=params.row_dim, index=index),), param_name), - Provider.table_cell(label), - ) - for index, label in zip(params.index, params.index): - self._set_provider( - Item((Label(tp=params.row_dim, index=index),), params.row_dim), - Provider.table_cell(label), - ) - - def del_param_table(self, row_dim: type) -> None: - """ - Remove a parameter table. - - Parameters - ---------- - row_dim: - Row dimension of the parameter table to remove. - """ - # 1. Remove providers pointing to table cells - params = self._param_tables[row_dim] - for index in params.index: - label = (Label(tp=row_dim, index=index),) - for param_name in params: - del self._providers[Item(label, param_name)] - del self._providers[Item(label, row_dim)] - # 2. Remove column to table mapping - for param_name in list(self._param_name_to_table_key): - if self._param_name_to_table_key[param_name] == row_dim: - del self._param_name_to_table_key[param_name] - # 3. Remove table - del self._param_tables[row_dim] - - def set_param_series(self, row_dim: type, index: Collection[Any]) -> None: - """ - Set a series of parameters. - - This is a convenience method for creating and setting a parameter table with - no columns and an index given by `index`. - - Parameters - ---------- - row_dim: - Row dimension of the parameter table to set. - index: - Index of the parameter table to set. - """ - self.set_param_table(ParamTable(row_dim, columns={}, index=index)) - - def _set_provider( - self, - key: Key, - provider: Provider, - ) -> None: - # isinstance does not work here and types.NoneType available only in 3.10+ - if key == type(None): # noqa: E721 - raise ValueError(f'Provider {provider} returning `None` is not allowed') - if get_origin(key) == Series: - raise ValueError( - f'Provider {provider} returning a sciline.Series is not allowed. ' - 'Series is a special container reserved for use in conjunction with ' - 'sciline.ParamTable and must not be provided directly.' 
- ) - if (origin := get_origin(key)) is not None: - subproviders = self._subproviders.setdefault(origin, {}) - args = get_args(key) - subproviders[args] = provider - else: - self._providers[key] = provider - - def _get_provider( - self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None - ) -> Tuple[Provider, Dict[TypeVar, Key]]: - handler = handler or HandleAsBuildTimeException() - explanation: List[str] = [] - if (provider := self._providers.get(tp)) is not None: - return provider, {} - elif (origin := get_origin(tp)) is not None and ( - subproviders := self._subproviders.get(origin) - ) is not None: - requested = get_args(tp) - matches = [ - (subprovider, bound) - for args, subprovider in subproviders.items() - if ( - bound := _find_bounds_to_make_compatible_type_tuple(requested, args) - ) - is not None - ] - typevar_counts = [len(bound) for _, bound in matches] - min_typevar_count = min(typevar_counts, default=0) - matches = [ - m - for count, m in zip(typevar_counts, matches) - if count == min_typevar_count - ] - - if len(matches) == 1: - provider, bound = matches[0] - return provider, bound - elif len(matches) > 1: - matching_providers = [provider.location.name for provider, _ in matches] - raise AmbiguousProvider( - f"Multiple providers found for type {tp}." - f" Matching providers are: {matching_providers}." - ) - else: - typevars_in_expression = _extract_typevars_from_generic_type(origin) - if typevars_in_expression: - explanation = [ - ''.join( - map( - str, - ( - 'Note that ', - key_name(origin[typevars_in_expression]), - ' has constraints ', - ( - { - key_name(tv): tuple( - map(key_name, tv.__constraints__) - ) - for tv in typevars_in_expression - } - ), - ), - ) - ) - ] - return handler.handle_unsatisfied_requirement(tp, *explanation), {} - - def build( - self, - tp: Union[Type[T], Item[T]], - /, - *, - handler: ErrorHandler, - search_param_tables: bool = False, - ) -> Graph: - """ - Return a dict of providers required for building the requested type `tp`. - - This is mainly for internal and low-level use. Prefer using :py:meth:`get`. - - The values are tuples containing the provider and the dict of arguments for - the provider. The values in the latter dict reference other keys in the returned - graph. - - Parameters - ---------- - tp: - Type to build the graph for. - search_param_tables: - Whether to search parameter tables for concrete keys. - """ - graph: Graph = {} - stack: List[Union[Type[T], Item[T]]] = [tp] - while stack: - tp = stack.pop() - # First look in column labels of param tables - if search_param_tables and tp in self._param_name_to_table_key: - graph[tp] = _ParamSentinel(self._param_name_to_table_key[tp]) - continue - # Then also indices of param tables. This comes second because we need to - # prefer column labels over indices for multi-level grouping. - if search_param_tables and tp in self._param_tables: - graph[tp] = _ParamSentinel(tp) - continue - if get_origin(tp) == Series: - sub = self._build_series(tp, handler=handler) # type: ignore[arg-type] - graph.update(sub) - continue - provider, bound = self._get_provider(tp, handler=handler) - provider = provider.bind_type_vars(bound) - graph[tp] = provider - stack.extend(provider.arg_spec.keys() - graph.keys()) - return graph - - def _build_series( - self, tp: Type[Series[KeyType, ValueType]], handler: ErrorHandler - ) -> Graph: - """ - Build (sub)graph for a Series type implementing ParamTable-based functionality. - - We illustrate this with an example. 
Given a ParamTable with row_dim 'Idx': - - Idx | P1 | P2 - 0 | a | x - 1 | b | x - 2 | c | y - - and providers for A depending on P1 and B depending on A. Calling - build(Series[Idx,B]) will call _build_series(Series[Idx,B]). This results in - the following procedure: - - 1. Call build(P1), resulting, e.g., in a graph S->A->B, where S is a sentinel. - The sentinel is used because build() cannot find a unique P1, since it is - not a single value but a column in a table. - 2. Instantiation of `Replicator`, which will be used to replicate the - relevant parts of the graph (see illustration there). - 3. Insert a special `SeriesProvider` node, which will gather the duplicates of - the 'B' node and provides the requested Series[Idx,B]. - 4. Replicate the graph. Nodes that do not directly or indirectly depend on P1 - are not replicated. - - Conceptually, the final result will be { - 0: B(A(a)), - 1: B(A(b)), - 2: B(A(c)) - }. - - In more complex cases, we may be dealing with multiple levels of Series, - which is used for grouping operations. Consider the above example, but with - an additional provider for C depending on Series[Idx,B]. Calling - build(Series[P2,C]) will call _build_series(Series[P2,C]). This results in - the following procedure: - - a. Call build(C), which results in the procedure above, i.e., a nested call - to _build_series(Series[Idx,B]) and the resulting graph as explained above. - b. Instantiation of `GroupingReplicator`, which will be used to replicate the - relevant parts of the graph (see illustration there). - c. Insert a special `SeriesProvider` node, which will gather the duplicates of - the 'C' node and provides the requested Series[P2,C]. - d. Replicate the graph. Nodes that do not directly or indirectly depend on - the special `SeriesProvider` node (from step 3.) are not replicated. - - Conceptually, the final result will be { - x: C({ - 0: B(A(a)), - 1: B(A(b)) - }), - y: C({ - 2: B(A(c)) - }) - }. - """ - index_name: Type[KeyType] - value_type: Type[ValueType] - index_name, value_type = get_args(tp) - - subgraph = self.build(value_type, search_param_tables=True, handler=handler) - - replicator: ReplicatorBase[KeyType] - if ( - # For multi-level grouping a type is an index as well as a column label. - # In this case we do not want to replicate the graph, but group it (elif). - index_name not in self._param_name_to_table_key - and (params := self._param_tables.get(index_name)) is not None - ): - replicator = Replicator(param_table=params, graph_template=subgraph) - elif (table_key := self._param_name_to_table_key.get(index_name)) is not None: - replicator = GroupingReplicator( - param_table=self._param_tables[table_key], - graph_template=subgraph, - label_name=index_name, - ) - else: - raise KeyError(f'No parameter table found for label {index_name}') - - graph: Graph = { - tp: SeriesProvider( - list(replicator.index), - index_name, - args=(replicator.key(idx, value_type) for idx in replicator.index), - ) - } - - for key, provider in subgraph.items(): - if key in replicator: - graph.update(replicator.replicate(key, provider, self._get_provider)) - else: - graph[key] = provider - return graph - @overload def compute(self, tp: Type[T], **kwargs: Any) -> T: ... @@ -718,17 +79,11 @@ def compute(self, tp: Type[T], **kwargs: Any) -> T: def compute(self, tp: Iterable[Type[T]], **kwargs: Any) -> Dict[Type[T], T]: ... - @overload - def compute(self, tp: Item[T], **kwargs: Any) -> T: - ... - @overload def compute(self, tp: UnionType, **kwargs: Any) -> Any: ...
- def compute( - self, tp: type | Iterable[type] | Item[T] | UnionType, **kwargs: Any - ) -> Any: + def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any: """ Compute result for the given keys. @@ -764,7 +119,7 @@ def visualize( def get( self, - keys: type | Iterable[type] | Item[T] | object, + keys: type | Iterable[type] | UnionType, *, scheduler: Optional[Scheduler] = None, handler: Optional[ErrorHandler] = None, @@ -788,16 +143,15 @@ def get( raises an exception only when the graph is computed. This can be achieved by passing :py:class:`HandleAsComputeTimeException` as the handler. """ - handler = handler or HandleAsBuildTimeException() - if _is_multiple_keys(keys): - keys = tuple(keys) # type: ignore[arg-type] - graph: Graph = {} - for t in keys: # type: ignore[var-annotated] - graph.update(self.build(t, handler=handler)) + if multi := _is_multiple_keys(keys): + targets = tuple(keys) # type: ignore[arg-type] else: - graph = self.build(keys, handler=handler) # type: ignore[arg-type] + targets = (keys,) # type: ignore[assignment] + graph = to_task_graph(self, targets=targets, handler=handler) return TaskGraph( - graph=graph, targets=keys, scheduler=scheduler # type: ignore[arg-type] + graph=graph, + targets=targets if multi else keys, # type: ignore[arg-type] + scheduler=scheduler, ) @overload @@ -854,29 +208,6 @@ def bind_and_call( return results[0] return results - def copy(self) -> Pipeline: - """ - Make a copy of the pipeline. - """ - out = Pipeline() - out._providers = self._providers.copy() - out._subproviders = {k: v.copy() for k, v in self._subproviders.items()} - out._param_tables = self._param_tables.copy() - out._param_name_to_table_key = self._param_name_to_table_key.copy() - return out - - def __copy__(self) -> Pipeline: - return self.copy() - def _repr_html_(self) -> str: - providers_without_parameters = ( - (origin, tuple(), value) for origin, value in self._providers.items() - ) # type: ignore[var-annotated] - providers_with_parameters = ( - (origin, args, value) - for origin in self._subproviders - for args, value in self._subproviders[origin].items() - ) - return pipeline_html_repr( - chain(providers_without_parameters, providers_with_parameters) - ) + nodes = ((key, data) for key, data in self._graph.nodes.items()) + return pipeline_html_repr(nodes) diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py index 9d658671..e74f83c0 100644 --- a/src/sciline/scheduler.py +++ b/src/sciline/scheduler.py @@ -5,14 +5,14 @@ Any, Callable, Dict, - List, + Hashable, Optional, Protocol, Tuple, runtime_checkable, ) -from sciline.typing import Graph, Key +from sciline.typing import Graph class CycleError(Exception): @@ -25,7 +25,7 @@ class Scheduler(Protocol): Scheduler interface compatible with :py:class:`sciline.Pipeline`. """ - def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]: + def get(self, graph: Graph, keys: list[Hashable]) -> Tuple[Any, ...]: """ Compute the result for given keys from the graph. @@ -44,7 +44,7 @@ class NaiveScheduler: :py:class:`DaskScheduler` instead. """ - def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]: + def get(self, graph: Graph, keys: list[Hashable]) -> Tuple[Any, ...]: import graphlib dependencies = { @@ -56,7 +56,7 @@ def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]: tasks = list(ts.static_order()) except graphlib.CycleError as e: raise CycleError from e - results: Dict[Key, Any] = {} + results: Dict[Hashable, Any] = {} for t in tasks: results[t] = graph[t].call(results) return tuple(results[key] for key in keys) @@ -87,7 +87,7 @@ def __init__(self, scheduler: Optional[Callable[..., Any]] = None) -> None: else: self._dask_get = scheduler - def get(self, graph: Graph, keys: List[Key]) -> Any: + def get(self, graph: Graph, keys: list[Hashable]) -> Any: from dask.utils import apply # Use `apply` to allow passing keyword arguments. diff --git a/src/sciline/series.py b/src/sciline/series.py deleted file mode 100644 index 2b5b5607..00000000 --- a/src/sciline/series.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from __future__ import annotations - -from typing import Iterator, Mapping, Type, TypeVar - -Key = TypeVar('Key') -Value = TypeVar('Value') - - -class Series(Mapping[Key, Value]): - """A series of values with labels (row index) and named row dimension.""" - - def __init__(self, row_dim: Type[Key], items: Mapping[Key, Value]) -> None: - """ - Create a new series. - - Parameters - ---------- - row_dim: - The row dimension. This must be a type or a type-alias (not an instance). - items: - The items of the series. - """ - self._row_dim = row_dim - self._map: Mapping[Key, Value] = items - - @property - def row_dim(self) -> type: - """The row dimension of the series.""" - return self._row_dim - - def __contains__(self, item: object) -> bool: - return item in self._map - - def __iter__(self) -> Iterator[Key]: - return iter(self._map) - - def __len__(self) -> int: - return len(self._map) - - def __getitem__(self, key: Key) -> Value: - return self._map[key] - - def __repr__(self) -> str: - return f"Series(row_dim={self.row_dim.__name__}, {self._map})" - - def _repr_html_(self) -> str: - return ( - f"<table><tr><th>{self.row_dim.__name__}</th><th>Value</th></tr>" - + "".join( - f"<tr><td>{k}</td><td>{v}</td></tr>" for k, v in self._map.items() - ) - + "</table>" - ) - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Series): - return NotImplemented - return self.row_dim == other.row_dim and self._map == other._map diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index e9662b93..cf80dc35 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -3,12 +3,12 @@ from __future__ import annotations from html import escape -from typing import Any, Generator, Optional, Sequence, Tuple, TypeVar, Union +from typing import Any, Generator, Hashable, Sequence, TypeVar from ._utils import key_name from .scheduler import DaskScheduler, NaiveScheduler, Scheduler from .serialize import json_serialize_task_graph -from .typing import Graph, Item, Json, Key +from .typing import Graph, Json, Key T = TypeVar("T") @@ -59,6 +59,9 @@ def wrap(s: str) -> str: ) +Targets = Hashable | tuple[Hashable, ...] + + class TaskGraph: """ Holds a concrete task graph and keys to compute. @@ -68,11 +71,7 @@ class TaskGraph: """ def __init__( - self, - *, - graph: Graph, - targets: Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]], - scheduler: Optional[Scheduler] = None, + self, *, graph: Graph, targets: Targets, scheduler: Scheduler | None = None ) -> None: self._graph = graph self._keys = targets @@ -87,12 +86,7 @@ def __init__( ) self._scheduler = scheduler - def compute( - self, - targets: Optional[ - Union[type, Tuple[type, ...], Item[T], Tuple[Item[T], ...]] - ] = None, - ) -> Any: + def compute(self, targets: Targets | None = None) -> Any: """ Compute the result of the graph. @@ -162,7 +156,7 @@ def serialize(self) -> dict[str, Json]: def _repr_html_(self) -> str: leafs = sorted( [ - escape(key_name(key)) + escape(key_name(key)) # type: ignore[arg-type] for key in ( self._keys if isinstance(self._keys, tuple) else [self._keys] ) diff --git a/src/sciline/typing.py b/src/sciline/typing.py index 93b1625d..883ddd88 100644 --- a/src/sciline/typing.py +++ b/src/sciline/typing.py @@ -1,14 +1,11 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -from dataclasses import dataclass from typing import ( Any, Dict, - Generic, List, Optional, Tuple, - Type, TypeVar, Union, get_args, @@ -17,23 +14,10 @@ from ._provider import Provider - -@dataclass(frozen=True) -class Label: - tp: type - index: Any - - T = TypeVar('T') -@dataclass(frozen=True) -class Item(Generic[T]): - label: Tuple[Label, ...]
- tp: Type[T] - - -Key = Union[type, Item[Any]] +Key = type Graph = dict[Key, Provider] diff --git a/src/sciline/visualize.py b/src/sciline/visualize.py index d6a5d855..a7e9ddae 100644 --- a/src/sciline/visualize.py +++ b/src/sciline/visualize.py @@ -1,23 +1,13 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from dataclasses import dataclass -from typing import ( - Any, - Dict, - List, - Optional, - Tuple, - Type, - TypeVar, - Union, - get_args, - get_origin, -) +from typing import Any, Dict, Hashable, get_args, get_origin +import cyclebane from graphviz import Digraph from ._provider import Provider, ProviderKind -from .typing import Graph, Item, Key, get_optional +from .typing import Graph, Key, get_optional @dataclass @@ -41,7 +31,7 @@ def to_graphviz( graph: Graph, compact: bool = False, cluster_generics: bool = True, - cluster_color: Optional[str] = '#f0f0ff', + cluster_color: str | None = '#f0f0ff', **kwargs: Any, ) -> Digraph: """ @@ -149,19 +139,15 @@ def _format_provider(provider: Provider, ret: Key, compact: bool) -> str: return f'{provider.location.qualname}_{_format_type(ret, compact=compact).name}' -T = TypeVar('T') - - def _extract_type_and_labels( - key: Union[Item[T], Type[T]], compact: bool -) -> Tuple[Type[T], List[Union[type, Tuple[type, Any]]]]: - if isinstance(key, Item): - label = key.label - return key.tp, [lb.tp if compact else (lb.tp, lb.index) for lb in label] + key: Hashable | cyclebane.graph.NodeName, compact: bool +) -> tuple[Hashable, list[Hashable | tuple[Hashable, Hashable]]]: + if isinstance(key, cyclebane.graph.NodeName): + return key.name, list(key.index.axes if compact else key.index.to_tuple()) return key, [] -def _format_type(tp: Key, compact: bool = False) -> Node: +def _format_type(tp: Hashable, compact: bool = False) -> Node: """ Helper for _format_graph. @@ -169,16 +155,15 @@ def _format_type(tp: Key, compact: bool = False) -> Node: but strip all module prefixes from the type name as well as the params. We may make this configurable in the future. """ - tp, labels = _extract_type_and_labels(tp, compact=compact) - if (tp_ := get_optional(tp)) is not None: + if (tp_ := get_optional(tp)) is not None: # type: ignore[arg-type] tp = tp_ - def get_base(tp: Key) -> str: - return tp.__name__ if hasattr(tp, '__name__') else str(tp).split('.')[-1] + def get_base(tp: Hashable) -> str: + return str(tp.__name__) if hasattr(tp, '__name__') else str(tp).split('.')[-1] - def format_label(label: Union[type, Tuple[type, Any]]) -> str: + def format_label(label: Hashable | tuple[Hashable, Any]) -> str: if isinstance(label, tuple): tp, index = label return f'{get_base(tp)}={index}' diff --git a/tests/complex_workflow_test.py b/tests/complex_workflow_test.py index 0d68862b..ace0b0a7 100644 --- a/tests/complex_workflow_test.py +++ b/tests/complex_workflow_test.py @@ -22,7 +22,7 @@ class RawData: DirectBeam = NewType('DirectBeam', npt.NDArray[np.float64]) SolidAngle = NewType('SolidAngle', npt.NDArray[np.float64]) -Run = TypeVar('Run') +Run = TypeVar('Run', SampleRun, BackgroundRun) # TODO Giving the base twice works with mypy, how can we avoid typing it twice? 
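The hunk above gives `Run` explicit constraints because the rewrite only supports generic providers whose concrete output types are known up front. A minimal sketch of the resulting pattern, assembled only from constructs that appear elsewhere in this diff (`sl.Scope`, constrained TypeVars, `__setitem__`); the names are illustrative, not part of the change:

    from typing import NewType, TypeVar

    import sciline as sl

    SampleRun = NewType('SampleRun', int)
    BackgroundRun = NewType('BackgroundRun', int)
    # Constraints make the set of concrete run types known when providers
    # are inserted, so the data graph gets one node per constraint.
    RunType = TypeVar('RunType', SampleRun, BackgroundRun)


    class RawData(sl.Scope[RunType, int], int): ...


    class SquaredData(sl.Scope[RunType, int], int): ...


    def square(x: RawData[RunType]) -> SquaredData[RunType]:
        return SquaredData[RunType](x * x)


    pipeline = sl.Pipeline([square])
    pipeline[RawData[SampleRun]] = RawData[SampleRun](3)
    pipeline[RawData[BackgroundRun]] = RawData[BackgroundRun](4)
    assert pipeline.compute(SquaredData[SampleRun]) == 9
    assert pipeline.compute(SquaredData[BackgroundRun]) == 16

The TODO in the hunk stands: the concrete run types have to be spelled out twice, once as NewTypes and once as constraints.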
diff --git a/tests/param_table_test.py b/tests/param_table_test.py deleted file mode 100644 index 931c247b..00000000 --- a/tests/param_table_test.py +++ /dev/null @@ -1,53 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) -import pytest - -import sciline as sl - - -def test_raises_with_zero_columns() -> None: - with pytest.raises(ValueError): - sl.ParamTable(row_dim=int, columns={}) - - -def test_raises_with_inconsistent_column_sizes() -> None: - with pytest.raises(ValueError): - sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0]}) - - -def test_raises_with_inconsistent_index_length() -> None: - with pytest.raises(ValueError): - sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0]}, index=[1, 2, 3]) - - -def test_raises_with_non_unique_index() -> None: - with pytest.raises(ValueError): - sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[1, 1, 2]) - - -def test_contains_includes_all_columns() -> None: - pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) - assert int in pt - assert float in pt - assert str not in pt - - -def test_contains_does_not_include_index() -> None: - pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) - assert int not in pt - - -def test_len_is_number_of_columns() -> None: - pt = sl.ParamTable(row_dim=int, columns={int: [1, 2, 3], float: [1.0, 2.0, 3.0]}) - assert len(pt) == 2 - - -def test_defaults_to_range_index() -> None: - pt = sl.ParamTable(row_dim=int, columns={float: [1.0, 2.0, 3.0]}) - assert pt.index == [0, 1, 2] - - -def test_index_with_no_columns() -> None: - pt = sl.ParamTable(row_dim=int, columns={}, index=[1, 2, 3]) - assert pt.index == [1, 2, 3] - assert len(pt) == 0 diff --git a/tests/pipeline_map_reduce_test.py b/tests/pipeline_map_reduce_test.py new file mode 100644 index 00000000..9bbffffa --- /dev/null +++ b/tests/pipeline_map_reduce_test.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import NewType + +import pytest + +import sciline as sl + +A = NewType('A', int) +B = NewType('B', int) +C = NewType('C', int) +D = NewType('D', int) +X = NewType('X', int) + + +def a_to_b(a: A) -> B: + return B(a + 1) + + +def b_to_c(b: B) -> C: + return C(b + 2) + + +def c_to_d(c: C) -> D: + return D(c + 4) + + +def test_map_returns_pipeline_that_can_compute_for_each_value() -> None: + ab = sl.Pipeline((a_to_b,)) + mapped = ab.map({A: [A(10 * i) for i in range(3)]}) + with pytest.raises(sl.UnsatisfiedRequirement): + # B is not in the graph any more, since it has been duplicated + mapped.compute(B) + from cyclebane.graph import IndexValues, NodeName + + for i in range(3): + index = IndexValues(('dim_0',), (i,)) + assert mapped.compute(NodeName(A, index)) == A(10 * i) # type: ignore[call-overload] # noqa: E501 + assert mapped.compute(NodeName(B, index)) == B(10 * i + 1) # type: ignore[call-overload] # noqa: E501 + + +def test_reduce_returns_pipeline_passing_mapped_branches_to_reducing_func() -> None: + ab = sl.Pipeline((a_to_b,)) + mapped = ab.map({A: [A(10 * i) for i in range(3)]}) + # Define key to make mypy happy. This can actually be anything but currently + # the type-hinting of Pipeline is too specific, disallowing, e.g., strings. 
+ Result = NewType('Result', int) + assert mapped.reduce(func=min, name=Result).compute(Result) == Result(1) + assert mapped.reduce(func=max, name=Result).compute(Result) == Result(21) diff --git a/tests/pipeline_setitem_test.py b/tests/pipeline_setitem_test.py new file mode 100644 index 00000000..65e37bce --- /dev/null +++ b/tests/pipeline_setitem_test.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: BSD-3-Clause +# Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import NewType + +import pytest + +import sciline as sl + +A = NewType('A', int) +B = NewType('B', int) +C = NewType('C', int) +D = NewType('D', int) +X = NewType('X', int) + + +def a_to_b(a: A) -> B: + return B(a + 1) + + +def ab_to_c(a: A, b: B) -> C: + return C(a + b) + + +def b_to_c(b: B) -> C: + return C(b + 2) + + +def c_to_d(c: C) -> D: + return D(c + 4) + + +def test_setitem_can_compose_pipelines() -> None: + ab = sl.Pipeline((a_to_b,)) + ab[A] = 0 + bc = sl.Pipeline((b_to_c,)) + bc[B] = ab + assert bc.compute(C) == C(3) + + +def test_setitem_raises_if_value_pipeline_has_no_unique_output() -> None: + abx = sl.Pipeline((a_to_b,)) + abx[X] = 666 + bc = sl.Pipeline((b_to_c,)) + with pytest.raises(ValueError, match='Graph must have exactly one sink node'): + bc[B] = abx + + +def test_setitem_can_add_value_pipeline_at_new_node_given_by_key() -> None: + ab = sl.Pipeline((a_to_b,)) + ab[A] = 0 + empty = sl.Pipeline(()) + empty[B] = ab + assert empty.compute(B) == B(1) + + +def test_setitem_replaces_existing_node_value_with_value_pipeline() -> None: + ab = sl.Pipeline((a_to_b,)) + ab[A] = 0 + bc = sl.Pipeline((b_to_c,)) + bc[B] = 111 + bc[B] = ab + assert bc.compute(C) == C(3) + + +def test_setitem_replaces_existing_branch_with_value() -> None: + ab = sl.Pipeline((a_to_b,)) + ab[A] = 0 + bc = sl.Pipeline((b_to_c,)) + bc[B] = ab + bc[B] = 111 + assert bc.compute(C) == C(113) + # __setitem__ prunes the entire branch, instead of just cutting the edges + with pytest.raises(sl.UnsatisfiedRequirement): + bc.compute(A) + + +def test_setitem_with_conflicting_nodes_in_value_pipeline_raises_on_data_mismatch() -> ( + None +): + ab = sl.Pipeline((a_to_b,)) + ab[A] = 100 + abc = sl.Pipeline((ab_to_c,)) + abc[A] = 666 + with pytest.raises(ValueError, match="Node data differs"): + abc[B] = ab + + +def test_setitem_with_conflicting_nodes_in_value_pipeline_accepts_on_data_match() -> ( + None +): + ab = sl.Pipeline((a_to_b,)) + ab[A] = 100 + abc = sl.Pipeline((ab_to_c,)) + abc[A] = 100 + abc[B] = ab + assert abc.compute(C) == C(201) + + +def test_setitem_with_conflicting_nodes_in_value_pipeline_raises_on_unique_data() -> ( + None +): + ab = sl.Pipeline((a_to_b,)) + ab[A] = 100 + abc = sl.Pipeline((ab_to_c,)) + # Missing data is just as bad as conflicting data. At first glance, it might seem + # like this should be allowed, but it would be a source of bugs and confusion. + # In particular since it would allow for adding providers or parents to an + # indirect dependency of the key in setitem, which would be very confusing. 
+ with pytest.raises(ValueError, match="Node data differs"): + abc[B] = ab + + +def test_setitem_with_conflicting_node_inputs_in_value_pipeline_raises() -> None: + def x_to_b(x: X) -> B: + return B(x + 1) + + def bc_to_d(b: B, c: C) -> D: + return D(b + c) + + xbc = sl.Pipeline((x_to_b, b_to_c)) + xbc[X] = 666 + abcd = sl.Pipeline((a_to_b, bc_to_d)) + abcd[A] = 100 + with pytest.raises(ValueError, match="Node inputs differ"): + # If this worked naively by composing the NetworkX graph, it would look as + # below, giving B 2 inputs instead of 1, even though provider has 1 argument, + # corrupting the graph: + # A X + # \ / + # B + # |\ + # C | + # |/ + # D + abcd[C] = xbc diff --git a/tests/pipeline_test.py b/tests/pipeline_test.py index f8531260..68c5bbdc 100644 --- a/tests/pipeline_test.py +++ b/tests/pipeline_test.py @@ -83,7 +83,7 @@ def func2(x: int) -> str: def test_Scope_subclass_can_be_set_as_param() -> None: - Param = TypeVar('Param') + Param = TypeVar('Param', int, float) class Str(sl.Scope[Param, str], str): ... @@ -95,7 +95,7 @@ class Str(sl.Scope[Param, str], str): def test_Scope_subclass_can_be_set_as_param_with_unbound_typevar() -> None: - Param = TypeVar('Param') + Param = TypeVar('Param', int, float) class Str(sl.Scope[Param, str], str): ... @@ -107,8 +107,8 @@ class Str(sl.Scope[Param, str], str): def test_ScopeTwoParam_subclass_can_be_set_as_param() -> None: - Param1 = TypeVar('Param1') - Param2 = TypeVar('Param2') + Param1 = TypeVar('Param1', int, float) + Param2 = TypeVar('Param2', int, float) class Str(sl.ScopeTwoParams[Param1, Param2, str], str): ... @@ -120,8 +120,8 @@ class Str(sl.ScopeTwoParams[Param1, Param2, str], str): def test_ScopeTwoParam_subclass_can_be_set_as_param_with_unbound_typevar() -> None: - Param1 = TypeVar('Param1') - Param2 = TypeVar('Param2') + Param1 = TypeVar('Param1', int, float) + Param2 = TypeVar('Param2', int, float) class Str(sl.ScopeTwoParams[Param1, Param2, str], str): ... @@ -133,7 +133,7 @@ class Str(sl.ScopeTwoParams[Param1, Param2, str], str): def test_generic_providers_produce_use_dependencies_based_on_bound_typevar() -> None: - Param = TypeVar('Param') + Param = TypeVar('Param', int, float) class Str(sl.Scope[Param, str], str): ... @@ -161,7 +161,11 @@ def provide_int() -> int: ncall += 1 return 3 - Param = TypeVar('Param') + Run1 = NewType('Run1', int) + Run2 = NewType('Run2', int) + Result = NewType('Result', str) + + Param = TypeVar('Param', Run1, Run2, Result) class Float(sl.Scope[Param, float], float): ... @@ -172,10 +176,6 @@ class Str(sl.Scope[Param, str], str): def int_float_to_str(x: int, y: Float[Param]) -> Str[Param]: return Str(f"{x};{y}") - Run1 = NewType('Run1', int) - Run2 = NewType('Run2', int) - Result = NewType('Result', str) - def float1() -> Float[Run1]: return Float[Run1](1.5) @@ -193,7 +193,7 @@ def use_strings(s1: Str[Run1], s2: Str[Run2]) -> Result: def test_subclasses_of_generic_provider_defined_with_Scope_work() -> None: - Param = TypeVar('Param') + Param = TypeVar('Param', int, float) class StrT(sl.Scope[Param, str], str): ... @@ -234,7 +234,8 @@ def make_str2() -> Str2[Param]: def test_subclasses_of_generic_array_provider_defined_with_Scope_work() -> None: - Param = TypeVar('Param') + # int is unused, but a single constraint is not allowed by Python + Param = TypeVar('Param', str, int) class ArrayT(sl.Scope[Param, npt.NDArray[np.int64]], npt.NDArray[np.int64]): ... 
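As the comment in the hunk above notes, Python's `typing` module rejects a `TypeVar` with a single constraint, hence the unused second constraint. A small sketch of how a constrained generic provider behaves after this change, mirroring `test_TypeVar_requirement_of_provider_can_be_bound` below; the names are illustrative:

    from typing import TypeVar

    import sciline as sl

    # Constraints are mandatory for generic providers in the rewrite;
    # typing requires at least two of them.
    T = TypeVar('T', int, float)


    def duplicate(x: T) -> list[T]:
        return [x, x]


    # Inserting the provider expands it into one node per constraint,
    # list[int] and list[float], each wired to its own input.
    pipeline = sl.Pipeline([duplicate])
    pipeline[int] = 3
    pipeline[float] = 1.5
    assert pipeline.compute(list[int]) == [3, 3]
    assert pipeline.compute(list[float]) == [1.5, 1.5]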
@@ -268,6 +269,12 @@ def provide_none() -> None: pipeline.insert(provide_none) +def test_setting_None_param_raises() -> None: + pipeline = sl.Pipeline() + with pytest.raises(ValueError): + pipeline[None] = 3 # type: ignore[index] + + def test_inserting_provider_with_no_return_type_raises() -> None: def provide_none(): # type: ignore[no-untyped-def] return None @@ -280,31 +287,30 @@ def provide_none(): # type: ignore[no-untyped-def] def test_TypeVar_requirement_of_provider_can_be_bound() -> None: - T = TypeVar('T') + T = TypeVar('T', int, float) def provider_int() -> int: return 3 - def provider(x: T) -> List[T]: + def provider(x: T) -> list[T]: return [x, x] pipeline = sl.Pipeline([provider_int, provider]) - assert pipeline.compute(List[int]) == [3, 3] + assert pipeline.compute(list[int]) == [3, 3] def test_TypeVar_that_cannot_be_bound_raises_UnboundTypeVar() -> None: - T = TypeVar('T') + T = TypeVar('T', int, float) - def provider(_: T) -> int: - return 1 + def provider(_: T) -> str: + return 'abc' - pipeline = sl.Pipeline([provider]) with pytest.raises(sl.UnboundTypeVar): - pipeline.compute(int) + sl.Pipeline([provider]) def test_unsatisfiable_TypeVar_requirement_of_provider_raises() -> None: - T = TypeVar('T') + T = TypeVar('T', int, float) def provider_int() -> int: return 3 @@ -318,8 +324,8 @@ def provider(x: T) -> List[T]: def test_TypeVar_params_are_not_associated_unless_they_match() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', int, float) class A(Generic[T1]): ... @@ -336,15 +342,15 @@ def not_matching(x: A[T1]) -> B[T2]: def matching(x: A[T1]) -> B[T1]: return B[T1]() - pipeline = sl.Pipeline([source, not_matching]) with pytest.raises(sl.UnboundTypeVar): - pipeline.compute(B[int]) + sl.Pipeline([source, not_matching]) pipeline = sl.Pipeline([source, matching]) pipeline.compute(B[int]) def test_multi_Generic_with_fully_bound_arguments() -> None: + # Note that no constraints necessary here, Sciline never sees the typevars T1 = TypeVar('T1') T2 = TypeVar('T2') @@ -361,8 +367,8 @@ def source() -> A[int, float]: def test_multi_Generic_with_partially_bound_arguments() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', int, float) @dataclass class A(Generic[T1, T2]): @@ -380,29 +386,30 @@ def partially_bound(x: T1) -> A[int, T1]: def test_multi_Generic_with_multiple_unbound() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', bool, str) @dataclass class A(Generic[T1, T2]): first: T1 second: T2 - def int_source() -> int: - return 1 - - def float_source() -> float: - return 2.0 - def unbound(x: T1, y: T2) -> A[T1, T2]: return A[T1, T2](x, y) - pipeline = sl.Pipeline([int_source, float_source, unbound]) - assert pipeline.compute(A[int, float]) == A[int, float](1, 2.0) - assert pipeline.compute(A[float, int]) == A[float, int](2.0, 1) + pipeline = sl.Pipeline([unbound]) + pipeline[int] = 1 + pipeline[float] = 2.0 + pipeline[bool] = True + pipeline[str] = 'a' + assert pipeline.compute(A[int, bool]) == A[int, bool](1, True) + assert pipeline.compute(A[int, str]) == A[int, str](1, 'a') + assert pipeline.compute(A[float, bool]) == A[float, bool](2.0, True) + assert pipeline.compute(A[float, str]) == A[float, str](2.0, 'a') def test_distinct_fully_bound_instances_yield_distinct_results() -> None: + # Note that no constraints necessary here, Sciline never sees the typevar T1 = TypeVar('T1') @dataclass @@ -421,8 +428,8 @@ def 
float_source() -> A[float]: def test_distinct_partially_bound_instances_yield_distinct_results() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', int, str) @dataclass class A(Generic[T1, T2]): @@ -432,20 +439,20 @@ class A(Generic[T1, T2]): def str_source() -> str: return 'a' - def int_source(x: T1) -> A[int, T1]: - return A[int, T1](1, x) + def int_source(x: T2) -> A[int, T2]: + return A[int, T2](1, x) - def float_source(x: T1) -> A[float, T1]: - return A[float, T1](2.0, x) + def float_source(x: T2) -> A[float, T2]: + return A[float, T2](2.0, x) pipeline = sl.Pipeline([str_source, int_source, float_source]) assert pipeline.compute(A[int, str]) == A[int, str](1, 'a') assert pipeline.compute(A[float, str]) == A[float, str](2.0, 'a') -def test_multiple_matching_partial_providers_raises() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') +def test_multiple_matching_partial_providers_uses_latest() -> None: + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', int, float) @dataclass class A(Generic[T1, T2]): @@ -456,7 +463,7 @@ def int_source() -> int: return 1 def float_source() -> float: - return 2.0 + return 1.0 def provider1(x: T1) -> A[int, T1]: return A[int, T1](1, x) @@ -466,14 +473,14 @@ def provider2(x: T2) -> A[T2, float]: pipeline = sl.Pipeline([int_source, float_source, provider1, provider2]) assert pipeline.compute(A[int, int]) == A[int, int](1, 1) - assert pipeline.compute(A[float, float]) == A[float, float](2.0, 2.0) - with pytest.raises(sl.AmbiguousProvider): - pipeline.compute(A[int, float]) + assert pipeline.compute(A[float, float]) == A[float, float](1.0, 2.0) + # Multiple matches, but the latest one (provider2) is used + assert pipeline.compute(A[int, float]) == A[int, float](1, 2.0) def test_TypeVar_params_track_to_multiple_sources() -> None: - T1 = TypeVar('T1') - T2 = TypeVar('T2') + T1 = TypeVar('T1', int, float) + T2 = TypeVar('T2', int, float) @dataclass class A(Generic[T1]): @@ -658,7 +665,8 @@ def func(x: T) -> A[T]: def test_setitem_can_replace_generic_param_with_generic_param() -> None: - T = TypeVar('T') + # float unused, but must have more than one constraint + T = TypeVar('T', int, float) @dataclass class A(Generic[T]): @@ -924,7 +932,7 @@ def d(_f: float) -> None: assert calls.index('d') in (2, 3) -def test_prioritizes_specialized_provider_over_generic() -> None: +def test_inserting_generic_provider_replaces_specialized_provider() -> None: A = NewType('A', str) B = NewType('B', str) V = TypeVar('V', A, B) @@ -938,68 +946,11 @@ def p1(x: V) -> H[V]: def p2(x: B) -> H[B]: return H[B]("Special") - pl = sl.Pipeline([p1, p2], params={A: 'A', B: 'B'}) + # p2 will be replaced immediately by p1. 
+    pl = sl.Pipeline([p2, p1], params={A: 'A', B: 'B'})
     assert str(pl.compute(H[A])) == "Generic"
-    assert str(pl.compute(H[B])) == "Special"
-
-
-def test_prioritizes_specialized_provider_over_generic_several_typevars() -> None:
-    A = NewType('A', str)
-    B = NewType('B', str)
-    T1 = TypeVar('T1')
-    T2 = TypeVar('T2')
-
-    @dataclass
-    class C(Generic[T1, T2]):
-        first: T1
-        second: T2
-        third: str
-
-    def p1(x: T1, y: T2) -> C[T1, T2]:
-        return C(x, y, 'generic')
-
-    def p2(x: A, y: T2) -> C[A, T2]:
-        return C(x, y, 'medium generic')
-
-    def p3(x: T2, y: B) -> C[T2, B]:
-        return C(x, y, 'generic medium')
-
-    def p4(x: A, y: B) -> C[A, B]:
-        return C(x, y, 'special')
-
-    pl = sl.Pipeline([p1, p2, p3, p4], params={A: A('A'), B: B('B')})
-
-    assert pl.compute(C[B, A]) == C('B', 'A', 'generic')
-    assert pl.compute(C[A, A]) == C('A', 'A', 'medium generic')
-    assert pl.compute(C[B, B]) == C('B', 'B', 'generic medium')
-    assert pl.compute(C[A, B]) == C('A', 'B', 'special')
-
-
-def test_prioritizes_specialized_provider_raises() -> None:
-    A = NewType('A', str)
-    B = NewType('B', str)
-    T1 = TypeVar('T1')
-    T2 = TypeVar('T2')
-
-    @dataclass
-    class C(Generic[T1, T2]):
-        first: T1
-        second: T2
-
-    def p1(x: A, y: T1) -> C[A, T1]:
-        return C(x, y)
-
-    def p2(x: T1, y: B) -> C[T1, B]:
-        return C(x, y)
-
-    pl = sl.Pipeline([p1, p2], params={A: A('A'), B: B('B')})
-
-    with pytest.raises(sl.AmbiguousProvider):
-        pl.compute(C[A, B])
-
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pl.compute(C[B, A])
+    assert str(pl.compute(H[B])) == "Generic"
 
 
 def test_compute_time_handler_allows_for_building_but_not_computing() -> None:
@@ -1052,7 +1003,7 @@ def test_pipeline_copy_after_setitem() -> None:
 
 
 def test_copy_with_generic_providers() -> None:
-    Param = TypeVar('Param')
+    Param = TypeVar('Param', int, float)
 
     class Str(sl.Scope[Param, str], str):
         ...
@@ -1118,7 +1069,11 @@ def test_pipeline_setitem_on_original_does_not_affect_copy() -> None:
 
 
 def test_pipeline_with_generics_setitem_on_original_does_not_affect_copy() -> None:
-    RunType = TypeVar('RunType')
+    Sample = NewType('Sample', int)
+    Background = NewType('Background', int)
+    Result = NewType('Result', int)
+
+    RunType = TypeVar('RunType', Sample, Background)
 
     class RawData(sl.Scope[RunType, int], int):
         ...
@@ -1126,10 +1081,6 @@ class RawData(sl.Scope[RunType, int], int):
     class SquaredData(sl.Scope[RunType, int], int):
         ...
 
-    Sample = NewType('Sample', int)
-    Background = NewType('Background', int)
-    Result = NewType('Result', int)
-
     def square(x: RawData[RunType]) -> SquaredData[RunType]:
        return SquaredData[RunType](x * x)
 
@@ -1154,7 +1105,11 @@ def process(
 
 
 def test_pipeline_with_generics_setitem_on_copy_does_not_affect_original() -> None:
-    RunType = TypeVar('RunType')
+    Sample = NewType('Sample', int)
+    Background = NewType('Background', int)
+    Result = NewType('Result', int)
+
+    RunType = TypeVar('RunType', Sample, Background)
 
     class RawData(sl.Scope[RunType, int], int):
         ...
@@ -1162,10 +1117,6 @@ class RawData(sl.Scope[RunType, int], int):
     class SquaredData(sl.Scope[RunType, int], int):
         ...
 
-    Sample = NewType('Sample', int)
-    Background = NewType('Background', int)
-    Result = NewType('Result', int)
-
     def square(x: RawData[RunType]) -> SquaredData[RunType]:
         return SquaredData[RunType](x * x)
 
@@ -1339,57 +1290,6 @@ def p2(v1: S2, v2: T2) -> N[S2, T2]:
         pipeline.get(N[M[str], float])
 
 
-def test_number_of_type_vars_defines_most_specialized() -> None:
-    Green = NewType('Green', str)
-    Blue = NewType('Blue', str)
-    Color = TypeVar('Color', Green, Blue)
-
-    @dataclass
-    class Likes(Generic[Color]):
-        color: Color
-
-    Preference = TypeVar('Preference')
-
-    @dataclass
-    class Person(Generic[Preference, Color]):
-        preference: Preference
-        hatcolor: Color
-        provided_by: str
-
-    def p(c: Color) -> Likes[Color]:
-        return Likes(c)
-
-    def p0(p: Preference, c: Color) -> Person[Preference, Color]:
-        return Person(p, c, 'p0')
-
-    def p1(c: Color) -> Person[Likes[Color], Color]:
-        return Person(Likes(c), c, 'p1')
-
-    def p2(p: Preference) -> Person[Preference, Green]:
-        return Person(p, Green('g'), 'p2')
-
-    pipeline = sl.Pipeline((p, p0, p1, p2))
-    pipeline[Blue] = 'b'
-    pipeline[Green] = 'g'
-
-    # only provided by p0
-    assert pipeline.compute(Person[Likes[Green], Blue]) == Person(
-        Likes(Green('g')), Blue('b'), 'p0'
-    )
-    # provided by p1 and p0 but p1 is preferred because it has fewer typevars
-    assert pipeline.compute(Person[Likes[Blue], Blue]) == Person(
-        Likes(Blue('b')), Blue('b'), 'p1'
-    )
-    # provided by p2 and p0 but p2 is preferred because it has fewer typevars
-    assert pipeline.compute(Person[Likes[Blue], Green]) == Person(
-        Likes(Blue('b')), Green('g'), 'p2'
-    )
-
-    with pytest.raises(sl.AmbiguousProvider):
-        # provided by p1 and p2 with the same number of typevars
-        pipeline.get(Person[Likes[Green], Green])
-
-
 def test_pipeline_with_decorated_provider() -> None:
     R = TypeVar('R')
 
@@ -1526,3 +1426,12 @@ def __new__(cls, x: int) -> str:  # type: ignore[misc]
 
     with pytest.raises(TypeError):
         sl.Pipeline([C], params={int: 3})
+
+
+def test_inserting_provider_with_duplicate_arguments_raises() -> None:
+    def bad(x: int, y: int) -> float:
+        return float(x + y)
+
+    pipeline = sl.Pipeline()
+    with pytest.raises(ValueError, match="Duplicate type hints"):
+        pipeline.insert(bad)
diff --git a/tests/pipeline_with_optional_test.py b/tests/pipeline_with_optional_test.py
index 9ca926eb..e7550d24 100644
--- a/tests/pipeline_with_optional_test.py
+++ b/tests/pipeline_with_optional_test.py
@@ -64,7 +64,7 @@ def use_optional(*, x: Optional[int]) -> str:
         pipeline.get(str)
 
 
-def test_Union_argument_order_matters() -> None:
+def test_Union_argument_order_does_not_matter() -> None:
     def use_union(x: int | float) -> str:
         return f'{x}'
 
@@ -73,8 +73,10 @@ def use_union(x: int | float) -> str:
     assert pipeline.compute(str) == '1'
     pipeline = sl.Pipeline([use_union])
     pipeline[float | int] = 1  # type: ignore[index]
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pipeline.get(str)
+    assert pipeline.compute(str) == '1'
+    # Note that the above works because the hashes are the same:
+    assert hash(int | float) == hash(float | int)
+    assert hash(Union[int, float]) == hash(Union[float, int])
 
 
 def test_optional_dependency_cannot_be_filled_transitively() -> None:
diff --git a/tests/pipeline_with_param_table_test.py b/tests/pipeline_with_param_table_test.py
deleted file mode 100644
index b052f598..00000000
--- a/tests/pipeline_with_param_table_test.py
+++ /dev/null
@@ -1,741 +0,0 @@
-# SPDX-License-Identifier: BSD-3-Clause
-# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
-from typing import List, NewType, Optional, TypeVar
-
-import pytest
-
-import sciline as sl
-from sciline.typing import Item, Label
-
-
-def test_set_param_table_raises_if_param_names_are_duplicate() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    with pytest.raises(ValueError):
-        pl.set_param_table(sl.ParamTable(str, {float: [4.0, 5.0, 6.0]}))
-    assert pl.compute(Item((Label(int, 1),), float)) == 2.0
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pl.compute(Item((Label(str, 1),), float))
-
-
-def test_set_param_table_removes_columns_of_replaced_table() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    # We could imagine that this would be allowed if the index
-    # (here: automatic index as range(3)) is the same. For now we do not.
-    pl.set_param_table(sl.ParamTable(int, {str: ['a', 'b', 'c']}))
-    assert pl.compute(Item((Label(int, 1),), str)) == 'b'
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pl.compute(Item((Label(int, 1),), float))
-
-
-def test_can_get_elements_of_param_table() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(Item((Label(int, 1),), float)) == 2.0
-
-
-def test_can_get_elements_of_param_table_with_explicit_index() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13]))
-    assert pl.compute(Item((Label(int, 12),), float)) == 2.0
-
-
-def test_can_replace_param_table() -> None:
-    pl = sl.Pipeline()
-    table1 = sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13])
-    pl.set_param_table(table1)
-    assert pl.compute(Item((Label(int, 12),), float)) == 2.0
-    table2 = sl.ParamTable(int, {float: [4.0, 5.0, 6.0]}, index=[21, 22, 23])
-    pl.set_param_table(table2)
-    assert pl.compute(Item((Label(int, 22),), float)) == 5.0
-
-
-def test_can_replace_param_table_with_table_of_different_length() -> None:
-    pl = sl.Pipeline()
-    table1 = sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}, index=[11, 12, 13])
-    pl.set_param_table(table1)
-    assert pl.compute(Item((Label(int, 13),), float)) == 3.0
-    table2 = sl.ParamTable(int, {float: [4.0, 5.0]}, index=[21, 22])
-    pl.set_param_table(table2)
-    assert pl.compute(Item((Label(int, 22),), float)) == 5.0
-    # Make sure rows beyond the new table are not accessible
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pl.compute(Item((Label(int, 13),), float))
-
-
-def test_failed_replace_due_to_column_clash_in_other_table() -> None:
-    Row1 = NewType("Row1", int)
-    Row2 = NewType("Row2", int)
-    table1 = sl.ParamTable(Row1, {float: [1.0, 2.0, 3.0]})
-    table2 = sl.ParamTable(Row2, {str: ['a', 'b', 'c']})
-    table1_replacement = sl.ParamTable(
-        Row1, {float: [1.1, 2.2, 3.3], str: ['a', 'b', 'c']}
-    )
-    pl = sl.Pipeline()
-    pl.set_param_table(table1)
-    pl.set_param_table(table2)
-    with pytest.raises(ValueError):
-        pl.set_param_table(table1_replacement)
-    # Make sure the original table is still accessible
-    assert pl.compute(Item((Label(Row1, 1),), float)) == 2.0
-
-
-def test_can_depend_on_elements_of_param_table() -> None:
-    # This is not a valid type annotation, not sure why it works with get_type_hints
-    def use_elem(x: Item((Label(int, 1),), float)) -> str:  # type: ignore[valid-type]
-        return str(x)
-
-    pl = sl.Pipeline([use_elem])
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(str) == "2.0"
-
-
-def test_can_depend_on_elements_of_param_table_kwarg() -> None:
-    # This is not a valid type annotation, not sure why it works with get_type_hints
-    def use_elem(
-        *, x: Item((Label(int, 1),), float)  # type: ignore[valid-type]
-    ) -> str:
-        return str(x)
-
-    pl = sl.Pipeline([use_elem])
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(str) == "2.0"
-
-
-def test_can_compute_series_of_param_values() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-
-def test_cannot_compute_series_of_non_table_param() -> None:
-    pl = sl.Pipeline()
-    # Table for defining length
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    pl[str] = 'abc'
-    # The alternative option would be to expect to return
-    # sl.Series(int, {0: 'abc', 1: 'abc', 2: 'abc'})
-    # For now, we are not supporting this since it is unclear if this would be
-    # conceptually sound and risk free.
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pl.compute(sl.Series[int, str])
-
-
-def test_can_compute_series_of_derived_values() -> None:
-    def process(x: float) -> str:
-        return str(x)
-
-    pl = sl.Pipeline([process])
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(sl.Series[int, str]) == sl.Series(
-        int, {0: "1.0", 1: "2.0", 2: "3.0"}
-    )
-
-
-def test_can_compute_series_of_derived_values_kwarg() -> None:
-    def process(*, x: float) -> str:
-        return str(x)
-
-    pl = sl.Pipeline([process])
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl.compute(sl.Series[int, str]) == sl.Series(
-        int, {0: "1.0", 1: "2.0", 2: "3.0"}
-    )
-
-
-def test_creating_pipeline_with_provider_of_series_raises() -> None:
-    series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-    def make_series() -> sl.Series[int, float]:
-        return series
-
-    with pytest.raises(ValueError):
-        sl.Pipeline([make_series])
-
-
-def test_creating_pipeline_with_series_param_raises() -> None:
-    series = sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-    with pytest.raises(ValueError):
-        sl.Pipeline([], params={sl.Series[int, float]: series})
-
-
-def test_explicit_index_of_param_table_is_forwarded_correctly() -> None:
-    def process(x: float) -> int:
-        return int(x)
-
-    pl = sl.Pipeline([process])
-    pl.set_param_table(
-        sl.ParamTable(str, {float: [1.0, 2.0, 3.0]}, index=['a', 'b', 'c'])
-    )
-    assert pl.compute(sl.Series[str, int]) == sl.Series(str, {'a': 1, 'b': 2, 'c': 3})
-
-
-def test_can_gather_index() -> None:
-    Sum = NewType("Sum", float)
-    Name = NewType("Name", str)
-
-    def gather(x: sl.Series[Name, float]) -> Sum:
-        return Sum(sum(x.values()))
-
-    def make_float(x: str) -> float:
-        return float(x)
-
-    pl = sl.Pipeline([gather, make_float])
-    pl.set_param_table(sl.ParamTable(Name, {str: ["1.0", "2.0", "3.0"]}))
-    assert pl.compute(Sum) == 6.0
-
-
-def test_can_zip() -> None:
-    Sum = NewType("Sum", str)
-    Str = NewType("Str", str)
-    Run = NewType("Run", int)
-
-    def gather_zip(x: sl.Series[Run, Str], y: sl.Series[Run, int]) -> Sum:
-        z = [f'{x_}{y_}' for x_, y_ in zip(x.values(), y.values())]
-        return Sum(str(z))
-
-    def use_str(x: str) -> Str:
-        return Str(x)
-
-    pl = sl.Pipeline([gather_zip, use_str])
-    pl.set_param_table(sl.ParamTable(Run, {str: ['a', 'a', 'ccc'], int: [1, 2, 3]}))
-
-    assert pl.compute(Sum) == "['a1', 'a2', 'ccc3']"
-
-
-def test_diamond_dependency_pulls_values_from_columns_in_same_param_table() -> None:
-    Sum = NewType("Sum", float)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Row = NewType("Row", int)
-
-    def gather(x: sl.Series[Row, float]) -> Sum:
-        return Sum(sum(x.values()))
-
-    def join(x: Param1, y: Param2) -> float:
-        return x / y
-
-    pl = sl.Pipeline([gather, join])
-    pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 4, 9], Param2: [1, 2, 3]}))
-
-    assert pl.compute(Sum) == Sum(6)
-
-
-def test_diamond_dependency_on_same_column() -> None:
-    Sum = NewType("Sum", float)
-    Param = NewType("Param", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Row = NewType("Row", int)
-
-    def gather(x: sl.Series[Row, float]) -> Sum:
-        return Sum(sum(x.values()))
-
-    def to_param1(x: Param) -> Param1:
-        return Param1(x)
-
-    def to_param2(x: Param) -> Param2:
-        return Param2(x)
-
-    def join(x: Param1, y: Param2) -> float:
-        return x / y
-
-    pl = sl.Pipeline([gather, join, to_param1, to_param2])
-    pl.set_param_table(sl.ParamTable(Row, {Param: [1, 2, 3]}))
-
-    assert pl.compute(Sum) == Sum(3)
-
-
-def test_dependencies_on_different_param_tables_broadcast() -> None:
-    Row1 = NewType("Row1", int)
-    Row2 = NewType("Row2", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Product = NewType("Product", str)
-
-    def gather_both(x: sl.Series[Row1, Param1], y: sl.Series[Row2, Param2]) -> Product:
-        broadcast = [[x_, y_] for x_ in x.values() for y_ in y.values()]
-        return Product(str(broadcast))
-
-    pl = sl.Pipeline([gather_both])
-    pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
-    pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
-    assert pl.compute(Product) == "[[1, 4], [1, 5], [2, 4], [2, 5], [3, 4], [3, 5]]"
-
-
-def test_dependency_on_other_param_table_in_parent_broadcasts_branch() -> None:
-    Row1 = NewType("Row1", int)
-    Row2 = NewType("Row2", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Summed2 = NewType("Summed2", int)
-    Product = NewType("Product", str)
-
-    def gather2_and_combine(x: Param1, y: sl.Series[Row2, Param2]) -> Summed2:
-        return Summed2(x * sum(y.values()))
-
-    def gather1(x: sl.Series[Row1, Summed2]) -> Product:
-        return Product(str(list(x.values())))
-
-    pl = sl.Pipeline([gather1, gather2_and_combine])
-    pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
-    pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
-    assert pl.compute(Product) == "[9, 18, 27]"
-
-
-def test_dependency_on_other_param_table_in_grandparent_broadcasts_branch() -> None:
-    Row1 = NewType("Row1", int)
-    Row2 = NewType("Row2", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Summed2 = NewType("Summed2", int)
-    Combined = NewType("Combined", int)
-    Product = NewType("Product", str)
-
-    def gather2(x: sl.Series[Row2, Param2]) -> Summed2:
-        return Summed2(sum(x.values()))
-
-    def combine(x: Param1, y: Summed2) -> Combined:
-        return Combined(x * y)
-
-    def gather1(x: sl.Series[Row1, Combined]) -> Product:
-        return Product(str(list(x.values())))
-
-    pl = sl.Pipeline([gather1, gather2, combine])
-    pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
-    pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
-    assert pl.compute(Product) == "[9, 18, 27]"
-
-
-def test_nested_dependencies_on_different_param_tables() -> None:
-    Row1 = NewType("Row1", int)
-    Row2 = NewType("Row2", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Combined = NewType("Combined", int)
-
-    def combine(x: Param1, y: Param2) -> Combined:
-        return Combined(x * y)
-
-    pl = sl.Pipeline([combine])
-    pl.set_param_table(sl.ParamTable(Row1, {Param1: [1, 2, 3]}))
-    pl.set_param_table(sl.ParamTable(Row2, {Param2: [4, 5]}))
-    assert pl.compute(sl.Series[Row1, sl.Series[Row2, Combined]]) == sl.Series(
-        Row1,
-        {
-            0: sl.Series(Row2, {0: 4, 1: 5}),
-            1: sl.Series(Row2, {0: 8, 1: 10}),
-            2: sl.Series(Row2, {0: 12, 1: 15}),
-        },
-    )
-    assert pl.compute(sl.Series[Row2, sl.Series[Row1, Combined]]) == sl.Series(
-        Row2,
-        {
-            0: sl.Series(Row1, {0: 4, 1: 8, 2: 12}),
-            1: sl.Series(Row1, {0: 5, 1: 10, 2: 15}),
-        },
-    )
-
-
-def test_can_groupby_by_requesting_series_of_series() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
-    expected = sl.Series(
-        Param1,
-        {1: sl.Series(Row, {0: 4, 1: 5}), 3: sl.Series(Row, {2: 6})},
-    )
-    assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == expected
-
-
-def test_groupby_by_requesting_series_of_series_preserves_indices() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(
-        sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}, index=[11, 12, 13])
-    )
-    assert pl.compute(sl.Series[Param1, sl.Series[Row, Param2]]) == sl.Series(
-        Param1, {1: sl.Series(Row, {11: 4, 12: 5}), 3: sl.Series(Row, {13: 6})}
-    )
-
-
-def test_can_groupby_by_param_used_in_ancestor() -> None:
-    Row = NewType("Row", int)
-    Param = NewType("Param", str)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(Row, {Param: ['x', 'x', 'y']}))
-    expected = sl.Series(
-        Param,
-        {"x": sl.Series(Row, {0: "x", 1: "x"}), "y": sl.Series(Row, {2: "y"})},
-    )
-    assert pl.compute(sl.Series[Param, sl.Series[Row, Param]]) == expected
-
-
-def test_multi_level_groupby_raises_with_params_from_same_table() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(
-        sl.ParamTable(
-            Row, {Param1: [1, 1, 1, 3], Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}
-        )
-    )
-    with pytest.raises(ValueError, match='Could not find unique grouping node'):
-        pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
-
-
-def test_multi_level_groupby_with_params_from_different_table() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    grouping1 = sl.ParamTable(Row, {Param2: [0, 1, 1, 2], Param3: [7, 8, 9, 10]})
-    # We are not providing an explicit index here, so this only happens to work because
-    # the values of Param2 match range(2).
-    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]})
-    pl.set_param_table(grouping1)
-    pl.set_param_table(grouping2)
-    assert pl.compute(
-        sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
-    ) == sl.Series(
-        Param1,
-        {
-            1: sl.Series(
-                Param2, {0: sl.Series(Row, {0: 7}), 1: sl.Series(Row, {1: 8, 2: 9})}
-            ),
-            3: sl.Series(Param2, {2: sl.Series(Row, {3: 10})}),
-        },
-    )
-
-
-def test_multi_level_groupby_with_params_from_different_table_can_select() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    grouping1 = sl.ParamTable(Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]})
-    # Note the missing index "6" here.
-    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1]}, index=[4, 5])
-    pl.set_param_table(grouping1)
-    pl.set_param_table(grouping2)
-    assert pl.compute(
-        sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
-    ) == sl.Series(
-        Param1,
-        {
-            1: sl.Series(
-                Param2, {4: sl.Series(Row, {0: 7}), 5: sl.Series(Row, {1: 8, 2: 9})}
-            )
-        },
-    )
-
-
-def test_multi_level_groupby_with_params_from_different_table_preserves_index() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    grouping1 = sl.ParamTable(
-        Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
-    )
-    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[4, 5, 6])
-    pl.set_param_table(grouping1)
-    pl.set_param_table(grouping2)
-    assert pl.compute(
-        sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
-    ) == sl.Series(
-        Param1,
-        {
-            1: sl.Series(
-                Param2,
-                {4: sl.Series(Row, {100: 7}), 5: sl.Series(Row, {200: 8, 300: 9})},
-            ),
-            3: sl.Series(Param2, {6: sl.Series(Row, {400: 10})}),
-        },
-    )
-
-
-def test_multi_level_groupby_with_params_from_different_table_can_reorder() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    grouping1 = sl.ParamTable(
-        Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
-    )
-    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=[6, 5, 4])
-    pl.set_param_table(grouping1)
-    pl.set_param_table(grouping2)
-    assert pl.compute(
-        sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]]
-    ) == sl.Series(
-        Param1,
-        {
-            1: sl.Series(
-                Param2,
-                {6: sl.Series(Row, {400: 10}), 5: sl.Series(Row, {200: 8, 300: 9})},
-            ),
-            3: sl.Series(Param2, {4: sl.Series(Row, {100: 7})}),
-        },
-    )
-
-
-@pytest.mark.parametrize("index", [None, [4, 5, 7]])
-def test_multi_level_groupby_raises_on_index_mismatch(
-    index: Optional[List[int]],
-) -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-    Param3 = NewType("Param3", int)
-
-    pl = sl.Pipeline()
-    grouping1 = sl.ParamTable(
-        Row, {Param2: [4, 5, 5, 6], Param3: [7, 8, 9, 10]}, index=[100, 200, 300, 400]
-    )
-    grouping2 = sl.ParamTable(Param2, {Param1: [1, 1, 3]}, index=index)
-    pl.set_param_table(grouping1)
-    pl.set_param_table(grouping2)
-    with pytest.raises(ValueError):
-        pl.compute(sl.Series[Param1, sl.Series[Param2, sl.Series[Row, Param3]]])
-
-
-@pytest.mark.parametrize("index", [None, [4, 5, 6]])
-def test_groupby_over_param_table(index: Optional[List[int]]) -> None:
-    Index = NewType("Index", int)
-    Name = NewType("Name", str)
-    Param = NewType("Param", int)
-    ProcessedParam = NewType("ProcessedParam", int)
-    SummedGroup = NewType("SummedGroup", int)
-    ProcessedGroup = NewType("ProcessedGroup", int)
-
-    def process_param(x: Param) -> ProcessedParam:
-        return ProcessedParam(x + 1)
-
-    def sum_group(group: sl.Series[Index, ProcessedParam]) -> SummedGroup:
-        return SummedGroup(sum(group.values()))
-
-    def process(x: SummedGroup) -> ProcessedGroup:
-        return ProcessedGroup(2 * x)
-
-    params = sl.ParamTable(
-        Index, {Param: [1, 2, 3], Name: ['a', 'a', 'b']}, index=index
-    )
-    pl = sl.Pipeline([process_param, sum_group, process])
-    pl.set_param_table(params)
-
-    graph = pl.get(sl.Series[Name, ProcessedGroup])
-    assert graph.compute() == sl.Series(Name, {'a': 10, 'b': 8})
-
-
-def test_requesting_series_index_that_is_not_in_param_table_raises() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
-
-    with pytest.raises(KeyError):
-        pl.compute(sl.Series[int, Param2])
-
-
-def test_requesting_series_index_that_is_a_param_raises_if_not_grouping() -> None:
-    Row = NewType("Row", int)
-    Param1 = NewType("Param1", int)
-    Param2 = NewType("Param2", int)
-
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(Row, {Param1: [1, 1, 3], Param2: [4, 5, 6]}))
-    with pytest.raises(ValueError, match='Could not find unique grouping node'):
-        pl.compute(sl.Series[Param1, Param2])
-
-
-def test_generic_providers_work_with_param_tables() -> None:
-    Param = TypeVar('Param')
-    Row = NewType("Row", int)
-
-    class Str(sl.Scope[Param, str], str):
-        ...
-
-    def parametrized(x: Param) -> Str[Param]:
-        return Str(f'{x}')
-
-    def make_float() -> float:
-        return 1.5
-
-    pipeline = sl.Pipeline([make_float, parametrized])
-    pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
-
-    assert pipeline.compute(Str[float]) == Str[float]('1.5')
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        pipeline.compute(Str[int])
-    assert pipeline.compute(sl.Series[Row, Str[int]]) == sl.Series(
-        Row,
-        {
-            0: Str[int]('1'),
-            1: Str[int]('2'),
-            2: Str[int]('3'),
-        },
-    )
-
-
-def test_generic_provider_can_depend_on_param_series() -> None:
-    Param = TypeVar('Param')
-    Row = NewType("Row", int)
-
-    class Str(sl.Scope[Param, str], str):
-        ...
-
-    def parametrized_gather(x: sl.Series[Row, Param]) -> Str[Param]:
-        return Str(f'{list(x.values())}')
-
-    pipeline = sl.Pipeline([parametrized_gather])
-    pipeline.set_param_table(
-        sl.ParamTable(Row, {int: [1, 2, 3], float: [1.5, 2.5, 3.5]})
-    )
-
-    assert pipeline.compute(Str[int]) == Str[int]('[1, 2, 3]')
-    assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
-
-
-def test_generic_provider_can_depend_on_derived_param_series() -> None:
-    T = TypeVar('T')
-    Row = NewType("Row", int)
-
-    class Str(sl.Scope[T, str], str):
-        ...
-
-    def use_param(x: int) -> float:
-        return x + 0.5
-
-    def parametrized_gather(x: sl.Series[Row, T]) -> Str[T]:
-        return Str(f'{list(x.values())}')
-
-    pipeline = sl.Pipeline([parametrized_gather, use_param])
-    pipeline.set_param_table(sl.ParamTable(Row, {int: [1, 2, 3]}))
-
-    assert pipeline.compute(Str[float]) == Str[float]('[1.5, 2.5, 3.5]')
-
-
-def test_params_in_table_can_be_generic() -> None:
-    T = TypeVar('T')
-    Row = NewType("Row", int)
-
-    class Str(sl.Scope[T, str], str):
-        ...
-
-    class Param(sl.Scope[T, str], str):
-        ...
-
-    def parametrized_gather(x: sl.Series[Row, Param[T]]) -> Str[T]:
-        return Str(','.join(x.values()))
-
-    pipeline = sl.Pipeline([parametrized_gather])
-    pipeline.set_param_table(
-        sl.ParamTable(Row, {Param[int]: ["1", "2"], Param[float]: ["1.5", "2.5"]})
-    )
-
-    assert pipeline.compute(Str[int]) == Str[int]('1,2')
-    assert pipeline.compute(Str[float]) == Str[float]('1.5,2.5')
-
-
-def test_compute_time_handler_works_alongside_param_table() -> None:
-    Missing = NewType("Missing", str)
-
-    def process(x: float, missing: Missing) -> str:
-        return str(x) + missing
-
-    pl = sl.Pipeline([process])
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    pl.get(sl.Series[int, str], handler=sl.HandleAsComputeTimeException())
-
-
-def test_param_table_column_and_param_of_same_type_can_coexist() -> None:
-    pl = sl.Pipeline()
-    pl[float] = 1.0
-    pl.set_param_table(sl.ParamTable(int, {float: [2.0, 3.0]}))
-    assert pl.compute(float) == 1.0
-    assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 2.0, 1: 3.0})
-
-
-def test_pipeline_copy_with_param_table() -> None:
-    a = sl.Pipeline()
-    a.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    b = a.copy()
-    assert b.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-
-
-def test_pipeline_set_param_table_on_original_does_not_affect_copy() -> None:
-    a = sl.Pipeline()
-    b = a.copy()
-    a.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert a.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        b.compute(sl.Series[int, float])
-
-
-def test_pipeline_set_param_table_on_copy_does_not_affect_original() -> None:
-    a = sl.Pipeline()
-    b = a.copy()
-    b.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert b.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
-    with pytest.raises(sl.UnsatisfiedRequirement):
-        a.compute(sl.Series[int, float])
-
-
-def test_can_make_html_repr_with_param_table() -> None:
-    pl = sl.Pipeline()
-    pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
-    assert pl._repr_html_()
-
-
-def test_set_param_series_sets_up_pipeline_so_derived_series_can_be_computed() -> None:
-    ints = [1, 2, 3]
-
-    def to_str(x: int) -> str:
-        return str(x)
-
-    pl = sl.Pipeline((to_str,))
-    pl.set_param_series(int, ints)
-    assert pl.compute(sl.Series[int, str]) == sl.Series(int, {1: '1', 2: '2', 3: '3'})
-
-
-def test_multiple_param_series_can_be_broadcast() -> None:
-    ints = [1, 2, 3]
-    floats = [1.0, 2.0]
-
-    def to_str(x: int, y: float) -> str:
-        return str(x) + str(y)
-
-    pl = sl.Pipeline((to_str,))
-    pl.set_param_series(int, ints)
-    pl.set_param_series(float, floats)
-    assert pl.compute(sl.Series[int, sl.Series[float, str]]) == sl.Series(
-        int,
-        {
-            1: sl.Series(float, {1.0: '11.0', 2.0: '12.0'}),
-            2: sl.Series(float, {1.0: '21.0', 2.0: '22.0'}),
-            3: sl.Series(float, {1.0: '31.0', 2.0: '32.0'}),
-        },
-    )
diff --git a/tests/pipeline_with_postponed_annotations_test.py b/tests/pipeline_with_postponed_annotations_test.py
index 12db530b..8af481d5 100644
--- a/tests/pipeline_with_postponed_annotations_test.py
+++ b/tests/pipeline_with_postponed_annotations_test.py
@@ -18,7 +18,7 @@
 Int = NewType('Int', int)
 Str = NewType('Str', str)
 
-T = TypeVar('T')
+T = TypeVar('T', Int, Str)
 
 
 class Number(sl.Scope[T, int], int):
diff --git a/tests/serialize/json_test.py b/tests/serialize/json_test.py
index 465dfe03..834ab948 100644
--- a/tests/serialize/json_test.py
+++ b/tests/serialize/json_test.py
@@ -223,124 +223,6 @@ def test_serialize_kwonlyargs() -> None:
     assert res == expected_serialized_kwonlyargs_graph
 
 
-def repeated_arg(a: str, b: str) -> list[str]:
-    return [a, b]
-
-
-# Ids correspond to the result of assign_predictable_ids
-expected_serialized_repeated_arg_nodes = [
-    {
-        'id': '0',
-        'label': 'list[str]',
-        'kind': 'data',
-        'type': 'builtins.list[builtins.str]',
-    },
-    {
-        'id': '1',
-        'label': 'repeated_arg',
-        'kind': 'function',
-        'function': 'tests.serialize.json_test.repeated_arg',
-        'args': ['101', '102'],
-        'kwargs': {},
-    },
-    {
-        'id': '2',
-        'label': 'str',
-        'kind': 'data',
-        'type': 'builtins.str',
-    },
-]
-expected_serialized_repeated_arg_edges = [
-    {'id': '100', 'source': '1', 'target': '0'},
-    # The edge is repeated and disambiguated by the `args` of the function node.
-    {'id': '101', 'source': '2', 'target': '1'},
-    {'id': '102', 'source': '2', 'target': '1'},
-]
-expected_serialized_repeated_arg_graph = {
-    'directed': True,
-    'multigraph': False,
-    'nodes': expected_serialized_repeated_arg_nodes,
-    'edges': expected_serialized_repeated_arg_edges,
-}
-
-
-def test_serialize_repeated_arg() -> None:
-    pl = sl.Pipeline([repeated_arg], params={str: 'abc'})
-    graph = pl.get(list[str])
-    res = graph.serialize()
-    res = make_graph_predictable(res)
-    assert res == expected_serialized_repeated_arg_graph
-
-
-def repeated_arg_kwonlyarg(a: str, *, b: str) -> list[str]:
-    return [a, b]
-
-
-def repeated_kwonlyargs(*, x: int, b: int) -> str:
-    return str(x + b)
-
-
-# Ids correspond to the result of assign_predictable_ids
-expected_serialized_repeated_kwonlyarg_nodes = [
-    {
-        'id': '0',
-        'label': 'int',
-        'kind': 'data',
-        'type': 'builtins.int',
-    },
-    {
-        'id': '1',
-        'label': 'list[str]',
-        'kind': 'data',
-        'type': 'builtins.list[builtins.str]',
-    },
-    {
-        'id': '2',
-        'label': 'repeated_arg_kwonlyarg',
-        'kind': 'function',
-        'function': 'tests.serialize.json_test.repeated_arg_kwonlyarg',
-        'args': ['103'],
-        'kwargs': {'b': '104'},
-    },
-    {
-        'id': '3',
-        'label': 'str',
-        'kind': 'data',
-        'type': 'builtins.str',
-    },
-    {
-        'id': '4',
-        'label': 'repeated_kwonlyargs',
-        'kind': 'function',
-        'function': 'tests.serialize.json_test.repeated_kwonlyargs',
-        'args': [],
-        'kwargs': {'x': '100', 'b': '101'},
-    },
-]
-expected_serialized_repeated_kwonlyarg_edges = [
-    {'id': '100', 'source': '0', 'target': '4'},
-    {'id': '101', 'source': '0', 'target': '4'},
-    {'id': '102', 'source': '2', 'target': '1'},
-    {'id': '103', 'source': '3', 'target': '2'},
-    {'id': '104', 'source': '3', 'target': '2'},
-    {'id': '105', 'source': '4', 'target': '3'},
-]
-expected_serialized_repeated_kwonlyarg_graph = {
-    'directed': True,
-    'multigraph': False,
-    'nodes': expected_serialized_repeated_kwonlyarg_nodes,
-    'edges': expected_serialized_repeated_kwonlyarg_edges,
-}
-
-
-def test_serialize_repeated_konlywarg() -> None:
-    pl = sl.Pipeline([repeated_arg_kwonlyarg, repeated_kwonlyargs], params={int: 4})
-    graph = pl.get(list[str])
-    res = graph.serialize()
-    res = make_graph_predictable(res)
-    assert res == expected_serialized_repeated_kwonlyarg_graph
-
-
 # Ids correspond to the result of assign_predictable_ids
 expected_serialized_lambda_nodes = [
     {
@@ -398,14 +280,6 @@ def __call__(self, x: int) -> float:
         graph.serialize()
 
 
-def test_serialize_param_table() -> None:
-    pl = sl.Pipeline([as_float])
-    pl.set_param_table(sl.ParamTable(str, {int: [3, -5]}))
-    graph = pl.get(sl.Series[str, float])
-    with pytest.raises(ValueError):
-        graph.serialize()
-
-
 def test_serialize_validate_schema() -> None:
     pl = sl.Pipeline([make_int_b, zeros, to_string], params={Int[A]: 3})
     graph = pl.get(str)
diff --git a/tests/task_graph_test.py b/tests/task_graph_test.py
index 799c231c..4a31b5cf 100644
--- a/tests/task_graph_test.py
+++ b/tests/task_graph_test.py
@@ -31,7 +31,9 @@ def as_float(x: int) -> float:
 
 def make_task_graph() -> Graph:
     pl = sl.Pipeline([as_float], params={int: 1})
-    return pl.build(float, handler=sl.HandleAsBuildTimeException())
+    return sl.data_graph.to_task_graph(
+        pl, (float,), handler=sl.HandleAsBuildTimeException()
+    )
 
 
 def test_default_scheduler_is_dask_when_dask_available() -> None:
diff --git a/tests/utils_test.py b/tests/utils_test.py
index 96054fc2..bd3d77c1 100644
--- a/tests/utils_test.py
+++ b/tests/utils_test.py
@@ -8,7 +8,6 @@
 import sciline
 from sciline import _utils
 from sciline._provider import Provider
-from sciline.typing import Item, Label
 
 
 def module_foo(x: list[str]) -> str:
@@ -50,11 +49,6 @@ def test_key_name_type_var() -> None:
     assert _utils.key_name(MyType) == '~MyType'  # type: ignore[arg-type]
 
 
-def test_key_name_item() -> None:
-    item = Item(tp=int, label=(Label(tp=float, index=0), Label(tp=str, index=1)))
-    assert _utils.key_name(item) == 'int(float:0, str:1)'
-
-
 def test_key_name_builtin_generic() -> None:
     MyType = NewType('MyType', str)
     assert _utils.key_name(list) == 'list'
@@ -128,14 +122,7 @@ def test_key_full_qualname_type_var() -> None:
     assert res == 'tests.utils_test.~MyType'
 
 
-def test_key_full_qualname_item() -> None:
-    item = Item(tp=int, label=(Label(tp=float, index=0), Label(tp=str, index=1)))
-    assert (
-        _utils.key_full_qualname(item)
-        == 'builtins.int(builtins.float:0, builtins.str:1)'
-    )
-
-
+@pytest.mark.skipif(sys.version_info < (3, 10), reason="requires python3.10 or higher")
 def test_key_full_qualname_builtin_generic() -> None:
     MyType = NewType('MyType', str)
     assert _utils.key_full_qualname(list) == 'builtins.list'
diff --git a/tests/visualize_test.py b/tests/visualize_test.py
index ea404af2..266c0cf9 100644
--- a/tests/visualize_test.py
+++ b/tests/visualize_test.py
@@ -43,5 +43,5 @@ class SubA(A[T]):
 
 def test_optional_types_formatted_as_their_content() -> None:
-    formatted = sl.visualize._format_type(Optional[float])  # type: ignore[arg-type]
+    formatted = sl.visualize._format_type(Optional[float])
     assert formatted.name == 'float'
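
Notes on the changes above.

The recurring change from unconstrained type variables such as TypeVar('RunType') to constrained ones such as TypeVar('RunType', Sample, Background) is the mechanism that lets a single generic provider expand to a known, finite set of data-graph nodes. A minimal sketch of the pattern, mirroring the RawData/SquaredData tests above (runnable only against the rewritten sciline, and assuming sl.Scope keeps its current semantics):

import sciline as sl
from typing import NewType, TypeVar

Sample = NewType('Sample', int)
Background = NewType('Background', int)
RunType = TypeVar('RunType', Sample, Background)  # constraints are now mandatory


class RawData(sl.Scope[RunType, int], int):
    ...


class SquaredData(sl.Scope[RunType, int], int):
    ...


def square(x: RawData[RunType]) -> SquaredData[RunType]:
    # One provider, but exactly two data-graph nodes:
    # SquaredData[Sample] and SquaredData[Background].
    return SquaredData[RunType](x * x)


pl = sl.Pipeline([square])
pl[RawData[Sample]] = RawData[Sample](3)
assert pl.compute(SquaredData[Sample]) == 9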
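The new test_inserting_provider_with_duplicate_arguments_raises pins down a consequence of the graph representation: each parameter type identifies an incoming dependency of the provider, so two parameters of the same type cannot be told apart and the provider is rejected at insertion time. (This reasoning is inferred from the design; the diff only specifies the error.) A standalone illustration:

import sciline as sl


def bad(x: int, y: int) -> float:
    # Both parameters map to the same dependency key, `int`.
    return float(x + y)


pipeline = sl.Pipeline()
try:
    pipeline.insert(bad)
except ValueError as err:
    assert 'Duplicate type hints' in str(err)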
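The rename of test_Union_argument_order_matters to test_Union_argument_order_does_not_matter is backed by a Python guarantee, not a sciline change: unions compare equal by their set of arguments, and equal objects must hash equal, so both orderings address the same dictionary key. A self-contained check (Python 3.10+ for the X | Y syntax):

from typing import Union

assert (int | float) == (float | int)
assert hash(int | float) == hash(float | int)
assert hash(Union[int, float]) == hash(Union[float, int])

# Equal, equally-hashed keys are interchangeable in a dict, which is why
# setting pipeline[float | int] satisfies a provider annotated with int | float.
d = {int | float: 1}
assert d[float | int] == 1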
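Finally, pipeline_with_param_table_test.py is deleted outright: ParamTable, Series, Item, and Label have no counterpart in the data-graph design, whose map and reduce operations cover the same use cases. A rough sketch of how the simplest deleted test (test_can_compute_series_of_derived_values) might look in that style; the exact map/reduce signatures shown here are an assumption, not a settled API:

import sciline as sl


def process(x: float) -> str:
    return str(x)


def combine(*results: str) -> list[str]:
    # Hypothetical reduction gathering the mapped branches.
    return list(results)


pl = sl.Pipeline([process])
# Assumed API: map() duplicates the float -> str branch once per value,
# reduce() funnels the mapped str nodes into a single named target.
mapped = pl.map({float: [1.0, 2.0, 3.0]})
reduced = mapped.reduce(func=combine, name='combined')
# reduced.compute('combined') would then yield ['1.0', '2.0', '3.0'],
# replacing pl.compute(sl.Series[int, str]) from the deleted test.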