Skip to content

Commit

Permalink
I forgot what I did
Browse files Browse the repository at this point in the history
  • Loading branch information
gongy committed Nov 27, 2024
1 parent 6f734a7 commit 7fb598f
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 21 deletions.
6 changes: 4 additions & 2 deletions modal/_runtime/asgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ async def messages():
await disconnect_app()
return

await messages_to_app.put(first_message)
for m in first_message:
await messages_to_app.put(m)
async for message in message_gen:
await messages_to_app.put(message)
for m in message:
await messages_to_app.put(m)

async def send(msg):
# Automatically split body chunks that are greater than the output size limit, to
Expand Down
27 changes: 24 additions & 3 deletions modal/_runtime/container_io_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,10 +505,31 @@ async def format_blob_data(self, data: bytes) -> Dict[str, Any]:
else {"data": data}
)

async def get_data_in(self, function_call_id: str) -> AsyncIterator[Any]:
@synchronizer.no_io_translation
async def get_data_in(self, function_call_id: str) -> AsyncIterator[List[Any]]:
"""Read from the `data_in` stream of a function call."""
async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
yield data

if True:
async for data in _stream_function_call_data(self._client, function_call_id, "data_in_asgi"):
yield [data]
else:
q = asyncio.Queue(1024)

async def fill_queue():
async for data in _stream_function_call_data(self._client, function_call_id, "data_in"):
await q.put(data)
await q.put(None)

t = asyncio.create_task(fill_queue())
try:
while data := await q.get():
data_array = [data]
while q.qsize() > 0:
data_array.append(q.get_nowait())
print(f"yielding {len(data_array)} items")
yield data_array
finally:
t.cancel()

async def put_data_out_request(
self,
Expand Down
5 changes: 1 addition & 4 deletions modal/_runtime/user_code_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ class ImportedFunction(Service):
_user_defined_callable: Callable[..., Any]

def get_finalized_functions(
self,
fun_def: api_pb2.Function,
container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager",
client,
self, fun_def: api_pb2.Function, container_io_manager: "modal._runtime.container_io_manager.ContainerIOManager"
) -> Dict[str, "FinalizedFunction"]:
# Check this property before we turn it into a method (overriden by webhooks)
is_async = get_is_async(self._user_defined_callable)
Expand Down
11 changes: 8 additions & 3 deletions modal/_utils/function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import modal_proto
from modal_proto import api_pb2

from .._serialization import deserialize, deserialize_data_format, serialize
from .._serialization import _deserialize_asgi, deserialize, deserialize_data_format, serialize
from .._traceback import append_modal_tb
from ..config import config, logger
from ..exception import DeserializationError, ExecutionError, FunctionTimeoutError, InvalidError, RemoteError
Expand Down Expand Up @@ -358,7 +358,7 @@ def callable_has_non_self_non_default_params(f: Callable[..., Any]) -> bool:


async def _stream_function_call_data(
client, function_call_id: str, variant: Literal["data_in", "data_out"]
client, function_call_id: str, variant: Literal["data_in", "data_out", "data_in_asgi"]
) -> AsyncGenerator[Any, None]:
"""Read from the `data_in` or `data_out` stream of a function call."""
import time
Expand All @@ -371,6 +371,8 @@ async def _stream_function_call_data(

if variant == "data_in":
stub_fn = client.container_stub.FunctionCallGetDataIn
elif variant == "data_in_asgi":
stub_fn = client.container_stub.FunctionCallGetDataInAsgi
elif variant == "data_out":
stub_fn = client.container_stub.FunctionCallGetDataOut

Expand Down Expand Up @@ -404,9 +406,12 @@ async def fill_queue():

t = asyncio.create_task(fill_queue())
try:
# await asyncio.sleep(0.05)
t0 = time.time()
while chunk := await q.get():
if variant == "data_in_asgi":
yield _deserialize_asgi(chunk)
continue

if chunk.index <= last_index:
continue
if chunk.data_blob_id:
Expand Down
9 changes: 9 additions & 0 deletions modal/_utils/grpc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TypeVar,
)

import grpc.aio
import grpclib.client
import grpclib.config
import grpclib.events
Expand Down Expand Up @@ -70,6 +71,14 @@ def connected(self):
]


def create_grpcio_channel(server_url: str, metadata: Dict[str, str] = {}) -> grpc.aio.Channel:
options = [
("grpc.max_receive_message_length", 64 * 1024 * 1024), # 64MB
("grpc.max_send_message_length", 64 * 1024 * 1024), # 64MB
]
return grpc.aio.insecure_channel(server_url, options=options)


def create_channel(
server_url: str,
metadata: Dict[str, str] = {},
Expand Down
11 changes: 2 additions & 9 deletions modal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import grpclib.client
from google.protobuf import empty_pb2
from google.protobuf.message import Message
from grpc.aio import insecure_channel
from grpclib import GRPCError, Status
from synchronicity.async_wrap import asynccontextmanager

Expand All @@ -31,7 +30,7 @@

from ._utils import async_utils
from ._utils.async_utils import TaskContext, synchronize_api
from ._utils.grpc_utils import connect_channel, create_channel, retry_transient_errors
from ._utils.grpc_utils import connect_channel, create_channel, create_grpcio_channel, retry_transient_errors
from .config import _check_config, _is_remote, config, logger
from .exception import AuthError, ClientClosed, ConnectionError, DeprecationError, VersionError

Expand Down Expand Up @@ -123,13 +122,7 @@ async def _open(self):
assert self._stub is None
metadata = _get_metadata(self.client_type, self._credentials, self.version)
self._channel = create_channel(self.server_url, metadata=metadata)
self._container_channel = insecure_channel(
self.server_url,
options=[
("grpc.max_receive_message_length", 64 * 1024 * 1024), # 64MB
("grpc.max_send_message_length", 64 * 1024 * 1024), # 64MB
],
)
self._container_channel = create_grpcio_channel(self.server_url, metadata=metadata)
try:
await connect_channel(self._channel)
except OSError as exc:
Expand Down
1 change: 1 addition & 0 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,7 @@ service ModalClient {
rpc FunctionBindParams(FunctionBindParamsRequest) returns (FunctionBindParamsResponse);
rpc FunctionCallCancel(FunctionCallCancelRequest) returns (google.protobuf.Empty);
rpc FunctionCallGetDataIn(FunctionCallGetDataRequest) returns (stream DataChunk);
rpc FunctionCallGetDataInAsgi(FunctionCallGetDataRequest) returns (stream Asgi);
rpc FunctionCallGetDataOut(FunctionCallGetDataRequest) returns (stream DataChunk);
rpc FunctionCallList(FunctionCallListRequest) returns (FunctionCallListResponse);
rpc FunctionCallPutDataOut(FunctionCallPutDataRequest) returns (google.protobuf.Empty);
Expand Down

0 comments on commit 7fb598f

Please sign in to comment.