From 0ee695d5ddb9d8f91730afd71d0565d1dc74bba5 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 15:39:08 +0530 Subject: [PATCH 01/17] feat: sagemaker custom container with executors --- .pre-commit-config.yaml | 2 +- jina/serve/runtimes/servers/__init__.py | 95 ++++++----- jina/serve/runtimes/servers/http.py | 37 ++++- .../serve/runtimes/worker/request_handling.py | 153 ++++++++++-------- 4 files changed, 178 insertions(+), 109 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5b4dd3438b911..85e1cc39ad07f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -43,7 +43,7 @@ repos: args: - -S - repo: https://github.com/pycqa/isort - rev: 5.10.1 + rev: 5.12.0 hooks: - id: isort exclude: ^(jina/helloworld/|jina/proto/pb/jina_pb2.py|jina/proto/pb/jina_pb2_grpc.py|jina/proto/pb2/jina_pb2.py|jina/proto/pb2/jina_pb2_grpc.py|docs/|jina/resources/|jina/proto/docarray_v1|jina/proto/docarray_v2) diff --git a/jina/serve/runtimes/servers/__init__.py b/jina/serve/runtimes/servers/__init__.py index 195a08ed5e4f6..4a03ef71651c8 100644 --- a/jina/serve/runtimes/servers/__init__.py +++ b/jina/serve/runtimes/servers/__init__.py @@ -1,4 +1,5 @@ import abc +import threading import time from types import SimpleNamespace from typing import TYPE_CHECKING, Dict, Optional, Union @@ -6,13 +7,15 @@ from jina.logging.logger import JinaLogger from jina.serve.instrumentation import InstrumentationMixin from jina.serve.runtimes.monitoring import MonitoringMixin -import threading __all__ = ['BaseServer'] if TYPE_CHECKING: import multiprocessing + from jina.serve.runtimes.gateway.request_handling import GatewayRequestHandler + from jina.serve.runtimes.worker.request_handling import WorkerRequestHandler + class BaseServer(MonitoringMixin, InstrumentationMixin): """ @@ -20,20 +23,22 @@ class BaseServer(MonitoringMixin, InstrumentationMixin): """ def __init__( - self, - name: Optional[str] = 'gateway', - runtime_args: Optional[Dict] = None, - req_handler_cls=None, - req_handler=None, - is_cancel=None, - **kwargs, + self, + name: Optional[str] = 'gateway', + runtime_args: Optional[Dict] = None, + req_handler_cls=None, + req_handler=None, + is_cancel=None, + **kwargs, ): self.name = name or '' self.runtime_args = runtime_args self.works_as_load_balancer = False self.is_cancel = is_cancel or threading.Event() if isinstance(runtime_args, Dict): - self.works_as_load_balancer = runtime_args.get('gateway_load_balancer', False) + self.works_as_load_balancer = runtime_args.get( + 'gateway_load_balancer', False + ) if isinstance(self.runtime_args, dict): self.logger = JinaLogger(self.name, **self.runtime_args) else: @@ -53,7 +58,9 @@ def __init__( metrics_exporter_host=self.runtime_args.metrics_exporter_host, metrics_exporter_port=self.runtime_args.metrics_exporter_port, ) - self._request_handler = req_handler or self._get_request_handler() + self._request_handler: Union[ + 'GatewayRequestHandler', 'WorkerRequestHandler' + ] = (req_handler or self._get_request_handler()) if hasattr(self._request_handler, 'streamer'): self.streamer = self._request_handler.streamer # backward compatibility self.executor = self._request_handler.executor # backward compatibility @@ -90,7 +97,7 @@ def _get_request_handler(self): aio_tracing_client_interceptors=self.aio_tracing_client_interceptors(), tracing_client_interceptor=self.tracing_client_interceptor(), deployment_name=self.name.split('/')[0], - works_as_load_balancer=self.works_as_load_balancer + 
works_as_load_balancer=self.works_as_load_balancer, ) def _add_gateway_args(self): @@ -122,21 +129,33 @@ def port(self): """Gets the first port of the port list argument. To be used in the regular case where a Gateway exposes a single port :return: The first port to be exposed """ - return self.runtime_args.port[0] if isinstance(self.runtime_args.port, list) else self.runtime_args.port + return ( + self.runtime_args.port[0] + if isinstance(self.runtime_args.port, list) + else self.runtime_args.port + ) @property def ports(self): """Gets all the list of ports from the runtime_args as a list. :return: The lists of ports to be exposed """ - return self.runtime_args.port if isinstance(self.runtime_args.port, list) else [self.runtime_args.port] + return ( + self.runtime_args.port + if isinstance(self.runtime_args.port, list) + else [self.runtime_args.port] + ) @property def protocols(self): """Gets all the list of protocols from the runtime_args as a list. :return: The lists of protocols to be exposed """ - return self.runtime_args.protocol if isinstance(self.runtime_args.protocol, list) else [self.runtime_args.protocol] + return ( + self.runtime_args.protocol + if isinstance(self.runtime_args.protocol, list) + else [self.runtime_args.protocol] + ) @property def host(self): @@ -168,11 +187,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): @staticmethod def is_ready( - ctrl_address: str, - protocol: Optional[str] = 'grpc', - timeout: float = 1.0, - logger=None, - **kwargs, + ctrl_address: str, + protocol: Optional[str] = 'grpc', + timeout: float = 1.0, + logger=None, + **kwargs, ) -> bool: """ Check if status is ready. @@ -183,15 +202,11 @@ def is_ready( :param kwargs: extra keyword arguments :return: True if status is ready else False. """ + from jina.enums import ProtocolType from jina.serve.runtimes.servers.grpc import GRPCServer from jina.serve.runtimes.servers.http import FastAPIBaseServer - from jina.enums import ProtocolType - if ( - protocol is None - or protocol == ProtocolType.GRPC - or protocol == 'grpc' - ): + if protocol is None or protocol == ProtocolType.GRPC or protocol == 'grpc': res = GRPCServer.is_ready(ctrl_address) else: res = FastAPIBaseServer.is_ready(ctrl_address) @@ -199,11 +214,11 @@ def is_ready( @staticmethod async def async_is_ready( - ctrl_address: str, - protocol: Optional[str] = 'grpc', - timeout: float = 1.0, - logger=None, - **kwargs, + ctrl_address: str, + protocol: Optional[str] = 'grpc', + timeout: float = 1.0, + logger=None, + **kwargs, ) -> bool: """ Check if status is ready. @@ -214,15 +229,11 @@ async def async_is_ready( :param kwargs: extra keyword arguments :return: True if status is ready else False. 
""" + from jina.enums import ProtocolType from jina.serve.runtimes.servers.grpc import GRPCServer from jina.serve.runtimes.servers.http import FastAPIBaseServer - from jina.enums import ProtocolType - if ( - protocol is None - or protocol == ProtocolType.GRPC - or protocol == 'grpc' - ): + if protocol is None or protocol == ProtocolType.GRPC or protocol == 'grpc': res = await GRPCServer.async_is_ready(ctrl_address, logger=logger) else: res = await FastAPIBaseServer.async_is_ready(ctrl_address, logger=logger) @@ -230,12 +241,12 @@ async def async_is_ready( @classmethod def wait_for_ready_or_shutdown( - cls, - timeout: Optional[float], - ready_or_shutdown_event: Union['multiprocessing.Event', 'threading.Event'], - ctrl_address: str, - health_check: bool = False, - **kwargs, + cls, + timeout: Optional[float], + ready_or_shutdown_event: Union['multiprocessing.Event', 'threading.Event'], + ctrl_address: str, + health_check: bool = False, + **kwargs, ): """ Check if the runtime has successfully started diff --git a/jina/serve/runtimes/servers/http.py b/jina/serve/runtimes/servers/http.py index a7ea54ce5da76..88a847abc41b8 100644 --- a/jina/serve/runtimes/servers/http.py +++ b/jina/serve/runtimes/servers/http.py @@ -122,7 +122,9 @@ def filter(self, record: logging.LogRecord) -> bool: ) if isinstance(self._request_handler, GatewayRequestHandler): - await self._request_handler.streamer._get_endpoints_input_output_models(is_cancel=self.is_cancel) + await self._request_handler.streamer._get_endpoints_input_output_models( + is_cancel=self.is_cancel + ) self._request_handler.streamer._validate_flow_docarray_compatibility() # app property will generate a new fastapi app each time called @@ -170,7 +172,9 @@ def should_exit(self): return self._should_exit @staticmethod - def is_ready(ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs) -> bool: + def is_ready( + ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs + ) -> bool: """ Check if status is ready. :param ctrl_address: the address where the control request needs to be sent @@ -192,7 +196,9 @@ def is_ready(ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs) -> return False @staticmethod - async def async_is_ready(ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs) -> bool: + async def async_is_ready( + ctrl_address: str, timeout: float = 1.0, logger=None, **kwargs + ) -> bool: """ Async Check if status is ready. 
:param ctrl_address: the address where the control request needs to be sent @@ -254,3 +260,28 @@ def app(self): cors=self.cors, logger=self.logger, ) + + +class SagemakerHTTPServer(FastAPIBaseServer): + """ + :class:`SagemakerHTTPServer` is a FastAPIBaseServer that uses a custom FastAPI app for sagemaker endpoints + + """ + + @property + def app(self): + """Get the sagemaker fastapi app + :return: Return a FastAPI app for the sagemaker container + """ + return self._request_handler._http_fastapi_sagemaker_app( + title=self.title, + description=self.description, + no_crud_endpoints=self.no_crud_endpoints, + no_debug_endpoints=self.no_debug_endpoints, + expose_endpoints=self.expose_endpoints, + expose_graphql_endpoint=self.expose_graphql_endpoint, + tracing=self.tracing, + tracer_provider=self.tracer_provider, + cors=self.cors, + logger=self.logger, + ) diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index d136361f50bd5..cc464a8e79031 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -51,16 +51,16 @@ class WorkerRequestHandler: _KEY_RESULT = '__results__' def __init__( - self, - args: 'argparse.Namespace', - logger: 'JinaLogger', - metrics_registry: Optional['CollectorRegistry'] = None, - tracer_provider: Optional['trace.TracerProvider'] = None, - meter_provider: Optional['metrics.MeterProvider'] = None, - meter=None, - tracer=None, - deployment_name: str = '', - **kwargs, + self, + args: 'argparse.Namespace', + logger: 'JinaLogger', + metrics_registry: Optional['CollectorRegistry'] = None, + tracer_provider: Optional['trace.TracerProvider'] = None, + meter_provider: Optional['metrics.MeterProvider'] = None, + meter=None, + tracer=None, + deployment_name: str = '', + **kwargs, ): """Initialize private parameters and execute private loading functions. @@ -83,8 +83,8 @@ def __init__( self._is_closed = False if self.metrics_registry: with ImportExtensions( - required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + required=True, + help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', ): from prometheus_client import Counter, Summary @@ -189,6 +189,28 @@ async def _shutdown(): return app + def _http_fastapi_sagemaker_app(self, **kwargs): + from jina.serve.runtimes.worker.http_fastapi_app import get_fastapi_app + + request_models_map = self._executor._get_endpoint_models_dict() + + def call_handle(request): + is_generator = request_models_map[request.header.exec_endpoint][ + 'is_generator' + ] + + return self.process_single_data(request, None, is_generator=is_generator) + + app = get_fastapi_app( + request_models_map=request_models_map, caller=call_handle, **kwargs + ) + + @app.on_event('shutdown') + async def _shutdown(): + await self.close() + + return app + async def _hot_reload(self): import inspect @@ -205,9 +227,9 @@ async def _hot_reload(self): watched_files.add(extra_python_file) with ImportExtensions( - required=True, - logger=self.logger, - help_text='''hot reload requires watchfiles dependency to be installed. You can do `pip install + required=True, + logger=self.logger, + help_text='''hot reload requires watchfiles dependency to be installed. 
You can do `pip install watchfiles''', ): from watchfiles import awatch @@ -274,16 +296,14 @@ def _init_batchqueue_dict(self): } def _init_monitoring( - self, - metrics_registry: Optional['CollectorRegistry'] = None, - meter: Optional['metrics.Meter'] = None, + self, + metrics_registry: Optional['CollectorRegistry'] = None, + meter: Optional['metrics.Meter'] = None, ): - if metrics_registry: - with ImportExtensions( - required=True, - help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', + required=True, + help_text='You need to install the `prometheus_client` to use the montitoring functionality of jina', ): from prometheus_client import Counter, Summary @@ -339,10 +359,10 @@ def _init_monitoring( self._sent_response_size_histogram = None def _load_executor( - self, - metrics_registry: Optional['CollectorRegistry'] = None, - tracer_provider: Optional['trace.TracerProvider'] = None, - meter_provider: Optional['metrics.MeterProvider'] = None, + self, + metrics_registry: Optional['CollectorRegistry'] = None, + tracer_provider: Optional['trace.TracerProvider'] = None, + meter_provider: Optional['metrics.MeterProvider'] = None, ): """ Load the executor to this runtime, specified by ``uses`` CLI argument. @@ -554,19 +574,28 @@ def _setup_req_doc_array_cls(self, requests, exec_endpoint, is_response=False): if not docarray_v2: req.document_array_cls = DocumentArray else: - if not endpoint_info.is_generator and not endpoint_info.is_singleton_doc: - req.document_array_cls = endpoint_info.request_schema if not is_response else endpoint_info.response_schema - else: - req.document_array_cls = DocList[ + if ( + not endpoint_info.is_generator + and not endpoint_info.is_singleton_doc + ): + req.document_array_cls = ( endpoint_info.request_schema - ] if not is_response else DocList[endpoint_info.response_schema] + if not is_response + else endpoint_info.response_schema + ) + else: + req.document_array_cls = ( + DocList[endpoint_info.request_schema] + if not is_response + else DocList[endpoint_info.response_schema] + ) except AttributeError: pass def _setup_requests( - self, - requests: List['DataRequest'], - exec_endpoint: str, + self, + requests: List['DataRequest'], + exec_endpoint: str, ): """Execute a request using the executor. @@ -582,7 +611,7 @@ def _setup_requests( return requests, params async def handle_generator( - self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None + self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None ) -> Generator: """Prepares and executes a request for generator endpoints. @@ -600,9 +629,7 @@ async def handle_generator( f'Request endpoint must match one of the available endpoints.' ) - requests, params = self._setup_requests( - requests, exec_endpoint - ) + requests, params = self._setup_requests(requests, exec_endpoint) if exec_endpoint in self._batchqueue_config: warnings.warn( 'Batching is not supported for generator executors endpoints. Ignoring batch size.' @@ -619,7 +646,7 @@ async def handle_generator( ) async def handle( - self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None + self, requests: List['DataRequest'], tracing_context: Optional['Context'] = None ) -> DataRequest: """Initialize private parameters and execute private loading functions. 
@@ -641,9 +668,7 @@ async def handle( ) return requests[0] - requests, params = self._setup_requests( - requests, exec_endpoint - ) + requests, params = self._setup_requests(requests, exec_endpoint) len_docs = len(requests[0].docs) # TODO we can optimize here and access the if exec_endpoint in self._batchqueue_config: assert len(requests) == 1, 'dynamic batching does not support no_reduce' @@ -691,7 +716,7 @@ async def handle( @staticmethod def replace_docs( - request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None + request: List['DataRequest'], docs: 'DocumentArray', ndarray_type: str = None ) -> None: """Replaces the docs in a message with new Documents. @@ -739,7 +764,7 @@ async def close(self): @staticmethod def _get_docs_matrix_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> Tuple[Optional[List['DocumentArray']], Optional[Dict[str, 'DocumentArray']]]: """ Returns a docs matrix from a list of DataRequest objects. @@ -763,7 +788,7 @@ def _get_docs_matrix_from_request( @staticmethod def get_parameters_dict_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'Dict': """ Returns a parameters dict from a list of DataRequest objects. @@ -783,7 +808,7 @@ def get_parameters_dict_from_request( @staticmethod def get_docs_from_request( - requests: List['DataRequest'], + requests: List['DataRequest'], ) -> 'DocumentArray': """ Gets a field from the message @@ -863,7 +888,7 @@ def reduce_requests(requests: List['DataRequest']) -> 'DataRequest': # serving part async def process_single_data( - self, request: DataRequest, context, is_generator: bool = False + self, request: DataRequest, context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -877,7 +902,7 @@ async def process_single_data( return await self.process_data([request], context, is_generator=is_generator) async def stream_doc( - self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' + self, request: SingleDocumentRequest, context: 'grpc.aio.ServicerContext' ) -> SingleDocumentRequest: """ Process the received requests and return the result as a new request, used for streaming behavior, one doc IN, several out @@ -990,7 +1015,9 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: ).schema() if inner_dict['parameters']['model'] is not None: - inner_dict['parameters']['model'] = inner_dict['parameters']['model'].schema() + inner_dict['parameters']['model'] = inner_dict['parameters'][ + 'model' + ].schema() else: for endpoint_name, inner_dict in schemas.items(): inner_dict['input']['model'] = inner_dict['input']['model'].schema() @@ -1000,7 +1027,7 @@ async def endpoint_discovery(self, empty, context) -> jina_pb2.EndpointsProto: return endpoints_proto def _extract_tracing_context( - self, metadata: 'grpc.aio.Metadata' + self, metadata: 'grpc.aio.Metadata' ) -> Optional['Context']: if self.tracer: from opentelemetry.propagate import extract @@ -1016,7 +1043,7 @@ def _log_data_request(self, request: DataRequest): ) async def process_data( - self, requests: List[DataRequest], context, is_generator: bool = False + self, requests: List[DataRequest], context, is_generator: bool = False ) -> DataRequest: """ Process the received requests and return the result as a new request @@ -1028,7 +1055,7 @@ async def process_data( """ self.logger.debug('recv a process_data request') with MetricsTimer( - self._summary, self._receiving_request_seconds, 
self._metric_attributes + self._summary, self._receiving_request_seconds, self._metric_attributes ): try: if self.logger.debug_enabled: @@ -1077,8 +1104,8 @@ async def process_data( ) if ( - self.args.exit_on_exceptions - and type(ex).__name__ in self.args.exit_on_exceptions + self.args.exit_on_exceptions + and type(ex).__name__ in self.args.exit_on_exceptions ): self.logger.info('Exiting because of "--exit-on-exceptions".') raise RuntimeTerminated @@ -1103,7 +1130,7 @@ async def _status(self, empty, context) -> jina_pb2.JinaInfoProto: return info_proto async def stream( - self, request_iterator, context=None, *args, **kwargs + self, request_iterator, context=None, *args, **kwargs ) -> AsyncIterator['Request']: """ stream requests from client iterator and stream responses back. @@ -1121,8 +1148,8 @@ async def stream( Call = stream def _create_snapshot_status( - self, - snapshot_directory: str, + self, + snapshot_directory: str, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated snapshot id: {_id}') @@ -1135,7 +1162,7 @@ def _create_snapshot_status( ) def _create_restore_status( - self, + self, ) -> 'jina_pb2.SnapshotStatusProto': _id = str(uuid.uuid4()) self.logger.debug(f'Generated restore id: {_id}') @@ -1154,9 +1181,9 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': """ self.logger.debug(f' Calling snapshot') if ( - self._snapshot - and self._snapshot_thread - and self._snapshot_thread.is_alive() + self._snapshot + and self._snapshot_thread + and self._snapshot_thread.is_alive() ): raise RuntimeError( f'A snapshot with id {self._snapshot.id.value} is currently in progress. Cannot start another.' @@ -1174,7 +1201,7 @@ async def snapshot(self, request, context) -> 'jina_pb2.SnapshotStatusProto': return self._snapshot async def snapshot_status( - self, request: 'jina_pb2.SnapshotId', context + self, request: 'jina_pb2.SnapshotId', context ) -> 'jina_pb2.SnapshotStatusProto': """ method to start a snapshot process of the Executor @@ -1235,7 +1262,7 @@ async def restore(self, request: 'jina_pb2.RestoreSnapshotCommand', context): return self._restore async def restore_status( - self, request, context + self, request, context ) -> 'jina_pb2.RestoreSnapshotStatusProto': """ method to start a snapshot process of the Executor From 3c62f55ad8edc5b6b444a7e4b5ba16e11667d7f2 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 15:42:13 +0530 Subject: [PATCH 02/17] feat: sagemaker provider --- jina/enums.py | 7 ++++++ jina/parsers/orchestrate/pod.py | 38 +++++++++++++++++++++------------ 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/jina/enums.py b/jina/enums.py index 8668e08f2f4d3..f85d26bdc8db0 100644 --- a/jina/enums.py +++ b/jina/enums.py @@ -264,6 +264,13 @@ class WebsocketSubProtocols(str, Enum): BYTES = 'bytes' +class ProviderType(BetterEnum): + """Provider type.""" + + NONE = 0 #: no provider + SAGEMAKER = 1 #: AWS SageMaker + + def replace_enum_to_str(obj): """ Transform BetterEnum type into string. 
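
The ProviderType enum added above is what the new `--provider` flag (added to the pod parser in the next file) parses into, using the same from_string/to_string helpers this series already relies on for ProtocolType. A minimal sketch of the intended round trip — illustrative only, not part of the diff, and assuming the patch series is applied:

    from jina.enums import ProviderType

    # argparse uses ProviderType.from_string as the `type=` callable for --provider
    provider = ProviderType.from_string('SAGEMAKER')
    assert provider is ProviderType.SAGEMAKER

    # the CLI help text is generated from to_string(), yielding ['NONE', 'SAGEMAKER']
    print([p.to_string() for p in list(ProviderType)])

    # downstream runtime code branches on the enum rather than the raw CLI string
    if provider == ProviderType.SAGEMAKER:
        print('start the SageMaker-compatible HTTP server')
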
diff --git a/jina/parsers/orchestrate/pod.py b/jina/parsers/orchestrate/pod.py index 519fef6bb868e..f9927172b0245 100644 --- a/jina/parsers/orchestrate/pod.py +++ b/jina/parsers/orchestrate/pod.py @@ -4,12 +4,12 @@ from dataclasses import dataclass from typing import Dict -from jina.enums import PodRoleType, ProtocolType +from jina.enums import PodRoleType, ProtocolType, ProviderType from jina.helper import random_port from jina.parsers.helper import ( _SHOW_ALL_ARGS, - CastToIntAction, CastPeerPorts, + CastToIntAction, KVAppendAction, add_arg_group, ) @@ -52,7 +52,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): type=int, default=600000, help='The timeout in milliseconds of a Pod waits for the runtime to be ready, -1 for waiting ' - 'forever', + 'forever', ) gp.add_argument( @@ -68,7 +68,8 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): action=KVAppendAction, metavar='KEY: VALUE', nargs='*', - help='The map of environment variables that are read from kubernetes cluster secrets' if _SHOW_ALL_ARGS + help='The map of environment variables that are read from kubernetes cluster secrets' + if _SHOW_ALL_ARGS else argparse.SUPPRESS, ) gp.add_argument( @@ -76,7 +77,8 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): type=str, nargs='+', default=None, - help='List of ImagePullSecrets that the Kubernetes Pods need to have access to in order to pull the image. Used in `to_kubernetes_yaml`' if _SHOW_ALL_ARGS + help='List of ImagePullSecrets that the Kubernetes Pods need to have access to in order to pull the image. Used in `to_kubernetes_yaml`' + if _SHOW_ALL_ARGS else argparse.SUPPRESS, ) @@ -106,7 +108,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): action='store_true', default=False, help='If set, starting a Pod/Deployment does not block the thread/process. It then relies on ' - '`wait_start_success` at outer function for the postpone check.' + '`wait_start_success` at outer function for the postpone check.' if _SHOW_ALL_ARGS else argparse.SUPPRESS, ) @@ -116,7 +118,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): action='store_true', default=False, help='If set, the current Pod/Deployment can not be further chained, ' - 'and the next `.add()` will chain after the last Pod/Deployment not this current one.', + 'and the next `.add()` will chain after the last Pod/Deployment not this current one.', ) gp.add_argument( @@ -134,9 +136,9 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'): action='store_true', default=False, help='If set, the Executor will restart while serving if YAML configuration source or Executor modules ' - 'are changed. If YAML configuration is changed, the whole deployment is reloaded and new ' - 'processes will be restarted. If only Python modules of the Executor have changed, they will be ' - 'reloaded to the interpreter without restarting process.', + 'are changed. If YAML configuration is changed, the whole deployment is reloaded and new ' + 'processes will be restarted. If only Python modules of the Executor have changed, they will be ' + 'reloaded to the interpreter without restarting process.', ) gp.add_argument( '--install-requirements', @@ -195,6 +197,14 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'): help=f'Communication protocol of the server exposed by the {server_name}. This can be a single value or a list of protocols, depending on your chosen Gateway. 
Choose the convenient protocols from: {[protocol.to_string() for protocol in list(ProtocolType)]}.', ) + arg_group.add_argument( + '--provider', + type=ProviderType.from_string, + choices=list(ProviderType), + default=[ProviderType.NONE], + help=f'If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: {[provider.to_string() for provider in list(ProviderType)]}.', + ) + arg_group.add_argument( '--monitoring', action='store_true', @@ -225,7 +235,7 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'): action='store_true', default=False, help='If set, the sdk implementation of the OpenTelemetry tracer will be available and will be enabled for automatic tracing of requests and customer span creation. ' - 'Otherwise a no-op implementation will be provided.', + 'Otherwise a no-op implementation will be provided.', ) arg_group.add_argument( @@ -247,7 +257,7 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'): action='store_true', default=False, help='If set, the sdk implementation of the OpenTelemetry metrics will be available for default monitoring and custom measurements. ' - 'Otherwise a no-op implementation will be provided.', + 'Otherwise a no-op implementation will be provided.', ) arg_group.add_argument( @@ -283,8 +293,8 @@ def mixin_stateful_parser(parser): type=str, default=None, help='When using --stateful option, it is required to tell the cluster what are the cluster configuration. This is important' - 'when the Deployment is restarted. It indicates the ports to which each replica of the cluster binds.' - ' It is expected to be a single list if shards == 1 or a dictionary if shards > 1.', + 'when the Deployment is restarted. It indicates the ports to which each replica of the cluster binds.' 
+ ' It is expected to be a single list if shards == 1 or a dictionary if shards > 1.', action=CastPeerPorts, nargs='+', ) From 1583bbf8550c6e0d7366bcdaec0bc1f7421e7421 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 15:43:42 +0530 Subject: [PATCH 03/17] feat: sagemaker provider --- jina/serve/runtimes/asyncio.py | 135 +++++++----- .../runtimes/worker/http_sagemaker_app.py | 197 ++++++++++++++++++ .../serve/runtimes/worker/request_handling.py | 2 +- 3 files changed, 285 insertions(+), 49 deletions(-) create mode 100644 jina/serve/runtimes/worker/http_sagemaker_app.py diff --git a/jina/serve/runtimes/asyncio.py b/jina/serve/runtimes/asyncio.py index d07f869d19ddb..756d2daae3fa8 100644 --- a/jina/serve/runtimes/asyncio.py +++ b/jina/serve/runtimes/asyncio.py @@ -37,17 +37,17 @@ class AsyncNewLoopRuntime: """ def __init__( - self, - args: 'argparse.Namespace', - cancel_event: Optional[ - Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event'] - ] = None, - signal_handlers_installed_event: Optional[ - Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event'] - ] = None, - req_handler_cls=None, - gateway_load_balancer: bool = False, - **kwargs, + self, + args: 'argparse.Namespace', + cancel_event: Optional[ + Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event'] + ] = None, + signal_handlers_installed_event: Optional[ + Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event'] + ] = None, + req_handler_cls=None, + gateway_load_balancer: bool = False, + **kwargs, ): self.req_handler_cls = req_handler_cls self.gateway_load_balancer = gateway_load_balancer @@ -60,7 +60,9 @@ def __init__( self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) self.is_cancel = cancel_event or asyncio.Event() - self.is_signal_handlers_installed = signal_handlers_installed_event or asyncio.Event() + self.is_signal_handlers_installed = ( + signal_handlers_installed_event or asyncio.Event() + ) self.logger.debug(f'Setting signal handlers') @@ -113,12 +115,12 @@ async def _wait_for_cancel(self): """Do NOT override this method when inheriting from :class:`GatewayPod`""" # threads are not using asyncio.Event, but threading.Event if isinstance(self.is_cancel, asyncio.Event) and not hasattr( - self.server, '_should_exit' + self.server, '_should_exit' ): await self.is_cancel.wait() else: while not self.is_cancel.is_set() and not getattr( - self.server, '_should_exit', False + self.server, '_should_exit', False ): await asyncio.sleep(0.1) @@ -139,14 +141,17 @@ def _cancel(self): def _get_server(self): # construct server type based on protocol (and potentially req handler class to keep backwards compatibility) - from jina.enums import ProtocolType + from jina.enums import ProtocolType, ProviderType + if self.req_handler_cls.__name__ == 'GatewayRequestHandler': self.timeout_send = self.args.timeout_send if self.timeout_send: self.timeout_send /= 1e3 # convert ms to seconds if not self.args.port: self.args.port = random_ports(len(self.args.protocol)) - _set_gateway_uses(self.args, gateway_load_balancer=self.gateway_load_balancer) + _set_gateway_uses( + self.args, gateway_load_balancer=self.gateway_load_balancer + ) uses_with = self.args.uses_with or {} non_defaults = ArgNamespace.get_non_defaults_args( self.args, set_gateway_parser() @@ -184,8 +189,25 @@ def _get_server(self): if isinstance(server, BaseServer): server.is_cancel = self.is_cancel return server + elif ( + hasattr(self.args, 'provider') + and self.args.provider == 
ProviderType.SAGEMAKER + ): + from jina.serve.runtimes.servers.http import SagemakerHTTPServer + + return SagemakerHTTPServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + proxy=getattr(self.args, 'proxy', None), + uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + cors=getattr(self.args, 'cors', None), + is_cancel=self.is_cancel, + ) elif not hasattr(self.args, 'protocol') or ( - len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC + len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC ): from jina.serve.runtimes.servers.grpc import GRPCServer @@ -199,38 +221,55 @@ def _get_server(self): proxy=getattr(self.args, 'proxy', None), ) - elif len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.HTTP: - from jina.serve.runtimes.servers.http import HTTPServer # we need a concrete implementation of this - return HTTPServer(name=self.args.name, - runtime_args=self.args, - req_handler_cls=self.req_handler_cls, - proxy=getattr(self.args, 'proxy', None), - uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), - ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), - ssl_certfile=getattr(self.args, 'ssl_certfile', None), - cors=getattr(self.args, 'cors', None), - is_cancel=self.is_cancel, - ) - elif len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.WEBSOCKET: - from jina.serve.runtimes.servers.websocket import \ - WebSocketServer # we need a concrete implementation of this - return WebSocketServer(name=self.args.name, - runtime_args=self.args, - req_handler_cls=self.req_handler_cls, - proxy=getattr(self.args, 'proxy', None), - uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), - ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), - ssl_certfile=getattr(self.args, 'ssl_certfile', None), - is_cancel=self.is_cancel) + elif ( + len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.HTTP + ): + from jina.serve.runtimes.servers.http import ( + HTTPServer, # we need a concrete implementation of this + ) + + return HTTPServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + proxy=getattr(self.args, 'proxy', None), + uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + cors=getattr(self.args, 'cors', None), + is_cancel=self.is_cancel, + ) + elif ( + len(self.args.protocol) == 1 + and self.args.protocol[0] == ProtocolType.WEBSOCKET + ): + from jina.serve.runtimes.servers.websocket import ( + WebSocketServer, # we need a concrete implementation of this + ) + + return WebSocketServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + proxy=getattr(self.args, 'proxy', None), + uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None), + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + is_cancel=self.is_cancel, + ) elif len(self.args.protocol) > 1: - from jina.serve.runtimes.servers.composite import \ - CompositeServer # we need a concrete implementation of this - return CompositeServer(name=self.args.name, - runtime_args=self.args, - req_handler_cls=self.req_handler_cls, - ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), - ssl_certfile=getattr(self.args, 'ssl_certfile', 
None), - is_cancel=self.is_cancel) + from jina.serve.runtimes.servers.composite import ( + CompositeServer, # we need a concrete implementation of this + ) + + return CompositeServer( + name=self.args.name, + runtime_args=self.args, + req_handler_cls=self.req_handler_cls, + ssl_keyfile=getattr(self.args, 'ssl_keyfile', None), + ssl_certfile=getattr(self.args, 'ssl_certfile', None), + is_cancel=self.is_cancel, + ) def _send_telemetry_event(self, event, extra_kwargs=None): gateway_kwargs = {} diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py new file mode 100644 index 0000000000000..0dffcca4ab778 --- /dev/null +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -0,0 +1,197 @@ +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union + +from jina import Document, DocumentArray +from jina._docarray import docarray_v2 +from jina.importer import ImportExtensions +from jina.types.request.data import DataRequest + +if TYPE_CHECKING: + from jina.logging.logger import JinaLogger + +if docarray_v2: + from docarray import BaseDoc, DocList + + +def get_fastapi_app( + request_models_map: Dict, + caller: Callable, + logger: 'JinaLogger', + cors: bool = False, + **kwargs, +): + """ + Get the app from FastAPI as the REST interface. + + :param request_models_map: Map describing the endpoints and its Pydantic models + :param caller: Callable to be handled by the endpoints of the returned FastAPI app + :param logger: Logger object + :param cors: If set, a CORS middleware is added to FastAPI frontend to allow cross-origin access. + :param kwargs: Extra kwargs to make it compatible with other methods + :return: fastapi app + """ + with ImportExtensions(required=True): + from fastapi import FastAPI, Response, HTTPException + import pydantic + from fastapi.middleware.cors import CORSMiddleware + + import os + + from pydantic import BaseModel, Field + from pydantic.config import BaseConfig, inherit_config + + from jina.proto import jina_pb2 + from jina.serve.runtimes.gateway.models import _to_camel_case + + class Header(BaseModel): + request_id: Optional[str] = Field( + description='Request ID', example=os.urandom(16).hex() + ) + + class Config(BaseConfig): + alias_generator = _to_camel_case + allow_population_by_field_name = True + + class InnerConfig(BaseConfig): + alias_generator = _to_camel_case + allow_population_by_field_name = True + + app = FastAPI() + + if cors: + app.add_middleware( + CORSMiddleware, + allow_origins=['*'], + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) + logger.warning('CORS is enabled. This service is accessible from any website!') + + def validate_invocations_route(): + nonlocal request_models_map + + if '/invocations' in request_models_map: + return + + if len(request_models_map) == 2: # one is _jina_dry_run_ + for route in request_models_map: + if route != '_jina_dry_run_': + logger.warning( + f'No "/invocations" route found. 
Using "{route}" as "/invocations" route' + ) + request_models_map['/invocations'] = request_models_map[route] + return + + raise ValueError( + 'The request_models_map must contain the key "/invocations" to be able to serve the model' + ) + + def add_post_route( + endpoint_path, + input_model, + output_model, + input_doc_list_model=None, + output_doc_list_model=None, + ): + app_kwargs = dict( + path=f'/{endpoint_path.strip("/")}', + methods=['POST'], + summary=f'Endpoint {endpoint_path}', + response_model=output_model, + ) + if docarray_v2: + from docarray.base_doc.docarray_response import DocArrayResponse + + app_kwargs['response_class'] = DocArrayResponse + + @app.api_route(**app_kwargs) + async def post(body: input_model, response: Response): + req = DataRequest() + if body.header is not None: + req.header.request_id = body.header.request_id + + if body.parameters is not None: + req.parameters = body.parameters + req.header.exec_endpoint = endpoint_path + data = body.data + if isinstance(data, list): + if not docarray_v2: + req.data.docs = DocumentArray.from_pydantic_model(data) + else: + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_list_model](data) + else: + if not docarray_v2: + req.data.docs = DocumentArray([Document.from_pydantic_model(data)]) + else: + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_list_model]([data]) + if body.header is None: + req.header.request_id = req.docs[0].id + + resp = await caller(req) + status = resp.header.status + + if status.code == jina_pb2.StatusProto.ERROR: + raise HTTPException(status_code=499, detail=status.description) + else: + if not docarray_v2: + docs_response = resp.docs.to_dict() + else: + docs_response = resp.docs + return output_model(data=docs_response, parameters=resp.parameters) + + validate_invocations_route() + for endpoint, input_output_map in request_models_map.items(): + if endpoint != '_jina_dry_run_': + input_doc_model = input_output_map['input']['model'] + output_doc_model = input_output_map['output']['model'] + parameters_model = input_output_map['parameters']['model'] or Optional[Dict] + default_parameters = ( + ... if input_output_map['parameters']['model'] else None + ) + + if docarray_v2: + _config = inherit_config(InnerConfig, BaseDoc.__config__) + else: + _config = input_doc_model.__config__ + + endpoint_input_model = pydantic.create_model( + f'{endpoint.strip("/")}_input_model', + data=(Union[List[input_doc_model], input_doc_model], ...), + parameters=(parameters_model, default_parameters), + header=(Optional[Header], None), + __config__=_config, + ) + + endpoint_output_model = pydantic.create_model( + f'{endpoint.strip("/")}_output_model', + data=(Union[List[output_doc_model], output_doc_model], ...), + parameters=(Optional[Dict], None), + __config__=_config, + ) + + add_post_route( + endpoint, + input_model=endpoint_input_model, + output_model=endpoint_output_model, + input_doc_list_model=input_doc_model, + output_doc_list_model=output_doc_model, + ) + + from jina.serve.runtimes.gateway.health_model import JinaHealthModel + + @app.get( + path='/ping', + summary='Get the health of Jina Executor service', + response_model=JinaHealthModel, + ) + async def _executor_health(): + """ + Get the health of this Gateway service. + .. 
# noqa: DAR201 + + """ + return {} + + return app diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index cc464a8e79031..775158b2972d2 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -190,7 +190,7 @@ async def _shutdown(): return app def _http_fastapi_sagemaker_app(self, **kwargs): - from jina.serve.runtimes.worker.http_fastapi_app import get_fastapi_app + from jina.serve.runtimes.worker.http_sagemaker_app import get_fastapi_app request_models_map = self._executor._get_endpoint_models_dict() From c2db154866068bb15d0935b4a80bbdb54bae57dd Mon Sep 17 00:00:00 2001 From: Jina Dev Bot Date: Thu, 14 Sep 2023 10:16:48 +0000 Subject: [PATCH 04/17] style: fix overload and cli autocomplete --- jina/orchestrate/deployments/__init__.py | 2 ++ jina/orchestrate/flow/base.py | 9 +++++++++ jina/serve/executors/__init__.py | 2 ++ jina_cli/autocomplete.py | 4 ++++ 4 files changed, 17 insertions(+) diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 8553ef51f43d1..8a68e061ddc8c 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -281,6 +281,7 @@ def __init__( port_monitoring: Optional[int] = None, prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], + provider: Optional[str] = ['NONE'], py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -380,6 +381,7 @@ def __init__( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your diff --git a/jina/orchestrate/flow/base.py b/jina/orchestrate/flow/base.py index 3fc3af7d1548e..dc53951b130bd 100644 --- a/jina/orchestrate/flow/base.py +++ b/jina/orchestrate/flow/base.py @@ -202,6 +202,7 @@ def __init__( port_monitoring: Optional[int] = None, prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = ['GRPC'], + provider: Optional[str] = ['NONE'], proxy: Optional[bool] = False, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, @@ -272,6 +273,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. 
otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -462,6 +464,7 @@ def __init__( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -866,6 +869,7 @@ def add( port_monitoring: Optional[int] = None, prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], + provider: Optional[str] = ['NONE'], py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -965,6 +969,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1127,6 +1132,7 @@ def add( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your @@ -1319,6 +1325,7 @@ def config_gateway( port_monitoring: Optional[int] = None, prefetch: Optional[int] = 1000, protocol: Optional[Union[str, List[str]]] = ['GRPC'], + provider: Optional[str] = ['NONE'], proxy: Optional[bool] = False, py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, @@ -1389,6 +1396,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. 
This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway @@ -1488,6 +1496,7 @@ def config_gateway( Used to control the speed of data input into a Flow. 0 disables prefetch (1000 requests is the default) :param protocol: Communication protocol of the server exposed by the Gateway. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. :param proxy: If set, respect the http_proxy and https_proxy environment variables. otherwise, it will unset these proxy variables before start. gRPC seems to prefer no proxy :param py_modules: The customized python modules need to be imported before loading the gateway diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 2f4117c9e772d..12b6dd947eccd 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -956,6 +956,7 @@ def serve( port_monitoring: Optional[int] = None, prefer_platform: Optional[str] = None, protocol: Optional[Union[str, List[str]]] = ['GRPC'], + provider: Optional[str] = ['NONE'], py_modules: Optional[List[str]] = None, quiet: Optional[bool] = False, quiet_error: Optional[bool] = False, @@ -1055,6 +1056,7 @@ def serve( :param port_monitoring: The port on which the prometheus server is exposed, default is a random port between [49152, 65535] :param prefer_platform: The preferred target Docker platform. (e.g. "linux/amd64", "linux/arm64") :param protocol: Communication protocol of the server exposed by the Executor. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: ['GRPC', 'HTTP', 'WEBSOCKET']. + :param provider: If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: ['NONE', 'SAGEMAKER']. 
:param py_modules: The customized python modules need to be imported before loading the executor Note that the recommended way is to only import a single module - a simple python file, if your diff --git a/jina_cli/autocomplete.py b/jina_cli/autocomplete.py index 9db1f702c63a0..c101d1244b868 100644 --- a/jina_cli/autocomplete.py +++ b/jina_cli/autocomplete.py @@ -71,6 +71,7 @@ '--ports', '--protocol', '--protocols', + '--provider', '--monitoring', '--port-monitoring', '--retries', @@ -178,6 +179,7 @@ '--port-in', '--protocol', '--protocols', + '--provider', '--monitoring', '--port-monitoring', '--retries', @@ -383,6 +385,7 @@ '--ports', '--protocol', '--protocols', + '--provider', '--monitoring', '--port-monitoring', '--retries', @@ -456,6 +459,7 @@ '--ports', '--protocol', '--protocols', + '--provider', '--monitoring', '--port-monitoring', '--retries', From 2ab578fcef0f06b91de6d4f3a1655759cfe17e98 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 17:29:57 +0530 Subject: [PATCH 05/17] feat: sagemaker provider --- .../runtimes/worker/http_sagemaker_app.py | 21 +------------------ .../serve/runtimes/worker/request_handling.py | 21 +++++++++++++++++++ 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py index 0dffcca4ab778..8530fece5149f 100644 --- a/jina/serve/runtimes/worker/http_sagemaker_app.py +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -67,25 +67,6 @@ class InnerConfig(BaseConfig): ) logger.warning('CORS is enabled. This service is accessible from any website!') - def validate_invocations_route(): - nonlocal request_models_map - - if '/invocations' in request_models_map: - return - - if len(request_models_map) == 2: # one is _jina_dry_run_ - for route in request_models_map: - if route != '_jina_dry_run_': - logger.warning( - f'No "/invocations" route found. 
Using "{route}" as "/invocations" route' - ) - request_models_map['/invocations'] = request_models_map[route] - return - - raise ValueError( - 'The request_models_map must contain the key "/invocations" to be able to serve the model' - ) - def add_post_route( endpoint_path, input_model, @@ -141,7 +122,6 @@ async def post(body: input_model, response: Response): docs_response = resp.docs return output_model(data=docs_response, parameters=resp.parameters) - validate_invocations_route() for endpoint, input_output_map in request_models_map.items(): if endpoint != '_jina_dry_run_': input_doc_model = input_output_map['input']['model'] @@ -181,6 +161,7 @@ async def post(body: input_model, response: Response): from jina.serve.runtimes.gateway.health_model import JinaHealthModel + # `/ping` route is required by AWS Sagemaker @app.get( path='/ping', summary='Get the health of Jina Executor service', diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 775158b2972d2..705c7c1080f84 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -194,6 +194,26 @@ def _http_fastapi_sagemaker_app(self, **kwargs): request_models_map = self._executor._get_endpoint_models_dict() + def validate_invocations_route(): + # validate if the request_models_map contains the key '/invocations' + nonlocal request_models_map + + if '/invocations' in request_models_map: + return + + if len(request_models_map) == 2: # one is _jina_dry_run_ + for route in request_models_map: + if route != '_jina_dry_run_': + self.logger.warning( + f'No "/invocations" route found. Using "{route}" as "/invocations" route' + ) + request_models_map['/invocations'] = request_models_map[route] + return + + raise ValueError( + 'The request_models_map must contain the key "/invocations" to be able to serve the model' + ) + def call_handle(request): is_generator = request_models_map[request.header.exec_endpoint][ 'is_generator' @@ -201,6 +221,7 @@ def call_handle(request): return self.process_single_data(request, None, is_generator=is_generator) + validate_invocations_route() app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs ) From a393b2ce21ffd2a6a61aef164adea51ac0928f81 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 17:36:52 +0530 Subject: [PATCH 06/17] feat: sagemaker provider --- jina/serve/runtimes/servers/http.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jina/serve/runtimes/servers/http.py b/jina/serve/runtimes/servers/http.py index 88a847abc41b8..cadc7e6aeeda9 100644 --- a/jina/serve/runtimes/servers/http.py +++ b/jina/serve/runtimes/servers/http.py @@ -268,6 +268,12 @@ class SagemakerHTTPServer(FastAPIBaseServer): """ + @property + def port(self): + """Get the port for the sagemaker server + :return: Return the port for the sagemaker server, always 8080""" + return 8080 + @property def app(self): """Get the sagemaker fastapi app From 1d4e361c47fd18a2e9e950632b29670ca364b409 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 17:43:47 +0530 Subject: [PATCH 07/17] feat: sagemaker provider --- jina/serve/runtimes/servers/http.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jina/serve/runtimes/servers/http.py b/jina/serve/runtimes/servers/http.py index cadc7e6aeeda9..a60063ac6adb6 100644 --- a/jina/serve/runtimes/servers/http.py +++ b/jina/serve/runtimes/servers/http.py @@ -139,7 +139,7 @@ def filter(self, 
record: logging.LogRecord) -> bool: **self.uvicorn_kwargs, ) ) - self.logger.debug(f'UviServer server setup') + self.logger.debug(f'UviServer server setup on port {self.port}') await self.server.setup() self.logger.debug(f'HTTP server setup successful') From b32f26be7e80e86fe5ad6d593690c26057a76b84 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Thu, 14 Sep 2023 17:48:54 +0530 Subject: [PATCH 08/17] feat: sagemaker provider --- jina/serve/runtimes/servers/http.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/jina/serve/runtimes/servers/http.py b/jina/serve/runtimes/servers/http.py index a60063ac6adb6..4ca3685c735ec 100644 --- a/jina/serve/runtimes/servers/http.py +++ b/jina/serve/runtimes/servers/http.py @@ -274,6 +274,12 @@ def port(self): :return: Return the port for the sagemaker server, always 8080""" return 8080 + @property + def ports(self): + """Get the port for the sagemaker server + :return: Return the port for the sagemaker server, always 8080""" + return [8080] + @property def app(self): """Get the sagemaker fastapi app From ee42d6a205b2d7fd4767e05ae2a6d7aecf904bae Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Fri, 15 Sep 2023 09:36:39 +0530 Subject: [PATCH 09/17] fix: move sagemaker route to executor --- jina/serve/executors/__init__.py | 37 ++++++++++++++++--- .../serve/runtimes/worker/request_handling.py | 22 +---------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 12b6dd947eccd..6188dedade5ce 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -68,9 +68,10 @@ def is_pydantic_model(annotation: Type) -> bool: :param annotation: The annotation from which to extract PydantiModel. :return: boolean indicating if a Pydantic model is inside the annotation """ - from pydantic import BaseModel from typing import get_args, get_origin + from pydantic import BaseModel + origin = get_origin(annotation) or annotation args = get_args(annotation) @@ -92,8 +93,9 @@ def get_inner_pydantic_model(annotation: Type) -> bool: :return: The inner Pydantic model expected """ try: + from typing import Optional, Type, Union, get_args, get_origin + from pydantic import BaseModel - from typing import Type, Optional, get_args, get_origin, Union origin = get_origin(annotation) or annotation args = get_args(annotation) @@ -179,7 +181,7 @@ def validate(self): self.is_generator and self.is_batch_docs ), f'Cannot specify the `docs` parameter if the endpoint {self.fn.__name__} is a generator' if docarray_v2: - from docarray import DocList, BaseDoc + from docarray import BaseDoc, DocList if not self.is_generator: if self.is_batch_docs and ( @@ -390,10 +392,11 @@ def __init__( self._add_requests(requests) self._add_dynamic_batching(dynamic_batching) self._add_runtime_args(runtime_args) + self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args)) + self._validate_sagemaker() self._init_instrumentation(runtime_args) self._init_monitoring() self._init_workspace = workspace - self.logger = JinaLogger(self.__class__.__name__, **vars(self.runtime_args)) if __dry_run_endpoint__ not in self.requests: self.requests[ __dry_run_endpoint__ @@ -596,6 +599,31 @@ def _add_requests(self, _requests: Optional[Dict]): f'expect {typename(self)}.{func} to be a function, but receiving {typename(_func)}' ) + def _validate_sagemaker(self): + # sagemaker expects the POST /invocations endpoint to be defined. 
+ # if it is not defined, we check if there is only one endpoint defined, + # and if so, we use it as the POST /invocations endpoint, or raise an error + if ( + not hasattr(self, 'runtime_args') + or not self.runtime_args.provider != 'sagemaker' + ): + return + + if '/invocations' in self.requests: + return + + if len(self.requests) == 1: + route = list(self.requests.keys())[0] + self.logger.warning( + f'No "/invocations" route found. Using "{route}" as "/invocations" route' + ) + self.requests['/invocations'] = self.requests[route] + return + + raise ValueError( + 'No "/invocations" route found. Please define a "/invocations" route' + ) + def _add_dynamic_batching(self, _dynamic_batching: Optional[Dict]): if _dynamic_batching: self.dynamic_batching = getattr(self, 'dynamic_batching', {}) @@ -695,7 +723,6 @@ async def __acall__(self, req_endpoint: str, **kwargs): async def __acall_endpoint__( self, req_endpoint, tracing_context: Optional['Context'], **kwargs ): - # Decorator to make sure that `parameters` are passed as PydanticModels if needed def parameters_as_pydantic_models_decorator(func, parameters_pydantic_model): @functools.wraps(func) # Step 2: Use functools.wraps to preserve metadata diff --git a/jina/serve/runtimes/worker/request_handling.py b/jina/serve/runtimes/worker/request_handling.py index 705c7c1080f84..c8584d19baa40 100644 --- a/jina/serve/runtimes/worker/request_handling.py +++ b/jina/serve/runtimes/worker/request_handling.py @@ -194,26 +194,6 @@ def _http_fastapi_sagemaker_app(self, **kwargs): request_models_map = self._executor._get_endpoint_models_dict() - def validate_invocations_route(): - # validate if the request_models_map contains the key '/invocations' - nonlocal request_models_map - - if '/invocations' in request_models_map: - return - - if len(request_models_map) == 2: # one is _jina_dry_run_ - for route in request_models_map: - if route != '_jina_dry_run_': - self.logger.warning( - f'No "/invocations" route found. 
Using "{route}" as "/invocations" route' - ) - request_models_map['/invocations'] = request_models_map[route] - return - - raise ValueError( - 'The request_models_map must contain the key "/invocations" to be able to serve the model' - ) - def call_handle(request): is_generator = request_models_map[request.header.exec_endpoint][ 'is_generator' @@ -221,7 +201,6 @@ def call_handle(request): return self.process_single_data(request, None, is_generator=is_generator) - validate_invocations_route() app = get_fastapi_app( request_models_map=request_models_map, caller=call_handle, **kwargs ) @@ -404,6 +383,7 @@ def _load_executor( 'shards': self.args.shards, 'replicas': self.args.replicas, 'name': self.args.name, + 'provider': self.args.provider, 'metrics_registry': metrics_registry, 'tracer_provider': tracer_provider, 'meter_provider': meter_provider, From 5801568ea52c1882c6069574e7f2cf413be3bc04 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Fri, 15 Sep 2023 09:45:44 +0530 Subject: [PATCH 10/17] fix: only support docarray v2 --- .../runtimes/worker/http_sagemaker_app.py | 39 +++++++------------ 1 file changed, 14 insertions(+), 25 deletions(-) diff --git a/jina/serve/runtimes/worker/http_sagemaker_app.py b/jina/serve/runtimes/worker/http_sagemaker_app.py index 8530fece5149f..f095b8cfc9000 100644 --- a/jina/serve/runtimes/worker/http_sagemaker_app.py +++ b/jina/serve/runtimes/worker/http_sagemaker_app.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from jina import Document, DocumentArray from jina._docarray import docarray_v2 from jina.importer import ImportExtensions from jina.types.request.data import DataRequest @@ -42,6 +41,10 @@ def get_fastapi_app( from jina.proto import jina_pb2 from jina.serve.runtimes.gateway.models import _to_camel_case + if not docarray_v2: + logger.warning('Only docarray v2 is supported with Sagemaker. 
') + return + class Header(BaseModel): request_id: Optional[str] = Field( description='Request ID', example=os.urandom(16).hex() @@ -74,16 +77,15 @@ def add_post_route( input_doc_list_model=None, output_doc_list_model=None, ): + from docarray.base_doc.docarray_response import DocArrayResponse + app_kwargs = dict( path=f'/{endpoint_path.strip("/")}', methods=['POST'], summary=f'Endpoint {endpoint_path}', response_model=output_model, + response_class=DocArrayResponse, ) - if docarray_v2: - from docarray.base_doc.docarray_response import DocArrayResponse - - app_kwargs['response_class'] = DocArrayResponse @app.api_route(**app_kwargs) async def post(body: input_model, response: Response): @@ -94,19 +96,14 @@ async def post(body: input_model, response: Response): if body.parameters is not None: req.parameters = body.parameters req.header.exec_endpoint = endpoint_path + data = body.data if isinstance(data, list): - if not docarray_v2: - req.data.docs = DocumentArray.from_pydantic_model(data) - else: - req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model](data) + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_list_model](data) else: - if not docarray_v2: - req.data.docs = DocumentArray([Document.from_pydantic_model(data)]) - else: - req.document_array_cls = DocList[input_doc_model] - req.data.docs = DocList[input_doc_list_model]([data]) + req.document_array_cls = DocList[input_doc_model] + req.data.docs = DocList[input_doc_list_model]([data]) if body.header is None: req.header.request_id = req.docs[0].id @@ -116,11 +113,7 @@ async def post(body: input_model, response: Response): if status.code == jina_pb2.StatusProto.ERROR: raise HTTPException(status_code=499, detail=status.description) else: - if not docarray_v2: - docs_response = resp.docs.to_dict() - else: - docs_response = resp.docs - return output_model(data=docs_response, parameters=resp.parameters) + return output_model(data=resp.docs, parameters=resp.parameters) for endpoint, input_output_map in request_models_map.items(): if endpoint != '_jina_dry_run_': @@ -131,11 +124,7 @@ async def post(body: input_model, response: Response): ... 
if input_output_map['parameters']['model'] else None ) - if docarray_v2: - _config = inherit_config(InnerConfig, BaseDoc.__config__) - else: - _config = input_doc_model.__config__ - + _config = inherit_config(InnerConfig, BaseDoc.__config__) endpoint_input_model = pydantic.create_model( f'{endpoint.strip("/")}_input_model', data=(Union[List[input_doc_model], input_doc_model], ...), From 9c00771c8a53604334899003c37bdc9287806697 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Fri, 15 Sep 2023 14:23:14 +0530 Subject: [PATCH 11/17] fix: provider in runtime_args --- jina/serve/executors/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 6188dedade5ce..503d9da649a0b 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -605,6 +605,7 @@ def _validate_sagemaker(self): # and if so, we use it as the POST /invocations endpoint, or raise an error if ( not hasattr(self, 'runtime_args') + or not hasattr(self.runtime_args, 'provider') or not self.runtime_args.provider != 'sagemaker' ): return From bd91cda768b73b741b9eaa06c02f0e5f094543c0 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Fri, 15 Sep 2023 14:34:04 +0530 Subject: [PATCH 12/17] fix: provider in runtime_args --- jina/serve/executors/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jina/serve/executors/__init__.py b/jina/serve/executors/__init__.py index 503d9da649a0b..8daa6616110e3 100644 --- a/jina/serve/executors/__init__.py +++ b/jina/serve/executors/__init__.py @@ -28,7 +28,7 @@ from jina._docarray import DocumentArray, docarray_v2 from jina.constants import __args_executor_init__, __cache_path__, __default_endpoint__ -from jina.enums import BetterEnum +from jina.enums import BetterEnum, ProviderType from jina.helper import ( ArgNamespace, T, @@ -606,7 +606,7 @@ def _validate_sagemaker(self): if ( not hasattr(self, 'runtime_args') or not hasattr(self.runtime_args, 'provider') - or not self.runtime_args.provider != 'sagemaker' + or self.runtime_args.provider != ProviderType.SAGEMAKER.value ): return From fb71c13b837c3e7b5e4271dc4d531e7af117710c Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 18 Sep 2023 12:09:04 +0530 Subject: [PATCH 13/17] test: sagemaker integration tests --- .../sagemaker/SampleExecutor/README.md | 2 + .../sagemaker/SampleExecutor/config.yml | 8 +++ .../sagemaker/SampleExecutor/executor.py | 28 ++++++++++ .../sagemaker/SampleExecutor/requirements.txt | 0 tests/integration/sagemaker/__init__.py | 0 tests/integration/sagemaker/test_sagemaker.py | 55 +++++++++++++++++++ 6 files changed, 93 insertions(+) create mode 100644 tests/integration/sagemaker/SampleExecutor/README.md create mode 100644 tests/integration/sagemaker/SampleExecutor/config.yml create mode 100644 tests/integration/sagemaker/SampleExecutor/executor.py create mode 100644 tests/integration/sagemaker/SampleExecutor/requirements.txt create mode 100644 tests/integration/sagemaker/__init__.py create mode 100644 tests/integration/sagemaker/test_sagemaker.py diff --git a/tests/integration/sagemaker/SampleExecutor/README.md b/tests/integration/sagemaker/SampleExecutor/README.md new file mode 100644 index 0000000000000..49da1225f4487 --- /dev/null +++ b/tests/integration/sagemaker/SampleExecutor/README.md @@ -0,0 +1,2 @@ +# SampleExecutor + diff --git a/tests/integration/sagemaker/SampleExecutor/config.yml b/tests/integration/sagemaker/SampleExecutor/config.yml new file mode 100644 index 
0000000000000..6b819858f2fc8 --- /dev/null +++ b/tests/integration/sagemaker/SampleExecutor/config.yml @@ -0,0 +1,8 @@ +jtype: SampleExecutor +py_modules: + - executor.py +metas: + name: SampleExecutor + description: + url: + keywords: [] \ No newline at end of file diff --git a/tests/integration/sagemaker/SampleExecutor/executor.py b/tests/integration/sagemaker/SampleExecutor/executor.py new file mode 100644 index 0000000000000..4bccb70388561 --- /dev/null +++ b/tests/integration/sagemaker/SampleExecutor/executor.py @@ -0,0 +1,28 @@ +import numpy as np +from docarray import BaseDoc, DocList +from docarray.typing import NdArray +from pydantic import Field + +from jina import Executor, requests + + +class TextDoc(BaseDoc): + text: str + + +class EmbeddingResponseModel(BaseDoc): + embeddings: NdArray = Field(description="The embedding of the texts", default=[]) + + class Config(BaseDoc.Config): + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {NdArray: lambda v: v.tolist()} + + +class SampleExecutor(Executor): + @requests(on="/encode") + def foo(self, docs: DocList[TextDoc], **kwargs) -> DocList[EmbeddingResponseModel]: + ret = [] + for doc in docs: + ret.append(EmbeddingResponseModel(embeddings=np.random.random((1, 64)))) + return DocList[EmbeddingResponseModel](ret) diff --git a/tests/integration/sagemaker/SampleExecutor/requirements.txt b/tests/integration/sagemaker/SampleExecutor/requirements.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/sagemaker/__init__.py b/tests/integration/sagemaker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/integration/sagemaker/test_sagemaker.py b/tests/integration/sagemaker/test_sagemaker.py new file mode 100644 index 0000000000000..6ef1b6994b40b --- /dev/null +++ b/tests/integration/sagemaker/test_sagemaker.py @@ -0,0 +1,55 @@ +import os +from contextlib import AbstractContextManager + +import requests + +from jina.orchestrate.pods import Pod +from jina.parsers import set_pod_parser + + +class chdir(AbstractContextManager): + def __init__(self, path): + self.path = path + self._old_cwd = [] + + def __enter__(self): + self._old_cwd.append(os.getcwd()) + os.chdir(self.path) + + def __exit__(self, *excinfo): + os.chdir(self._old_cwd.pop()) + + +def test_provider_sagemaker(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + args, _ = set_pod_parser().parse_known_args( + [ + '--uses', + 'config.yml', + '--provider', + 'sagemaker', + 'serve', # This is added by sagemaker + ] + ) + with Pod(args): + # provider=sagemaker would set the port to 8080 + port = 8080 + # Test the `GET /ping` endpoint (added by jina for sagemaker) + rsp = requests.get(f'http://localhost:{port}/ping') + assert rsp.status_code == 200 + assert rsp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + rsp = requests.post( + f'http://localhost:{port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert rsp.status_code == 200 + resp_json = rsp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 From 396fc0858772e370bd22096491899bbedddb36c3 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 18 Sep 2023 13:54:54 +0530 Subject: [PATCH 14/17] test: sagemaker integration tests --- tests/integration/sagemaker/test_sagemaker.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 
deletions(-) diff --git a/tests/integration/sagemaker/test_sagemaker.py b/tests/integration/sagemaker/test_sagemaker.py index 6ef1b6994b40b..362e79e319d7f 100644 --- a/tests/integration/sagemaker/test_sagemaker.py +++ b/tests/integration/sagemaker/test_sagemaker.py @@ -3,6 +3,7 @@ import requests +from jina import Deployment from jina.orchestrate.pods import Pod from jina.parsers import set_pod_parser @@ -20,7 +21,7 @@ def __exit__(self, *excinfo): os.chdir(self._old_cwd.pop()) -def test_provider_sagemaker(): +def test_provider_sagemaker_pod(): with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): args, _ = set_pod_parser().parse_known_args( [ @@ -35,14 +36,39 @@ def test_provider_sagemaker(): # provider=sagemaker would set the port to 8080 port = 8080 # Test the `GET /ping` endpoint (added by jina for sagemaker) - rsp = requests.get(f'http://localhost:{port}/ping') + resp = requests.get(f'http://localhost:{port}/ping') + assert resp.status_code == 200 + assert resp.json() == {} + + # Test the `POST /invocations` endpoint + # Note: this endpoint is not implemented in the sample executor + resp = requests.post( + f'http://localhost:{port}/invocations', + json={ + 'data': [ + {'text': 'hello world'}, + ] + }, + ) + assert resp.status_code == 200 + resp_json = resp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +def test_provider_sagemaker_deployment(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + dep_port = 12345 + with Deployment(uses='config.yml', provider='sagemaker', port=dep_port) as dep: + # Test the `GET /ping` endpoint (added by jina for sagemaker) + rsp = requests.get(f'http://localhost:{dep_port}/ping') assert rsp.status_code == 200 assert rsp.json() == {} # Test the `POST /invocations` endpoint # Note: this endpoint is not implemented in the sample executor rsp = requests.post( - f'http://localhost:{port}/invocations', + f'http://localhost:{dep_port}/invocations', json={ 'data': [ {'text': 'hello world'}, @@ -50,6 +76,6 @@ def test_provider_sagemaker(): }, ) assert rsp.status_code == 200 - resp_json = rsp.json() - assert len(resp_json['data']) == 1 - assert len(resp_json['data'][0]['embeddings'][0]) == 64 + rsp_json = rsp.json() + assert len(rsp_json['data']) == 1 + assert len(rsp_json['data'][0]['embeddings'][0]) == 64 From 791b3ee98ea8eed39d1aca793f9b21cb20894793 Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 18 Sep 2023 13:55:24 +0530 Subject: [PATCH 15/17] fix: deployment with sagemaker provider --- jina/orchestrate/deployments/__init__.py | 26 +++++++++++++------ tests/integration/sagemaker/test_sagemaker.py | 6 ++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 24af1d87286e7..8cffaa15f646f 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -1,14 +1,15 @@ import asyncio import copy import json +import multiprocessing import os +import platform import re import subprocess -import threading -import multiprocessing -import platform import sys +import threading import time +import warnings from argparse import Namespace from collections import defaultdict from contextlib import ExitStack @@ -29,7 +30,13 @@ __docker_host__, __windows__, ) -from jina.enums import DeploymentRoleType, PodRoleType, PollingType, ProtocolType +from jina.enums import ( + DeploymentRoleType, + PodRoleType, + PollingType, + 
ProtocolType, + ProviderType, +) from jina.helper import ( ArgNamespace, parse_host_scheme, @@ -471,6 +478,13 @@ def __init__( args = ArgNamespace.kwargs2namespace(kwargs, parser, True) self.args = args self._gateway_load_balancer = False + if self.args.provider == ProviderType.SAGEMAKER: + if self.args.port != 8080: + warnings.warn( + f'Port is changed to 8080 for Sagemaker deployment. Port {self.args.port} is ignored' + ) + self.args.protocol = [ProtocolType.HTTP] + self.args.port = [8080] if self._include_gateway and ProtocolType.HTTP in self.args.protocol: self._gateway_load_balancer = True log_config = kwargs.get('log_config') @@ -1306,7 +1320,6 @@ def _roundrobin_cuda_device(device_str: Optional[str], replicas: int): selected_devices = [] if device_str[2:]: - for device in Deployment._parse_devices(device_str[2:], num_devices): selected_devices.append(device) else: @@ -1448,7 +1461,6 @@ def _set_pod_args(self) -> Dict[int, List[Namespace]]: @staticmethod def _set_uses_before_after_args(args: Namespace, entity_type: str) -> Namespace: - _args = copy.deepcopy(args) _args.pod_role = PodRoleType.WORKER _args.host = _args.host[0] or __default_host__ @@ -1650,7 +1662,6 @@ def _reload_deployment(changed_file): watch_changes = self.args.reload if watch_changes and self._is_executor_from_yaml: - with ImportExtensions( required=True, help_text='''reload requires watchfiles dependency to be installed. You can run `pip install @@ -1694,7 +1705,6 @@ def _get_summary_table(self, all_panels: List[Panel]): swagger_ui_link = None redoc_link = None for _port, _protocol in zip(_ports, _protocols): - address_table.add_row(':chains:', 'Protocol', _protocol) _protocol = _protocol.lower() diff --git a/tests/integration/sagemaker/test_sagemaker.py b/tests/integration/sagemaker/test_sagemaker.py index 362e79e319d7f..4e9d3b6f3ba07 100644 --- a/tests/integration/sagemaker/test_sagemaker.py +++ b/tests/integration/sagemaker/test_sagemaker.py @@ -76,6 +76,6 @@ def test_provider_sagemaker_deployment(): }, ) assert rsp.status_code == 200 - rsp_json = rsp.json() - assert len(rsp_json['data']) == 1 - assert len(rsp_json['data'][0]['embeddings'][0]) == 64 + resp_json = rsp.json() + assert len(resp_json['data']) == 1 + assert len(resp_json['data'][0]['embeddings'][0]) == 64 From eee67f77656de357d127ed6bd95bd10f9f57434a Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 18 Sep 2023 14:53:08 +0530 Subject: [PATCH 16/17] ci: move sagemaker to docarray v2 dir --- .github/workflows/cd.yml | 1 + .github/workflows/ci.yml | 1 + .../{ => docarray_v2}/sagemaker/SampleExecutor/README.md | 0 .../{ => docarray_v2}/sagemaker/SampleExecutor/config.yml | 0 .../{ => docarray_v2}/sagemaker/SampleExecutor/executor.py | 0 .../{ => docarray_v2}/sagemaker/SampleExecutor/requirements.txt | 0 tests/integration/{ => docarray_v2}/sagemaker/__init__.py | 0 tests/integration/{ => docarray_v2}/sagemaker/test_sagemaker.py | 0 8 files changed, 2 insertions(+) rename tests/integration/{ => docarray_v2}/sagemaker/SampleExecutor/README.md (100%) rename tests/integration/{ => docarray_v2}/sagemaker/SampleExecutor/config.yml (100%) rename tests/integration/{ => docarray_v2}/sagemaker/SampleExecutor/executor.py (100%) rename tests/integration/{ => docarray_v2}/sagemaker/SampleExecutor/requirements.txt (100%) rename tests/integration/{ => docarray_v2}/sagemaker/__init__.py (100%) rename tests/integration/{ => docarray_v2}/sagemaker/test_sagemaker.py (100%) diff --git a/.github/workflows/cd.yml b/.github/workflows/cd.yml index 
d4511b7f90a39..ffd6e2be468a3 100644 --- a/.github/workflows/cd.yml +++ b/.github/workflows/cd.yml @@ -146,6 +146,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT timeout-minutes: 45 diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1dbb56b23c7e5..e19af8a54d120 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -481,6 +481,7 @@ jobs: pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_singleton.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_parameters_as_pydantic.py pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/test_streaming.py + pytest --suppress-no-test-exit-code --force-flaky --min-passes 1 --max-runs 5 --cov=jina --cov-report=xml --timeout=600 -v -s --ignore-glob='tests/integration/hub_usage/dummyhub*' tests/integration/docarray_v2/sagemaker/test_sagemaker.py echo "flag it as jina for codeoverage" echo "codecov_flag=jina" >> $GITHUB_OUTPUT timeout-minutes: 45 diff --git a/tests/integration/sagemaker/SampleExecutor/README.md b/tests/integration/docarray_v2/sagemaker/SampleExecutor/README.md similarity index 100% rename from tests/integration/sagemaker/SampleExecutor/README.md rename to tests/integration/docarray_v2/sagemaker/SampleExecutor/README.md diff --git a/tests/integration/sagemaker/SampleExecutor/config.yml b/tests/integration/docarray_v2/sagemaker/SampleExecutor/config.yml similarity index 100% rename from tests/integration/sagemaker/SampleExecutor/config.yml rename to tests/integration/docarray_v2/sagemaker/SampleExecutor/config.yml diff --git a/tests/integration/sagemaker/SampleExecutor/executor.py b/tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py similarity index 100% rename from tests/integration/sagemaker/SampleExecutor/executor.py rename to tests/integration/docarray_v2/sagemaker/SampleExecutor/executor.py diff --git a/tests/integration/sagemaker/SampleExecutor/requirements.txt b/tests/integration/docarray_v2/sagemaker/SampleExecutor/requirements.txt similarity index 100% rename from tests/integration/sagemaker/SampleExecutor/requirements.txt rename to 
tests/integration/docarray_v2/sagemaker/SampleExecutor/requirements.txt diff --git a/tests/integration/sagemaker/__init__.py b/tests/integration/docarray_v2/sagemaker/__init__.py similarity index 100% rename from tests/integration/sagemaker/__init__.py rename to tests/integration/docarray_v2/sagemaker/__init__.py diff --git a/tests/integration/sagemaker/test_sagemaker.py b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py similarity index 100% rename from tests/integration/sagemaker/test_sagemaker.py rename to tests/integration/docarray_v2/sagemaker/test_sagemaker.py From 812667b5e96de97307d1011d476af8d0540baa5a Mon Sep 17 00:00:00 2001 From: Deepankar Mahapatro Date: Mon, 18 Sep 2023 16:17:41 +0530 Subject: [PATCH 17/17] fix: sagemaker args check --- jina/orchestrate/deployments/__init__.py | 14 +++++++++++--- .../docarray_v2/sagemaker/test_sagemaker.py | 8 ++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/jina/orchestrate/deployments/__init__.py b/jina/orchestrate/deployments/__init__.py index 8cffaa15f646f..27327402ebdcf 100644 --- a/jina/orchestrate/deployments/__init__.py +++ b/jina/orchestrate/deployments/__init__.py @@ -479,12 +479,20 @@ def __init__( self.args = args self._gateway_load_balancer = False if self.args.provider == ProviderType.SAGEMAKER: - if self.args.port != 8080: + if self._gateway_kwargs.get('port', 0) == 8080: + raise ValueError( + f'Port 8080 is reserved for Sagemaker deployment. Please use another port' + ) + if self.args.port != [8080]: warnings.warn( f'Port is changed to 8080 for Sagemaker deployment. Port {self.args.port} is ignored' ) - self.args.protocol = [ProtocolType.HTTP] - self.args.port = [8080] + self.args.port = [8080] + if self.args.protocol != [ProtocolType.HTTP]: + warnings.warn( + f'Protocol is changed to HTTP for Sagemaker deployment. Protocol {self.args.protocol} is ignored' + ) + self.args.protocol = [ProtocolType.HTTP] if self._include_gateway and ProtocolType.HTTP in self.args.protocol: self._gateway_load_balancer = True log_config = kwargs.get('log_config') diff --git a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py index 4e9d3b6f3ba07..2d1d0c6d88cb3 100644 --- a/tests/integration/docarray_v2/sagemaker/test_sagemaker.py +++ b/tests/integration/docarray_v2/sagemaker/test_sagemaker.py @@ -1,6 +1,7 @@ import os from contextlib import AbstractContextManager +import pytest import requests from jina import Deployment @@ -79,3 +80,10 @@ def test_provider_sagemaker_deployment(): resp_json = rsp.json() assert len(resp_json['data']) == 1 assert len(resp_json['data'][0]['embeddings'][0]) == 64 + + +def test_provider_sagemaker_deployment_wrong_port(): + with chdir(os.path.join(os.path.dirname(__file__), 'SampleExecutor')): + with pytest.raises(ValueError): + with Deployment(uses='config.yml', provider='sagemaker', port=8080) as dep: + pass
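
Editor's note (not part of the patch series): the behaviour these integration tests pin down can be reproduced against a locally running deployment. The sketch below is illustrative only and mirrors the tests above; it assumes the `SampleExecutor` directory (with its `config.yml`) is the current working directory and that port 8080 is free, since `provider='sagemaker'` pins the server to 8080 regardless of the requested port.

```python
# Minimal sketch (not part of the patch): exercise the SageMaker-style
# endpoints added in this series against a local Deployment.
# Assumptions: SampleExecutor's config.yml is in the current working
# directory and port 8080 is free (the sagemaker provider forces 8080).
import requests

from jina import Deployment

with Deployment(uses='config.yml', provider='sagemaker'):
    # `/ping` is the health-check route AWS SageMaker requires;
    # the server answers with an empty JSON object.
    assert requests.get('http://localhost:8080/ping').json() == {}

    # `/invocations` is not defined by SampleExecutor, so the executor's
    # single `/encode` endpoint is aliased to it (with a warning), as
    # implemented in `_validate_sagemaker`.
    resp = requests.post(
        'http://localhost:8080/invocations',
        json={'data': [{'text': 'hello world'}]},
    )
    resp.raise_for_status()
    # SampleExecutor returns one random (1, 64) embedding per document.
    print(len(resp.json()['data'][0]['embeddings'][0]))  # -> 64
```

Passing an explicit `port=8080` to `Deployment` raises a `ValueError` in the final patch of the series, because that port is reserved for the underlying SageMaker server; any other explicit port is ignored with a warning.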