From fe57eb26badb596fd9bd8a0b8b65f00f060b009d Mon Sep 17 00:00:00 2001 From: Dov Shlachter Date: Fri, 21 Jan 2022 12:54:48 -0800 Subject: [PATCH] feat: add interceptor-like functionality to REST transport (#1142) Interceptors are a gRPC feature that wraps rpcs in continuation-passing-style pre and post method custom functions. These can be used e.g. for logging, local caching, and tweaking metadata. This PR adds interceptor like functionality to the REST transport in generated GAPICs. The REST transport interceptors differ in a few ways: 1) They are not continuations. For each method there is a slot for a "pre"function, and for each method with a non-empty return there is a slot for a "post" function. 2) There is always an interceptor for each method. The default simply does nothing. 3) Existing gRPC interceptors and the new REST interceptors are not composable or interoperable. --- .../%service/transports/__init__.py.j2 | 2 + .../services/%service/transports/rest.py.j2 | 73 +++++++++++++++++-- .../%name_%version/%sub/test_%service.py.j2 | 50 +++++++++++++ .../%service/transports/__init__.py.j2 | 2 + .../services/%service/transports/rest.py.j2 | 73 +++++++++++++++++-- .../%name_%version/%sub/test_%service.py.j2 | 58 ++++++++++++++- noxfile.py | 17 +++-- .../unit/gapic/asset_v1/test_asset_service.py | 1 + .../unit/gapic/redis_v1/test_cloud_redis.py | 1 + 9 files changed, 255 insertions(+), 22 deletions(-) diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 index 1241886b63..88d196a7c2 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/__init__.py.j2 @@ -11,6 +11,7 @@ from .grpc import {{ service.name }}GrpcTransport {% endif %} {% if 'rest' in opts.transport %} from .rest import {{ service.name }}RestTransport +from .rest import {{ service.name }}RestInterceptor {% endif %} # Compile a registry of transports. @@ -29,6 +30,7 @@ __all__ = ( {% endif %} {% if 'rest' in opts.transport %} '{{ service.name }}RestTransport', + '{{ service.name }}RestInterceptor', {% endif %} ) {% endblock %} diff --git a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 index 488646be12..a4dc7c61e3 100644 --- a/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/ads-templates/%namespace/%name/%version/%sub/services/%service/transports/rest.py.j2 @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( rest_version=requests_version, ) + +class {{ service.name }}RestInterceptor: + """Interceptor for {{ service.name }}. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the {{ service.name }}RestTransport. + + .. code-block: + class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor): +{% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(response): + logging.log(f"Received response: {response}") + {% endif %} + +{% endfor %} + transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor()) + client = {{ service.client_name }}(transport=transport) + + + """ + {% for method in service.methods.values()|sort(attribute="name") if not(method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the request or metadata + before they are sent to the {{ service.name }} server. + """ + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}: + """Post-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the response + after it is returned by the {{ service.name }} server but before + it is returned to user code. + """ + return response + {% endif %} + + {% endfor %} + + @dataclasses.dataclass class {{service.name}}RestStub: _session: AuthorizedSession _host: str + _interceptor: {{ service.name }}RestInterceptor + class {{service.name}}RestTransport({{service.name}}Transport): """REST backend transport for {{ service.name }}. @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO, always_use_jwt_access: Optional[bool]=False, url_scheme: str='https', + interceptor: Optional[{{ service.name }}RestInterceptor] = None, ) -> None: """Instantiate the transport. @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endif %} if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or {{ service.name }}RestInterceptor() self._prep_wrapped_messages(client_info) {% if service.has_lro %} @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): }, {% endfor %}{# rule in method.http_options #} ] - + request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata) request_kwargs = {{method.input.ident}}.to_dict(request) transcoded_request = path_template.transcode( http_options, **request_kwargs) @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not method.void %} # Return the response {% if method.lro %} - return_op = operations_pb2.Operation() - json_format.Parse(response.content, return_op, ignore_unknown_fields=True) - return return_op + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) {% else %} - return {{method.output.ident}}.from_json( + resp = {{method.output.ident}}.from_json( response.content, ignore_unknown_fields=True ) - {% endif %}{# method.lro #} + resp = self._interceptor.post_{{ method.name|snake_case }}(resp) + return resp {% endif %}{# method.void #} {% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #} {% if not method.http_options %} @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {{method.output.ident}}]: stub = self._STUBS.get("{{method.name | snake_case}}") if not stub: - stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host) + stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor) return stub diff --git a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index af7d28335d..a7934d84e8 100644 --- a/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/ads-templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -35,6 +35,7 @@ from google.api_core import grpc_helpers from google.api_core import path_template {% if service.has_lro %} from google.api_core import future +from google.api_core import operation from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% if "rest" in opts.transport %} @@ -1113,6 +1114,55 @@ def test_{{ method_name }}_rest_unset_required_fields(): {% endif %}{# required_fields #} +{% if not (method.server_streaming or method.client_streaming) %} +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_{{ method_name }}_rest_interceptors(null_interceptor): + transport = transports.{{ service.name }}RestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(), + ) + client = {{ service.client_name }}(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + {% if method.lro %} + mock.patch.object(operation.Operation, "_set_result_from_operation"), \ + {% endif %} + {% if not method.void %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \ + {% endif %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre: + pre.assert_not_called() + {% if not method.void %} + post.assert_not_called() + {% endif %} + + transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},} + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + {% if not method.void %} + req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %} + {% endif %} + + request = {{ method.input.ident }}() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + {% if not method.void %} + post.return_value = {{ method.output.ident }} + {% endif %} + + client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + {% if not method.void %} + post.assert_called_once() + {% endif %} +{% endif %}{# streaming #} + def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}): client = {{ service.client_name }}( diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 index 107e2bd4e8..66be2e5c29 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/__init__.py.j2 @@ -12,6 +12,7 @@ from .grpc_asyncio import {{ service.name }}GrpcAsyncIOTransport {% endif %} {% if 'rest' in opts.transport %} from .rest import {{ service.name }}RestTransport +from .rest import {{ service.name }}RestInterceptor {% endif %} @@ -34,6 +35,7 @@ __all__ = ( {% endif %} {% if 'rest' in opts.transport %} '{{ service.name }}RestTransport', + '{{ service.name }}RestInterceptor', {% endif %} ) {% endblock %} diff --git a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 index 488646be12..b208c0940f 100644 --- a/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 +++ b/gapic/templates/%namespace/%name_%version/%sub/services/%service/transports/rest.py.j2 @@ -49,10 +49,67 @@ DEFAULT_CLIENT_INFO = gapic_v1.client_info.ClientInfo( rest_version=requests_version, ) + +class {{ service.name }}RestInterceptor: + """Interceptor for {{ service.name }}. + + Interceptors are used to manipulate requests, request metadata, and responses + in arbitrary ways. + Example use cases include: + * Logging + * Verifying requests according to service or custom semantics + * Stripping extraneous information from responses + + These use cases and more can be enabled by injecting an + instance of a custom subclass when constructing the {{ service.name }}RestTransport. + + .. code-block: + class MyCustom{{ service.name }}Interceptor({{ service.name }}RestInterceptor): + {% for _, method in service.methods|dictsort if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(request, metadata): + logging.log(f"Received request: {request}") + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(response): + logging.log(f"Received response: {response}") + {% endif %} + +{% endfor %} + transport = {{ service.name }}RestTransport(interceptor=MyCustom{{ service.name }}Interceptor()) + client = {{ service.client_name }}(transport=transport) + + + """ + {% for method in service.methods.values()|sort(attribute="name") if not (method.server_streaming or method.client_streaming) %} + def pre_{{ method.name|snake_case }}(self, request: {{method.input.ident}}, metadata: Sequence[Tuple[str, str]]) -> Tuple[{{method.input.ident}}, Sequence[Tuple[str, str]]]: + """Pre-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the request or metadata + before they are sent to the {{ service.name }} server. + """ + return request, metadata + + {% if not method.void %} + def post_{{ method.name|snake_case }}(self, response: {{method.output.ident}}) -> {{method.output.ident}}: + """Post-rpc interceptor for {{ method.name|snake_case }} + + Override in a subclass to manipulate the response + after it is returned by the {{ service.name }} server but before + it is returned to user code. + """ + return response + {% endif %} + + {% endfor %} + + @dataclasses.dataclass class {{service.name}}RestStub: _session: AuthorizedSession _host: str + _interceptor: {{ service.name }}RestInterceptor + class {{service.name}}RestTransport({{service.name}}Transport): """REST backend transport for {{ service.name }}. @@ -80,6 +137,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): client_info: gapic_v1.client_info.ClientInfo=DEFAULT_CLIENT_INFO, always_use_jwt_access: Optional[bool]=False, url_scheme: str='https', + interceptor: Optional[{{ service.name }}RestInterceptor] = None, ) -> None: """Instantiate the transport. @@ -130,6 +188,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% endif %} if client_cert_source_for_mtls: self._session.configure_mtls_channel(client_cert_source_for_mtls) + self._interceptor = interceptor or {{ service.name }}RestInterceptor() self._prep_wrapped_messages(client_info) {% if service.has_lro %} @@ -233,7 +292,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): }, {% endfor %}{# rule in method.http_options #} ] - + request, metadata = self._interceptor.pre_{{ method.name|snake_case }}(request, metadata) request_kwargs = {{method.input.ident}}.to_dict(request) transcoded_request = path_template.transcode( http_options, **request_kwargs) @@ -288,16 +347,16 @@ class {{service.name}}RestTransport({{service.name}}Transport): {% if not method.void %} # Return the response {% if method.lro %} - return_op = operations_pb2.Operation() - json_format.Parse(response.content, return_op, ignore_unknown_fields=True) - return return_op + resp = operations_pb2.Operation() + json_format.Parse(response.content, resp, ignore_unknown_fields=True) {% else %} - return {{method.output.ident}}.from_json( + resp = {{method.output.ident}}.from_json( response.content, ignore_unknown_fields=True ) - {% endif %}{# method.lro #} + resp = self._interceptor.post_{{ method.name|snake_case }}(resp) + return resp {% endif %}{# method.void #} {% else %}{# method.http_options and not (method.server_streaming or method.client_streaming) #} {% if not method.http_options %} @@ -323,7 +382,7 @@ class {{service.name}}RestTransport({{service.name}}Transport): {{method.output.ident}}]: stub = self._STUBS.get("{{method.name | snake_case}}") if not stub: - stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host) + stub = self._STUBS["{{method.name | snake_case}}"] = self._{{method.name}}(self._session, self._host, self._interceptor) return stub diff --git a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 index 56cdbc6287..cdee5b7697 100644 --- a/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 +++ b/gapic/templates/tests/unit/gapic/%name_%version/%sub/test_%service.py.j2 @@ -39,6 +39,7 @@ from google.api_core import grpc_helpers_async from google.api_core import path_template {% if service.has_lro %} from google.api_core import future +from google.api_core import operation from google.api_core import operations_v1 from google.longrunning import operations_pb2 {% if "rest" in opts.transport %} @@ -1515,6 +1516,57 @@ def test_{{ method_name }}_rest_unset_required_fields(): {% endif %}{# required_fields #} +{% if not (method.server_streaming or method.client_streaming) %} +@pytest.mark.parametrize("null_interceptor", [True, False]) +def test_{{ method_name }}_rest_interceptors(null_interceptor): + transport = transports.{{ service.name }}RestTransport( + credentials=ga_credentials.AnonymousCredentials(), + interceptor=None if null_interceptor else transports.{{ service.name}}RestInterceptor(), + ) + client = {{ service.client_name }}(transport=transport) + with mock.patch.object(type(client.transport._session), "request") as req, \ + mock.patch.object(path_template, "transcode") as transcode, \ + {% if method.lro %} + mock.patch.object(operation.Operation, "_set_result_from_operation"), \ + {% endif %} + {% if not method.void %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "post_{{method.name|snake_case}}") as post, \ + {% endif %} + mock.patch.object(transports.{{ service.name }}RestInterceptor, "pre_{{ method.name|snake_case }}") as pre: + pre.assert_not_called() + {% if not method.void %} + post.assert_not_called() + {% endif %} + + transcode.return_value = {"method": "post", "uri": "my_uri", "body": None, "query_params": {},} + + req.return_value = Response() + req.return_value.status_code = 200 + req.return_value.request = PreparedRequest() + {% if not method.void %} + req.return_value._content = {% if method.output.ident.package == method.ident.package %}{{ method.output.ident }}.to_json({{ method.output.ident }}()){% else %}json_format.MessageToJson({{ method.output.ident }}()){% endif %} + {% endif %} + + request = {{ method.input.ident }}() + metadata =[ + ("key", "val"), + ("cephalopod", "squid"), + ] + pre.return_value = request, metadata + {% if not method.void %} + post.return_value = {{ method.output.ident }} + {% endif %} + + client.{{ method_name }}(request, metadata=[("key", "val"), ("cephalopod", "squid"),]) + + pre.assert_called_once() + {% if not method.void %} + post.assert_called_once() + {% endif %} + +{% endif %}{# streaming #} + + def test_{{ method_name }}_rest_bad_request(transport: str = 'rest', request_type={{ method.input.ident }}): client = {{ service.client_name }}( credentials=ga_credentials.AnonymousCredentials(), @@ -1829,7 +1881,7 @@ def test_credentials_transport_error(): client_options={"credentials_file": "credentials.json"}, transport=transport, ) - + # It is an error to provide an api_key and a transport instance. transport = transports.{{ service.name }}{{ opts.transport[0].capitalize() }}Transport( credentials=ga_credentials.AnonymousCredentials(), @@ -1841,7 +1893,7 @@ def test_credentials_transport_error(): client_options=options, transport=transport, ) - + # It is an error to provide an api_key and a credential. options = mock.Mock() options.api_key = "api_key" @@ -2141,6 +2193,8 @@ def test_{{ service.name|snake_case }}_rest_lro_client(): # Ensure that subsequent calls to the property send the exact same object. assert transport.operations_client is transport.operations_client {%- endif %} + + {% endif %} {# rest #} def test_{{ service.name|snake_case }}_host_no_port(): diff --git a/noxfile.py b/noxfile.py index 6154ea94bb..a9df7d65d4 100644 --- a/noxfile.py +++ b/noxfile.py @@ -310,12 +310,17 @@ def run_showcase_unit_tests(session, fail_under=100): # Run the tests. session.run( "py.test", - "-n=auto", - "--quiet", - "--cov=google", - "--cov-append", - f"--cov-fail-under={str(fail_under)}", - *(session.posargs or [path.join("tests", "unit")]), + *( + session.posargs + or [ + "-n=auto", + "--quiet", + "--cov=google", + "--cov-append", + f"--cov-fail-under={str(fail_under)}", + path.join("tests", "unit"), + ] + ), ) diff --git a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py index a2747714b8..dd4f527b9c 100644 --- a/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py +++ b/tests/integration/goldens/asset/tests/unit/gapic/asset_v1/test_asset_service.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template diff --git a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py index e89e2e73fd..b189511ab7 100644 --- a/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py +++ b/tests/integration/goldens/redis/tests/unit/gapic/redis_v1/test_cloud_redis.py @@ -29,6 +29,7 @@ from google.api_core import gapic_v1 from google.api_core import grpc_helpers from google.api_core import grpc_helpers_async +from google.api_core import operation from google.api_core import operation_async # type: ignore from google.api_core import operations_v1 from google.api_core import path_template