Skip to content

Commit

Permalink
Merge pull request #280 from OpenFreeEnergy/feature/iss-277-restart-policy
Browse files Browse the repository at this point in the history

Implement task restart policies
  • Loading branch information
dotsdl authored Jan 23, 2025
2 parents c32f00d + 4056752 commit e01e445
Show file tree
Hide file tree
Showing 11 changed files with 1,921 additions and 49 deletions.
7 changes: 6 additions & 1 deletion alchemiscale/compute/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from fastapi import FastAPI, APIRouter, Body, Depends
from fastapi.middleware.gzip import GZipMiddleware
from gufe.tokenization import GufeTokenizable, JSON_HANDLER
from gufe.protocols import ProtocolDAGResult

from ..base.api import (
QueryGUFEHandler,
Expand Down Expand Up @@ -329,7 +330,7 @@ def set_task_result(
validate_scopes(task_sk.scope, token)

pdr = json.loads(protocoldagresult, cls=JSON_HANDLER.decoder)
pdr = GufeTokenizable.from_dict(pdr)
pdr: ProtocolDAGResult = GufeTokenizable.from_dict(pdr)

tf_sk, _ = n4js.get_task_transformation(
task=task_scoped_key,
Expand All @@ -351,7 +352,11 @@ def set_task_result(
if protocoldagresultref.ok:
n4js.set_task_complete(tasks=[task_sk])
else:
n4js.add_protocol_dag_result_ref_tracebacks(
pdr.protocol_unit_failures, result_sk
)
n4js.set_task_error(tasks=[task_sk])
n4js.resolve_task_restarts(tasks=[task_sk])

return result_sk

Expand Down
94 changes: 94 additions & 0 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,100 @@ def get_task_transformation(
return str(transformation)


@router.post("/networks/{network_scoped_key}/restartpatterns/add")
def add_task_restart_patterns(
    network_scoped_key: str,
    *,
    patterns: list[str] = Body(embed=True),
    num_allowed_restarts: int = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
):
    """Attach Task restart patterns to the given AlchemicalNetwork.

    Each pattern is registered against the network's TaskHub with the
    given number of allowed restarts.
    """
    network_sk = ScopedKey.from_str(network_scoped_key)
    validate_scopes(network_sk.scope, token)

    # restart patterns are stored on the network's TaskHub, so resolve it first
    taskhub_sk = n4js.get_taskhub(network_sk)
    n4js.add_task_restart_patterns(taskhub_sk, patterns, num_allowed_restarts)


@router.post("/networks/{network_scoped_key}/restartpatterns/remove")
def remove_task_restart_patterns(
    network_scoped_key: str,
    *,
    patterns: list[str] = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
):
    """Remove the given restart patterns from an AlchemicalNetwork's TaskHub."""
    network_sk = ScopedKey.from_str(network_scoped_key)
    validate_scopes(network_sk.scope, token)

    # patterns live on the TaskHub, not the network node itself
    taskhub_sk = n4js.get_taskhub(network_sk)
    n4js.remove_task_restart_patterns(taskhub_sk, patterns)


@router.get("/networks/{network_scoped_key}/restartpatterns/clear")
def clear_task_restart_patterns(
    network_scoped_key: str,
    *,
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
):
    """Remove all Task restart patterns from an AlchemicalNetwork's TaskHub.

    Returns the network scoped key (as a single-element list) on success.

    NOTE(review): this endpoint mutates state but is exposed as a GET; the
    client calls it via its query path, so changing the verb would break
    existing clients — flagging rather than changing.
    """
    sk = ScopedKey.from_str(network_scoped_key)
    validate_scopes(sk.scope, token)

    # patterns are attached to the network's TaskHub
    taskhub_scoped_key = n4js.get_taskhub(sk)
    n4js.clear_task_restart_patterns(taskhub_scoped_key)
    return [network_scoped_key]


@router.post("/bulk/networks/restartpatterns/get")
def get_task_restart_patterns(
    *,
    networks: list[str] = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
) -> dict[str, set[tuple[str, int]]]:
    """Return, for each requested network, its set of (pattern, allowed
    restarts) pairs."""
    network_scoped_keys = [ScopedKey.from_str(network) for network in networks]
    for network_sk in network_scoped_keys:
        validate_scopes(network_sk.scope, token)

    taskhub_scoped_keys = n4js.get_taskhubs(network_scoped_keys)

    # patterns are stored against TaskHubs; build a map back to their networks
    # (get_taskhubs preserves input order, so zip pairs them correctly)
    taskhub_network_map = dict(zip(taskhub_scoped_keys, network_scoped_keys))

    restart_patterns = n4js.get_task_restart_patterns(taskhub_scoped_keys)

    # key the response by network scoped key rather than taskhub scoped key
    return {
        str(taskhub_network_map[taskhub_sk]): patterns
        for taskhub_sk, patterns in restart_patterns.items()
    }


@router.post("/networks/{network_scoped_key}/restartpatterns/maxretries")
def set_task_restart_patterns_max_retries(
    network_scoped_key: str,
    *,
    patterns: list[str] = Body(embed=True),
    max_retries: int = Body(embed=True),
    n4js: Neo4jStore = Depends(get_n4js_depends),
    token: TokenData = Depends(get_token_data_depends),
):
    """Set the max retries for the given restart patterns on a network's
    TaskHub."""
    network_sk = ScopedKey.from_str(network_scoped_key)
    validate_scopes(network_sk.scope, token)

    # resolve the TaskHub the patterns are bound to
    taskhub_sk = n4js.get_taskhub(network_sk)
    n4js.set_task_restart_patterns_max_retries(taskhub_sk, patterns, max_retries)


### results


Expand Down
110 changes: 109 additions & 1 deletion alchemiscale/interface/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,6 @@ def get_transformation_results(
visualize
If ``True``, show retrieval progress indicators.
"""

if not return_protocoldagresults:
Expand Down Expand Up @@ -1739,3 +1738,112 @@ def get_task_failures(
)

return pdrs

def add_task_restart_patterns(
    self,
    network_scoped_key: ScopedKey,
    patterns: list[str],
    num_allowed_restarts: int,
) -> ScopedKey:
    """Add a list of `Task` restart patterns to an `AlchemicalNetwork`.

    Parameters
    ----------
    network_scoped_key
        The `ScopedKey` for the `AlchemicalNetwork` to add the patterns to.
    patterns
        The regular expression strings to compare to `ProtocolUnitFailure`
        tracebacks. Matching patterns will set the `Task` status back to
        'waiting'.
    num_allowed_restarts
        The number of times each pattern will be able to restart each
        `Task`. When this number is exceeded, the `Task` is canceled from
        the `AlchemicalNetwork` and left with the `error` status.

    Returns
    -------
    network_scoped_key
        The `ScopedKey` of the `AlchemicalNetwork` the patterns were added to.
    """
    self._post_resource(
        f"/networks/{network_scoped_key}/restartpatterns/add",
        {"patterns": patterns, "num_allowed_restarts": num_allowed_restarts},
    )
    return network_scoped_key

def get_task_restart_patterns(
    self, network_scoped_key: ScopedKey
) -> dict[str, int]:
    """Get the `Task` restart patterns applied to an `AlchemicalNetwork`
    along with the number of retries allowed for each pattern.

    Parameters
    ----------
    network_scoped_key
        The `ScopedKey` of the `AlchemicalNetwork` to query.

    Returns
    -------
    patterns
        A dictionary whose keys are all of the patterns applied to the
        `AlchemicalNetwork` and whose values are the number of retries each
        pattern will allow.
    """
    # the server only exposes a bulk endpoint; wrap the single network and
    # unwrap its entry from the response
    response = self._post_resource(
        "/bulk/networks/restartpatterns/get",
        data={"networks": [str(network_scoped_key)]},
    )
    # entries come back as (pattern, retries) pairs
    return dict(response[str(network_scoped_key)])

def set_task_restart_patterns_allowed_restarts(
    self,
    network_scoped_key: ScopedKey,
    patterns: list[str],
    num_allowed_restarts: int,
) -> None:
    """Set the number of `Task` restarts that patterns are allowed to
    perform for the given `AlchemicalNetwork`.

    Parameters
    ----------
    network_scoped_key
        The `ScopedKey` of the `AlchemicalNetwork` the `patterns` are
        applied to.
    patterns
        The patterns to set the number of allowed restarts for.
    num_allowed_restarts
        The new number of allowed restarts.
    """
    # the server-side field for this value is named `max_retries`
    payload = {"patterns": patterns, "max_retries": num_allowed_restarts}
    target = f"/networks/{network_scoped_key}/restartpatterns/maxretries"
    self._post_resource(target, payload)

def remove_task_restart_patterns(
    self, network_scoped_key: ScopedKey, patterns: list[str]
) -> None:
    """Remove specific `Task` restart patterns from an `AlchemicalNetwork`.

    Parameters
    ----------
    network_scoped_key
        The `ScopedKey` of the `AlchemicalNetwork` the `patterns` are
        applied to.
    patterns
        The patterns to remove from the `AlchemicalNetwork`.
    """
    target = f"/networks/{network_scoped_key}/restartpatterns/remove"
    self._post_resource(target, {"patterns": patterns})

def clear_task_restart_patterns(self, network_scoped_key: ScopedKey) -> None:
    """Clear all restart patterns from an `AlchemicalNetwork`.

    Parameters
    ----------
    network_scoped_key
        The `ScopedKey` of the `AlchemicalNetwork` to be cleared of restart
        patterns.
    """
    # NOTE(review): the server exposes this operation as a GET endpoint,
    # hence `_query_resource` rather than `_post_resource`
    target = f"/networks/{network_scoped_key}/restartpatterns/clear"
    self._query_resource(target)
105 changes: 103 additions & 2 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from copy import copy
from datetime import datetime
from enum import Enum
from typing import Union, Dict, Optional
from typing import Union, Optional, List
from uuid import uuid4
import hashlib


from pydantic import BaseModel, Field
from pydantic import BaseModel
from gufe.tokenization import GufeTokenizable, GufeKey

from ..models import ScopedKey, Scope
Expand Down Expand Up @@ -143,6 +143,107 @@ def _defaults(cls):
return super()._defaults()


class TaskRestartPattern(GufeTokenizable):
    """A pattern to compare returned Task tracebacks to.

    Attributes
    ----------
    pattern: str
        A regular expression pattern that can match to returned tracebacks of errored Tasks.
    max_retries: int
        The number of times the pattern can trigger a restart for a Task.
    taskhub_scoped_key: str
        The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key.
    """

    pattern: str
    max_retries: int
    # fix: previously annotated as `taskhub_sk`, but __init__ actually sets
    # `taskhub_scoped_key`; the annotation and docstring now match the real
    # attribute (stored as a plain string even when a ScopedKey is given)
    taskhub_scoped_key: str

    def __init__(
        self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey]
    ):
        """Create a restart pattern bound to a TaskHub.

        Raises
        ------
        ValueError
            If `pattern` is not a non-empty string, or `max_retries` is not
            a positive integer.
        """
        if not isinstance(pattern, str) or pattern == "":
            raise ValueError("`pattern` must be a non-empty string")

        self.pattern = pattern

        if not isinstance(max_retries, int) or max_retries <= 0:
            raise ValueError("`max_retries` must have a positive integer value.")
        self.max_retries = max_retries

        self.taskhub_scoped_key = str(taskhub_scoped_key)

    def _gufe_tokenize(self):
        # identity is (pattern, taskhub) only — deliberately excludes
        # max_retries so that changing the retry budget does not produce a
        # new key; md5 here is a content hash, not a security primitive
        key_string = self.pattern + self.taskhub_scoped_key
        return hashlib.md5(key_string.encode()).hexdigest()

    @classmethod
    def _defaults(cls):
        return super()._defaults()

    @classmethod
    def _from_dict(cls, dct):
        return cls(**dct)

    def _to_dict(self):
        return {
            "pattern": self.pattern,
            "max_retries": self.max_retries,
            "taskhub_scoped_key": self.taskhub_scoped_key,
        }


class Tracebacks(GufeTokenizable):
    """Tracebacks captured from failed ProtocolUnits, with their provenance.

    Attributes
    ----------
    tracebacks: list[str]
        The tracebacks returned with the ProtocolUnitFailures.
    source_keys: list[GufeKey]
        The GufeKeys of the ProtocolUnits that failed.
    failure_keys: list[GufeKey]
        The GufeKeys of the ProtocolUnitFailures.
    """

    def __init__(
        self,
        tracebacks: List[str],
        source_keys: List[GufeKey],
        failure_keys: List[GufeKey],
    ):
        bad_tracebacks = ValueError(
            "`tracebacks` must be a non-empty list of non-empty string values"
        )
        if not isinstance(tracebacks, list) or not tracebacks:
            raise bad_tracebacks
        if any(not isinstance(entry, str) or entry == "" for entry in tracebacks):
            raise bad_tracebacks

        # TODO: validate source_keys / failure_keys as well
        self.tracebacks = tracebacks
        self.source_keys = source_keys
        self.failure_keys = failure_keys

    @classmethod
    def _defaults(cls):
        return super()._defaults()

    @classmethod
    def _from_dict(cls, dct):
        return cls(**dct)

    def _to_dict(self):
        return {
            "tracebacks": self.tracebacks,
            "source_keys": self.source_keys,
            "failure_keys": self.failure_keys,
        }


class TaskHub(GufeTokenizable):
"""
Expand Down
Loading

0 comments on commit e01e445

Please sign in to comment.