Skip to content

Commit

Permalink
StateManager: Pass intermediates using normalized values, rename coer…
Browse files Browse the repository at this point in the history
…ce to normalize (#298)
  • Loading branch information
paulmelnikow authored Oct 21, 2022
1 parent 7437e9a commit a2d25c7
Show file tree
Hide file tree
Showing 7 changed files with 107 additions and 147 deletions.
181 changes: 56 additions & 125 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ rds-graphile-worker-client = {version = "^0.1.1", optional = true}
semver = "3.0.0.dev3"
simplejson = "^3.17.5"
jsonschema = "^4.1.2"
artifax = {version = "0.4", optional = true}
# artifax = {git = "https://github.com/curvewise-forks/artifax.git", rev = "2e496525a04185525cfedbb3a9375101ec7faec1", optional = true}
# artifax = {version = "0.4", optional = true}
artifax = {git = "https://github.com/curvewise-forks/artifax.git", rev = "3cd288a1698a798c0e488eba31e8e9ce3f075e1c", optional = true}
# Temporarily declare artifax dependencies until we publish an artifax fork.
# pathos = {version = "*", optional = true}
# exos = {version = "*", optional = true}
pathos = {version = "*", optional = true}
exos = {version = "*", optional = true}

[tool.poetry.extras]
aws_lambda_build = ["executor"]
client = ["boto3"]
compute_graph = ["artifax"]
# compute_graph = ["artifax", "pathos", "exos"]
# compute_graph = ["artifax"]
compute_graph = ["artifax", "pathos", "exos"]
lambda_common = ["harrison"]
cli = ["click"]
rds_graphile_worker = ["rds-graphile-worker-client"]
Expand Down
4 changes: 2 additions & 2 deletions werkit/compute/graph/_custom_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def deserialize(cls, json_data: JSONType) -> CanonicalType:

@classmethod
@abstractmethod
def coerce(cls, value: t.Any) -> CanonicalType:
def normalize(cls, value: t.Any) -> CanonicalType:
"""
Coerce the given value to the canonical native type. Raise an exception
if it can't be coerced.
if it can't be normalized.
"""

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions werkit/compute/graph/_dependency_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def deserialize(self, value: JSONType) -> t.Any:
value_type.validate(value)
return value_type.deserialize(value)

def coerce(self, name: str, value: t.Any) -> t.Any:
def normalize(self, name: str, value: t.Any) -> t.Any:
if self.value_type_is_built_in:
return coerce_value_to_builtin_type(
name=name,
Expand All @@ -60,7 +60,7 @@ def coerce(self, name: str, value: t.Any) -> t.Any:
)
else:
# TODO: Perhaps catch and re-throw to improve the error message.
return t.cast(t.Type[CustomType], self.value_type).coerce(value)
return t.cast(t.Type[CustomType], self.value_type).normalize(value)

def serialize_value(self, value: t.Any) -> JSONType:
if self.value_type_is_built_in:
Expand Down
28 changes: 20 additions & 8 deletions werkit/compute/graph/_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,23 @@ def deserialize(self, **kwargs: t.Dict) -> None:
}
self.store.update(deserialized)

def coerce(self, **kwargs: t.Dict) -> t.Dict:
def normalize(self, **kwargs: t.Dict) -> t.Dict:
return {
name: self.dependency_graph.all_nodes[name].coerce(name=name, value=value)
name: self.dependency_graph.all_nodes[name].normalize(
name=name, value=value
)
for name, value in kwargs.items()
}

def set(self, **kwargs: t.Dict) -> None:
self._assert_known_keys(kwargs.keys())
coerced = self.coerce(**kwargs)
self.store.update(coerced)
normalized = self.normalize(**kwargs)
self.store.update(normalized)

def evaluate(
self, targets: t.List[str] = None, handle_exceptions: bool = False
) -> None:
import functools
from artifax import Artifax

if targets is not None:
Expand All @@ -47,18 +50,27 @@ def evaluate(
else:
self._assert_known_keys(targets)

def wrap_node(name, node):
wrapped = node.bind(self.instance)

def wrapper(*args):
value = wrapped(*args)
return node.normalize(name, value)

functools.update_wrapper(wrapper, wrapped)
return wrapper

afx = Artifax(
{
name: self.dependency_graph.compute_nodes[name].bind(self.instance)
for name in self.dependency_graph.compute_nodes.keys()
name: wrap_node(name, node)
for name, node in self.dependency_graph.compute_nodes.items()
}
)
if self.store:
afx.set(**self.store)
afx.build(targets=targets)
# TODO: `afx.build()` should always return an object.
coerced = self.coerce(**afx._result)
self.store.update(coerced)
self.store.update(**afx._result)

def serialize(self, targets: t.List[str] = None) -> t.Dict:
if targets is not None:
Expand Down
9 changes: 9 additions & 0 deletions werkit/compute/graph/test_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ def test_state_manager_with_custom_type() -> None:
assert thing.description == "Example description"
assert thing.count == 25

# Due to rounding in normalize(), other_thing should be rounded.
assert state_manager.store["other_thing"] == (1.52, 2.52, 3.52)


def test_state_manager_propagates_normalized_value() -> None:
state_manager = MyComputeProcessWithCustomType().state_manager
state_manager.set(a=1, b=2)
state_manager.evaluate()

assert state_manager.store["further_derived_thing"] == "(1.52, 2.52, 3.52)"


def test_state_manager_deserializes_custom_type() -> None:
state_manager = MyComputeProcessWithCustomType().state_manager

Expand Down
16 changes: 12 additions & 4 deletions werkit/compute/graph/testing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ def name(cls) -> str:
return "MyModel"

@classmethod
def coerce(cls, value: t.Any) -> MyModel:
def normalize(cls, value: t.Any) -> MyModel:
if not isinstance(value, MyModel):
raise ValueError(f"Can't coerce {type(value).__name__} to {cls.__name__}")
raise ValueError(
f"Can't normalize {type(value).__name__} to {cls.__name__}"
)
return value

@classmethod
Expand All @@ -93,9 +95,11 @@ class Vector3(CustomType[tuple]):
DECIMALS = 2

@classmethod
def coerce(cls, value: t.Any) -> tuple:
def normalize(cls, value: t.Any) -> tuple:
if not isinstance(value, tuple):
raise ValueError(f"Can't coerce {type(value).__name__} to {cls.__name__}")
raise ValueError(
f"Can't normalize {type(value).__name__} to {cls.__name__}"
)
elif not len(value) == 3:
raise ValueError("Excepted tuple to have length 3")
return tuple(round(coord, cls.DECIMALS) for coord in value)
Expand All @@ -120,3 +124,7 @@ def thing(self) -> MyModel:
@output(value_type=Vector3)
def other_thing(self) -> tuple:
return (1.5151, 2.5151, 3.5151)

@output(value_type=str)
def further_derived_thing(self, other_thing) -> str:
return str(other_thing)

0 comments on commit a2d25c7

Please sign in to comment.