
Merge pull request #2142 from FedML-AI/dimitris/fail_fast_policy_merge
Fast Fail and Timeout Enforcement Policy for Model Deploy Endpoints
fedml-dimitris authored Jun 5, 2024
2 parents 9a8f307 + b0a55ad commit 10c5e17
Showing 11 changed files with 308 additions and 162 deletions.
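Editor's note: the policy has two visible pieces in this diff — a per-endpoint request timeout (`request_timeout_sec`, declared in the deployment YAML, persisted with the endpoint's user settings, defaulting to 30 s) and a Redis-backed pending-requests counter the inference gateway can consult to shed load. A minimal sketch of how the two could be combined; `MAX_PENDING`, `do_inference`, and the overall wiring are illustrative assumptions, not code from this commit:

    # Hypothetical fail-fast gate around one inference request.
    # MAX_PENDING and do_inference are assumptions, not part of this PR.
    import asyncio

    MAX_PENDING = 100  # assumed back-pressure threshold


    async def guarded_inference(cache, endpoint_settings, do_inference, request):
        # Fail fast: refuse new work when too many requests are already in flight.
        if cache.get_pending_requests_counter() >= MAX_PENDING:
            raise RuntimeError("too many pending requests; failing fast")

        cache.update_pending_requests_counter(increase=True)
        try:
            # Enforce the per-endpoint timeout persisted under "request_timeout_sec"
            # (default 30, per ClientConstants.INFERENCE_REQUEST_TIMEOUT below).
            timeout_s = endpoint_settings.get("request_timeout_sec", 30)
            return await asyncio.wait_for(do_inference(request), timeout=timeout_s)
        finally:
            cache.update_pending_requests_counter(decrease=True)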
10 changes: 10 additions & 0 deletions python/examples/deploy/debug/inference_timeout/config.yaml
@@ -0,0 +1,10 @@
workspace: "./src"
entry_point: "serve_main.py"
bootstrap: |
  echo "Bootstrap start..."
  sleep 5
  echo "Bootstrap finished"
auto_detect_public_ip: true
use_gpu: true

request_timeout_sec: 10
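The new `request_timeout_sec` field caps how long a single inference request may take (10 s here, against the 30 s default added elsewhere in this commit). A rough sketch of how such a value could travel from the YAML into the Redis-backed endpoint settings via `set_user_setting_replica_num`; only the `timeout_s` parameter and the `request_timeout_sec` key come from this diff, the rest is assumed wiring:

    # Illustrative wiring only.
    import yaml

    from fedml.computing.scheduler.model_scheduler.device_model_cache import FedMLModelCache

    with open("config.yaml") as f:
        cfg = yaml.safe_load(f)

    cache = FedMLModelCache.get_instance()  # assumed accessor for the Singleton
    cache.set_user_setting_replica_num(
        end_point_id="ep-123",  # hypothetical endpoint id
        replica_num=1,
        timeout_s=cfg.get("request_timeout_sec", 30),  # 30 mirrors the new default
    )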
32 changes: 32 additions & 0 deletions python/examples/deploy/debug/inference_timeout/src/serve_main.py
@@ -0,0 +1,32 @@
from fedml.serving import FedMLPredictor
from fedml.serving import FedMLInferenceRunner
import uuid
import torch

# Calculate the number of elements
num_elements = 1_073_741_824 // 4  # using integer division for whole elements


class DummyPredictor(FedMLPredictor):
    def __init__(self):
        super().__init__()
        # Create a tensor with these many elements
        tensor = torch.empty(num_elements, dtype=torch.float32)

        # Move the tensor to GPU
        tensor_gpu = tensor.cuda()

        # for debug
        with open("/tmp/dummy_gpu_occupier.txt", "w") as f:
            f.write("GPU is occupied")

        self.worker_id = uuid.uuid4()

    def predict(self, request):
        return {f"AlohaV0From{self.worker_id}": request}


if __name__ == "__main__":
    predictor = DummyPredictor()
    fedml_inference_runner = FedMLInferenceRunner(predictor)
    fedml_inference_runner.run()
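Once deployed, this predictor simply echoes the request back tagged with its worker id, while pinning roughly 1 GiB of float32 data on the GPU — convenient for exercising the timeout path. A hedged probe with an explicit client-side timeout; the URL and payload are placeholders, not values from this commit:

    # Illustrative client call against the deployed endpoint.
    import httpx

    resp = httpx.post(
        "http://<endpoint-host>:<port>/predict",  # placeholder address
        json={"text": "hello"},
        timeout=10.0,  # matches request_timeout_sec in config.yaml above
    )
    print(resp.status_code, resp.json())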
2 changes: 0 additions & 2 deletions python/fedml/computing/scheduler/comm_utils/constants.py
@@ -78,8 +78,6 @@ class SchedulerConstants:
     ENDPOINT_INFERENCE_READY_TIMEOUT = 15
     ENDPOINT_STATUS_CHECK_TIMEOUT = 60 * 3
 
-    MQTT_INFERENCE_TIMEOUT = 60 * 6
-
     TRAIN_PROVISIONING_TIMEOUT = 60 * 25
     TRAIN_STARTING_TIMEOUT = 60 * 15
     TRAIN_STOPPING_TIMEOUT = 60 * 5
@@ -95,6 +95,7 @@ class ClientConstants(object):
     INFERENCE_ENGINE_TYPE_INT_DEFAULT = 2
     INFERENCE_MODEL_VERSION = "1"
     INFERENCE_INFERENCE_SERVER_VERSION = "v2"
+    INFERENCE_REQUEST_TIMEOUT = 30
 
     MSG_MODELOPS_DEPLOYMENT_STATUS_INITIALIZING = "INITIALIZING"
     MSG_MODELOPS_DEPLOYMENT_STATUS_DEPLOYING = "DEPLOYING"
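`INFERENCE_REQUEST_TIMEOUT` plausibly serves as the fallback when an endpoint's settings carry no explicit `request_timeout_sec`. A one-line sketch of that assumed lookup pattern (the `endpoint_settings` dict here would come from `get_endpoint_settings`, shown further down):

    # Assumed lookup: prefer the endpoint's stored setting, else the constant.
    timeout_s = endpoint_settings.get(
        "request_timeout_sec", ClientConstants.INFERENCE_REQUEST_TIMEOUT)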
@@ -1,13 +1,12 @@
+import traceback
+from typing import Mapping
+from urllib.parse import urlparse
+
 import httpx
-import traceback
 
 from .device_client_constants import ClientConstants
-import requests
 
 from fastapi.responses import Response
 from fastapi.responses import StreamingResponse
-from urllib.parse import urlparse
-from typing import Mapping
 
 
 class FedMLHttpInference:
109 changes: 67 additions & 42 deletions python/fedml/computing/scheduler/model_scheduler/device_model_cache.py
@@ -33,6 +33,8 @@ class FedMLModelCache(Singleton):
 
     FEDML_KEY_COUNT_PER_SCAN = 1000
 
+    FEDML_PENDING_REQUESTS_COUNTER = "FEDML_PENDING_REQUESTS_COUNTER"
+
     def __init__(self):
         if not hasattr(self, "redis_pool"):
             self.redis_pool = None
@@ -110,7 +112,7 @@ def set_user_setting_replica_num(self, end_point_id,
                                      replica_num: int, enable_auto_scaling: bool = False,
                                      scale_min: int = 0, scale_max: int = 0, state: str = "UNKNOWN",
                                      target_queries_per_replica: int = 60, aggregation_window_size_seconds: int = 60,
-                                     scale_down_delay_seconds: int = 120
+                                     scale_down_delay_seconds: int = 120, timeout_s: int = 30
                                      ) -> bool:
         """
         Key: FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG--<end_point_id>
@@ -136,7 +138,8 @@ def set_user_setting_replica_num(self, end_point_id,
             "scale_min": scale_min, "scale_max": scale_max, "state": state,
             "target_queries_per_replica": target_queries_per_replica,
             "aggregation_window_size_seconds": aggregation_window_size_seconds,
-            "scale_down_delay_seconds": scale_down_delay_seconds
+            "scale_down_delay_seconds": scale_down_delay_seconds,
+            "request_timeout_sec": timeout_s
         }
         try:
             self.redis_connection.set(self.get_user_setting_replica_num_key(end_point_id), json.dumps(replica_num_dict))
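After this change, the user-setting record stored at `FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG--<end_point_id>` carries the timeout next to the autoscaling knobs. A representative record — values are illustrative, and the first two keys are assumed from the parameter names rather than shown in this hunk:

    # Representative replica_num_dict as serialized to Redis (values illustrative).
    replica_num_dict = {
        "replica_num": 2,
        "enable_auto_scaling": False,
        "scale_min": 1,
        "scale_max": 4,
        "state": "DEPLOYED",
        "target_queries_per_replica": 60,
        "aggregation_window_size_seconds": 60,
        "scale_down_delay_seconds": 120,
        "request_timeout_sec": 30,
    }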
@@ -362,7 +365,7 @@ def get_idle_device(self, end_point_id, end_point_name,
             if "model_status" in result_payload and result_payload["model_status"] == "DEPLOYED":
                 idle_device_list.append({"device_id": device_id, "end_point_id": end_point_id})
 
-        logging.info(f"{len(idle_device_list)} devices has this model on it: {idle_device_list}")
+        logging.info(f"{len(idle_device_list)} devices this model has on it: {idle_device_list}")
 
         if len(idle_device_list) <= 0:
             return None, None
@@ -824,38 +827,37 @@ def get_monitor_metrics_key(self, end_point_id, end_point_name, model_name, model_version):
             end_point_id, end_point_name, model_name, model_version)
 
     def get_endpoint_metrics(self,
-                             endpoint_id,
+                             end_point_id,
                              k_recent=None) -> List[Any]:
         model_deployment_monitor_metrics = list()
         try:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_DEPLOYMENT_MONITOR_TAG,
-                endpoint_id)
-            model_deployment_monitor_endpoint_keys = \
+                end_point_id)
+            model_deployment_monitor_endpoint_key = \
                 self.redis_connection.keys(pattern=key_pattern)
             # Since the reply is a list, we need to make sure the list
             # is non-empty otherwise the index will raise an error.
-            if model_deployment_monitor_endpoint_keys:
+            if model_deployment_monitor_endpoint_key:
                 model_deployment_monitor_endpoint_key = \
-                    model_deployment_monitor_endpoint_keys[0]
-            else:
-                raise Exception("Function `get_endpoint_metrics` Key {} does not exist."
-                                .format(key_pattern))
-            # Set start and end index depending on the size of the
-            # list and the requested number of most recent records.
-            num_records = self.redis_connection.llen(name=model_deployment_monitor_endpoint_key)
-            # if k_most_recent is None, then fetch all by default.
-            start, end = 0, -1
-            # if k_most_recent is positive then fetch [-k_most_recent:]
-            if k_recent and k_recent > 0:
-                start = num_records - k_recent
-            model_deployment_monitor_metrics = \
-                self.redis_connection.lrange(
-                    name=model_deployment_monitor_endpoint_key,
-                    start=start,
-                    end=end)
-            model_deployment_monitor_metrics = [
-                json.loads(m) for m in model_deployment_monitor_metrics]
+                    model_deployment_monitor_endpoint_key[0]
+
+                # Set start and end index depending on the size of the
+                # list and the requested number of most recent records.
+                num_records = self.redis_connection.llen(
+                    name=model_deployment_monitor_endpoint_key)
+                # if k_most_recent is None, then fetch all by default.
+                start, end = 0, -1
+                # if k_most_recent is positive then fetch [-k_most_recent:]
+                if k_recent and k_recent > 0:
+                    start = num_records - k_recent
+                model_deployment_monitor_metrics = \
+                    self.redis_connection.lrange(
+                        name=model_deployment_monitor_endpoint_key,
+                        start=start,
+                        end=end)
+                model_deployment_monitor_metrics = [
+                    json.loads(m) for m in model_deployment_monitor_metrics]
 
         except Exception as e:
             logging.error(e)
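A side note on the windowing logic kept above: when `k_recent` exceeds the list length, `start = num_records - k_recent` goes negative, and Redis `LRANGE` reinterprets a negative start as an offset from the tail rather than the head; it only clamps to the head once `start <= -num_records`. Worked examples (assuming 5 stored records):

    # LRANGE windowing with num_records = 5:
    #   k_recent = 2  -> start = 3   -> lrange(key, 3, -1)   -> last 2 records
    #   k_recent = 9  -> start = -4  -> lrange(key, -4, -1)  -> last 4 records (not all 5)
    #   k_recent = 15 -> start = -10 -> lrange(key, -10, -1) -> clamped to head, all 5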
@@ -868,24 +870,24 @@ def get_endpoint_replicas_results(self, endpoint_id) -> List[Any]:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_DEPLOYMENT_RESULT_TAG,
                 endpoint_id)
-            model_deployment_result_key = \
+            model_deployment_result_keys = \
                 self.redis_connection.keys(pattern=key_pattern)
-            if model_deployment_result_key:
+            if model_deployment_result_keys:
                 model_deployment_result_key = \
-                    model_deployment_result_key[0]
+                    model_deployment_result_keys[0]
+                replicas_results = \
+                    self.redis_connection.lrange(
+                        name=model_deployment_result_key,
+                        start=0,
+                        end=-1)
+                # Format the result value to a properly formatted json.
+                for replica_idx, replica in enumerate(replicas_results):
+                    replicas_results[replica_idx] = json.loads(replica)
+                    replicas_results[replica_idx]["result"] = \
+                        json.loads(replicas_results[replica_idx]["result"])
             else:
                 raise Exception("Function `get_endpoint_replicas_results` Key {} does not exist."
                                 .format(key_pattern))
-            replicas_results = \
-                self.redis_connection.lrange(
-                    name=model_deployment_result_key,
-                    start=0,
-                    end=-1)
-
-            # Format the result value to a properly formatted json.
-            for replica_idx, replica in enumerate(replicas_results):
-                replicas_results[replica_idx] = json.loads(replica)
-                replicas_results[replica_idx]["result"] = json.loads(replicas_results[replica_idx]["result"])
 
         except Exception as e:
             logging.error(e)
@@ -898,11 +900,16 @@ def get_endpoint_settings(self, endpoint_id) -> Dict:
             key_pattern = "{}*{}*".format(
                 self.FEDML_MODEL_ENDPOINT_REPLICA_USER_SETTING_TAG,
                 endpoint_id)
-            endpoint_settings = \
+
+            endpoint_settings_keys = \
                 self.redis_connection.keys(pattern=key_pattern)
-            if endpoint_settings:
+
+            if len(endpoint_settings_keys) > 0:
                 endpoint_settings = \
-                    json.load(endpoint_settings[0])
+                    self.redis_connection.get(endpoint_settings_keys[0])
+
+                if not isinstance(endpoint_settings, dict):
+                    endpoint_settings = json.loads(endpoint_settings)
             else:
                 raise Exception("Function `get_endpoint_settings` Key {} does not exist."
                                 .format(key_pattern))
@@ -966,3 +973,21 @@ def delete_endpoint_scaling_down_decision_time(self, end_point_id) -> bool:
         return bool(self.redis_connection.hdel(
             self.FEDML_MODEL_ENDPOINT_SCALING_DOWN_DECISION_TIME_TAG,
             end_point_id))
+
+    def get_pending_requests_counter(self) -> int:
+        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
+            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+        return int(self.redis_connection.get(self.FEDML_PENDING_REQUESTS_COUNTER))
+
+    def update_pending_requests_counter(self, increase=False, decrease=False) -> int:
+        if not self.redis_connection.exists(self.FEDML_PENDING_REQUESTS_COUNTER):
+            self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+        if increase:
+            self.redis_connection.incr(self.FEDML_PENDING_REQUESTS_COUNTER)
+        if decrease:
+            # Making sure the counter never becomes negative!
+            if self.get_pending_requests_counter() < 0:
+                self.redis_connection.set(self.FEDML_PENDING_REQUESTS_COUNTER, 0)
+            else:
+                self.redis_connection.decr(self.FEDML_PENDING_REQUESTS_COUNTER)
+        return self.get_pending_requests_counter()
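The counter gives the gateway a cheap global view of in-flight work. Redis `INCR`/`DECR` are atomic, while the exists-then-set initialization and the negative-value guard are separate read-modify-write steps, so concurrent gateway processes could interleave them at startup; for a single gateway this is benign. A sketch of the intended call pattern around one proxied request — `run_inference` and the singleton accessor are assumptions, not code from this commit:

    # Illustrative call pattern around a single proxied request.
    from fedml.computing.scheduler.model_scheduler.device_model_cache import FedMLModelCache

    cache = FedMLModelCache.get_instance()  # assumed accessor for the Singleton

    cache.update_pending_requests_counter(increase=True)
    try:
        result = run_inference(request)  # placeholder for the proxied model call
    finally:
        # Always release the slot, even on timeout or failure.
        cache.update_pending_requests_counter(decrease=True)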
(Diffs for the remaining changed files did not load in this view.)