diff --git a/packages/examples/cvat/exchange-oracle/poetry.lock b/packages/examples/cvat/exchange-oracle/poetry.lock index 4c5324f873..d32f499341 100644 --- a/packages/examples/cvat/exchange-oracle/poetry.lock +++ b/packages/examples/cvat/exchange-oracle/poetry.lock @@ -2872,6 +2872,8 @@ files = [ {file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"}, {file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"}, {file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"}, + {file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"}, + {file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"}, {file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"}, {file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"}, {file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"}, @@ -3155,6 +3157,23 @@ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-mock" +version = "3.14.0" +description = "Thin-wrapper around the mock package for easier use with pytest" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-mock-3.14.0.tar.gz", hash = "sha256:2719255a1efeceadbc056d6bf3df3d1c5015530fb40cf347c0f9afac88410bd0"}, + {file = "pytest_mock-3.14.0-py3-none-any.whl", hash = "sha256:0b72c38033392a5f4621342fe11e9219ac11ec9d375f8e2a0c164539e0d70f6f"}, +] + +[package.dependencies] +pytest = ">=6.2.5" + +[package.extras] +dev = ["pre-commit", "pytest-asyncio", "tox"] + [[package]] name = "python-dateutil" version = "2.8.2" @@ -3594,30 +3613,50 @@ files = [ {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:b42169467c42b692c19cf539c38d4602069d8c1505e97b86387fcf7afb766e1d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-macosx_13_0_arm64.whl", hash = "sha256:07238db9cbdf8fc1e9de2489a4f68474e70dffcb32232db7c08fa61ca0c7c462"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:fff3573c2db359f091e1589c3d7c5fc2f86f5bdb6f24252c2d8e539d4e45f412"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-manylinux_2_24_aarch64.whl", hash = "sha256:aa2267c6a303eb483de8d02db2871afb5c5fc15618d894300b88958f729ad74f"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:840f0c7f194986a63d2c2465ca63af8ccbbc90ab1c6001b1978f05119b5e7334"}, + {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:024cfe1fc7c7f4e1aff4a81e718109e13409767e4f871443cbff3dba3578203d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win32.whl", hash = "sha256:c69212f63169ec1cfc9bb44723bf2917cbbd8f6191a00ef3410f5a7fe300722d"}, {file = "ruamel.yaml.clib-0.2.8-cp310-cp310-win_amd64.whl", hash = "sha256:cabddb8d8ead485e255fe80429f833172b4cadf99274db39abc080e068cbcc31"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:bef08cd86169d9eafb3ccb0a39edb11d8e25f3dae2b28f5c52fd997521133069"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:b16420e621d26fdfa949a8b4b47ade8810c56002f5389970db4ddda51dbff248"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_24_x86_64.whl", hash = "sha256:25c515e350e5b739842fc3228d662413ef28f295791af5e5110b543cf0b57d9b"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-manylinux_2_24_aarch64.whl", hash = "sha256:1707814f0d9791df063f8c19bb51b0d1278b8e9a2353abbb676c2f685dee6afe"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:46d378daaac94f454b3a0e3d8d78cafd78a026b1d71443f4966c696b48a6d899"}, + {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:09b055c05697b38ecacb7ac50bdab2240bfca1a0c4872b0fd309bb07dc9aa3a9"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win32.whl", hash = "sha256:53a300ed9cea38cf5a2a9b069058137c2ca1ce658a874b79baceb8f892f915a7"}, {file = "ruamel.yaml.clib-0.2.8-cp311-cp311-win_amd64.whl", hash = "sha256:c2a72e9109ea74e511e29032f3b670835f8a59bbdc9ce692c5b4ed91ccf1eedb"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ebc06178e8821efc9692ea7544aa5644217358490145629914d8020042c24aa1"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:edaef1c1200c4b4cb914583150dcaa3bc30e592e907c01117c08b13a07255ec2"}, {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d176b57452ab5b7028ac47e7b3cf644bcfdc8cacfecf7e71759f7f51a59e5c92"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-manylinux_2_24_aarch64.whl", hash = "sha256:1dc67314e7e1086c9fdf2680b7b6c2be1c0d8e3a8279f2e993ca2a7545fecf62"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3213ece08ea033eb159ac52ae052a4899b56ecc124bb80020d9bbceeb50258e9"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aab7fd643f71d7946f2ee58cc88c9b7bfc97debd71dcc93e03e2d174628e7e2d"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win32.whl", hash = "sha256:5c365d91c88390c8d0a8545df0b5857172824b1c604e867161e6b3d59a827eaa"}, + {file = "ruamel.yaml.clib-0.2.8-cp312-cp312-win_amd64.whl", hash = "sha256:1758ce7d8e1a29d23de54a16ae867abd370f01b5a69e1a3ba75223eaa3ca1a1b"}, {file = "ruamel.yaml.clib-0.2.8-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:a5aa27bad2bb83670b71683aae140a1f52b0857a2deff56ad3f6c13a017a26ed"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c58ecd827313af6864893e7af0a3bb85fd529f862b6adbefe14643947cfe2942"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-macosx_12_0_arm64.whl", hash = "sha256:f481f16baec5290e45aebdc2a5168ebc6d35189ae6fea7a58787613a25f6e875"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_24_aarch64.whl", hash = "sha256:77159f5d5b5c14f7c34073862a6b7d34944075d9f93e681638f6d753606c6ce6"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:7f67a1ee819dc4562d444bbafb135832b0b909f81cc90f7aa00260968c9ca1b3"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4ecbf9c3e19f9562c7fdd462e8d18dd902a47ca046a2e64dba80699f0b6c09b7"}, + {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:87ea5ff66d8064301a154b3933ae406b0863402a799b16e4a1d24d9fbbcbe0d3"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win32.whl", hash = "sha256:75e1ed13e1f9de23c5607fe6bd1aeaae21e523b32d83bb33918245361e9cc51b"}, {file = "ruamel.yaml.clib-0.2.8-cp37-cp37m-win_amd64.whl", hash = "sha256:3f215c5daf6a9d7bbed4a0a4f760f3113b10e82ff4c5c44bec20a68c8014f675"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1b617618914cb00bf5c34d4357c37aa15183fa229b24767259657746c9077615"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:a6a9ffd280b71ad062eae53ac1659ad86a17f59a0fdc7699fd9be40525153337"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_24_aarch64.whl", hash = "sha256:305889baa4043a09e5b76f8e2a51d4ffba44259f6b4c72dec8ca56207d9c6fe1"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:700e4ebb569e59e16a976857c8798aee258dceac7c7d6b50cab63e080058df91"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:e2b4c44b60eadec492926a7270abb100ef9f72798e18743939bdbf037aab8c28"}, + {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:e79e5db08739731b0ce4850bed599235d601701d5694c36570a99a0c5ca41a9d"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win32.whl", hash = "sha256:955eae71ac26c1ab35924203fda6220f84dce57d6d7884f189743e2abe3a9fbe"}, {file = "ruamel.yaml.clib-0.2.8-cp38-cp38-win_amd64.whl", hash = "sha256:56f4252222c067b4ce51ae12cbac231bce32aee1d33fbfc9d17e5b8d6966c312"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:03d1162b6d1df1caa3a4bd27aa51ce17c9afc2046c31b0ad60a0a96ec22f8001"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:bba64af9fa9cebe325a62fa398760f5c7206b215201b0ec825005f1b18b9bccf"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_24_aarch64.whl", hash = "sha256:a1a45e0bb052edf6a1d3a93baef85319733a888363938e1fc9924cb00c8df24c"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:da09ad1c359a728e112d60116f626cc9f29730ff3e0e7db72b9a2dbc2e4beed5"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:184565012b60405d93838167f425713180b949e9d8dd0bbc7b49f074407c5a8b"}, + {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a75879bacf2c987c003368cf14bed0ffe99e8e85acfa6c0bfffc21a090f16880"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win32.whl", hash = "sha256:84b554931e932c46f94ab306913ad7e11bba988104c5cff26d90d03f68258cd5"}, {file = "ruamel.yaml.clib-0.2.8-cp39-cp39-win_amd64.whl", hash = "sha256:25ac8c08322002b06fa1d49d1646181f0b2c72f5cbc15a85e80b4c30a544bb15"}, {file = "ruamel.yaml.clib-0.2.8.tar.gz", hash = "sha256:beb2e0404003de9a4cab9753a8805a8fe9320ee6673136ed7f04255fe60bb512"}, @@ -4314,4 +4353,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.10,<3.13" -content-hash = "b712896e193904e263c976c5caea68fe5e47bb16ec5c63f867a40d965acced1e" +content-hash = "b4cacb79b88787c2788df663fba5d516cd8d2c5f24e10e5e9fce4ddd818d377e" diff --git a/packages/examples/cvat/exchange-oracle/pyproject.toml b/packages/examples/cvat/exchange-oracle/pyproject.toml index 9f5cd3631b..ea06c6e5f3 100644 --- a/packages/examples/cvat/exchange-oracle/pyproject.toml +++ b/packages/examples/cvat/exchange-oracle/pyproject.toml @@ -31,6 +31,7 @@ hexbytes = ">=1.2.0" # required for to_0x_hex() function [tool.poetry.group.dev.dependencies] pre-commit = "^3.0.4" ruff = "^0.6.0" +pytest-mock = "^3.14.0" [tool.ruff] line-length = 100 @@ -99,7 +100,6 @@ ignore = [ "ERA001", # Found commented-out code "N801", # Class name should use CapWords convention "PLR0915", # Too many statements - "F401", # Imported but unused "PLR2004", # Magic value used in comparison, consider replacing with a constant variable "ANN002", # Missing type annotation for `*args` "TRY300", # Consider moving this statement to an `else` block @@ -130,6 +130,7 @@ ignore = [ ] # alembic is not a package in a traditional sense, so putting __init__.py there doesn't make sense "alembic/*" = ["INP001"] +"__init__.py" = ["F401"] [tool.ruff.lint.pep8-naming] classmethod-decorators = [ diff --git a/packages/examples/cvat/exchange-oracle/src/.env.template b/packages/examples/cvat/exchange-oracle/src/.env.template index f8fc892092..2432a56cd0 100644 --- a/packages/examples/cvat/exchange-oracle/src/.env.template +++ b/packages/examples/cvat/exchange-oracle/src/.env.template @@ -32,6 +32,8 @@ PROCESS_JOB_LAUNCHER_WEBHOOKS_INT= PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE= PROCESS_RECORDING_ORACLE_WEBHOOKS_INT= PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE= +PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE= +PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT= TRACK_COMPLETED_PROJECTS_INT= TRACK_COMPLETED_PROJECTS_CHUNK_SIZE= TRACK_COMPLETED_TASKS_INT= @@ -90,6 +92,7 @@ HUMAN_APP_SIGNATURE= LOCALHOST_RECORDING_ORACLE_ADDRESS= LOCALHOST_RECORDING_ORACLE_URL= LOCALHOST_JOB_LAUNCHER_URL= +LOCALHOST_REPUTATION_ORACLE_URL= # Encryption PGP_PRIVATE_KEY= diff --git a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py index f8ccd86338..ee408f595d 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/escrow.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/escrow.py @@ -6,6 +6,7 @@ from human_protocol_sdk.storage import StorageUtils from src.core.config import Config +from src.core.types import OracleWebhookTypes def get_escrow(chain_id: int, escrow_address: str) -> EscrowData: @@ -56,12 +57,16 @@ def get_escrow_manifest(chain_id: int, escrow_address: str) -> dict: return json.loads(manifest_content) -def get_job_launcher_address(chain_id: int, escrow_address: str) -> str: - return get_escrow(chain_id, escrow_address).launcher - - -def get_recording_oracle_address(chain_id: int, escrow_address: str) -> str: - if address := Config.localhost.recording_oracle_address: - return address - - return get_escrow(chain_id, escrow_address).recording_oracle +def get_available_webhook_types( + chain_id: int, escrow_address: str +) -> dict[str, OracleWebhookTypes]: + escrow = get_escrow(chain_id, escrow_address) + return { + escrow.launcher.lower(): OracleWebhookTypes.job_launcher, + ( + Config.localhost.recording_oracle_address or escrow.recording_oracle + ).lower(): OracleWebhookTypes.recording_oracle, + ( + Config.localhost.reputation_oracle_url or escrow.reputation_oracle + ).lower(): OracleWebhookTypes.reputation_oracle, + } diff --git a/packages/examples/cvat/exchange-oracle/src/chain/kvstore.py b/packages/examples/cvat/exchange-oracle/src/chain/kvstore.py index 1a71d0bc2b..f1ed6f935c 100644 --- a/packages/examples/cvat/exchange-oracle/src/chain/kvstore.py +++ b/packages/examples/cvat/exchange-oracle/src/chain/kvstore.py @@ -16,6 +16,15 @@ def get_recording_oracle_url(chain_id: int, escrow_address: str) -> str: return OperatorUtils.get_leader(ChainId(chain_id), escrow.recording_oracle).webhook_url +def get_reputation_oracle_url(chain_id: int, escrow_address: str) -> str: + if url := Config.localhost.recording_oracle_url: + return url + + escrow = get_escrow(chain_id, escrow_address) + + return OperatorUtils.get_leader(ChainId(chain_id), escrow.recording_oracle).webhook_url + + def get_job_launcher_url(chain_id: int, escrow_address: str) -> str: if url := Config.localhost.job_launcher_url: return url diff --git a/packages/examples/cvat/exchange-oracle/src/core/config.py b/packages/examples/cvat/exchange-oracle/src/core/config.py index 7a6481574b..cd271f5208 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/config.py +++ b/packages/examples/cvat/exchange-oracle/src/core/config.py @@ -83,18 +83,25 @@ class LocalhostConfig(_NetworkConfig): recording_oracle_address = os.environ.get("LOCALHOST_RECORDING_ORACLE_ADDRESS") recording_oracle_url = os.environ.get("LOCALHOST_RECORDING_ORACLE_URL") + reputation_oracle_url = os.environ.get("LOCALHOST_REPUTATION_ORACLE_URL") class CronConfig: process_job_launcher_webhooks_int = int(os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_INT", 30)) - process_job_launcher_webhooks_chunk_size = os.environ.get( - "PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5 + process_job_launcher_webhooks_chunk_size = int( + os.environ.get("PROCESS_JOB_LAUNCHER_WEBHOOKS_CHUNK_SIZE", 5) ) process_recording_oracle_webhooks_int = int( os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_INT", 30) ) - process_recording_oracle_webhooks_chunk_size = os.environ.get( - "PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5 + process_recording_oracle_webhooks_chunk_size = int( + os.environ.get("PROCESS_RECORDING_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + ) + process_reputation_oracle_webhooks_chunk_size = int( + os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_CHUNK_SIZE", 5) + ) + process_reputation_oracle_webhooks_int = int( + os.environ.get("PROCESS_REPUTATION_ORACLE_WEBHOOKS_INT", 5) ) track_completed_projects_int = int(os.environ.get("TRACK_COMPLETED_PROJECTS_INT", 30)) track_completed_projects_chunk_size = os.environ.get("TRACK_COMPLETED_PROJECTS_CHUNK_SIZE", 5) diff --git a/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py b/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py index 6188d598f0..1f1a790feb 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py +++ b/packages/examples/cvat/exchange-oracle/src/core/oracle_events.py @@ -1,5 +1,3 @@ -from typing import Union - from pydantic import BaseModel from src.core.types import ( @@ -7,6 +5,7 @@ JobLauncherEventTypes, OracleWebhookTypes, RecordingOracleEventTypes, + ReputationOracleEventTypes, ) EventTypeTag = ExchangeOracleEventTypes | JobLauncherEventTypes | RecordingOracleEventTypes @@ -44,6 +43,14 @@ class ExchangeOracleEvent_TaskFinished(OracleEvent): pass # escrow is enough for now +class ExchangeOracleEvent_EscrowCleaned(OracleEvent): + pass + + +class ReputationOracleEvent_EscrowCompleted(OracleEvent): + pass + + _event_type_map = { JobLauncherEventTypes.escrow_created: JobLauncherEvent_EscrowCreated, JobLauncherEventTypes.escrow_canceled: JobLauncherEvent_EscrowCanceled, @@ -51,6 +58,8 @@ class ExchangeOracleEvent_TaskFinished(OracleEvent): RecordingOracleEventTypes.task_rejected: RecordingOracleEvent_TaskRejected, ExchangeOracleEventTypes.task_creation_failed: ExchangeOracleEvent_TaskCreationFailed, ExchangeOracleEventTypes.task_finished: ExchangeOracleEvent_TaskFinished, + ExchangeOracleEventTypes.escrow_cleaned: ExchangeOracleEvent_EscrowCleaned, + ReputationOracleEventTypes.escrow_completed: ReputationOracleEvent_EscrowCompleted, } @@ -83,6 +92,7 @@ def parse_event( OracleWebhookTypes.job_launcher: JobLauncherEventTypes, OracleWebhookTypes.recording_oracle: RecordingOracleEventTypes, OracleWebhookTypes.exchange_oracle: ExchangeOracleEventTypes, + OracleWebhookTypes.reputation_oracle: ReputationOracleEventTypes, } sender_events = sender_events_mapping.get(sender) diff --git a/packages/examples/cvat/exchange-oracle/src/core/storage.py b/packages/examples/cvat/exchange-oracle/src/core/storage.py index b934b865c0..223e213b76 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/storage.py +++ b/packages/examples/cvat/exchange-oracle/src/core/storage.py @@ -2,9 +2,17 @@ from src.core.types import Networks +def compose_data_bucket_prefix(escrow_address: str, chain_id: Networks): + return f"{escrow_address}@{chain_id}" + + +def compose_results_bucket_prefix(escrow_address: str, chain_id: Networks): + return f"{escrow_address}@{chain_id}{Config.storage_config.results_dir_suffix}" + + def compose_data_bucket_filename(escrow_address: str, chain_id: Networks, filename: str) -> str: - return f"{escrow_address}@{chain_id}/{filename}" + return f"{compose_data_bucket_prefix(escrow_address, chain_id)}/{filename}" def compose_results_bucket_filename(escrow_address: str, chain_id: Networks, filename: str) -> str: - return f"{escrow_address}@{chain_id}{Config.storage_config.results_dir_suffix}/{filename}" + return f"{compose_results_bucket_prefix(escrow_address, chain_id)}/{filename}" diff --git a/packages/examples/cvat/exchange-oracle/src/core/types.py b/packages/examples/cvat/exchange-oracle/src/core/types.py index 3e4869c7a8..6f29a98cd5 100644 --- a/packages/examples/cvat/exchange-oracle/src/core/types.py +++ b/packages/examples/cvat/exchange-oracle/src/core/types.py @@ -23,6 +23,7 @@ class ProjectStatuses(str, Enum, metaclass=BetterEnumMeta): validation = "validation" canceled = "canceled" recorded = "recorded" + deleted = "deleted" class TaskStatuses(str, Enum, metaclass=BetterEnumMeta): @@ -55,11 +56,13 @@ class OracleWebhookTypes(str, Enum, metaclass=BetterEnumMeta): exchange_oracle = "exchange_oracle" job_launcher = "job_launcher" recording_oracle = "recording_oracle" + reputation_oracle = "reputation_oracle" class ExchangeOracleEventTypes(str, Enum, metaclass=BetterEnumMeta): task_creation_failed = "task_creation_failed" task_finished = "task_finished" + escrow_cleaned = "escrow_cleaned" class JobLauncherEventTypes(str, Enum, metaclass=BetterEnumMeta): @@ -72,6 +75,11 @@ class RecordingOracleEventTypes(str, Enum, metaclass=BetterEnumMeta): task_rejected = "task_rejected" +class ReputationOracleEventTypes(str, Enum, metaclass=BetterEnumMeta): + # TODO: rename to ReputationOracleEventType + escrow_completed = "escrow_completed" + + class OracleWebhookStatuses(str, Enum, metaclass=BetterEnumMeta): pending = "pending" completed = "completed" diff --git a/packages/examples/cvat/exchange-oracle/src/crons/__init__.py b/packages/examples/cvat/exchange-oracle/src/crons/__init__.py index a3ec8def0b..f2581784ea 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/__init__.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/__init__.py @@ -2,15 +2,7 @@ from fastapi import FastAPI from src.core.config import Config -from src.crons.process_job_launcher_webhooks import ( - process_incoming_job_launcher_webhooks, - process_outgoing_job_launcher_webhooks, -) -from src.crons.process_recording_oracle_webhooks import ( - process_incoming_recording_oracle_webhooks, - process_outgoing_recording_oracle_webhooks, -) -from src.crons.state_trackers import ( +from src.crons.cvat.state_trackers import ( track_assignments, track_completed_escrows, track_completed_projects, @@ -18,9 +10,18 @@ track_escrow_creation, track_task_creation, ) +from src.crons.webhooks.job_launcher import ( + process_incoming_job_launcher_webhooks, + process_outgoing_job_launcher_webhooks, +) +from src.crons.webhooks.recording_oracle import ( + process_incoming_recording_oracle_webhooks, + process_outgoing_recording_oracle_webhooks, +) +from src.crons.webhooks.reputation_oracle import process_incoming_reputation_oracle_webhooks -def setup_cron_jobs(app: FastAPI): +def setup_cron_jobs(app: FastAPI) -> None: @app.on_event("startup") def cron_record(): scheduler = BackgroundScheduler() @@ -44,6 +45,11 @@ def cron_record(): "interval", seconds=Config.cron_config.process_recording_oracle_webhooks_int, ) + scheduler.add_job( + process_incoming_reputation_oracle_webhooks, + "interval", + seconds=Config.cron_config.process_reputation_oracle_webhooks_int, + ) scheduler.add_job( track_completed_projects, "interval", diff --git a/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py b/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py new file mode 100644 index 0000000000..951580fcdc --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/crons/_cron_job.py @@ -0,0 +1,71 @@ +import inspect +import logging +from collections.abc import Callable +from functools import wraps +from typing import NamedTuple + +from sqlalchemy.orm import Session + +from src.db import SessionLocal +from src.log import get_logger_name + + +class CronSpec(NamedTuple): + manage_session: bool + repr: str + + +def _validate_cron_function_signature(fn: Callable[..., None]) -> CronSpec: + cron_repr = repr(fn.__name__) + parameters = dict(inspect.signature(fn).parameters) + + session_param = parameters.pop("session", None) + if session_param is not None and session_param.annotation is not Session: + raise TypeError(f"{cron_repr} session argument type must be of type {Session.__qualname__}") + + logger_param = parameters.pop("logger", None) + if logger_param is None or logger_param.annotation is not logging.Logger: + raise TypeError(f"{cron_repr} must have logger argument with type of {logging.Logger}") + + if parameters: + raise TypeError( + f"{cron_repr} expected to have only have logger and session arguments," + f" not {set(parameters.keys())}" + ) + + return CronSpec(manage_session=session_param is not None, repr=cron_repr) + + +def cron_job(fn: Callable[..., None]) -> Callable[[], None]: + """ + Wrapper that supplies logger and optionally session to the cron job. + + Example usage: + >>> @cron_job + >>> def handle_webhook(logger: logging.Logger) -> None: + >>> ... + Example usage with session: + >>> @cron_job + >>> def handle_webhook(logger: logging.Logger, session: Session) -> None: + >>> ... + + Returns: + Cron job ready to be registered in scheduler. + """ + logger = logging.getLogger(get_logger_name(f"{fn.__module__}.{fn.__name__}")) + cron_spec = _validate_cron_function_signature(fn) + + @wraps(fn) + def wrapper(): + logger.debug(f"Cron {cron_spec.repr} is starting") + try: + if not cron_spec.manage_session: + return fn(logger) + with SessionLocal.begin() as session: + return fn(logger, session) + except Exception: + logger.exception(f"Exception while running {cron_spec.repr} cron") + finally: + logger.debug(f"Cron {cron_spec.repr} finished") + + return wrapper diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/__init__.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py new file mode 100644 index 0000000000..14064a483e --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/crons/cvat/state_trackers.py @@ -0,0 +1,294 @@ +import logging + +from sqlalchemy import exc as sa_errors +from sqlalchemy.orm import Session + +import src.cvat.api_calls as cvat_api +import src.models.cvat as cvat_models +import src.services.cvat as cvat_service +import src.services.webhook as oracle_db_service +from src.core.config import CronConfig +from src.core.oracle_events import ExchangeOracleEvent_TaskCreationFailed +from src.core.types import JobStatuses, OracleWebhookTypes, ProjectStatuses, TaskStatuses +from src.crons._cron_job import cron_job +from src.db import SessionLocal +from src.db import errors as db_errors +from src.db.utils import ForUpdateParams +from src.handlers.completed_escrows import handle_completed_escrows + + +@cron_job +def track_completed_projects(logger: logging.Logger, session: Session) -> None: + """ + Tracks completed projects: + 1. Retrieves projects with "annotation" status + 2. Retrieves tasks related to this project + 3. If all tasks are completed -> updates project status to "completed" + """ + projects = cvat_service.get_projects_by_status( + session, + ProjectStatuses.annotation, + task_status=TaskStatuses.completed, + limit=CronConfig.track_completed_projects_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + completed_project_ids = [] + + for project in projects: + tasks = cvat_service.get_tasks_by_cvat_project_id(session, project.cvat_id) + if tasks and all(task.status == TaskStatuses.completed for task in tasks): + cvat_service.update_project_status(session, project.id, ProjectStatuses.completed) + + completed_project_ids.append(project.cvat_id) + + if completed_project_ids: + logger.info( + "Found new completed projects: {}".format( + ", ".join(str(t) for t in completed_project_ids) + ) + ) + + +@cron_job +def track_completed_tasks(logger: logging.Logger, session: Session) -> None: + """ + Tracks completed tasks: + 1. Retrieves tasks with "annotation" status + 2. Retrieves jobs related to this task + 3. If all jobs are completed -> updates task status to "completed" + """ + tasks = cvat_service.get_tasks_by_status( + session, + TaskStatuses.annotation, + job_status=JobStatuses.completed, + project_status=ProjectStatuses.annotation, + limit=CronConfig.track_completed_tasks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + completed_task_ids = [] + + for task in tasks: + jobs = cvat_service.get_jobs_by_cvat_task_id(session, task.cvat_id) + if jobs and all(job.status == JobStatuses.completed for job in jobs): + cvat_service.update_task_status(session, task.id, TaskStatuses.completed) + + completed_task_ids.append(task.cvat_id) + + if completed_task_ids: + logger.info( + "Found new completed tasks: {}".format(", ".join(str(t) for t in completed_task_ids)) + ) + + +@cron_job +def track_assignments(logger: logging.Logger) -> None: + """ + Tracks assignments: + 1. Checks time for each active assignment + 2. If an assignment is timed out, expires it + 3. If a project or task state is not "annotation", cancels assignments + """ + with SessionLocal.begin() as session: + assignments = cvat_service.get_unprocessed_expired_assignments( + session, + limit=CronConfig.track_assignments_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + for assignment in assignments: + logger.info( + "Expiring the unfinished assignment {} (user {}, job id {})".format( + assignment.id, + assignment.user_wallet_address, + assignment.cvat_job_id, + ) + ) + + latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( + session, assignment.cvat_job_id + ) + if latest_assignment.id == assignment.id: + # Avoid un-assigning if it's not the latest assignment + + cvat_api.update_job_assignee( + assignment.cvat_job_id, assignee_id=None + ) # note that calling it in a loop can take too much time + + cvat_service.expire_assignment(session, assignment.id) + + with SessionLocal.begin() as session: + assignments = cvat_service.get_active_assignments( + session, + limit=CronConfig.track_assignments_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + for assignment in assignments: + if assignment.job.project.status != ProjectStatuses.annotation: + logger.warning( + "Canceling the unfinished assignment {} (user {}, job id {}) - " + "the project state is not annotation".format( + assignment.id, + assignment.user_wallet_address, + assignment.cvat_job_id, + ) + ) + + latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( + session, assignment.cvat_job_id + ) + if latest_assignment.id == assignment.id: + # Avoid un-assigning if it's not the latest assignment + + cvat_api.update_job_assignee( + assignment.cvat_job_id, assignee_id=None + ) # note that calling it in a loop can take too much time + + cvat_service.cancel_assignment(session, assignment.id) + + +@cron_job +def track_completed_escrows(logger: logging.Logger) -> None: + handle_completed_escrows(logger) + + +@cron_job +def track_task_creation(logger: logging.Logger, session: Session) -> None: + """ + Checks task creation status to report failed tasks and continue task creation process. + """ + + # TODO: maybe add load balancing (e.g. round-robin queue, shuffling) + # to avoid blocking new tasks + uploads = cvat_service.get_active_task_uploads( + session, + limit=CronConfig.track_creating_tasks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + if not uploads: + return + + logger.debug( + "Checking the data uploading status of CVAT tasks: {}".format( + ", ".join(str(u.task_id) for u in uploads) + ) + ) + + completed: list[cvat_models.DataUpload] = [] + failed: list[cvat_models.DataUpload] = [] + for upload in uploads: + status, reason = cvat_api.get_task_upload_status(upload.task_id) + project = upload.task.project + if not status or status == cvat_api.UploadStatus.FAILED: + # TODO: add retries if 5xx + failed.append(upload) + + oracle_db_service.outbox.create_webhook( + session, + escrow_address=project.escrow_address, + chain_id=project.chain_id, + type=OracleWebhookTypes.job_launcher, + event=ExchangeOracleEvent_TaskCreationFailed(reason=reason), + ) + elif status == cvat_api.UploadStatus.FINISHED: + try: + cvat_jobs = cvat_api.fetch_task_jobs(upload.task_id) + + existing_jobs = cvat_service.get_jobs_by_cvat_task_id(session, upload.task_id) + existing_job_ids = set(j.cvat_id for j in existing_jobs) + + for cvat_job in cvat_jobs: + if cvat_job.id in existing_job_ids: + continue + + cvat_service.create_job( + session, + cvat_job.id, + upload.task_id, + upload.task.cvat_project_id, + status=JobStatuses(cvat_job.state), + ) + + completed.append(upload) + except cvat_api.exceptions.ApiException as e: + failed.append(upload) + + oracle_db_service.outbox.create_webhook( + session, + escrow_address=project.escrow_address, + chain_id=project.chain_id, + type=OracleWebhookTypes.job_launcher, + event=ExchangeOracleEvent_TaskCreationFailed(reason=str(e)), + ) + + if completed or failed: + cvat_service.finish_data_uploads(session, failed + completed) + + logger.info( + "Updated creation status of CVAT tasks: {}".format( + "; ".join( + f"{k}: {v}" + for k, v in { + "success": ", ".join(str(u.task_id) for u in completed), + "failed": ", ".join(str(u.task_id) for u in failed), + }.items() + if v + ) + ) + ) + + +@cron_job +def track_escrow_creation(logger: logging.Logger, session: Session) -> None: + creations = cvat_service.get_active_escrow_creations( + session, + limit=CronConfig.track_escrow_creation_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + if not creations: + return + + logger.debug( + "Checking escrow creation statuses for escrows: {}".format( + ", ".join(str(c.escrow_address) for c in creations) + ) + ) + + finished: list[cvat_models.EscrowCreation] = [] + for creation in creations: + created_jobs_count = cvat_service.count_jobs_by_escrow_address( + session, + escrow_address=creation.escrow_address, + chain_id=creation.chain_id, + status=JobStatuses.new, + ) + + if created_jobs_count != creation.total_jobs: + continue + + with session.begin_nested(): + try: + cvat_service.update_project_statuses_by_escrow_address( + session=session, + escrow_address=creation.escrow_address, + chain_id=creation.chain_id, + status=ProjectStatuses.annotation, + ) + finished.append(creation) + except sa_errors.OperationalError as e: + if isinstance(e.orig, db_errors.LockNotAvailable): + continue + raise + + if finished: + cvat_service.finish_escrow_creations(session, finished) + + logger.info( + "Updated creation status of escrows: {}".format( + ", ".join(c.escrow_address for c in finished) + ) + ) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py deleted file mode 100644 index 2682f6ff95..0000000000 --- a/packages/examples/cvat/exchange-oracle/src/crons/process_job_launcher_webhooks.py +++ /dev/null @@ -1,211 +0,0 @@ -import logging - -import httpx -from human_protocol_sdk.constants import Status as EscrowStatus -from sqlalchemy.orm import Session - -import src.handlers.job_creation as cvat -import src.services.cvat as cvat_db_service -import src.services.webhook as oracle_db_service -from src.chain.escrow import validate_escrow -from src.chain.kvstore import get_job_launcher_url -from src.core.config import Config, CronConfig -from src.core.oracle_events import ExchangeOracleEvent_TaskCreationFailed -from src.core.types import JobLauncherEventTypes, OracleWebhookTypes, ProjectStatuses -from src.db import SessionLocal -from src.db.utils import ForUpdateParams -from src.log import ROOT_LOGGER_NAME -from src.models.webhook import Webhook -from src.utils.logging import get_function_logger -from src.utils.webhooks import prepare_outgoing_webhook_body, prepare_signed_message - -module_logger_name = f"{ROOT_LOGGER_NAME}.cron.webhook" - - -def process_incoming_job_launcher_webhooks(): - """ - Process incoming job launcher webhooks - """ - logger = get_function_logger(module_logger_name) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - webhooks = oracle_db_service.inbox.get_pending_webhooks( - session, - OracleWebhookTypes.job_launcher, - limit=CronConfig.process_job_launcher_webhooks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - for webhook in webhooks: - try: - logger.debug( - "Processing webhook " - f"{webhook.type}.{webhook.event_type}~{webhook.signature} " - f"in escrow_address={webhook.escrow_address} " - f"(attempt {webhook.attempts + 1})" - ) - - handle_job_launcher_event(webhook, db_session=session, logger=logger) - - oracle_db_service.inbox.handle_webhook_success(session, webhook.id) - - logger.debug("Webhook handled successfully") - except Exception as e: - logger.exception(f"Webhook {webhook.id} handling failed: {e}") - oracle_db_service.inbox.handle_webhook_fail(session, webhook.id) - except Exception as e: - logger.exception(e) - finally: - logger.debug("Finishing cron job") - - -def handle_job_launcher_event(webhook: Webhook, *, db_session: Session, logger: logging.Logger): - assert webhook.type == OracleWebhookTypes.job_launcher - - match webhook.event_type: - case JobLauncherEventTypes.escrow_created: - try: - validate_escrow( - webhook.chain_id, - webhook.escrow_address, - allow_no_funds=True, - ) - - if cvat_db_service.get_project_by_escrow_address( - db_session, webhook.escrow_address, for_update=True - ): - logger.error( - f"Received an escrow creation event for " - f"escrow_address {webhook.escrow_address}. " - "A CVAT project for this escrow already exists, ignoring the event." - ) - return - - logger.info( - f"Creating a new CVAT project (escrow_address={webhook.escrow_address})" - ) - - cvat.create_task(webhook.escrow_address, webhook.chain_id) - - except Exception as ex: - try: - cvat.remove_task(webhook.escrow_address) - except Exception as ex_remove: - logger.exception(ex_remove) - - if webhook.attempts + 1 >= Config.webhook_max_retries: - # We should not notify before the webhook handling attempts have expired - oracle_db_service.outbox.create_webhook( - session=db_session, - escrow_address=webhook.escrow_address, - chain_id=webhook.chain_id, - type=OracleWebhookTypes.job_launcher, - event=ExchangeOracleEvent_TaskCreationFailed(reason=str(ex)), - ) - - raise - - case JobLauncherEventTypes.escrow_canceled: - validate_escrow( - webhook.chain_id, - webhook.escrow_address, - accepted_states=[EscrowStatus.Pending, EscrowStatus.Cancelled], - ) - - projects = cvat_db_service.get_projects_by_escrow_address( - db_session, webhook.escrow_address, for_update=True, limit=None - ) - if not projects: - logger.error( - "Received escrow cancel event " - f"(escrow_address={webhook.escrow_address}). " - "The project doesn't exist, ignoring" - ) - return - - for project in projects: - if project.status in [ - ProjectStatuses.canceled, - ProjectStatuses.recorded, - ]: - logger.error( - "Received escrow cancel event " - f"(escrow_address={webhook.escrow_address}). " - "The project is already finished, ignoring" - ) - continue - - logger.info( - f"Received escrow cancel event (escrow_address={webhook.escrow_address}). " - "Canceling the project" - ) - cvat_db_service.update_project_status( - db_session, project.id, ProjectStatuses.canceled - ) - - cvat_db_service.finish_escrow_creations_by_escrow_address( - db_session, escrow_address=webhook.escrow_address, chain_id=webhook.chain_id - ) - case _: - raise AssertionError(f"Unknown job launcher event {webhook.event_type}") - - -def process_outgoing_job_launcher_webhooks(): - """ - Process webhooks that needs to be sent to recording oracle: - * Retrieves `webhook_url` from KVStore - * Sends webhook to recording oracle - """ - logger = get_function_logger(module_logger_name) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - webhooks = oracle_db_service.outbox.get_pending_webhooks( - session, - OracleWebhookTypes.job_launcher, - limit=CronConfig.process_job_launcher_webhooks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - for webhook in webhooks: - try: - logger.debug( - "Processing webhook " - f"{webhook.type}.{webhook.event_type} " - f"in escrow_address={webhook.escrow_address} " - f"(attempt {webhook.attempts + 1})" - ) - - body = prepare_outgoing_webhook_body( - webhook.escrow_address, - webhook.chain_id, - webhook.event_type, - webhook.event_data, - timestamp=None, # TODO: launcher doesn't support it yet - ) - - _, signature = prepare_signed_message( - webhook.escrow_address, - webhook.chain_id, - body=body, - ) - - headers = {"human-signature": signature} - webhook_url = get_job_launcher_url(webhook.chain_id, webhook.escrow_address) - with httpx.Client() as client: - response = client.post(webhook_url, headers=headers, json=body) - response.raise_for_status() - - oracle_db_service.outbox.handle_webhook_success(session, webhook.id) - logger.debug("Webhook handled successfully") - except Exception as e: - logger.exception(f"Webhook {webhook.id} sending failed: {e}") - oracle_db_service.outbox.handle_webhook_fail(session, webhook.id) - except Exception as e: - logger.exception(e) - finally: - logger.debug("Finishing cron job") diff --git a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py b/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py deleted file mode 100644 index 868569b522..0000000000 --- a/packages/examples/cvat/exchange-oracle/src/crons/state_trackers.py +++ /dev/null @@ -1,351 +0,0 @@ -from sqlalchemy import exc as sa_errors - -import src.cvat.api_calls as cvat_api -import src.models.cvat as cvat_models -import src.services.cvat as cvat_service -import src.services.webhook as oracle_db_service -from src.core.config import CronConfig -from src.core.oracle_events import ExchangeOracleEvent_TaskCreationFailed -from src.core.types import JobStatuses, OracleWebhookTypes, ProjectStatuses, TaskStatuses -from src.db import SessionLocal -from src.db import errors as db_errors -from src.db.utils import ForUpdateParams -from src.handlers.completed_escrows import handle_completed_escrows -from src.log import ROOT_LOGGER_NAME -from src.utils.logging import get_function_logger - -module_logger = f"{ROOT_LOGGER_NAME}.cron.cvat" - - -def track_completed_projects() -> None: - """ - Tracks completed projects: - 1. Retrieves projects with "annotation" status - 2. Retrieves tasks related to this project - 3. If all tasks are completed -> updates project status to "completed" - """ - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - with SessionLocal.begin() as session: - # Get active projects from db - projects = cvat_service.get_projects_by_status( - session, - ProjectStatuses.annotation, - task_status=TaskStatuses.completed, - limit=CronConfig.track_completed_projects_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - completed_project_ids = [] - - for project in projects: - tasks = cvat_service.get_tasks_by_cvat_project_id(session, project.cvat_id) - if tasks and all(task.status == TaskStatuses.completed for task in tasks): - cvat_service.update_project_status( - session, project.id, ProjectStatuses.completed - ) - - completed_project_ids.append(project.cvat_id) - - if completed_project_ids: - logger.info( - "Found new completed projects: {}".format( - ", ".join(str(t) for t in completed_project_ids) - ) - ) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") - - -def track_completed_tasks() -> None: - """ - Tracks completed tasks: - 1. Retrieves tasks with "annotation" status - 2. Retrieves jobs related to this task - 3. If all jobs are completed -> updates task status to "completed" - """ - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - with SessionLocal.begin() as session: - tasks = cvat_service.get_tasks_by_status( - session, - TaskStatuses.annotation, - job_status=JobStatuses.completed, - project_status=ProjectStatuses.annotation, - limit=CronConfig.track_completed_tasks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - completed_task_ids = [] - - for task in tasks: - jobs = cvat_service.get_jobs_by_cvat_task_id(session, task.cvat_id) - if jobs and all(job.status == JobStatuses.completed for job in jobs): - cvat_service.update_task_status(session, task.id, TaskStatuses.completed) - - completed_task_ids.append(task.cvat_id) - - if completed_task_ids: - logger.info( - "Found new completed tasks: {}".format( - ", ".join(str(t) for t in completed_task_ids) - ) - ) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") - - -def track_assignments() -> None: - """ - Tracks assignments: - 1. Checks time for each active assignment - 2. If an assignment is timed out, expires it - 3. If a project or task state is not "annotation", cancels assignments - """ - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - assignments = cvat_service.get_unprocessed_expired_assignments( - session, - limit=CronConfig.track_assignments_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - for assignment in assignments: - logger.info( - "Expiring the unfinished assignment {} (user {}, job id {})".format( - assignment.id, - assignment.user_wallet_address, - assignment.cvat_job_id, - ) - ) - - latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( - session, assignment.cvat_job_id - ) - if latest_assignment.id == assignment.id: - # Avoid un-assigning if it's not the latest assignment - - cvat_api.update_job_assignee( - assignment.cvat_job_id, assignee_id=None - ) # note that calling it in a loop can take too much time - - cvat_service.expire_assignment(session, assignment.id) - - with SessionLocal.begin() as session: - assignments = cvat_service.get_active_assignments( - session, - limit=CronConfig.track_assignments_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - for assignment in assignments: - if assignment.job.project.status != ProjectStatuses.annotation: - logger.warning( - "Canceling the unfinished assignment {} (user {}, job id {}) - " - "the project state is not annotation".format( - assignment.id, - assignment.user_wallet_address, - assignment.cvat_job_id, - ) - ) - - latest_assignment = cvat_service.get_latest_assignment_by_cvat_job_id( - session, assignment.cvat_job_id - ) - if latest_assignment.id == assignment.id: - # Avoid un-assigning if it's not the latest assignment - - cvat_api.update_job_assignee( - assignment.cvat_job_id, assignee_id=None - ) # note that calling it in a loop can take too much time - - cvat_service.cancel_assignment(session, assignment.id) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") - - -def track_completed_escrows() -> None: - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - - handle_completed_escrows(logger) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") - - -def track_task_creation() -> None: - """ - Checks task creation status to report failed tasks and continue task creation process. - """ - - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - # TODO: maybe add load balancing (e.g. round-robin queue, shuffling) - # to avoid blocking new tasks - uploads = cvat_service.get_active_task_uploads( - session, - limit=CronConfig.track_creating_tasks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - if not uploads: - return - - logger.debug( - "Checking the data uploading status of CVAT tasks: {}".format( - ", ".join(str(u.task_id) for u in uploads) - ) - ) - - completed: list[cvat_models.DataUpload] = [] - failed: list[cvat_models.DataUpload] = [] - for upload in uploads: - status, reason = cvat_api.get_task_upload_status(upload.task_id) - project = upload.task.project - if not status or status == cvat_api.UploadStatus.FAILED: - # TODO: add retries if 5xx - failed.append(upload) - - oracle_db_service.outbox.create_webhook( - session, - escrow_address=project.escrow_address, - chain_id=project.chain_id, - type=OracleWebhookTypes.job_launcher, - event=ExchangeOracleEvent_TaskCreationFailed(reason=reason), - ) - elif status == cvat_api.UploadStatus.FINISHED: - try: - cvat_jobs = cvat_api.fetch_task_jobs(upload.task_id) - - existing_jobs = cvat_service.get_jobs_by_cvat_task_id( - session, upload.task_id - ) - existing_job_ids = set(j.cvat_id for j in existing_jobs) - - for cvat_job in cvat_jobs: - if cvat_job.id in existing_job_ids: - continue - - cvat_service.create_job( - session, - cvat_job.id, - upload.task_id, - upload.task.cvat_project_id, - status=JobStatuses(cvat_job.state), - ) - - completed.append(upload) - except cvat_api.exceptions.ApiException as e: - failed.append(upload) - - oracle_db_service.outbox.create_webhook( - session, - escrow_address=project.escrow_address, - chain_id=project.chain_id, - type=OracleWebhookTypes.job_launcher, - event=ExchangeOracleEvent_TaskCreationFailed(reason=str(e)), - ) - - if completed or failed: - cvat_service.finish_data_uploads(session, failed + completed) - - logger.info( - "Updated creation status of CVAT tasks: {}".format( - "; ".join( - f"{k}: {v}" - for k, v in { - "success": ", ".join(str(u.task_id) for u in completed), - "failed": ", ".join(str(u.task_id) for u in failed), - }.items() - if v - ) - ) - ) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") - - -def track_escrow_creation() -> None: - logger = get_function_logger(module_logger) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - creations = cvat_service.get_active_escrow_creations( - session, - limit=CronConfig.track_escrow_creation_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - - if not creations: - return - - logger.debug( - "Checking escrow creation statuses for escrows: {}".format( - ", ".join(str(c.escrow_address) for c in creations) - ) - ) - - finished: list[cvat_models.EscrowCreation] = [] - for creation in creations: - created_jobs_count = cvat_service.count_jobs_by_escrow_address( - session, - escrow_address=creation.escrow_address, - chain_id=creation.chain_id, - status=JobStatuses.new, - ) - - if created_jobs_count != creation.total_jobs: - continue - - with session.begin_nested(): - try: - cvat_service.update_project_statuses_by_escrow_address( - session=session, - escrow_address=creation.escrow_address, - chain_id=creation.chain_id, - status=ProjectStatuses.annotation, - ) - finished.append(creation) - except sa_errors.OperationalError as e: - if isinstance(e.orig, db_errors.LockNotAvailable): - continue - raise - - if finished: - cvat_service.finish_escrow_creations(session, finished) - - logger.info( - "Updated creation status of escrows: {}".format( - ", ".join(c.escrow_address for c in finished) - ) - ) - except Exception as error: - logger.exception(error) - finally: - logger.debug("Finishing cron job") diff --git a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/__init__.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/_common.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/_common.py new file mode 100644 index 0000000000..a75231388d --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/_common.py @@ -0,0 +1,85 @@ +import logging +from collections.abc import Callable +from contextlib import contextmanager + +import httpx +from sqlalchemy.orm import Session + +from src.core.types import OracleWebhookTypes +from src.db.utils import ForUpdateParams +from src.models.webhook import Webhook +from src.services import webhook as webhook_service +from src.utils.webhooks import prepare_outgoing_webhook_body, prepare_signed_message + + +@contextmanager +def handle_webhook( + logger: logging.Logger, + session: Session, + webhook: Webhook, + *, + on_fail: Callable[[Session, Webhook, Exception], None] = lambda _s, _w, _e: None, +): + logger.debug( + "Processing webhook " + f"{webhook.type}.{webhook.event_type}~{webhook.signature} " + f"in escrow_address={webhook.escrow_address} " + f"(attempt {webhook.attempts + 1})" + ) + savepoint = session.begin_nested() + try: + yield + except Exception as e: + savepoint.rollback() + logger.exception(f"Webhook {webhook.id} sending failed: {e}") + savepoint = session.begin_nested() + try: + on_fail(session, webhook, e) + except Exception: + savepoint.rollback() + raise + finally: + webhook_service.outbox.handle_webhook_fail(session, webhook.id) + else: + webhook_service.outbox.handle_webhook_success(session, webhook.id) + logger.debug("Webhook handled successfully") + + +def _send_webhook(url: str, webhook: Webhook, *, with_timestamp: bool = True) -> None: + body = prepare_outgoing_webhook_body( + webhook.escrow_address, + webhook.chain_id, + webhook.event_type, + webhook.event_data, + timestamp=webhook.created_at if with_timestamp else None, + ) + _, signature = prepare_signed_message( + webhook.escrow_address, + webhook.chain_id, + body=body, + ) + headers = {"human-signature": signature} + with httpx.Client() as client: + response = client.post(url, headers=headers, json=body) + response.raise_for_status() + + +def process_outgoing_webhooks( + logger: logging.Logger, + session: Session, + webhook_type: OracleWebhookTypes, + url_getter: Callable[[int, str], str], + chunk_size: int, + *, + with_timestamp: bool = True, +): + webhooks = webhook_service.outbox.get_pending_webhooks( + session, + webhook_type, + limit=chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + for webhook in webhooks: + with handle_webhook(logger, session, webhook): + webhook_url = url_getter(webhook.chain_id, webhook.escrow_address) + _send_webhook(webhook_url, webhook, with_timestamp=with_timestamp) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/job_launcher.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/job_launcher.py new file mode 100644 index 0000000000..6ce5e6901e --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/job_launcher.py @@ -0,0 +1,161 @@ +import logging + +from human_protocol_sdk.constants import Status as EscrowStatus +from sqlalchemy.orm import Session + +import src.handlers.job_creation as cvat +import src.services.cvat as cvat_db_service +import src.services.webhook as oracle_db_service +from src.chain.escrow import validate_escrow +from src.chain.kvstore import get_job_launcher_url +from src.core.config import Config, CronConfig +from src.core.oracle_events import ( + ExchangeOracleEvent_EscrowCleaned, + ExchangeOracleEvent_TaskCreationFailed, +) +from src.core.types import JobLauncherEventTypes, Networks, OracleWebhookTypes, ProjectStatuses +from src.crons._cron_job import cron_job +from src.crons.webhooks._common import handle_webhook, process_outgoing_webhooks +from src.db.utils import ForUpdateParams +from src.handlers.escrow_cleanup import cleanup_escrow +from src.models.webhook import Webhook + + +def handle_failure(session: Session, webhook: Webhook, exc: Exception) -> None: + if ( + webhook.event_type == JobLauncherEventTypes.escrow_created + and webhook.attempts + 1 >= Config.webhook_max_retries + ): + logging.error( + f"Exceeded maximum retries for {webhook.escrow_address=} creation. " + f"Notifying job launcher." + ) + # TODO: think about unifying this further + oracle_db_service.outbox.create_webhook( + session=session, + escrow_address=webhook.escrow_address, + chain_id=webhook.chain_id, + type=OracleWebhookTypes.job_launcher, + event=ExchangeOracleEvent_TaskCreationFailed(reason=str(exc)), + ) + + +@cron_job +def process_incoming_job_launcher_webhooks(logger: logging.Logger, session: Session): + """ + Process incoming job launcher webhooks + """ + webhooks = oracle_db_service.inbox.get_pending_webhooks( + session, + OracleWebhookTypes.job_launcher, + limit=CronConfig.process_job_launcher_webhooks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + + for webhook in webhooks: + with handle_webhook(logger, session, webhook, on_fail=handle_failure): + handle_job_launcher_event(webhook, db_session=session, logger=logger) + + +def handle_job_launcher_event(webhook: Webhook, *, db_session: Session, logger: logging.Logger): + assert webhook.type == OracleWebhookTypes.job_launcher + + match webhook.event_type: + case JobLauncherEventTypes.escrow_created: + try: + validate_escrow( + webhook.chain_id, + webhook.escrow_address, + allow_no_funds=True, + ) + + if cvat_db_service.get_project_by_escrow_address( + db_session, webhook.escrow_address, for_update=True + ): + logger.error( + f"Received an escrow creation event for " + f"escrow_address {webhook.escrow_address}. " + "A CVAT project for this escrow already exists, ignoring the event." + ) + return + + logger.info( + f"Creating a new CVAT project (escrow_address={webhook.escrow_address})" + ) + + cvat.create_task(webhook.escrow_address, webhook.chain_id) + + except Exception: + projects = cvat_db_service.get_projects_by_escrow_address( + db_session, webhook.escrow_address + ) + + cleanup_escrow(webhook.escrow_address, Networks(webhook.chain_id), projects) + cvat_db_service.delete_projects( + db_session, webhook.escrow_address, webhook.chain_id + ) + raise + + case JobLauncherEventTypes.escrow_canceled: + validate_escrow( + webhook.chain_id, + webhook.escrow_address, + accepted_states=[EscrowStatus.Pending, EscrowStatus.Cancelled], + ) + + projects = cvat_db_service.get_projects_by_escrow_address( + db_session, webhook.escrow_address, for_update=True, limit=None + ) + if not projects: + logger.error( + "Received escrow cancel event " + f"(escrow_address={webhook.escrow_address}). " + "The project doesn't exist, ignoring" + ) + return + + for project in projects: + if project.status in [ + ProjectStatuses.canceled, + ProjectStatuses.recorded, + ]: + logger.error( + "Received escrow cancel event " + f"(escrow_address={webhook.escrow_address}). " + "The project is already finished, ignoring" + ) + continue + + logger.info( + f"Received escrow cancel event (escrow_address={webhook.escrow_address}). " + "Canceling the project" + ) + + cvat_db_service.finish_escrow_creations_by_escrow_address( + db_session, escrow_address=webhook.escrow_address, chain_id=webhook.chain_id + ) + cvat_db_service.update_project_statuses_by_escrow_address( + db_session, webhook.escrow_address, webhook.chain_id, ProjectStatuses.canceled + ) + cleanup_escrow(webhook.escrow_address, Networks(webhook.chain_id), projects) + + oracle_db_service.outbox.create_webhook( + session=db_session, + escrow_address=webhook.escrow_address, + chain_id=webhook.chain_id, + type=OracleWebhookTypes.recording_oracle, + event=ExchangeOracleEvent_EscrowCleaned(), + ) + case _: + raise AssertionError(f"Unknown job launcher event {webhook.event_type}") + + +@cron_job +def process_outgoing_job_launcher_webhooks(logger: logging.Logger, session: Session): + process_outgoing_webhooks( + logger, + session, + OracleWebhookTypes.job_launcher, + get_job_launcher_url, + CronConfig.process_job_launcher_webhooks_chunk_size, + ) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py similarity index 58% rename from packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py rename to packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py index c49fac60b9..4ec83840c2 100644 --- a/packages/examples/cvat/exchange-oracle/src/crons/process_recording_oracle_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/recording_oracle.py @@ -1,6 +1,5 @@ import logging -import httpx from datumaro.util import take_by from sqlalchemy.orm import Session @@ -16,45 +15,27 @@ RecordingOracleEventTypes, TaskStatuses, ) -from src.db import SessionLocal +from src.crons._cron_job import cron_job +from src.crons.webhooks._common import handle_webhook, process_outgoing_webhooks from src.db.utils import ForUpdateParams -from src.log import ROOT_LOGGER_NAME from src.models.webhook import Webhook -from src.utils.logging import get_function_logger -from src.utils.webhooks import prepare_outgoing_webhook_body, prepare_signed_message -module_logger_name = f"{ROOT_LOGGER_NAME}.cron.webhook" - -def process_incoming_recording_oracle_webhooks(): +@cron_job +def process_incoming_recording_oracle_webhooks(logger: logging.Logger, session: Session): """ Process incoming oracle webhooks """ - logger = get_function_logger(module_logger_name) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - webhooks = oracle_db_service.inbox.get_pending_webhooks( - session, - OracleWebhookTypes.recording_oracle, - limit=CronConfig.process_recording_oracle_webhooks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) + webhooks = oracle_db_service.inbox.get_pending_webhooks( + session, + OracleWebhookTypes.recording_oracle, + limit=CronConfig.process_recording_oracle_webhooks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) - for webhook in webhooks: - try: - handle_recording_oracle_event(webhook, db_session=session, logger=logger) - - oracle_db_service.inbox.handle_webhook_success(session, webhook.id) - except Exception as e: - logger.exception(f"Webhook {webhook.id} handling failed: {e}") - oracle_db_service.inbox.handle_webhook_fail(session, webhook.id) - except Exception as e: - logger.exception(e) - finally: - logger.debug("Finishing cron job") + for webhook in webhooks: + with handle_webhook(logger, session, webhook): + handle_recording_oracle_event(webhook, db_session=session, logger=logger) def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logger: logging.Logger): @@ -149,58 +130,12 @@ def handle_recording_oracle_event(webhook: Webhook, *, db_session: Session, logg raise AssertionError(f"Unknown recording oracle event {webhook.event_type}") -def process_outgoing_recording_oracle_webhooks(): - """ - Process webhooks that needs to be sent to recording oracle: - * Retrieves `webhook_url` from KVStore - * Sends webhook to recording oracle - """ - logger = get_function_logger(module_logger_name) - - try: - logger.debug("Starting cron job") - - with SessionLocal.begin() as session: - webhooks = oracle_db_service.outbox.get_pending_webhooks( - session, - OracleWebhookTypes.recording_oracle, - limit=CronConfig.process_recording_oracle_webhooks_chunk_size, - for_update=ForUpdateParams(skip_locked=True), - ) - for webhook in webhooks: - try: - logger.debug( - "Processing webhook " - f"{webhook.type}.{webhook.event_type} " - f"(attempt {webhook.attempts + 1})" - ) - - body = prepare_outgoing_webhook_body( - webhook.escrow_address, - webhook.chain_id, - webhook.event_type, - webhook.event_data, - timestamp=webhook.created_at, - ) - - _, signature = prepare_signed_message( - webhook.escrow_address, - webhook.chain_id, - body=body, - ) - - headers = {"human-signature": signature} - webhook_url = get_recording_oracle_url(webhook.chain_id, webhook.escrow_address) - with httpx.Client() as client: - response = client.post(webhook_url, headers=headers, json=body) - response.raise_for_status() - - oracle_db_service.outbox.handle_webhook_success(session, webhook.id) - logger.debug("Webhook handled successfully") - except Exception as e: - logger.exception(f"Webhook {webhook.id} sending failed: {e}") - oracle_db_service.outbox.handle_webhook_fail(session, webhook.id) - except Exception as e: - logger.exception(e) - finally: - logger.debug("Finishing cron job") +@cron_job +def process_outgoing_recording_oracle_webhooks(logger: logging.Logger, session: Session): + process_outgoing_webhooks( + logger, + session, + OracleWebhookTypes.recording_oracle, + get_recording_oracle_url, + CronConfig.process_recording_oracle_webhooks_chunk_size, + ) diff --git a/packages/examples/cvat/exchange-oracle/src/crons/webhooks/reputation_oracle.py b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/reputation_oracle.py new file mode 100644 index 0000000000..263a47e825 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/crons/webhooks/reputation_oracle.py @@ -0,0 +1,67 @@ +import logging + +from sqlalchemy.orm import Session + +import src.services.cvat as db_service +import src.services.webhook as oracle_db_service +from src.chain.kvstore import get_reputation_oracle_url +from src.core.config import CronConfig +from src.core.oracle_events import ( + ExchangeOracleEvent_EscrowCleaned, +) +from src.core.types import ( + Networks, + OracleWebhookTypes, + ProjectStatuses, + ReputationOracleEventTypes, +) +from src.crons._cron_job import cron_job +from src.crons.webhooks._common import handle_webhook, process_outgoing_webhooks +from src.db.utils import ForUpdateParams +from src.handlers.escrow_cleanup import cleanup_escrow + + +@cron_job +def process_incoming_reputation_oracle_webhooks(logger: logging.Logger, session: Session): + webhooks = oracle_db_service.inbox.get_pending_webhooks( + session, + OracleWebhookTypes.reputation_oracle, + limit=CronConfig.process_reputation_oracle_webhooks_chunk_size, + for_update=ForUpdateParams(skip_locked=True), + ) + for webhook in webhooks: + with handle_webhook(logger, session, webhook): + match webhook.event_type: + case ReputationOracleEventTypes.escrow_completed: + projects = db_service.get_projects_by_escrow_address( + session, webhook.escrow_address + ) + cleanup_escrow(webhook.escrow_address, Networks(webhook.chain_id), projects) + + db_service.update_project_statuses_by_escrow_address( + session, + webhook.escrow_address, + webhook.chain_id, + status=ProjectStatuses.deleted, + ) + + oracle_db_service.outbox.create_webhook( + session=session, + escrow_address=webhook.escrow_address, + chain_id=webhook.chain_id, + type=OracleWebhookTypes.recording_oracle, + event=ExchangeOracleEvent_EscrowCleaned(), + ) + case _: + raise TypeError(f"Unknown reputation oracle event {webhook.event_type}") + + +@cron_job +def process_outgoing_reputation_oracle_webhooks(logger: logging.Logger, session: Session): + process_outgoing_webhooks( + logger, + session, + OracleWebhookTypes.recording_oracle, + get_reputation_oracle_url, + CronConfig.process_recording_oracle_webhooks_chunk_size, + ) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/escrow_cleanup.py b/packages/examples/cvat/exchange-oracle/src/handlers/escrow_cleanup.py new file mode 100644 index 0000000000..4536e44a93 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/src/handlers/escrow_cleanup.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import contextlib +import logging +from typing import TYPE_CHECKING + +from cvat_sdk.api_client.exceptions import NotFoundException + +import src.cvat.api_calls as cvat_api +import src.services.cloud as cloud_service +from src.core.config import Config +from src.core.storage import ( + compose_data_bucket_prefix, + compose_results_bucket_prefix, +) +from src.log import get_logger_name +from src.services.cloud.utils import BucketAccessInfo + +if TYPE_CHECKING: + from collections.abc import Generator + + from src.models.cvat import Project + +logger = logging.getLogger(get_logger_name(__name__)) + + +@contextlib.contextmanager +def _log_error(errors_container: list[Exception], message: str) -> Generator[None, None, None]: + try: + yield + except Exception as e: + errors_container.append(e) + logger.exception(message) + + +def _cleanup_cvat(projects: list[Project]) -> None: + """ + CVAT can throw a timeout error or 500 status code unexpectedly. + + We don't want these errors affecting deletion of other projects, but want to reraise them, + so we'll be able to retry later. + + We also want to ignore NotFoundException since project might have been deleted manually + or on the previous attempt. + """ + cloud_storage_ids_to_delete = set() # probably will allways have one element + errors = [] + for project in projects: + cloud_storage_ids_to_delete.add(project.cvat_cloudstorage_id) + if project.cvat_id is not None: + with ( + _log_error( + errors, f"Encountered error while deliting CVAT project {project.cvat_id}" + ), + contextlib.suppress(NotFoundException), + ): + cvat_api.delete_project(project.cvat_id) + + for cloud_storage_id in cloud_storage_ids_to_delete: + with ( + _log_error( + errors, f"Encountered error while deleting CVAT cloudstorage {cloud_storage_id}" + ), + contextlib.suppress(NotFoundException), + ): + cvat_api.delete_cloudstorage(cloud_storage_id) + + if errors: + raise RuntimeError( + f"Encountered {len(errors)} error(s) while deleting CVAT projects. " + "All errors have been logged.", + errors, + ) + + +def _cleanup_storage(escrow_address: str, chain_id: int) -> None: + storage_client = cloud_service.make_client(BucketAccessInfo.parse_obj(Config.storage_config)) + storage_client.remove_files( + prefix=compose_data_bucket_prefix(escrow_address, chain_id), + ) + storage_client.remove_files( + prefix=compose_results_bucket_prefix(escrow_address, chain_id), + ) + + +def cleanup_escrow(escrow_address: str, chain_id: int, projects: list[Project]) -> None: + """ + Cleans up CVAT resources and storage related to the given escrow. + """ + try: + _cleanup_cvat(projects) + finally: + # in case both _cleanup_cvat and _cleanup_storage raise an exception, + # both will be in the traceback + _cleanup_storage(escrow_address, chain_id) diff --git a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py index 93532aba4b..e66ca9cfd3 100644 --- a/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py +++ b/packages/examples/cvat/exchange-oracle/src/handlers/job_creation.py @@ -8,7 +8,7 @@ from itertools import chain, groupby from math import ceil from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, TypeVar, Union, cast +from typing import TYPE_CHECKING, TypeVar, cast import cv2 import datumaro as dm @@ -2554,14 +2554,3 @@ def create_task(escrow_address: str, chain_id: int) -> None: with builder_type(manifest, escrow_address, chain_id) as task_builder: task_builder.set_logger(logger) task_builder.build() - - -def remove_task(escrow_address: str) -> None: - with SessionLocal.begin() as session: - project = db_service.get_project_by_escrow_address(session, escrow_address) - if project is not None: - if project.cvat_cloudstorage_id: - cvat_api.delete_cloudstorage(project.cvat_cloudstorage_id) - if project.cvat_id: - cvat_api.delete_project(project.cvat_id) - db_service.delete_project(session, project.id) diff --git a/packages/examples/cvat/exchange-oracle/src/log.py b/packages/examples/cvat/exchange-oracle/src/log.py index 09349ef6f4..903942fc92 100644 --- a/packages/examples/cvat/exchange-oracle/src/log.py +++ b/packages/examples/cvat/exchange-oracle/src/log.py @@ -8,6 +8,10 @@ ROOT_LOGGER_NAME = "app" +def get_logger_name(module_name: str) -> str: + return f"{ROOT_LOGGER_NAME}.{module_name.removeprefix('src.')}" + + def setup_logging(): log_level_name = logging.getLevelName( Config.loglevel or (logging.DEBUG if Config.environment == "development" else logging.INFO) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py index 53cf13dcc0..7db5532b64 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/client.py @@ -15,6 +15,9 @@ def create_file(self, key: str, data: bytes = b"", *, bucket: str | None = None) @abstractmethod def remove_file(self, key: str, *, bucket: str | None = None): ... + @abstractmethod + def remove_files(self, prefix: str, *, bucket: str | None = None): ... + @abstractmethod def file_exists(self, key: str, *, bucket: str | None = None) -> bool: ... diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py index 014be4ce1a..d147f04d08 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/gcs.py @@ -32,6 +32,20 @@ def remove_file(self, key: str, *, bucket: str | None = None) -> None: bucket_client = self.client.get_bucket(bucket) bucket_client.delete_blob(unquote(key)) + def remove_files(self, prefix: str, *, bucket: str | None = None): + import warnings + + warnings.warn( + "Avoid usage of `GcsClient.remove_files`. See: " + "https://cloud.google.com/storage/docs/deleting-objects#delete-objects-in-bulk", + UserWarning, + stacklevel=2, + ) + bucket = unquote(bucket) if bucket else self._bucket + bucket_client = self.client.get_bucket(bucket) + keys = self.list_files(prefix=prefix) + bucket_client.delete_blobs([unquote(key) for key in keys]) + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket bucket_client = self.client.get_bucket(bucket) diff --git a/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py b/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py index ccb5557e7d..7c9cc61fbc 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cloud/s3.py @@ -1,4 +1,5 @@ from io import BytesIO +from typing import TYPE_CHECKING from urllib.parse import unquote import boto3 @@ -8,6 +9,9 @@ from src.services.cloud.client import StorageClient DEFAULT_S3_HOST = "s3.amazonaws.com" +if TYPE_CHECKING: + from mypy_boto3_s3 import S3Client as S3ClientStub + from mypy_boto3_s3 import S3ServiceResource as S3ServiceResourceStub class S3Client(StorageClient): @@ -27,8 +31,9 @@ def __init__( s3 = session.resource( "s3", **({"endpoint_url": unquote(endpoint_url)} if endpoint_url else {}) ) - self.resource = s3 - self.client = s3.meta.client + self.resource: S3ServiceResourceStub = s3 + + self.client: S3ClientStub = s3.meta.client if not access_key and not secret_key: self.client.meta.events.register("choose-signer.s3.*", disable_signing) @@ -41,6 +46,10 @@ def remove_file(self, key: str, *, bucket: str | None = None): bucket = unquote(bucket) if bucket else self._bucket self.client.delete_object(Bucket=bucket, Key=unquote(key)) + def remove_files(self, prefix: str, *, bucket: str | None = None): + bucket = unquote(bucket) if bucket else self._bucket + self.resource.Bucket(bucket).objects.filter(Prefix=unquote(prefix)).delete() + def file_exists(self, key: str, *, bucket: str | None = None) -> bool: bucket = unquote(bucket) if bucket else self._bucket try: diff --git a/packages/examples/cvat/exchange-oracle/src/services/cvat.py b/packages/examples/cvat/exchange-oracle/src/services/cvat.py index ccd6fd9ddd..bdca974032 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/cvat.py +++ b/packages/examples/cvat/exchange-oracle/src/services/cvat.py @@ -215,6 +215,7 @@ def update_project_statuses_by_escrow_address( Project.chain_id == chain_id, ) .values(status=status.value) + .returning(Project.cvat_id) ) session.execute(statement) @@ -224,6 +225,15 @@ def delete_project(session: Session, project_id: str) -> None: session.delete(project) +def delete_projects(session: Session, escrow_address: str, chain_id: int) -> None: + session.execute( + delete(Project).where( + Project.escrow_address == escrow_address, + Project.chain_id == chain_id, + ) + ) + + def is_project_completed(session: Session, project_id: str) -> bool: project = get_project_by_id(session, project_id) jobs = get_jobs_by_cvat_project_id(session, project.cvat_id) diff --git a/packages/examples/cvat/exchange-oracle/src/services/webhook.py b/packages/examples/cvat/exchange-oracle/src/services/webhook.py index e2bc9c7fbe..c5766329bd 100644 --- a/packages/examples/cvat/exchange-oracle/src/services/webhook.py +++ b/packages/examples/cvat/exchange-oracle/src/services/webhook.py @@ -10,6 +10,7 @@ from src.core.config import Config from src.core.oracle_events import OracleEvent, validate_event from src.core.types import OracleWebhookStatuses, OracleWebhookTypes +from src.db.utils import ForUpdateParams from src.db.utils import maybe_for_update as _maybe_for_update from src.models.webhook import Webhook from src.utils.enums import BetterEnumMeta @@ -91,7 +92,7 @@ def get_pending_webhooks( type: OracleWebhookTypes, *, limit: int = 10, - for_update: bool = False, + for_update: bool | ForUpdateParams = False, ) -> list[Webhook]: return ( _maybe_for_update(session.query(Webhook), enable=for_update) diff --git a/packages/examples/cvat/exchange-oracle/src/validators/signature.py b/packages/examples/cvat/exchange-oracle/src/validators/signature.py index 1d4bd6d0df..b47b39b418 100644 --- a/packages/examples/cvat/exchange-oracle/src/validators/signature.py +++ b/packages/examples/cvat/exchange-oracle/src/validators/signature.py @@ -5,7 +5,9 @@ from fastapi import HTTPException, Request -from src.chain.escrow import get_job_launcher_address, get_recording_oracle_address +from src.chain.escrow import ( + get_available_webhook_types, +) from src.chain.web3 import recover_signer from src.core.config import Config from src.core.types import OracleWebhookTypes @@ -18,29 +20,13 @@ async def validate_oracle_webhook_signature( data: bytes = await request.body() message: dict = literal_eval(data.decode("utf-8")) - signer = recover_signer(webhook.chain_id, message, signature) + signer = recover_signer(webhook.chain_id, message, signature).lower() + webhook_types = get_available_webhook_types(webhook.chain_id, webhook.escrow_address) - job_launcher_address = get_job_launcher_address(webhook.chain_id, webhook.escrow_address) - recording_oracle_address = get_recording_oracle_address( - webhook.chain_id, webhook.escrow_address - ) - possible_signers = { - OracleWebhookTypes.job_launcher: job_launcher_address, - OracleWebhookTypes.recording_oracle: recording_oracle_address, - } - - matched_signer = next( - ( - s_type - for s_type in possible_signers - if signer.lower() == possible_signers[s_type].lower() - ), - None, - ) - if not matched_signer: + if not (webhook_sender := webhook_types.get(signer)): raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED) - return matched_signer + return webhook_sender async def validate_cvat_signature(request: Request, x_signature_256: str): diff --git a/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py b/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py index 2feb6b8300..2cc02ba21f 100644 --- a/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py +++ b/packages/examples/cvat/exchange-oracle/tests/api/test_cvat_webhook_api.py @@ -76,7 +76,7 @@ def test_incoming_webhook_200_update_expired_assignmets(client: TestClient) -> N (job, _) = get_cvat_job_from_db(1) # Check if "update:job" event works with expired assignments wallet_address = "0x86e83d346041E8806e352681f3F14549C0d2BC68" - add_asignment_to_db(wallet_address, 1, job.cvat_id, datetime.now()) + add_asignment_to_db(wallet_address, 1, job.cvat_id, datetime.now(tz=timezone.utc)) data = { "event": "update:job", diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py index d65097e64c..2d69ea7813 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/chain/test_escrow.py @@ -8,15 +8,16 @@ from human_protocol_sdk.escrow import EscrowClientError, EscrowData from src.chain.escrow import ( + get_available_webhook_types, get_escrow_manifest, - get_job_launcher_address, - get_recording_oracle_address, validate_escrow, ) +from src.core.types import OracleWebhookTypes from tests.utils.constants import ( DEFAULT_MANIFEST_URL, ESCROW_ADDRESS, + EXCHANGE_ORACLE_ADDRESS, FACTORY_ADDRESS, JOB_LAUNCHER_ADDRESS, PGP_PASSPHRASE, @@ -24,6 +25,7 @@ PGP_PUBLIC_KEY1, PGP_PUBLIC_KEY2, RECORDING_ORACLE_ADDRESS, + REPUTATION_ORACLE_ADDRESS, TOKEN_ADDRESS, ) @@ -48,6 +50,8 @@ def setUp(self): created_at="", manifest_url=DEFAULT_MANIFEST_URL, recording_oracle=RECORDING_ORACLE_ADDRESS, + exchange_oracle=EXCHANGE_ORACLE_ADDRESS, + reputation_oracle=REPUTATION_ORACLE_ADDRESS, ) def test_validate_escrow(self): @@ -118,45 +122,26 @@ def test_get_escrow_manifest_invalid_address(self): with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): get_escrow_manifest(chain_id, "invalid_address") - def test_get_job_launcher_address(self): + def test_get_available_webhook_types(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = self.escrow_data - job_launcher_address = get_job_launcher_address(chain_id, escrow_address) - assert isinstance(job_launcher_address, str) - assert job_launcher_address == JOB_LAUNCHER_ADDRESS - - def test_get_job_launcher_address_invalid_address(self): - with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): - get_job_launcher_address(chain_id, "invalid_address") - - def test_get_job_launcher_address_invalid_chain_id(self): - with pytest.raises(ValueError, match="123 is not a valid ChainId"): - get_job_launcher_address(123, escrow_address) - - def test_get_job_launcher_address_empty_escrow(self): - with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: - mock_function.return_value = None - with pytest.raises(Exception, match=f"Can't find escrow {ESCROW_ADDRESS}"): - get_job_launcher_address(chain_id, escrow_address) - - def test_get_recording_oracle_address(self): - with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: - self.escrow_data.recording_oracle = RECORDING_ORACLE_ADDRESS - mock_function.return_value = self.escrow_data - recording_oracle_address = get_recording_oracle_address(chain_id, escrow_address) - assert isinstance(recording_oracle_address, str) - assert recording_oracle_address == RECORDING_ORACLE_ADDRESS + webhook_types = get_available_webhook_types(chain_id, escrow_address) + assert webhook_types == { + JOB_LAUNCHER_ADDRESS.lower(): OracleWebhookTypes.job_launcher, + REPUTATION_ORACLE_ADDRESS.lower(): OracleWebhookTypes.reputation_oracle, + RECORDING_ORACLE_ADDRESS.lower(): OracleWebhookTypes.recording_oracle, + } - def test_get_recording_oracle_address_invalid_address(self): + def test_get_available_webhook_types_invalid_address(self): with pytest.raises(EscrowClientError, match="Invalid escrow address: invalid_address"): - get_recording_oracle_address(chain_id, "invalid_address") + get_available_webhook_types(chain_id, "invalid_address") - def test_get_recording_oracle_address_invalid_chain_id(self): + def test_get_available_webhook_types_invalid_chain_id(self): with pytest.raises(ValueError, match="123 is not a valid ChainId"): - get_recording_oracle_address(123, escrow_address) + get_available_webhook_types(123, escrow_address) - def test_get_recording_oracle_address_empty_escrow(self): + def test_get_available_webhook_types_empty_escrow(self): with patch("src.chain.escrow.EscrowUtils.get_escrow") as mock_function: mock_function.return_value = None with pytest.raises(Exception, match=f"Can't find escrow {ESCROW_ADDRESS}"): - get_recording_oracle_address(chain_id, escrow_address) + get_available_webhook_types(chain_id, escrow_address) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py index 2a0008ce37..2acae81df7 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_assignments.py @@ -10,7 +10,7 @@ AssignmentStatuses, ProjectStatuses, ) -from src.crons.state_trackers import track_assignments +from src.crons.cvat.state_trackers import track_assignments from src.db import SessionLocal from src.models.cvat import Assignment, Project, User @@ -66,7 +66,7 @@ def test_track_expired_assignments(self): assert db_assignments[0].status == AssignmentStatuses.created.value assert db_assignments[1].status == AssignmentStatuses.created.value - with patch("src.crons.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: + with patch("src.crons.cvat.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: track_assignments() mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) @@ -81,7 +81,7 @@ def test_track_expired_assignments(self): @pytest.mark.xfail( strict=True, reason=""" -Fix src/crons/state_trackers.py +Fix src.crons.cvat.state_trackers.py Where in `cvat_service.get_active_assignments()` return value will be empty because it actually looking for the expired assignments """, @@ -138,7 +138,7 @@ def test_track_canceled_assignments(self): assert db_assignments[0].status == AssignmentStatuses.created.value assert db_assignments[1].status == AssignmentStatuses.created.value - with patch("src.crons.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: + with patch("src.crons.cvat.state_trackers.cvat_api.update_job_assignee") as mock_cvat_api: track_assignments() mock_cvat_api.assert_called_once_with(assignment_2.cvat_job_id, assignee_id=None) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py index dfcf5025f0..95ed109d73 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_escrows.py @@ -20,7 +20,7 @@ TaskStatuses, TaskTypes, ) -from src.crons.state_trackers import track_completed_escrows +from src.crons.cvat.state_trackers import track_completed_escrows from src.db import SessionLocal from src.models.cvat import Assignment, Image, Job, Project, Task, User from src.models.webhook import Webhook diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py index 36de489508..9adbf904dd 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_projects.py @@ -4,7 +4,7 @@ from sqlalchemy.sql import select from src.core.types import Networks, ProjectStatuses, TaskStatuses, TaskTypes -from src.crons.state_trackers import track_completed_projects +from src.crons.cvat.state_trackers import track_completed_projects from src.db import SessionLocal from src.models.cvat import Project, Task diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py index 614afd20d6..e96233a2bc 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_completed_tasks.py @@ -4,7 +4,7 @@ from sqlalchemy.sql import select from src.core.types import JobStatuses, Networks, ProjectStatuses, TaskStatuses, TaskTypes -from src.crons.state_trackers import track_completed_tasks +from src.crons.cvat.state_trackers import track_completed_tasks from src.db import SessionLocal from src.models.cvat import Job, Project, Task diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py index b1cc99f939..708a36c9d1 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_escrow_creation.py @@ -2,7 +2,7 @@ import uuid from src.core.types import ProjectStatuses -from src.crons.state_trackers import track_escrow_creation +from src.crons.cvat.state_trackers import track_escrow_creation from src.db import SessionLocal from src.models.cvat import EscrowCreation, Project diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py index c04c1181a5..fbfeb84a59 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/state_trackers/test_track_task_creation.py @@ -4,7 +4,7 @@ import src.cvat.api_calls as cvat_api from src.core.types import ExchangeOracleEventTypes, JobStatuses -from src.crons.state_trackers import track_task_creation +from src.crons.cvat.state_trackers import track_task_creation from src.db import SessionLocal from src.models.cvat import DataUpload, Job from src.models.webhook import Webhook @@ -31,7 +31,7 @@ def test_track_track_failed_task_creation(self): self.session.commit() with patch( - "src.crons.state_trackers.cvat_api.get_task_upload_status" + "src.crons.cvat.state_trackers.cvat_api.get_task_upload_status" ) as mock_get_task_upload_status: mock_get_task_upload_status.return_value = (cvat_api.UploadStatus.FAILED, "Failed") @@ -56,9 +56,9 @@ def test_track_track_completed_task_creation(self): new_cvat_job_id = 2 with ( patch( - "src.crons.state_trackers.cvat_api.get_task_upload_status" + "src.crons.cvat.state_trackers.cvat_api.get_task_upload_status" ) as mock_get_task_upload_status, - patch("src.crons.state_trackers.cvat_api.fetch_task_jobs") as mock_fetch_task_jobs, + patch("src.crons.cvat.state_trackers.cvat_api.fetch_task_jobs") as mock_fetch_task_jobs, ): mock_get_task_upload_status.return_value = (cvat_api.UploadStatus.FINISHED, None) mock_cvat_job_1 = Mock() @@ -93,10 +93,10 @@ def test_track_track_completed_task_creation_error(self): with ( patch( - "src.crons.state_trackers.cvat_api.get_task_upload_status" + "src.crons.cvat.state_trackers.cvat_api.get_task_upload_status" ) as mock_get_task_upload_status, patch( - "src.crons.state_trackers.cvat_api.fetch_task_jobs", + "src.crons.cvat.state_trackers.cvat_api.fetch_task_jobs", side_effect=cvat_api.exceptions.ApiException("Error"), ), ): diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py index 1fbb1b1c35..c31261e0aa 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_job_launcher_webhooks.py @@ -1,11 +1,12 @@ import json import unittest import uuid -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock, call, patch from human_protocol_sdk.constants import ChainId, Status from sqlalchemy.sql import select +from src.core.storage import compose_data_bucket_prefix, compose_results_bucket_prefix from src.core.types import ( ExchangeOracleEventTypes, JobLauncherEventTypes, @@ -16,13 +17,14 @@ TaskStatuses, TaskTypes, ) -from src.crons.process_job_launcher_webhooks import ( +from src.crons.webhooks.job_launcher import ( process_incoming_job_launcher_webhooks, process_outgoing_job_launcher_webhooks, ) from src.db import SessionLocal from src.models.cvat import EscrowCreation, Project from src.models.webhook import Webhook +from src.services.cloud import StorageClient from src.services.webhook import OracleWebhookDirectionTags from tests.utils.constants import DEFAULT_MANIFEST_URL, JOB_LAUNCHER_ADDRESS @@ -159,7 +161,14 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_exceed_max_r self.session.add(webhook) self.session.commit() - with patch("src.chain.escrow.get_escrow") as mock_escrow: + + mock_storage_client = MagicMock(spec=StorageClient) + with ( + patch("src.chain.escrow.get_escrow") as mock_escrow, + patch("src.services.cloud.make_client", return_value=mock_storage_client), + patch("src.cvat.api_calls.delete_project") as delete_project_mock, + patch("src.cvat.api_calls.delete_cloudstorage") as delete_cloudstorage_mock, + ): mock_escrow_data = Mock() mock_escrow_data.status = Status.Pending.name mock_escrow.return_value = mock_escrow_data @@ -186,6 +195,32 @@ def test_process_incoming_job_launcher_webhooks_escrow_created_type_exceed_max_r assert new_webhook.status == OracleWebhookStatuses.pending.value assert new_webhook.event_type == ExchangeOracleEventTypes.task_creation_failed assert new_webhook.attempts == 0 + assert mock_storage_client.remove_files.mock_calls == [ + call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + + assert delete_project_mock.mock_calls == [] + assert delete_cloudstorage_mock.mock_calls == [] + + outgoing_webhooks: list[Webhook] = list( + self.session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 1 + outgoing_webhook = outgoing_webhooks[0] + + assert outgoing_webhook.type == OracleWebhookTypes.job_launcher + assert outgoing_webhook.event_type == ExchangeOracleEventTypes.task_creation_failed + + assert mock_storage_client.remove_files.mock_calls == [ + call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + + assert delete_project_mock.mock_calls == [] + assert delete_cloudstorage_mock.mock_calls == [] def test_process_incoming_job_launcher_webhooks_escrow_created_type_remove_when_error( self, @@ -278,7 +313,14 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type(self): self.session.add(webhook) self.session.commit() - with patch("src.chain.escrow.get_escrow") as mock_escrow: + + mock_storage_client = MagicMock(spec=StorageClient) + with ( + patch("src.chain.escrow.get_escrow") as mock_escrow, + patch("src.services.cloud.make_client", return_value=mock_storage_client), + patch("src.cvat.api_calls.delete_project") as delete_project_mock, + patch("src.cvat.api_calls.delete_cloudstorage") as delete_cloudstorage_mock, + ): mock_escrow_data = Mock() mock_escrow_data.status = Status.Pending.name mock_escrow_data.balance = 1 @@ -300,6 +342,27 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type(self): assert db_project.status == ProjectStatuses.canceled.value + assert mock_storage_client.remove_files.mock_calls == [ + call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + + assert delete_project_mock.mock_calls == [ + call(1), + ] + assert delete_cloudstorage_mock.mock_calls == [call(1)] + + outgoing_webhooks: list[Webhook] = list( + self.session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 1 + outgoing_webhook = outgoing_webhooks[0] + + assert outgoing_webhook.type == OracleWebhookTypes.recording_oracle + assert outgoing_webhook.event_type == ExchangeOracleEventTypes.escrow_cleaned + def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multiple_creating_projects( # noqa: E501 self, ): @@ -341,7 +404,14 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multip self.session.commit() - with patch("src.chain.escrow.get_escrow") as mock_escrow: + mock_storage_client = MagicMock(spec=StorageClient) + + with ( + patch("src.chain.escrow.get_escrow") as mock_escrow, + patch("src.services.cloud.make_client", return_value=mock_storage_client), + patch("src.cvat.api_calls.delete_project") as delete_project_mock, + patch("src.cvat.api_calls.delete_cloudstorage") as delete_cloudstorage_mock, + ): mock_escrow_data = Mock() mock_escrow_data.status = Status.Pending.name mock_escrow_data.balance = 1 @@ -370,6 +440,27 @@ def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_with_multip ) assert bool(db_escrow_creation_tracker.finished_at) + assert db_project.status == ProjectStatuses.canceled.value + + assert mock_storage_client.remove_files.mock_calls == [ + call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + + assert delete_project_mock.mock_calls == [call(0), call(1), call(2)] + assert delete_cloudstorage_mock.mock_calls == [call(0), call(1), call(2)] + + outgoing_webhooks: list[Webhook] = list( + self.session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 1 + outgoing_webhook = outgoing_webhooks[0] + + assert outgoing_webhook.type == OracleWebhookTypes.recording_oracle + assert outgoing_webhook.event_type == ExchangeOracleEventTypes.escrow_cleaned + def test_process_incoming_job_launcher_webhooks_escrow_canceled_type_invalid_status( self, ): diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py index af027d81a9..2d07069132 100644 --- a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_recording_oracle_webhooks.py @@ -16,7 +16,7 @@ TaskStatuses, TaskTypes, ) -from src.crons.process_recording_oracle_webhooks import ( +from src.crons.webhooks.recording_oracle import ( process_incoming_recording_oracle_webhooks, process_outgoing_recording_oracle_webhooks, ) diff --git a/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py new file mode 100644 index 0000000000..57306f02c1 --- /dev/null +++ b/packages/examples/cvat/exchange-oracle/tests/integration/cron/test_process_reputation_oracle_webhooks.py @@ -0,0 +1,242 @@ +import uuid + +import pytest +from cvat_sdk.api_client.exceptions import NotFoundException +from sqlalchemy import select + +from src.core.storage import compose_data_bucket_prefix, compose_results_bucket_prefix +from src.core.types import ( + ExchangeOracleEventTypes, + JobStatuses, + Networks, + OracleWebhookStatuses, + OracleWebhookTypes, + ProjectStatuses, + ReputationOracleEventTypes, + TaskStatuses, + TaskTypes, +) +from src.crons.webhooks.reputation_oracle import process_incoming_reputation_oracle_webhooks +from src.cvat import api_calls +from src.db import SessionLocal +from src.models.cvat import Job, Project, Task +from src.models.webhook import Webhook +from src.services import cloud +from src.services.cloud import StorageClient +from src.services.webhook import OracleWebhookDirectionTags +from src.utils.time import utcnow + +escrow_address = "0x86e83d346041E8806e352681f3F14549C0d2BC67" +chain_id = Networks.localhost + + +@pytest.fixture +def session(): + db_session = SessionLocal() + yield db_session + db_session.close() + + +@pytest.fixture +def create_project(session): + cvat_id = 0 + + def _create_project(status: ProjectStatuses) -> Project: + nonlocal cvat_id + cvat_id += 1 + project = Project( + id=str(uuid.uuid4()), + cvat_id=cvat_id, + cvat_cloudstorage_id=1, + status=status, + job_type=TaskTypes.image_label_binary, + escrow_address=escrow_address, + chain_id=chain_id, + bucket_url="https://test.storage.googleapis.com/", + ) + session.add(project) + session.commit() + return project + + return _create_project + + +@pytest.fixture +def create_webhook(session): + def _create_webhook( + event_type: ReputationOracleEventTypes, + direction: OracleWebhookDirectionTags, + event_data: dict | None = None, + ) -> Webhook: + webhook = Webhook( + id=str(uuid.uuid4()), + signature="signature", + escrow_address=escrow_address, + chain_id=chain_id, + type=OracleWebhookTypes.reputation_oracle, + status=OracleWebhookStatuses.pending, + event_type=event_type, + event_data=event_data or {}, + direction=direction, + ) + session.add(webhook) + session.commit() + return webhook + + return _create_webhook + + +def mock_cloud_services(mocker): + mock_storage_client = mocker.MagicMock(spec=StorageClient) + mocker.patch.object(cloud, cloud.make_client.__name__, return_value=mock_storage_client) + delete_project_mock = mocker.patch.object(api_calls, api_calls.delete_project.__name__) + delete_cloudstorage_mock = mocker.patch.object( + api_calls, api_calls.delete_cloudstorage.__name__ + ) + return mock_storage_client, delete_project_mock, delete_cloudstorage_mock + + +def add_cvat_entities(session, task_status: JobStatuses, job_status: TaskStatuses): + task = Task(id=str(uuid.uuid4()), cvat_id=1, cvat_project_id=1, status=task_status) + job = Job( + id=str(uuid.uuid4()), + cvat_id=1, + cvat_project_id=1, + cvat_task_id=task.cvat_id, + status=job_status, + ) + session.add_all([task, job]) + + +@pytest.mark.parametrize( + ("project_status", "expected_project_status"), + [ + (ProjectStatuses.completed, ProjectStatuses.deleted), + (ProjectStatuses.deleted, ProjectStatuses.deleted), + (ProjectStatuses.annotation, ProjectStatuses.deleted), + ], +) +def test_process_incoming_reputation_oracle_webhook_escrow_completed( + session, create_project, create_webhook, mocker, project_status, expected_project_status +) -> None: + project1 = create_project(project_status) + project2 = create_project(project_status) + + add_cvat_entities(session, JobStatuses.completed, TaskStatuses.completed) + + webhook = create_webhook( + ReputationOracleEventTypes.escrow_completed, OracleWebhookDirectionTags.incoming + ) + mock_storage_client, delete_project_mock, delete_cloudstorage_mock = mock_cloud_services(mocker) + + process_incoming_reputation_oracle_webhooks() + + session.refresh(project1) + session.refresh(project2) + session.refresh(webhook) + + assert webhook.status == OracleWebhookStatuses.completed + assert webhook.attempts == 1 + assert project1.status == expected_project_status + assert mock_storage_client.remove_files.mock_calls == [ + mocker.call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + mocker.call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + assert delete_project_mock.mock_calls == [ + mocker.call(project1.cvat_id), + mocker.call(project2.cvat_id), + ] + assert delete_cloudstorage_mock.mock_calls == [mocker.call(1)] + + outgoing_webhooks = list( + session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 1 + assert outgoing_webhooks[0].type == OracleWebhookTypes.recording_oracle + assert outgoing_webhooks[0].event_type == ExchangeOracleEventTypes.escrow_cleaned + + +@pytest.mark.parametrize( + ("project_status", "expected_project_status"), + [ + (ProjectStatuses.completed, ProjectStatuses.deleted), + (ProjectStatuses.deleted, ProjectStatuses.deleted), + (ProjectStatuses.annotation, ProjectStatuses.deleted), + ], +) +def test_process_incoming_reputation_oracle_webhooks_escrow_completed_exceptions( + session, create_project, create_webhook, mocker, project_status, expected_project_status +) -> None: + project1 = create_project(project_status) + project2 = create_project(project_status) + + add_cvat_entities(session, JobStatuses.completed, TaskStatuses.completed) + + webhook = create_webhook( + ReputationOracleEventTypes.escrow_completed, OracleWebhookDirectionTags.incoming + ) + mock_storage_client, delete_project_mock, delete_cloudstorage_mock = mock_cloud_services(mocker) + + delete_project_mock.side_effect = [Exception, None] + + process_incoming_reputation_oracle_webhooks() + + session.refresh(project1) + session.refresh(project2) + session.refresh(webhook) + + assert webhook.status == OracleWebhookStatuses.pending + assert webhook.attempts == 1 + assert project1.status == project_status + + assert mock_storage_client.remove_files.mock_calls == [ + mocker.call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + mocker.call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + assert delete_project_mock.mock_calls == [ + mocker.call(project1.cvat_id), + mocker.call(project2.cvat_id), + ] + assert delete_cloudstorage_mock.mock_calls == [mocker.call(1)] + + outgoing_webhooks = list( + session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 0 + + delete_project_mock.reset_mock() + mock_storage_client.reset_mock() + delete_project_mock.side_effect = [None, NotFoundException] + delete_cloudstorage_mock.side_effect = NotFoundException + webhook.wait_until = utcnow() + session.commit() + + process_incoming_reputation_oracle_webhooks() + + assert webhook.status == OracleWebhookStatuses.completed + assert webhook.attempts == 2 + + assert project1.status == expected_project_status + assert project2.status == expected_project_status + + assert mock_storage_client.remove_files.mock_calls == [ + mocker.call(prefix=compose_data_bucket_prefix(escrow_address, chain_id)), + mocker.call(prefix=compose_results_bucket_prefix(escrow_address, chain_id)), + ] + assert delete_project_mock.mock_calls == [ + mocker.call(project1.cvat_id), + mocker.call(project2.cvat_id), + ] + + outgoing_webhooks = list( + session.scalars( + select(Webhook).where(Webhook.direction == OracleWebhookDirectionTags.outgoing) + ) + ) + assert len(outgoing_webhooks) == 1 + assert outgoing_webhooks[0].type == OracleWebhookTypes.recording_oracle + assert outgoing_webhooks[0].event_type == ExchangeOracleEventTypes.escrow_cleaned