Skip to content

Commit

Permalink
feat: support self-signed JWT flow for servie accounts
Browse files Browse the repository at this point in the history
  • Loading branch information
busunkim96 committed Feb 10, 2021
1 parent 9bde5c7 commit d696ce3
Show file tree
Hide file tree
Showing 5 changed files with 307 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ class {{ service.name }}Transport(abc.ABC):
{%- endfor %}
)

DEFAULT_HOST = {% if service.host %}'{{ service.host }}'{% else %}{{ None }}{% endif %}

def __init__(
self, *,
host: str{% if service.host %} = '{{ service.host }}'{% endif %},
host: str = DEFAULT_HOST,
credentials: credentials.Credentials = None,
credentials_file: typing.Optional[str] = None,
scopes: typing.Optional[typing.Sequence[str]] = AUTH_SCOPES,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import warnings
from typing import Callable, Dict, Optional, Sequence, Tuple

import google.api_core
from google.api_core import grpc_helpers # type: ignore
{%- if service.has_lro %}
from google.api_core import operations_v1 # type: ignore
Expand All @@ -12,6 +13,8 @@ from google.api_core import gapic_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version
import pkg_resources

import grpc # type: ignore

Expand All @@ -27,6 +30,17 @@ from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endfilter %}
from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO

try:
# google.auth.__version__ was added in 1.26.0
_GOOGLE_AUTH_VERSION = auth.__version__
except AttributeError:
try: # try pkg_resources if it is available
_GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
except pkg_resources.DistributionNotFound: # pragma: NO COVER
_GOOGLE_AUTH_VERSION = None

_API_CORE_VERSION = google.api_core.__version__


class {{ service.name }}GrpcTransport({{ service.name }}Transport):
"""gRPC backend transport for {{ service.name }}.
Expand Down Expand Up @@ -101,6 +115,22 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""

# If a custom API endpoint is set, set scopes to ensure the auth
# library does not used the self-signed JWT flow for service
# accounts
if host.split(":")[0] != self.DEFAULT_HOST and not scopes:
scopes = self.AUTH_SCOPES

# TODO(busunkim): Remove this if/else once google-auth >= 1.25.0 is required
if _GOOGLE_AUTH_VERSION and (
packaging.version.parse(_GOOGLE_AUTH_VERSION)
>= packaging.version.parse("1.25.0")
):
scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or self.AUTH_SCOPES}

self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
Expand All @@ -120,7 +150,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
Expand All @@ -138,7 +168,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=ssl_credentials,
scopes=scopes or self.AUTH_SCOPES,
scopes=scopes,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
Expand All @@ -150,7 +180,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
host = host if ":" in host else host + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
Expand All @@ -164,7 +194,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
credentials=credentials,
credentials_file=credentials_file,
ssl_credentials=self._ssl_channel_credentials,
scopes=scopes or self.AUTH_SCOPES,
scopes=scopes,
quota_project_id=quota_project_id,
options=[
("grpc.max_send_message_length", -1),
Expand All @@ -182,7 +212,7 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
host=host,
credentials=credentials,
credentials_file=credentials_file,
scopes=scopes or self.AUTH_SCOPES,
scopes=scopes,
quota_project_id=quota_project_id,
client_info=client_info,
)
Expand Down Expand Up @@ -220,7 +250,19 @@ class {{ service.name }}GrpcTransport({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
scopes = scopes or cls.AUTH_SCOPES
self_signed_jwt_kwargs = {}

# TODO(busunkim): Remove this if/else once google-api-core >= 1.26.0 is required
if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return grpc_helpers.create_channel(
host,
credentials=credentials,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from google.api_core import operations_v1 # type: ignore
from google import auth # type: ignore
from google.auth import credentials # type: ignore
from google.auth.transport.grpc import SslCredentials # type: ignore
import packaging.version

import grpc # type: ignore
from grpc.experimental import aio # type: ignore
Expand All @@ -28,6 +29,8 @@ from google.iam.v1 import policy_pb2 as policy # type: ignore
{% endfilter %}
from .base import {{ service.name }}Transport, DEFAULT_CLIENT_INFO
from .grpc import {{ service.name }}GrpcTransport
from .grpc import _API_CORE_VERSION
from .grpc import _GOOGLE_AUTH_VERSION


class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
Expand Down Expand Up @@ -75,7 +78,19 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
Returns:
aio.Channel: A gRPC AsyncIO channel object.
"""
scopes = scopes or cls.AUTH_SCOPES
self_signed_jwt_kwargs = {}

# TODO(busunkim): Remove this if/else once google-api-core >= 1.26.0 is required
if _API_CORE_VERSION and (
packaging.version.parse(_API_CORE_VERSION)
>= packaging.version.parse("1.26.0")
):
self_signed_jwt_kwargs["default_scopes"] = cls.AUTH_SCOPES
self_signed_jwt_kwargs["scopes"] = scopes
self_signed_jwt_kwargs["default_host"] = cls.DEFAULT_HOST
else:
self_signed_jwt_kwargs["scopes"] = scopes or cls.AUTH_SCOPES

return grpc_helpers_async.create_channel(
host,
credentials=credentials,
Expand Down Expand Up @@ -145,6 +160,21 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
google.api_core.exceptions.DuplicateCredentialArgs: If both ``credentials``
and ``credentials_file`` are passed.
"""
# If a custom API endpoint is set, set scopes to ensure the auth
# library does not used the self-signed JWT flow for service
# accounts
if host.split(":")[0] != self.DEFAULT_HOST and not scopes:
scopes = self.AUTH_SCOPES

# TODO: Remove this if/else once google-auth >= 1.25.0 is required
if _GOOGLE_AUTH_VERSION and packaging.version.parse(
_GOOGLE_AUTH_VERSION
) >= packaging.version.parse("1.25.0"):
scopes_kwargs = {"scopes": scopes, "default_scopes": self.AUTH_SCOPES}
else:
scopes_kwargs = {"scopes": scopes or self.AUTH_SCOPES}


self._ssl_channel_credentials = ssl_channel_credentials

if api_mtls_endpoint:
Expand All @@ -164,7 +194,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
host = api_mtls_endpoint if ":" in api_mtls_endpoint else api_mtls_endpoint + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

# Create SSL credentials with client_cert_source or application
# default SSL credentials.
Expand Down Expand Up @@ -194,7 +224,7 @@ class {{ service.grpc_asyncio_transport_name }}({{ service.name }}Transport):
host = host if ":" in host else host + ":443"

if credentials is None:
credentials, _ = auth.default(scopes=self.AUTH_SCOPES, quota_project_id=quota_project_id)
credentials, _ = auth.default(**scopes_kwargs, quota_project_id=quota_project_id)

if client_cert_source_for_mtls and not ssl_channel_credentials:
cert, key = client_cert_source_for_mtls()
Expand Down
1 change: 1 addition & 0 deletions gapic/templates/setup.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ setuptools.setup(
'google-api-core[grpc] >= 1.22.2, < 2.0.0dev',
'libcst >= 0.2.5',
'proto-plus >= 1.4.0',
'packaging >= 14.3',
{%- if api.requires_package(('google', 'iam', 'v1')) or opts.add_iam_methods %}
'grpc-google-iam-v1',
{%- endif %}
Expand Down
Loading

0 comments on commit d696ce3

Please sign in to comment.