Skip to content

Commit

Permalink
Merge pull request #78 from scipp/html-repr-for-pipeline
Browse files Browse the repository at this point in the history
Html repr for pipeline
  • Loading branch information
jokasimr authored Jan 11, 2024
2 parents 7c2d6a4 + e25cbe8 commit 5b462b6
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 0 deletions.
108 changes: 108 additions & 0 deletions src/sciline/display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import inspect
from dataclasses import dataclass
from html import escape
from itertools import chain
from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, TypeVar, Union

from .typing import Item, Key
from .utils import groupby, qualname

ProviderKind = Literal['function', 'parameter', 'table']


@dataclass
class ProviderDisplayData:
origin: Key
args: Tuple[Union[Key, TypeVar], ...]
kind: ProviderKind
value: Any


def _details(summary: str, body: str) -> str:
return f'''
<details>
<summary>{summary}</summary>
{body}
</details>
'''


def _provider_name(p: Any) -> str:
if isinstance(p, tuple):
(name, cname), values = p
return escape(f'{qualname(cname)}({qualname(name)})')
name = f'{qualname(p.origin)}'
if p.args:
args = ','.join(
('*' if isinstance(arg, TypeVar) else f'{qualname(arg)}' for arg in p.args)
)
name += f'[{args}]'
return escape(f'{name}')


def _provider_source(p: Any) -> str:
if isinstance(p, tuple):
(name, cname), values = p
return escape(f'ParamTable({qualname(name)}, length={len(values)})')
if p.kind == 'function':
module = getattr(inspect.getmodule(p.value), '__name__', '')
return _details(
escape(p.value.__name__),
escape(f'{module}.{p.value.__name__}'),
)
return ''


def _provider_value(p: Any) -> str:
if not isinstance(p, tuple) and p.kind == 'parameter':
html = escape(str(p.value)).strip()
return _details(f'{html[:30]}...', html) if len(html) > 30 else html
return ''


def pipeline_html_repr(
providers: Mapping[ProviderKind, Sequence[ProviderDisplayData]]
) -> str:
def table_name_and_column_name(p: ProviderDisplayData) -> Optional[Tuple[Any, Any]]:
if isinstance(p.origin, Item):
return (p.origin.label[0].tp, p.origin.tp)
return None

param_table_columns_by_name_colname = groupby(
table_name_and_column_name,
providers['table'],
)
provider_rows = '\n'.join(
(
f'''
<tr>
<td scope="row">{_provider_name(p)}</td>
<td scope="row">{_provider_value(p)}</td>
<td scope="row">{_provider_source(p)}</th>
</tr>'''
for p in sorted(
chain(
providers['function'],
providers['parameter'],
param_table_columns_by_name_colname.items(),
),
key=_provider_name,
)
)
)
return f'''
<div class="pipeline-html-repr">
<table>
<thead>
<tr>
<th scope="col">Name</th>
<th scope="col">Value</th>
<th scope="col">Source</th>
</tr>
</thead>
<tbody>
{provider_rows}
</tbody>
</table>
</div>
'''.strip()
35 changes: 35 additions & 0 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from sciline.task_graph import TaskGraph

from .display import ProviderDisplayData, ProviderKind, pipeline_html_repr
from .domain import Scope, ScopeTwoParams
from .handler import (
ErrorHandler,
Expand All @@ -37,6 +38,7 @@
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, Provider, get_optional, get_union
from .utils import groupby, qualname

T = TypeVar('T')
KeyType = TypeVar('KeyType')
Expand Down Expand Up @@ -138,6 +140,14 @@ def provide_none() -> None:
return None


def _kind_of_provider(p: Callable[..., Any]) -> ProviderKind:
if qualname(p) == f'{qualname(Pipeline.__setitem__)}.<locals>.<lambda>':
return 'parameter'
if qualname(p) == f'{qualname(Pipeline.set_param_table)}.<locals>.<lambda>':
return 'table'
return 'function'


class ReplicatorBase(Generic[IndexType]):
def __init__(self, index_name: type, index: Iterable[IndexType], path: List[Key]):
if len(path) == 0:
Expand Down Expand Up @@ -808,3 +818,28 @@ def copy(self) -> Pipeline:

def __copy__(self) -> Pipeline:
return self.copy()

def _repr_html_(self) -> str:
providers_without_parameters = (
(origin, tuple(), value) for origin, value in self._providers.items()
) # type: ignore[var-annotated]
providers_with_parameters = (
(origin, args, value)
for origin in self._subproviders
for args, value in self._subproviders[origin].items()
)
providers = groupby(
lambda p: p.kind,
(
ProviderDisplayData(
origin,
args,
kind := _kind_of_provider(value),
value() if kind != 'function' else value,
)
for origin, args, value in chain(
providers_with_parameters, providers_without_parameters
)
),
)
return pipeline_html_repr(providers)
18 changes: 18 additions & 0 deletions src/sciline/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Iterable, TypeVar

T = TypeVar('T')
G = TypeVar('G')


def groupby(f: Callable[[T], G], a: Iterable[T]) -> DefaultDict[G, list[T]]:
g = defaultdict(lambda: [])
for e in a:
g[f(e)].append(e)
return g


def qualname(obj: Any) -> str:
return str(
obj.__qualname__ if hasattr(obj, '__qualname__') else obj.__class__.__qualname__
)
5 changes: 5 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,3 +1197,8 @@ def process(
b[RawData[Sample]] = 7
assert a.compute(Result) == 29
assert b.compute(Result) == 53


def test_html_repr() -> None:
pipeline = sl.Pipeline([make_int], params={float: 5.0})
assert isinstance(pipeline._repr_html_(), str)
6 changes: 6 additions & 0 deletions tests/pipeline_with_param_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,3 +638,9 @@ def test_pipeline_set_param_table_on_copy_does_not_affect_original() -> None:
assert b.compute(sl.Series[int, float]) == sl.Series(int, {0: 1.0, 1: 2.0, 2: 3.0})
with pytest.raises(sl.UnsatisfiedRequirement):
a.compute(sl.Series[int, float])


def test_can_make_html_repr_with_param_table() -> None:
pl = sl.Pipeline()
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
assert pl._repr_html_()

0 comments on commit 5b462b6

Please sign in to comment.