diff --git a/script/get-ml-model-rgat/_cm.yaml b/script/get-ml-model-rgat/_cm.yaml new file mode 100644 index 0000000000..1aa9f3c1f8 --- /dev/null +++ b/script/get-ml-model-rgat/_cm.yaml @@ -0,0 +1,63 @@ +alias: get-ml-model-rgat +automation_alias: script +automation_uid: 5b4e0237da074764 +cache: true +category: AI/ML models +env: + CM_ML_MODEL: RGAT + CM_ML_MODEL_DATASET: ICBH +input_mapping: + checkpoint: RGAT_CHECKPOINT_PATH + download_path: CM_DOWNLOAD_PATH + to: CM_DOWNLOAD_PATH +new_env_keys: +- CM_ML_MODEL_* +- RGAT_CHECKPOINT_PATH +prehook_deps: +- enable_if_env: + CM_DOWNLOAD_TOOL: + - rclone + CM_TMP_REQUIRE_DOWNLOAD: + - 'yes' + env: + CM_DOWNLOAD_FINAL_ENV_NAME: CM_ML_MODEL_PATH + extra_cache_tags: rgat,gnn,model + force_cache: true + names: + - dae + tags: download-and-extract + update_tags_from_env_with_prefix: + _url.: + - CM_DOWNLOAD_URL +print_env_at_the_end: + RGAT_CHECKPOINT_PATH: R-GAT checkpoint path +tags: +- get +- raw +- ml-model +- rgat +uid: b409fd66c5ad4ed5 +variations: + fp32: + default: true + env: + CM_ML_MODEL_INPUT_DATA_TYPES: fp32 + CM_ML_MODEL_PRECISION: fp32 + CM_ML_MODEL_WEIGHT_DATA_TYPES: fp32 + group: precision + mlcommons: + default: true + default_variations: + download-tool: rclone + group: download-source + rclone: + adr: + dae: + tags: _rclone + env: + CM_DOWNLOAD_TOOL: rclone + CM_RCLONE_CONFIG_NAME: mlc-inference + group: download-tool + rclone,fp32: + env: + CM_DOWNLOAD_URL: mlc-inference:mlcommons-inference-wg-public/R-GAT/RGAT.pt diff --git a/script/get-ml-model-rgat/customize.py b/script/get-ml-model-rgat/customize.py new file mode 100644 index 0000000000..3f2c6c0af6 --- /dev/null +++ b/script/get-ml-model-rgat/customize.py @@ -0,0 +1,29 @@ +from cmind import utils +import os + + +def preprocess(i): + + os_info = i['os_info'] + env = i['env'] + + path = env.get('RGAT_CHECKPOINT_PATH', '').strip() + + if path == '' or not os.path.exists(path): + env['CM_TMP_REQUIRE_DOWNLOAD'] = 'yes' + + return {'return': 0} + + +def postprocess(i): + + env = i['env'] + + if env.get('RGAT_CHECKPOINT_PATH', '') == '': + env['RGAT_CHECKPOINT_PATH'] = env['CM_ML_MODEL_PATH'] + elif env.get('CM_ML_MODEL_PATH', '') == '': + env['CM_ML_MODEL_PATH'] = env['RGAT_CHECKPOINT_PATH'] + + env['CM_GET_DEPENDENT_CACHED_PATH'] = env['RGAT_CHECKPOINT_PATH'] + + return {'return': 0}