diff --git a/src/sciline/display.py b/src/sciline/display.py
index 2f9e8498..a73d3219 100644
--- a/src/sciline/display.py
+++ b/src/sciline/display.py
@@ -1,21 +1,9 @@
import inspect
-from dataclasses import dataclass
from html import escape
-from itertools import chain
-from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, TypeVar, Union
+from typing import Iterable, List, Tuple, TypeVar, Union
-from .typing import Item, Key
-from .utils import groupby, qualname
-
-ProviderKind = Literal['function', 'parameter', 'table']
-
-
-@dataclass
-class ProviderDisplayData:
- origin: Key
- args: Tuple[Union[Key, TypeVar], ...]
- kind: ProviderKind
- value: Any
+from .typing import Item, Key, Provider
+from .utils import groupby, keyname, kind_of_provider
def _details(summary: str, body: str) -> str:
@@ -27,50 +15,65 @@ def _details(summary: str, body: str) -> str:
'''
-def _provider_name(p: Any) -> str:
- if isinstance(p, tuple):
- (name, cname), values = p
- return escape(f'{qualname(cname)}({qualname(name)})')
- name = f'{qualname(p.origin)}'
- if p.args:
- args = ','.join(
- ('*' if isinstance(arg, TypeVar) else f'{qualname(arg)}' for arg in p.args)
- )
- name += f'[{args}]'
- return escape(f'{name}')
+def _provider_name(
+ p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
+) -> str:
+ key, args, _ = p
+ if args:
+ # This is always the case, but mypy complains
+ if hasattr(key, '__getitem__'):
+ return escape(keyname(key[args]))
+ return escape(keyname(key))
-def _provider_source(p: Any) -> str:
- if isinstance(p, tuple):
- (name, cname), values = p
- return escape(f'ParamTable({qualname(name)}, length={len(values)})')
- if p.kind == 'function':
- module = getattr(inspect.getmodule(p.value), '__name__', '')
+def _provider_source(
+ p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
+) -> str:
+ key, _, (v, *rest) = p
+ kind = kind_of_provider(v)
+ if kind == 'table':
+ # This is always the case, but mypy complains
+ if isinstance(key, Item):
+ return escape(
+ f'ParamTable({keyname(key.label[0].tp)}, length={len((v, *rest))})'
+ )
+ if kind == 'function':
+ module = getattr(inspect.getmodule(v), '__name__', '')
return _details(
- escape(p.value.__name__),
- escape(f'{module}.{p.value.__name__}'),
+ escape(v.__name__),
+ escape(f'{module}.{v.__name__}'),
)
return ''
-def _provider_value(p: Any) -> str:
- if not isinstance(p, tuple) and p.kind == 'parameter':
- html = escape(str(p.value)).strip()
+def _provider_value(
+ p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
+) -> str:
+ _, _, (v, *_) = p
+ kind = kind_of_provider(v)
+ if kind == 'parameter':
+ html = escape(str(v())).strip()
return _details(f'{html[:30]}...', html) if len(html) > 30 else html
return ''
def pipeline_html_repr(
- providers: Mapping[ProviderKind, Sequence[ProviderDisplayData]]
+ providers: Iterable[Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]]
) -> str:
- def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, Any]]:
- if isinstance(p.origin, Item):
- return (p.origin.label[0].tp, p.origin.tp)
- return None
+ def associate_table_values(
+ p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]
+ ) -> Tuple[Key, Union[type, Tuple[Union[Key, TypeVar], ...]]]:
+ key, args, v = p
+ if isinstance(key, Item):
+ return (key.label[0].tp, key.tp)
+ return (key, args)
- param_table_columns_by_name_colname = groupby(
- table_name_and_column_name,
- providers['table'],
+ providers_collected = (
+ (key, args, [value, *(v for _, _, v in rest)])
+ for ((key, args, value), *rest) in groupby(
+ associate_table_values,
+ providers,
+ ).values()
)
provider_rows = '\n'.join(
(
@@ -81,11 +84,7 @@ def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, An
{_provider_source(p)}
'''
for p in sorted(
- chain(
- providers['function'],
- providers['parameter'],
- param_table_columns_by_name_colname.items(),
- ),
+ providers_collected,
key=_provider_name,
)
)
diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py
index 738ea2ec..0e82ed0b 100644
--- a/src/sciline/pipeline.py
+++ b/src/sciline/pipeline.py
@@ -26,7 +26,7 @@
from sciline.task_graph import TaskGraph
-from .display import ProviderDisplayData, ProviderKind, pipeline_html_repr
+from .display import pipeline_html_repr
from .domain import Scope, ScopeTwoParams
from .handler import (
ErrorHandler,
@@ -38,7 +38,6 @@
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, Provider, get_optional, get_union
-from .utils import groupby, qualname
T = TypeVar('T')
KeyType = TypeVar('KeyType')
@@ -140,14 +139,6 @@ def provide_none() -> None:
return None
-def _kind_of_provider(p: Callable[..., Any]) -> ProviderKind:
- if qualname(p) == f'{qualname(Pipeline.__setitem__)}..':
- return 'parameter'
- if qualname(p) == f'{qualname(Pipeline.set_param_table)}..':
- return 'table'
- return 'function'
-
-
class ReplicatorBase(Generic[IndexType]):
def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
if len(path) == 0:
@@ -828,18 +819,6 @@ def _repr_html_(self) -> str:
for origin in self._subproviders
for args, value in self._subproviders[origin].items()
)
- providers = groupby(
- lambda p: p.kind,
- (
- ProviderDisplayData(
- origin,
- args,
- kind := _kind_of_provider(value),
- value() if kind != 'function' else value,
- )
- for origin, args, value in chain(
- providers_with_parameters, providers_without_parameters
- )
- ),
+ return pipeline_html_repr(
+ chain(providers_without_parameters, providers_with_parameters)
)
- return pipeline_html_repr(providers)
diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py
index 343d8962..9337c255 100644
--- a/src/sciline/scheduler.py
+++ b/src/sciline/scheduler.py
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
+import inspect
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple
from sciline.typing import Graph, Key
@@ -49,6 +50,9 @@ def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
results[t] = provider(*[results[arg] for arg in args])
return tuple(results[key] for key in keys)
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}()'
+
class DaskScheduler:
"""Wrapper for a Dask scheduler.
@@ -80,3 +84,8 @@ def get(self, graph: Graph, keys: List[Key]) -> Any:
if str(e).startswith("Cycle detected"):
raise CycleError from e
raise
+
+ def __repr__(self) -> str:
+ module = getattr(inspect.getmodule(self._dask_get), '__name__', '')
+ name = self._dask_get.__name__
+ return f'{self.__class__.__name__}({module}.{name})'
diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py
index 3e66a11e..830193b7 100644
--- a/src/sciline/task_graph.py
+++ b/src/sciline/task_graph.py
@@ -2,14 +2,62 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations
-from typing import Any, Optional, Tuple, TypeVar, Union
+from html import escape
+from typing import Any, Optional, Sequence, Tuple, TypeVar, Union
from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
from .typing import Graph, Item
+from .utils import keyname, kind_of_provider
T = TypeVar("T")
+def _list_items(items: Sequence[str]) -> str:
+ return '\n'.join(
+ (
+ '',
+ ('\n'.join((f'- {escape(it)}
' for it in items))),
+ ' ',
+ )
+ )
+
+
+def _list_max_n_then_hide(items: Sequence[str], n: int = 5, header: str = '') -> str:
+ def wrap(s: str) -> str:
+ return '\n'.join(
+ (
+ ''
+ '',
+ s,
+ ' ',
+ )
+ )
+
+ return wrap(
+ '\n'.join(
+ (
+ header,
+ _list_items(items),
+ )
+ )
+ if len(items) <= n
+ else '\n'.join(
+ (
+ '',
+ '',
+ '',
+ header,
+ _list_items((*items[:n], '...')),
+ '',
+ _list_items(items),
+ ' ',
+ )
+ )
+ )
+
+
class TaskGraph:
"""
Holds a concrete task graph and keys to compute.
@@ -79,3 +127,39 @@ def visualize(
from .visualize import to_graphviz
return to_graphviz(self._graph, **kwargs)
+
+ def _repr_html_(self) -> str:
+ leafs = sorted(
+ [
+ escape(keyname(key))
+ for key in (
+ self._keys if isinstance(self._keys, tuple) else [self._keys]
+ )
+ ]
+ )
+ roots = sorted(
+ {
+ escape(keyname(key))
+ for key, (value, _) in self._graph.items()
+ if kind_of_provider(value) != 'function'
+ }
+ )
+ scheduler = escape(str(self._scheduler))
+
+ def head(word: str) -> str:
+ return f'{word}'
+
+ return '\n'.join(
+ (
+ '',
+ '',
+ head('Output keys: '),
+ ','.join(leafs),
+ ' ',
+ head('Scheduler: '),
+ scheduler,
+ ' ',
+ _list_max_n_then_hide(roots, header=head('Input keys:')),
+ ' ',
+ )
+ )
diff --git a/src/sciline/typing.py b/src/sciline/typing.py
index ac7af46e..1dcbb52d 100644
--- a/src/sciline/typing.py
+++ b/src/sciline/typing.py
@@ -6,6 +6,7 @@
Callable,
Dict,
Generic,
+ Literal,
Optional,
Tuple,
Type,
@@ -32,6 +33,7 @@ class Item(Generic[T]):
Provider = Callable[..., Any]
+ProviderKind = Literal['function', 'parameter', 'table']
Key = Union[type, Item[Any]]
diff --git a/src/sciline/utils.py b/src/sciline/utils.py
index bcd77c5b..a560595b 100644
--- a/src/sciline/utils.py
+++ b/src/sciline/utils.py
@@ -1,5 +1,7 @@
from collections import defaultdict
-from typing import Any, Callable, DefaultDict, Iterable, TypeVar
+from typing import Any, Callable, DefaultDict, Iterable, TypeVar, Union, get_args
+
+from .typing import Item, Key, ProviderKind
T = TypeVar('T')
G = TypeVar('G')
@@ -16,3 +18,25 @@ def qualname(obj: Any) -> str:
return str(
obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__
)
+
+
+def keyname(key: Union[Key, TypeVar]) -> str:
+ if isinstance(key, TypeVar):
+ return str(key)
+ if isinstance(key, Item):
+ return f'{keyname(key.tp)}({keyname(key.label[0].tp)})'
+ args = get_args(key)
+ if len(args):
+ parameters = ', '.join(map(keyname, args))
+ return f'{qualname(key)}[{parameters}]'
+ return qualname(key)
+
+
+def kind_of_provider(p: Callable[..., Any]) -> ProviderKind:
+ from .pipeline import Pipeline
+
+ if qualname(p) == f'{qualname(Pipeline.__setitem__)}..':
+ return 'parameter'
+ if qualname(p) == f'{qualname(Pipeline.set_param_table)}..':
+ return 'table'
+ return 'function'
|