[Serve] Avoid looping over all snapshot ids for each long poll request #45881

52 changes: 34 additions & 18 deletions python/ray/serve/_private/long_poll.py
@@ -88,7 +88,11 @@ def __init__(
         self.key_listeners = key_listeners
         self.event_loop = call_in_event_loop
         self.snapshot_ids: Dict[KeyType, int] = {
-            key: -1 for key in self.key_listeners.keys()
+            # The initial snapshot id for each key is < 0,
+            # but real snapshot keys in the long poll host are always >= 0,
+            # so this will always trigger an initial update.
+            key: -1
+            for key in self.key_listeners.keys()
         }
         self.is_running = True

@@ -191,11 +195,9 @@ def __init__(
         ] = LISTEN_FOR_CHANGE_REQUEST_TIMEOUT_S,
     ):
         # Map object_key -> int
-        self.snapshot_ids: DefaultDict[KeyType, int] = defaultdict(
-            lambda: random.randint(0, 1_000_000)
-        )
+        self.snapshot_ids: Dict[KeyType, int] = {}
         # Map object_key -> object
-        self.object_snapshots: Dict[KeyType, Any] = dict()
+        self.object_snapshots: Dict[KeyType, Any] = {}
         # Map object_key -> set(asyncio.Event waiting for updates)
         self.notifier_events: DefaultDict[KeyType, Set[asyncio.Event]] = defaultdict(
             set
@@ -247,24 +249,32 @@ async def listen_for_change(
         immediately if the snapshot_ids are outdated, otherwise it will block
         until there's an update.
         """
-        watched_keys = keys_to_snapshot_ids.keys()
-        existent_keys = set(watched_keys).intersection(set(self.snapshot_ids.keys()))
-
         # If there are any keys with outdated snapshot ids,
         # return their updated values immediately.
-        updated_objects = {
-            key: UpdatedObject(self.object_snapshots[key], self.snapshot_ids[key])
-            for key in existent_keys
-            if self.snapshot_ids[key] != keys_to_snapshot_ids[key]
-        }
+        updated_objects = {}
+        for key, client_snapshot_id in keys_to_snapshot_ids.items():
+            try:
+                existing_id = self.snapshot_ids[key]
+            except KeyError:
+                # The caller may ask for keys that we don't know about (yet),
+                # just ignore them.
+                # This can happen when, for example,
+                # a deployment handle is manually created for an app
+                # that hasn't been deployed yet (by bypassing the safety checks).
+                continue
+
+            if existing_id != client_snapshot_id:
+                updated_objects[key] = UpdatedObject(
+                    self.object_snapshots[key], existing_id
+                )
         if len(updated_objects) > 0:
             self._count_send(updated_objects)
             return updated_objects

         # Otherwise, register asyncio events to be waited.
         async_task_to_events = {}
         async_task_to_watched_keys = {}
-        for key in watched_keys:
+        for key in keys_to_snapshot_ids.keys():
             # Create a new asyncio event for this key.
             event = asyncio.Event()
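
The change in this hunk can be illustrated outside of Ray with plain dictionaries. The sketch below is not the Serve implementation: the function names `outdated_keys_via_intersection` and `outdated_keys_via_lookup`, and the sample keys, are hypothetical. It only shows that a per-key lookup with a `KeyError` fallback finds the same outdated keys as the set intersection, without building a set over every key the host knows about.

```python
from typing import Dict, Hashable, Set


def outdated_keys_via_intersection(
    host_ids: Dict[Hashable, int], client_ids: Dict[Hashable, int]
) -> Set[Hashable]:
    # Old shape: set(host_ids) touches every key the host knows about,
    # even when the client only watches a handful of them.
    existent = set(client_ids).intersection(set(host_ids))
    return {key for key in existent if host_ids[key] != client_ids[key]}


def outdated_keys_via_lookup(
    host_ids: Dict[Hashable, int], client_ids: Dict[Hashable, int]
) -> Set[Hashable]:
    # New shape: one dict lookup per watched key; unknown keys are skipped.
    outdated = set()
    for key, client_id in client_ids.items():
        try:
            host_id = host_ids[key]
        except KeyError:
            continue
        if host_id != client_id:
            outdated.add(key)
    return outdated


# Hypothetical data: many keys on the host, three watched by one client.
host_ids = {f"deployment-{i}": i for i in range(10_000)}
client_ids = {"deployment-3": 3, "deployment-7": -1, "not-deployed-yet": -1}

old_result = outdated_keys_via_intersection(host_ids, client_ids)
new_result = outdated_keys_via_lookup(host_ids, client_ids)
print(old_result == new_result, new_result)  # prints: True {'deployment-7'}
```

With this shape, the work per request scales with the number of keys the client watches rather than with the number of keys the host tracks.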

@@ -398,10 +408,16 @@ def notify_changed(
         object_key: KeyType,
         updated_object: Any,
     ):
-        self.snapshot_ids[object_key] += 1
+        try:
+            self.snapshot_ids[object_key] += 1
Comment on lines -401 to +412
Contributor Author

Not a defaultdict anymore :(

+        except KeyError:
+            # Initial snapshot id must be >= 0, so that the long poll client
+            # can send a negative initial snapshot id to get a fast update.
+            # They should also be randomized;
+            # see https://github.com/ray-project/ray/pull/45881#discussion_r1645243485
+            self.snapshot_ids[object_key] = random.randint(0, 1_000_000)
Contributor

Now that this isn't a defaultdict anymore, could you double check whether any other places where we access self.snapshot_ids need this KeyError check?

Contributor Author

I see four uses:

The first two are in

    try:
        self.snapshot_ids[object_key] += 1
    except KeyError:
        # Initial snapshot id must be >= 0, so that the long poll client
        # can send a negative initial snapshot id to get a fast update.
        # They should also be randomized to try to avoid situations where,
        # if the controller restarts and a client has a now-invalid snapshot id
        # that happens to match what the controller restarts with,

which is the block guarded by the KeyError check itself.

The next is

    try:
        existing_id = self.snapshot_ids[key]
    except KeyError:
        # The caller may ask for keys that we don't know about (yet),
        # just ignore them.
        # This can happen when, for example,
        # a deployment handle is manually created for an app
        # that hasn't been deployed yet (by bypassing the safety checks).
        continue

which is also protected by a KeyError check.

The last is

    else:
        updated_object_key: str = async_task_to_watched_keys[done.pop()]
        updated_object = {
            updated_object_key: UpdatedObject(
                self.object_snapshots[updated_object_key],
                self.snapshot_ids[updated_object_key],
            )
        }
        self._count_send(updated_object)
        return updated_object

That block runs when a key changes, which is triggered by

    logger.debug(f"LongPollHost: Notify change for key {object_key}.")

That line comes after the first block, which ensures the snapshot id is present in the mapping, so I think we're good on that one too.

Contributor

Great, thanks for being thorough here!

Contributor Author

My pleasure!
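
For readers following the thread, the snapshot id invariant under discussion can be sketched in isolation. This is a minimal stand-in, not Ray Serve's `LongPollHost`: the class name `TinySnapshotHost` and the key `"route_table"` are hypothetical. It assumes only what the diff states: host ids start at a random value >= 0, later updates increment the id, and a fresh client starts at -1.

```python
import random
from typing import Any, Dict, Hashable


class TinySnapshotHost:
    """Hypothetical stand-in for the host side; not Ray Serve's LongPollHost."""

    def __init__(self) -> None:
        # Plain dicts: a missing key raises KeyError instead of being created.
        self.snapshot_ids: Dict[Hashable, int] = {}
        self.object_snapshots: Dict[Hashable, Any] = {}

    def notify_changed(self, key: Hashable, value: Any) -> None:
        try:
            self.snapshot_ids[key] += 1
        except KeyError:
            # First update for this key: start from a random non-negative id.
            self.snapshot_ids[key] = random.randint(0, 1_000_000)
        self.object_snapshots[key] = value


host = TinySnapshotHost()
host.notify_changed("route_table", {"/": "app"})
first_id = host.snapshot_ids["route_table"]
host.notify_changed("route_table", {"/": "app", "/v2": "app2"})
second_id = host.snapshot_ids["route_table"]

# A fresh client starts below every id the host can hand out,
# so its first comparison against a known key always reports a change.
client_initial_id = -1
```

Because the id is written before the snapshot is stored and before waiters are notified, any code that runs in response to a notification can read `snapshot_ids[key]` without its own `KeyError` guard, which is the point made in the thread.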

         self.object_snapshots[object_key] = updated_object
         logger.debug(f"LongPollHost: Notify change for key {object_key}.")

-        if object_key in self.notifier_events:
-            for event in self.notifier_events.pop(object_key):
-                event.set()
+        for event in self.notifier_events.pop(object_key, set()):
+            event.set()
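
The last change relies on `pop` with a default behaving the same on a `defaultdict` as on a plain `dict`: it returns the default for a missing key and does not invoke the default factory. A small sketch of that behavior, with hypothetical key names and a hypothetical `main` coroutine, might look like this:

```python
import asyncio
from collections import defaultdict


async def main():
    notifier_events = defaultdict(set)
    event = asyncio.Event()
    notifier_events["key-with-waiter"].add(event)

    # One key has a waiter, the other was never registered.
    for key in ("key-with-waiter", "key-without-waiters"):
        # pop(key, default) neither raises nor creates an entry for a missing key.
        for waiting in notifier_events.pop(key, set()):
            waiting.set()

    return event.is_set(), dict(notifier_events)


was_set, remaining = asyncio.run(main())
print(was_set, remaining)  # prints: True {}
```

This replaces a membership check followed by `pop` with a single dictionary operation.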