diff --git a/src/sciline/display.py b/src/sciline/display.py index 2f9e8498..a73d3219 100644 --- a/src/sciline/display.py +++ b/src/sciline/display.py @@ -1,21 +1,9 @@ import inspect -from dataclasses import dataclass from html import escape -from itertools import chain -from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, TypeVar, Union +from typing import Iterable, List, Tuple, TypeVar, Union -from .typing import Item, Key -from .utils import groupby, qualname - -ProviderKind = Literal['function', 'parameter', 'table'] - - -@dataclass -class ProviderDisplayData: - origin: Key - args: Tuple[Union[Key, TypeVar], ...] - kind: ProviderKind - value: Any +from .typing import Item, Key, Provider +from .utils import groupby, keyname, kind_of_provider def _details(summary: str, body: str) -> str: @@ -27,50 +15,65 @@ def _details(summary: str, body: str) -> str: ''' -def _provider_name(p: Any) -> str: - if isinstance(p, tuple): - (name, cname), values = p - return escape(f'{qualname(cname)}({qualname(name)})') - name = f'{qualname(p.origin)}' - if p.args: - args = ','.join( - ('*' if isinstance(arg, TypeVar) else f'{qualname(arg)}' for arg in p.args) - ) - name += f'[{args}]' - return escape(f'{name}') +def _provider_name( + p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] +) -> str: + key, args, _ = p + if args: + # This is always the case, but mypy complains + if hasattr(key, '__getitem__'): + return escape(keyname(key[args])) + return escape(keyname(key)) -def _provider_source(p: Any) -> str: - if isinstance(p, tuple): - (name, cname), values = p - return escape(f'ParamTable({qualname(name)}, length={len(values)})') - if p.kind == 'function': - module = getattr(inspect.getmodule(p.value), '__name__', '') +def _provider_source( + p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] +) -> str: + key, _, (v, *rest) = p + kind = kind_of_provider(v) + if kind == 'table': + # This is always the case, but mypy complains + if isinstance(key, Item): + return escape( + f'ParamTable({keyname(key.label[0].tp)}, length={len((v, *rest))})' + ) + if kind == 'function': + module = getattr(inspect.getmodule(v), '__name__', '') return _details( - escape(p.value.__name__), - escape(f'{module}.{p.value.__name__}'), + escape(v.__name__), + escape(f'{module}.{v.__name__}'), ) return '' -def _provider_value(p: Any) -> str: - if not isinstance(p, tuple) and p.kind == 'parameter': - html = escape(str(p.value)).strip() +def _provider_value( + p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]] +) -> str: + _, _, (v, *_) = p + kind = kind_of_provider(v) + if kind == 'parameter': + html = escape(str(v())).strip() return _details(f'{html[:30]}...', html) if len(html) > 30 else html return '' def pipeline_html_repr( - providers: Mapping[ProviderKind, Sequence[ProviderDisplayData]] + providers: Iterable[Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]] ) -> str: - def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, Any]]: - if isinstance(p.origin, Item): - return (p.origin.label[0].tp, p.origin.tp) - return None + def associate_table_values( + p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider] + ) -> Tuple[Key, Union[type, Tuple[Union[Key, TypeVar], ...]]]: + key, args, v = p + if isinstance(key, Item): + return (key.label[0].tp, key.tp) + return (key, args) - param_table_columns_by_name_colname = groupby( - table_name_and_column_name, - providers['table'], + providers_collected = ( + (key, args, [value, *(v for _, _, v in rest)]) + for ((key, args, value), *rest) in groupby( + associate_table_values, + providers, + ).values() ) provider_rows = '\n'.join( ( @@ -81,11 +84,7 @@ def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, An {_provider_source(p)} ''' for p in sorted( - chain( - providers['function'], - providers['parameter'], - param_table_columns_by_name_colname.items(), - ), + providers_collected, key=_provider_name, ) ) diff --git a/src/sciline/pipeline.py b/src/sciline/pipeline.py index 738ea2ec..0e82ed0b 100644 --- a/src/sciline/pipeline.py +++ b/src/sciline/pipeline.py @@ -26,7 +26,7 @@ from sciline.task_graph import TaskGraph -from .display import ProviderDisplayData, ProviderKind, pipeline_html_repr +from .display import pipeline_html_repr from .domain import Scope, ScopeTwoParams from .handler import ( ErrorHandler, @@ -38,7 +38,6 @@ from .scheduler import Scheduler from .series import Series from .typing import Graph, Item, Key, Label, Provider, get_optional, get_union -from .utils import groupby, qualname T = TypeVar('T') KeyType = TypeVar('KeyType') @@ -140,14 +139,6 @@ def provide_none() -> None: return None -def _kind_of_provider(p: Callable[..., Any]) -> ProviderKind: - if qualname(p) == f'{qualname(Pipeline.__setitem__)}..': - return 'parameter' - if qualname(p) == f'{qualname(Pipeline.set_param_table)}..': - return 'table' - return 'function' - - class ReplicatorBase(Generic[IndexType]): def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]): if len(path) == 0: @@ -828,18 +819,6 @@ def _repr_html_(self) -> str: for origin in self._subproviders for args, value in self._subproviders[origin].items() ) - providers = groupby( - lambda p: p.kind, - ( - ProviderDisplayData( - origin, - args, - kind := _kind_of_provider(value), - value() if kind != 'function' else value, - ) - for origin, args, value in chain( - providers_with_parameters, providers_without_parameters - ) - ), + return pipeline_html_repr( + chain(providers_without_parameters, providers_with_parameters) ) - return pipeline_html_repr(providers) diff --git a/src/sciline/scheduler.py b/src/sciline/scheduler.py index 343d8962..9337c255 100644 --- a/src/sciline/scheduler.py +++ b/src/sciline/scheduler.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +import inspect from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple from sciline.typing import Graph, Key @@ -49,6 +50,9 @@ def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]: results[t] = provider(*[results[arg] for arg in args]) return tuple(results[key] for key in keys) + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' + class DaskScheduler: """Wrapper for a Dask scheduler. @@ -80,3 +84,8 @@ def get(self, graph: Graph, keys: List[Key]) -> Any: if str(e).startswith("Cycle detected"): raise CycleError from e raise + + def __repr__(self) -> str: + module = getattr(inspect.getmodule(self._dask_get), '__name__', '') + name = self._dask_get.__name__ + return f'{self.__class__.__name__}({module}.{name})' diff --git a/src/sciline/task_graph.py b/src/sciline/task_graph.py index 3e66a11e..830193b7 100644 --- a/src/sciline/task_graph.py +++ b/src/sciline/task_graph.py @@ -2,14 +2,62 @@ # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) from __future__ import annotations -from typing import Any, Optional, Tuple, TypeVar, Union +from html import escape +from typing import Any, Optional, Sequence, Tuple, TypeVar, Union from .scheduler import DaskScheduler, NaiveScheduler, Scheduler from .typing import Graph, Item +from .utils import keyname, kind_of_provider T = TypeVar("T") +def _list_items(items: Sequence[str]) -> str: + return '\n'.join( + ( + '
    ', + ('\n'.join((f'
  • {escape(it)}
  • ' for it in items))), + '
', + ) + ) + + +def _list_max_n_then_hide(items: Sequence[str], n: int = 5, header: str = '') -> str: + def wrap(s: str) -> str: + return '\n'.join( + ( + '
' + '', + s, + '
', + ) + ) + + return wrap( + '\n'.join( + ( + header, + _list_items(items), + ) + ) + if len(items) <= n + else '\n'.join( + ( + '
', + '', + '', + header, + _list_items((*items[:n], '...')), + '', + _list_items(items), + '
', + ) + ) + ) + + class TaskGraph: """ Holds a concrete task graph and keys to compute. @@ -79,3 +127,39 @@ def visualize( from .visualize import to_graphviz return to_graphviz(self._graph, **kwargs) + + def _repr_html_(self) -> str: + leafs = sorted( + [ + escape(keyname(key)) + for key in ( + self._keys if isinstance(self._keys, tuple) else [self._keys] + ) + ] + ) + roots = sorted( + { + escape(keyname(key)) + for key, (value, _) in self._graph.items() + if kind_of_provider(value) != 'function' + } + ) + scheduler = escape(str(self._scheduler)) + + def head(word: str) -> str: + return f'
{word}
' + + return '\n'.join( + ( + '', + '
', + head('Output keys: '), + ','.join(leafs), + '
', + head('Scheduler: '), + scheduler, + '
', + _list_max_n_then_hide(roots, header=head('Input keys:')), + '
', + ) + ) diff --git a/src/sciline/typing.py b/src/sciline/typing.py index ac7af46e..1dcbb52d 100644 --- a/src/sciline/typing.py +++ b/src/sciline/typing.py @@ -6,6 +6,7 @@ Callable, Dict, Generic, + Literal, Optional, Tuple, Type, @@ -32,6 +33,7 @@ class Item(Generic[T]): Provider = Callable[..., Any] +ProviderKind = Literal['function', 'parameter', 'table'] Key = Union[type, Item[Any]] diff --git a/src/sciline/utils.py b/src/sciline/utils.py index bcd77c5b..a560595b 100644 --- a/src/sciline/utils.py +++ b/src/sciline/utils.py @@ -1,5 +1,7 @@ from collections import defaultdict -from typing import Any, Callable, DefaultDict, Iterable, TypeVar +from typing import Any, Callable, DefaultDict, Iterable, TypeVar, Union, get_args + +from .typing import Item, Key, ProviderKind T = TypeVar('T') G = TypeVar('G') @@ -16,3 +18,25 @@ def qualname(obj: Any) -> str: return str( obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__ ) + + +def keyname(key: Union[Key, TypeVar]) -> str: + if isinstance(key, TypeVar): + return str(key) + if isinstance(key, Item): + return f'{keyname(key.tp)}({keyname(key.label[0].tp)})' + args = get_args(key) + if len(args): + parameters = ', '.join(map(keyname, args)) + return f'{qualname(key)}[{parameters}]' + return qualname(key) + + +def kind_of_provider(p: Callable[..., Any]) -> ProviderKind: + from .pipeline import Pipeline + + if qualname(p) == f'{qualname(Pipeline.__setitem__)}..': + return 'parameter' + if qualname(p) == f'{qualname(Pipeline.set_param_table)}..': + return 'table' + return 'function'