Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support replacing params and providers #90

Merged
merged 1 commit into from
Dec 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,12 +437,8 @@ def _set_provider(
if (origin := get_origin(key)) is not None:
subproviders = self._subproviders.setdefault(origin, {})
args = get_args(key)
if args in subproviders:
raise ValueError(f'Provider for {key} already exists')
subproviders[args] = provider
else:
if key in self._providers:
raise ValueError(f'Provider for {key} already exists')
self._providers[key] = provider

def _get_provider(
Expand Down
109 changes: 106 additions & 3 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,114 @@ class B(Generic[T]):
pl[B[int]] = 1.0


def test_setitem_raises_if_key_exists() -> None:
def test_setitem_can_replace_param_with_param() -> None:
pl = sl.Pipeline()
pl[int] = 1
with pytest.raises(ValueError):
pl[int] = 2
pl[int] = 2
assert pl.compute(int) == 2


def test_insert_can_replace_param_with_provider() -> None:
def func() -> int:
return 2

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func)
assert pl.compute(int) == 2


def test_setitem_can_replace_provider_with_param() -> None:
def func() -> int:
return 2

pl = sl.Pipeline()
pl.insert(func)
pl[int] = 1
assert pl.compute(int) == 1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_setitem_can_replace_instance() -> None:
pl = sl.Pipeline()
original_number = 1
pl[int] = original_number
assert pl.compute(int) is original_number
new_number = 1
pl[int] = new_number
assert pl.compute(int) == original_number
assert pl.compute(int) is not original_number
assert pl.compute(int) is new_number

Can we also have a test case that shows instance check like this...?
Just to show that it's passed by lambda function, so the instance is kept not just the value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

int is immutable, not sure I see the value of such a test?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but it shows that it doesn't copy the value and return it but keep the instance itself.
In case we want to pass around dict, it might be useful information I thought...?

I just suggested what I wanted to try with the changes.
Feel free to drop them...!


def test_insert_can_replace_provider_with_provider() -> None:
def func1() -> int:
return 1

def func2() -> int:
return 2

pl = sl.Pipeline()
pl.insert(func1)
pl.insert(func2)
assert pl.compute(int) == 2


def test_insert_can_replace_generic_provider_with_generic_provider() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func1(x: T) -> A[T]:
return A[T](x)

def func2(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func1)
pl.insert(func2)
assert pl.compute(A[int]) == A[int](2)


def test_insert_can_replace_generic_param_with_generic_provider() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)
pl.insert(func)
assert pl.compute(A[int]) == A[int](2)


def test_setitem_can_replace_generic_provider_with_generic_param() -> None:
T = TypeVar('T', int, float)

@dataclass
class A(Generic[T]):
value: T

def func(x: T) -> A[T]:
return A[T](x + x)

pl = sl.Pipeline()
pl[int] = 1
pl.insert(func)
assert pl.compute(A[int]) == A[int](2)
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)


def test_setitem_can_replace_generic_param_with_generic_param() -> None:
T = TypeVar('T')

@dataclass
class A(Generic[T]):
value: T

pl = sl.Pipeline()
pl[A[T]] = A[T](1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)
pl[A[T]] = A[T](2) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](2)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_setitem_replace_generic_param_not_prioritized() -> None:
T = TypeVar('T')
@dataclass
class A(Generic[T]):
value: T
pl = sl.Pipeline(params={A[int]: 1})
pl[A[T]] = A[T](0.1) # type: ignore[valid-type]
assert pl.compute(A[int]) == A[int](1)
assert pl.compute(A[float]) == A[float](0.1)

Maybe it's worth show/check that the inserted generic provider doesn't overwrite or be prioritized to the explicit one...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unrelated to this change, I think? There are already tests for specializations.


def test_init_with_params() -> None:
Expand Down
8 changes: 8 additions & 0 deletions tests/pipeline_with_param_table_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,3 +605,11 @@ def process(x: float, missing: Missing) -> str:
pl = sl.Pipeline([process])
pl.set_param_table(sl.ParamTable(int, {float: [1.0, 2.0, 3.0]}))
pl.get(sl.Series[int, str], handler=sl.HandleAsComputeTimeException())


def test_param_table_column_and_param_of_same_type_can_coexist() -> None:
pl = sl.Pipeline()
pl[float] = 1.0
pl.set_param_table(sl.ParamTable(int, {float: [2.0, 3.0]}))
assert pl.compute(float) == 1.0
assert pl.compute(sl.Series[int, float]) == sl.Series(int, {0: 2.0, 1: 3.0})