Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: executor to sagemaker custom container #6046

Merged
merged 19 commits into from
Sep 18, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ repos:
args:
- -S
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this update, I keep getting the following error during commit (see the linked Stack Overflow answer):

RuntimeError: The Poetry configuration is invalid:
            - [extras.pipfile_deprecated_finder.2] 'pip-shims<=0.3.4' does not match '^[a-zA-Z-_.0-9]+$'

hooks:
- id: isort
exclude: ^(jina/helloworld/|jina/proto/pb/jina_pb2.py|jina/proto/pb/jina_pb2_grpc.py|jina/proto/pb2/jina_pb2.py|jina/proto/pb2/jina_pb2_grpc.py|docs/|jina/resources/|jina/proto/docarray_v1|jina/proto/docarray_v2)
Expand Down
7 changes: 7 additions & 0 deletions jina/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ class WebsocketSubProtocols(str, Enum):
BYTES = 'bytes'


class ProviderType(BetterEnum):
"""Provider type.

Identifies the hosting provider whose custom-container format an Executor
is made compatible with.
"""

NONE = 0 #: no provider
SAGEMAKER = 1 #: AWS SageMaker


def replace_enum_to_str(obj):
"""
Transform BetterEnum type into string.
Expand Down
38 changes: 24 additions & 14 deletions jina/parsers/orchestrate/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from dataclasses import dataclass
from typing import Dict

from jina.enums import PodRoleType, ProtocolType
from jina.enums import PodRoleType, ProtocolType, ProviderType
from jina.helper import random_port
from jina.parsers.helper import (
_SHOW_ALL_ARGS,
CastToIntAction,
CastPeerPorts,
CastToIntAction,
KVAppendAction,
add_arg_group,
)
Expand Down Expand Up @@ -52,7 +52,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'):
type=int,
default=600000,
help='The timeout in milliseconds of a Pod waits for the runtime to be ready, -1 for waiting '
'forever',
'forever',
)

gp.add_argument(
Expand All @@ -68,15 +68,17 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'):
action=KVAppendAction,
metavar='KEY: VALUE',
nargs='*',
help='The map of environment variables that are read from kubernetes cluster secrets' if _SHOW_ALL_ARGS
help='The map of environment variables that are read from kubernetes cluster secrets'
if _SHOW_ALL_ARGS
else argparse.SUPPRESS,
)
gp.add_argument(
'--image-pull-secrets',
type=str,
nargs='+',
default=None,
help='List of ImagePullSecrets that the Kubernetes Pods need to have access to in order to pull the image. Used in `to_kubernetes_yaml`' if _SHOW_ALL_ARGS
help='List of ImagePullSecrets that the Kubernetes Pods need to have access to in order to pull the image. Used in `to_kubernetes_yaml`'
if _SHOW_ALL_ARGS
else argparse.SUPPRESS,
)

Expand Down Expand Up @@ -106,7 +108,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'):
action='store_true',
default=False,
help='If set, starting a Pod/Deployment does not block the thread/process. It then relies on '
'`wait_start_success` at outer function for the postpone check.'
'`wait_start_success` at outer function for the postpone check.'
if _SHOW_ALL_ARGS
else argparse.SUPPRESS,
)
Expand All @@ -116,7 +118,7 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'):
action='store_true',
default=False,
help='If set, the current Pod/Deployment can not be further chained, '
'and the next `.add()` will chain after the last Pod/Deployment not this current one.',
'and the next `.add()` will chain after the last Pod/Deployment not this current one.',
)

gp.add_argument(
Expand All @@ -134,9 +136,9 @@ def mixin_pod_parser(parser, pod_type: str = 'worker'):
action='store_true',
default=False,
help='If set, the Executor will restart while serving if YAML configuration source or Executor modules '
'are changed. If YAML configuration is changed, the whole deployment is reloaded and new '
'processes will be restarted. If only Python modules of the Executor have changed, they will be '
'reloaded to the interpreter without restarting process.',
'are changed. If YAML configuration is changed, the whole deployment is reloaded and new '
'processes will be restarted. If only Python modules of the Executor have changed, they will be '
'reloaded to the interpreter without restarting process.',
)
gp.add_argument(
'--install-requirements',
Expand Down Expand Up @@ -195,6 +197,14 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'):
help=f'Communication protocol of the server exposed by the {server_name}. This can be a single value or a list of protocols, depending on your chosen Gateway. Choose the convenient protocols from: {[protocol.to_string() for protocol in list(ProtocolType)]}.',
)

arg_group.add_argument(
'--provider',
type=ProviderType.from_string,
choices=list(ProviderType),
default=[ProviderType.NONE],
help=f'If set, Executor is translated to a custom container compatible with the chosen provider. Choose the convenient providers from: {[provider.to_string() for provider in list(ProviderType)]}.',
)

arg_group.add_argument(
'--monitoring',
action='store_true',
Expand Down Expand Up @@ -225,7 +235,7 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'):
action='store_true',
default=False,
help='If set, the sdk implementation of the OpenTelemetry tracer will be available and will be enabled for automatic tracing of requests and customer span creation. '
'Otherwise a no-op implementation will be provided.',
'Otherwise a no-op implementation will be provided.',
)

arg_group.add_argument(
Expand All @@ -247,7 +257,7 @@ def mixin_pod_runtime_args_parser(arg_group, pod_type='worker'):
action='store_true',
default=False,
help='If set, the sdk implementation of the OpenTelemetry metrics will be available for default monitoring and custom measurements. '
'Otherwise a no-op implementation will be provided.',
'Otherwise a no-op implementation will be provided.',
)

arg_group.add_argument(
Expand Down Expand Up @@ -283,8 +293,8 @@ def mixin_stateful_parser(parser):
type=str,
default=None,
help='When using --stateful option, it is required to tell the cluster what are the cluster configuration. This is important'
'when the Deployment is restarted. It indicates the ports to which each replica of the cluster binds.'
' It is expected to be a single list if shards == 1 or a dictionary if shards > 1.',
'when the Deployment is restarted. It indicates the ports to which each replica of the cluster binds.'
' It is expected to be a single list if shards == 1 or a dictionary if shards > 1.',
action=CastPeerPorts,
nargs='+',
)
135 changes: 87 additions & 48 deletions jina/serve/runtimes/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,17 @@ class AsyncNewLoopRuntime:
"""

def __init__(
self,
args: 'argparse.Namespace',
cancel_event: Optional[
Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
] = None,
signal_handlers_installed_event: Optional[
Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
] = None,
req_handler_cls=None,
gateway_load_balancer: bool = False,
**kwargs,
self,
args: 'argparse.Namespace',
cancel_event: Optional[
Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
] = None,
signal_handlers_installed_event: Optional[
Union['asyncio.Event', 'multiprocessing.Event', 'threading.Event']
] = None,
req_handler_cls=None,
gateway_load_balancer: bool = False,
**kwargs,
):
self.req_handler_cls = req_handler_cls
self.gateway_load_balancer = gateway_load_balancer
Expand All @@ -60,7 +60,9 @@ def __init__(
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self.is_cancel = cancel_event or asyncio.Event()
self.is_signal_handlers_installed = signal_handlers_installed_event or asyncio.Event()
self.is_signal_handlers_installed = (
signal_handlers_installed_event or asyncio.Event()
)

self.logger.debug(f'Setting signal handlers')

Expand Down Expand Up @@ -113,12 +115,12 @@ async def _wait_for_cancel(self):
"""Do NOT override this method when inheriting from :class:`GatewayPod`"""
# threads are not using asyncio.Event, but threading.Event
if isinstance(self.is_cancel, asyncio.Event) and not hasattr(
self.server, '_should_exit'
self.server, '_should_exit'
):
await self.is_cancel.wait()
else:
while not self.is_cancel.is_set() and not getattr(
self.server, '_should_exit', False
self.server, '_should_exit', False
):
await asyncio.sleep(0.1)

Expand All @@ -139,14 +141,17 @@ def _cancel(self):

def _get_server(self):
# construct server type based on protocol (and potentially req handler class to keep backwards compatibility)
from jina.enums import ProtocolType
from jina.enums import ProtocolType, ProviderType

if self.req_handler_cls.__name__ == 'GatewayRequestHandler':
self.timeout_send = self.args.timeout_send
if self.timeout_send:
self.timeout_send /= 1e3 # convert ms to seconds
if not self.args.port:
self.args.port = random_ports(len(self.args.protocol))
_set_gateway_uses(self.args, gateway_load_balancer=self.gateway_load_balancer)
_set_gateway_uses(
self.args, gateway_load_balancer=self.gateway_load_balancer
)
uses_with = self.args.uses_with or {}
non_defaults = ArgNamespace.get_non_defaults_args(
self.args, set_gateway_parser()
Expand Down Expand Up @@ -184,8 +189,25 @@ def _get_server(self):
if isinstance(server, BaseServer):
server.is_cancel = self.is_cancel
return server
elif (
hasattr(self.args, 'provider')
and self.args.provider == ProviderType.SAGEMAKER
):
from jina.serve.runtimes.servers.http import SagemakerHTTPServer

return SagemakerHTTPServer(
name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
elif not hasattr(self.args, 'protocol') or (
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.GRPC
):
from jina.serve.runtimes.servers.grpc import GRPCServer

Expand All @@ -199,38 +221,55 @@ def _get_server(self):
proxy=getattr(self.args, 'proxy', None),
)

elif len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.HTTP:
from jina.serve.runtimes.servers.http import HTTPServer # we need a concrete implementation of this
return HTTPServer(name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
elif len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.WEBSOCKET:
from jina.serve.runtimes.servers.websocket import \
WebSocketServer # we need a concrete implementation of this
return WebSocketServer(name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
is_cancel=self.is_cancel)
elif (
len(self.args.protocol) == 1 and self.args.protocol[0] == ProtocolType.HTTP
):
from jina.serve.runtimes.servers.http import (
HTTPServer, # we need a concrete implementation of this
)

return HTTPServer(
name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
cors=getattr(self.args, 'cors', None),
is_cancel=self.is_cancel,
)
elif (
len(self.args.protocol) == 1
and self.args.protocol[0] == ProtocolType.WEBSOCKET
):
from jina.serve.runtimes.servers.websocket import (
WebSocketServer, # we need a concrete implementation of this
)

return WebSocketServer(
name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
proxy=getattr(self.args, 'proxy', None),
uvicorn_kwargs=getattr(self.args, 'uvicorn_kwargs', None),
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
is_cancel=self.is_cancel,
)
elif len(self.args.protocol) > 1:
from jina.serve.runtimes.servers.composite import \
CompositeServer # we need a concrete implementation of this
return CompositeServer(name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
is_cancel=self.is_cancel)
from jina.serve.runtimes.servers.composite import (
CompositeServer, # we need a concrete implementation of this
)

return CompositeServer(
name=self.args.name,
runtime_args=self.args,
req_handler_cls=self.req_handler_cls,
ssl_keyfile=getattr(self.args, 'ssl_keyfile', None),
ssl_certfile=getattr(self.args, 'ssl_certfile', None),
is_cancel=self.is_cancel,
)

def _send_telemetry_event(self, event, extra_kwargs=None):
gateway_kwargs = {}
Expand Down
Loading