Skip to content

Commit

Permalink
Merge pull request #99 from scipp/html-repr-for-taskgraph
Browse files Browse the repository at this point in the history
feat: task graph html repr
  • Loading branch information
jokasimr authored Jan 15, 2024
2 parents 5b462b6 + 805eeb3 commit 7951e09
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 76 deletions.
99 changes: 49 additions & 50 deletions src/sciline/display.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,9 @@
import inspect
from dataclasses import dataclass
from html import escape
from itertools import chain
from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, TypeVar, Union
from typing import Iterable, List, Tuple, TypeVar, Union

from .typing import Item, Key
from .utils import groupby, qualname

ProviderKind = Literal['function', 'parameter', 'table']


@dataclass
class ProviderDisplayData:
origin: Key
args: Tuple[Union[Key, TypeVar], ...]
kind: ProviderKind
value: Any
from .typing import Item, Key, Provider
from .utils import groupby, keyname, kind_of_provider


def _details(summary: str, body: str) -> str:
Expand All @@ -27,50 +15,65 @@ def _details(summary: str, body: str) -> str:
'''


def _provider_name(p: Any) -> str:
if isinstance(p, tuple):
(name, cname), values = p
return escape(f'{qualname(cname)}({qualname(name)})')
name = f'{qualname(p.origin)}'
if p.args:
args = ','.join(
('*' if isinstance(arg, TypeVar) else f'{qualname(arg)}' for arg in p.args)
)
name += f'[{args}]'
return escape(f'{name}')
def _provider_name(
p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
) -> str:
key, args, _ = p
if args:
# This is always the case, but mypy complains
if hasattr(key, '__getitem__'):
return escape(keyname(key[args]))
return escape(keyname(key))


def _provider_source(p: Any) -> str:
if isinstance(p, tuple):
(name, cname), values = p
return escape(f'ParamTable({qualname(name)}, length={len(values)})')
if p.kind == 'function':
module = getattr(inspect.getmodule(p.value), '__name__', '')
def _provider_source(
p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
) -> str:
key, _, (v, *rest) = p
kind = kind_of_provider(v)
if kind == 'table':
# This is always the case, but mypy complains
if isinstance(key, Item):
return escape(
f'ParamTable({keyname(key.label[0].tp)}, length={len((v, *rest))})'
)
if kind == 'function':
module = getattr(inspect.getmodule(v), '__name__', '')
return _details(
escape(p.value.__name__),
escape(f'{module}.{p.value.__name__}'),
escape(v.__name__),
escape(f'{module}.{v.__name__}'),
)
return ''


def _provider_value(p: Any) -> str:
if not isinstance(p, tuple) and p.kind == 'parameter':
html = escape(str(p.value)).strip()
def _provider_value(
p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], List[Provider]]
) -> str:
_, _, (v, *_) = p
kind = kind_of_provider(v)
if kind == 'parameter':
html = escape(str(v())).strip()
return _details(f'{html[:30]}...', html) if len(html) > 30 else html
return ''


def pipeline_html_repr(
providers: Mapping[ProviderKind, Sequence[ProviderDisplayData]]
providers: Iterable[Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]]
) -> str:
def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, Any]]:
if isinstance(p.origin, Item):
return (p.origin.label[0].tp, p.origin.tp)
return None
def associate_table_values(
p: Tuple[Key, Tuple[Union[Key, TypeVar], ...], Provider]
) -> Tuple[Key, Union[type, Tuple[Union[Key, TypeVar], ...]]]:
key, args, v = p
if isinstance(key, Item):
return (key.label[0].tp, key.tp)
return (key, args)

param_table_columns_by_name_colname = groupby(
table_name_and_column_name,
providers['table'],
providers_collected = (
(key, args, [value, *(v for _, _, v in rest)])
for ((key, args, value), *rest) in groupby(
associate_table_values,
providers,
).values()
)
provider_rows = '\n'.join(
(
Expand All @@ -81,11 +84,7 @@ def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, An
<td scope="row">{_provider_source(p)}</th>
</tr>'''
for p in sorted(
chain(
providers['function'],
providers['parameter'],
param_table_columns_by_name_colname.items(),
),
providers_collected,
key=_provider_name,
)
)
Expand Down
27 changes: 3 additions & 24 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

from sciline.task_graph import TaskGraph

from .display import ProviderDisplayData, ProviderKind, pipeline_html_repr
from .display import pipeline_html_repr
from .domain import Scope, ScopeTwoParams
from .handler import (
ErrorHandler,
Expand All @@ -38,7 +38,6 @@
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, Provider, get_optional, get_union
from .utils import groupby, qualname

T = TypeVar('T')
KeyType = TypeVar('KeyType')
Expand Down Expand Up @@ -140,14 +139,6 @@ def provide_none() -> None:
return None


def _kind_of_provider(p: Callable[..., Any]) -> ProviderKind:
if qualname(p) == f'{qualname(Pipeline.__setitem__)}.<locals>.<lambda>':
return 'parameter'
if qualname(p) == f'{qualname(Pipeline.set_param_table)}.<locals>.<lambda>':
return 'table'
return 'function'


class ReplicatorBase(Generic[IndexType]):
def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
if len(path) == 0:
Expand Down Expand Up @@ -828,18 +819,6 @@ def _repr_html_(self) -> str:
for origin in self._subproviders
for args, value in self._subproviders[origin].items()
)
providers = groupby(
lambda p: p.kind,
(
ProviderDisplayData(
origin,
args,
kind := _kind_of_provider(value),
value() if kind != 'function' else value,
)
for origin, args, value in chain(
providers_with_parameters, providers_without_parameters
)
),
return pipeline_html_repr(
chain(providers_without_parameters, providers_with_parameters)
)
return pipeline_html_repr(providers)
9 changes: 9 additions & 0 deletions src/sciline/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import inspect
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple

from sciline.typing import Graph, Key
Expand Down Expand Up @@ -49,6 +50,9 @@ def get(self, graph: Graph, keys: List[Key]) -> Tuple[Any, ...]:
results[t] = provider(*[results[arg] for arg in args])
return tuple(results[key] for key in keys)

def __repr__(self) -> str:
return f'{self.__class__.__name__}()'


class DaskScheduler:
"""Wrapper for a Dask scheduler.
Expand Down Expand Up @@ -80,3 +84,8 @@ def get(self, graph: Graph, keys: List[Key]) -> Any:
if str(e).startswith("Cycle detected"):
raise CycleError from e
raise

def __repr__(self) -> str:
module = getattr(inspect.getmodule(self._dask_get), '__name__', '')
name = self._dask_get.__name__
return f'{self.__class__.__name__}({module}.{name})'
86 changes: 85 additions & 1 deletion src/sciline/task_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,62 @@
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
from __future__ import annotations

from typing import Any, Optional, Tuple, TypeVar, Union
from html import escape
from typing import Any, Optional, Sequence, Tuple, TypeVar, Union

from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
from .typing import Graph, Item
from .utils import keyname, kind_of_provider

T = TypeVar("T")


def _list_items(items: Sequence[str]) -> str:
return '\n'.join(
(
'<ul>',
('\n'.join((f'<li>{escape(it)}</li>' for it in items))),
'</ul>',
)
)


def _list_max_n_then_hide(items: Sequence[str], n: int = 5, header: str = '') -> str:
def wrap(s: str) -> str:
return '\n'.join(
(
'<div class="task-graph-detail-list">'
'<style> .task-graph-detail-list ul { margin-top: 0; } </style>',
s,
'</div>',
)
)

return wrap(
'\n'.join(
(
header,
_list_items(items),
)
)
if len(items) <= n
else '\n'.join(
(
'<details>',
'<style>',
'details[open] .task-graph-summary ul { display: none; }',
'</style>',
'<summary class="task-graph-summary">',
header,
_list_items((*items[:n], '...')),
'</summary>',
_list_items(items),
'</details>',
)
)
)


class TaskGraph:
"""
Holds a concrete task graph and keys to compute.
Expand Down Expand Up @@ -79,3 +127,39 @@ def visualize(
from .visualize import to_graphviz

return to_graphviz(self._graph, **kwargs)

def _repr_html_(self) -> str:
leafs = sorted(
[
escape(keyname(key))
for key in (
self._keys if isinstance(self._keys, tuple) else [self._keys]
)
]
)
roots = sorted(
{
escape(keyname(key))
for key, (value, _) in self._graph.items()
if kind_of_provider(value) != 'function'
}
)
scheduler = escape(str(self._scheduler))

def head(word: str) -> str:
return f'<h5>{word}</h5>'

return '\n'.join(
(
'<style>.task-graph-repr h5 { display: inline; }</style>',
'<div class="task-graph-repr">',
head('Output keys: '),
','.join(leafs),
'<br>',
head('Scheduler: '),
scheduler,
'<br>',
_list_max_n_then_hide(roots, header=head('Input keys:')),
'</div>',
)
)
2 changes: 2 additions & 0 deletions src/sciline/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
Generic,
Literal,
Optional,
Tuple,
Type,
Expand All @@ -32,6 +33,7 @@ class Item(Generic[T]):


Provider = Callable[..., Any]
ProviderKind = Literal['function', 'parameter', 'table']


Key = Union[type, Item[Any]]
Expand Down
26 changes: 25 additions & 1 deletion src/sciline/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Iterable, TypeVar
from typing import Any, Callable, DefaultDict, Iterable, TypeVar, Union, get_args

from .typing import Item, Key, ProviderKind

T = TypeVar('T')
G = TypeVar('G')
Expand All @@ -16,3 +18,25 @@ def qualname(obj: Any) -> str:
return str(
obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__
)


def keyname(key: Union[Key, TypeVar]) -> str:
if isinstance(key, TypeVar):
return str(key)
if isinstance(key, Item):
return f'{keyname(key.tp)}({keyname(key.label[0].tp)})'
args = get_args(key)
if len(args):
parameters = ', '.join(map(keyname, args))
return f'{qualname(key)}[{parameters}]'
return qualname(key)


def kind_of_provider(p: Callable[..., Any]) -> ProviderKind:
from .pipeline import Pipeline

if qualname(p) == f'{qualname(Pipeline.__setitem__)}.<locals>.<lambda>':
return 'parameter'
if qualname(p) == f'{qualname(Pipeline.set_param_table)}.<locals>.<lambda>':
return 'table'
return 'function'

0 comments on commit 7951e09

Please sign in to comment.