Skip to content

Commit

Permalink
[BEAM-14014] Support impersonation credentials in dataflow runner (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
ryanthompson591 authored May 13, 2022
1 parent 04f4984 commit b774133
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 56 deletions.
39 changes: 39 additions & 0 deletions sdks/python/apache_beam/examples/wordcount_it_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from hamcrest.core.core.allof import all_of

from apache_beam.examples import wordcount
from apache_beam.internal.gcp import auth
from apache_beam.testing.load_tests.load_test_metrics_utils import InfluxDBMetricsPublisherOptions
from apache_beam.testing.load_tests.load_test_metrics_utils import MetricsReader
from apache_beam.testing.pipeline_verifiers import FileChecksumMatcher
Expand All @@ -47,6 +48,44 @@ class WordCountIT(unittest.TestCase):
def test_wordcount_it(self):
self._run_wordcount_it(wordcount.run)

@pytest.mark.it_postcommit
@pytest.mark.sickbay_direct
@pytest.mark.sickbay_spark
@pytest.mark.sickbay_flink
def test_wordcount_impersonation_it(self):
  """Tests impersonation on dataflow.

  For testing impersonation, we use three ingredients:
  - a principal to impersonate
  - a dataflow service account that only that principal is
    allowed to launch jobs as
  - a temp root that only the above two accounts have access to

  Jenkins and Dataflow workers both run as GCE default service account.
  So we remove that account from all the above.
  """
  # Credentials are cached on the auth._Credentials class. Clear the init
  # flag so this test builds fresh (impersonated) credentials instead of
  # reusing the ones created by a previous test.
  auth._Credentials._credentials_init = False

  ACCOUNT_TO_IMPERSONATE = (
      'allows-impersonation@apache-'
      'beam-testing.iam.gserviceaccount.com')
  RUNNER_ACCOUNT = (
      'impersonation-dataflow-worker@'
      'apache-beam-testing.iam.gserviceaccount.com')
  TEMP_DIR = 'gs://impersonation-test-bucket/temp-it'
  STAGING_LOCATION = 'gs://impersonation-test-bucket/staging-it'
  extra_options = {
      'impersonate_service_account': ACCOUNT_TO_IMPERSONATE,
      'service_account_email': RUNNER_ACCOUNT,
      'temp_location': TEMP_DIR,
      'staging_location': STAGING_LOCATION
  }
  try:
    self._run_wordcount_it(wordcount.run, **extra_options)
  finally:
    # Reset credentials for future tests. This must also happen when the
    # pipeline run fails; otherwise the impersonated credentials stay cached
    # and leak into every test that runs afterwards in this process.
    auth._Credentials._credentials_init = False

@pytest.mark.it_postcommit
@pytest.mark.it_validatescontainer
def test_wordcount_fnapi_it(self):
Expand Down
66 changes: 48 additions & 18 deletions sdks/python/apache_beam/internal/gcp/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@
import socket
import threading

from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions

# google.auth is only available when Beam is installed with the gcp extra.
try:
from google.auth import impersonated_credentials
import google.auth
import google_auth_httplib2
_GOOGLE_AUTH_AVAILABLE = True
Expand All @@ -40,6 +44,16 @@

_LOGGER = logging.getLogger(__name__)

# OAuth scopes requested for the application default credentials and, when
# impersonation is configured, used as the target scopes of the impersonated
# credentials.
CLIENT_SCOPES = [
'https://www.googleapis.com/auth/bigquery',
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/devstorage.full_control',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/datastore',
'https://www.googleapis.com/auth/spanner.admin',
'https://www.googleapis.com/auth/spanner.data'
]


def set_running_in_gce(worker_executing_project):
"""For internal use only; no backwards-compatibility guarantees.
Expand All @@ -59,16 +73,19 @@ def set_running_in_gce(worker_executing_project):
executing_project = worker_executing_project


def get_service_credentials():
def get_service_credentials(pipeline_options):
"""For internal use only; no backwards-compatibility guarantees.
Get credentials to access Google services.
Args:
pipeline_options: Pipeline options, used in creating credentials
like impersonated credentials.
Returns:
A ``google.auth.credentials.Credentials`` object or None if credentials
not found. Returned object is thread-safe.
"""
return _Credentials.get_service_credentials()
return _Credentials.get_service_credentials(pipeline_options)


if _GOOGLE_AUTH_AVAILABLE:
Expand Down Expand Up @@ -108,10 +125,7 @@ class _Credentials(object):
_credentials = None

@classmethod
def get_service_credentials(cls):
if cls._credentials_init:
return cls._credentials

def get_service_credentials(cls, pipeline_options):
with cls._credentials_lock:
if cls._credentials_init:
return cls._credentials
Expand All @@ -124,31 +138,24 @@ def get_service_credentials(cls):
_LOGGER.info(
"socket default timeout is %s seconds.", socket.getdefaulttimeout())

cls._credentials = cls._get_service_credentials()
cls._credentials = cls._get_service_credentials(pipeline_options)
cls._credentials_init = True

return cls._credentials

@staticmethod
def _get_service_credentials():
def _get_service_credentials(pipeline_options):
if not _GOOGLE_AUTH_AVAILABLE:
_LOGGER.warning(
'Unable to find default credentials because the google-auth library '
'is not available. Install the gcp extra (apache_beam[gcp]) to use '
'Google default credentials. Connecting anonymously.')
return None

client_scopes = [
'https://www.googleapis.com/auth/bigquery',
'https://www.googleapis.com/auth/cloud-platform',
'https://www.googleapis.com/auth/devstorage.full_control',
'https://www.googleapis.com/auth/userinfo.email',
'https://www.googleapis.com/auth/datastore',
'https://www.googleapis.com/auth/spanner.admin',
'https://www.googleapis.com/auth/spanner.data'
]
try:
credentials, _ = google.auth.default(scopes=client_scopes) # pylint: disable=c-extension-no-member
credentials, _ = google.auth.default(scopes=CLIENT_SCOPES) # pylint: disable=c-extension-no-member
credentials = _Credentials._add_impersonation_credentials(
credentials, pipeline_options)
credentials = _ApitoolsCredentialsAdapter(credentials)
logging.debug(
'Connecting using Google Application Default '
Expand All @@ -160,3 +167,26 @@ def _get_service_credentials():
'Connecting anonymously.',
e)
return None

@staticmethod
def _add_impersonation_credentials(credentials, pipeline_options):
  """Wraps credentials in impersonated credentials when requested.

  The ``impersonate_service_account`` option is a comma-separated delegation
  chain: the last entry is the account to impersonate and any preceding
  entries are the delegates used to reach it.

  Args:
    credentials: Source credentials used to obtain the impersonated ones.
    pipeline_options: A ``PipelineOptions`` instance or a ``dict`` from which
      ``impersonate_service_account`` is read. Any other value (for example
      ``None``) leaves ``credentials`` untouched.

  Returns:
    ``credentials`` unchanged when no impersonation is configured, otherwise
    an ``impersonated_credentials.Credentials`` object wrapping them.
  """
  if isinstance(pipeline_options, PipelineOptions):
    gcs_options = pipeline_options.view_as(GoogleCloudOptions)
    impersonate_service_account = gcs_options.impersonate_service_account
  elif isinstance(pipeline_options, dict):
    impersonate_service_account = pipeline_options.get(
        'impersonate_service_account')
  else:
    return credentials

  if not impersonate_service_account:
    return credentials

  # Tolerate whitespace around the separators ("a@x, b@y") and stray commas,
  # which would otherwise produce principals such as " b@y" or "".
  impersonate_accounts = [
      account.strip()
      for account in impersonate_service_account.split(',')
      if account.strip()
  ]
  if not impersonate_accounts:
    # Only separators/whitespace were given: treat it like an unset option,
    # consistent with how an empty string is handled above.
    _LOGGER.warning(
        'Ignoring impersonate_service_account=%r: no account names found.',
        impersonate_service_account)
    return credentials

  _LOGGER.info('Impersonating: %s', impersonate_service_account)
  target_principal = impersonate_accounts[-1]
  delegate_to = impersonate_accounts[:-1]
  return impersonated_credentials.Credentials(
      source_credentials=credentials,
      target_principal=target_principal,
      delegates=delegate_to,
      target_scopes=CLIENT_SCOPES,
  )
4 changes: 2 additions & 2 deletions sdks/python/apache_beam/io/gcp/bigquery_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ class BigQueryWrapper(object):
The wrapper is used to organize all the BigQuery integration points and
offer a common place where retry logic for failures can be controlled.
In addition it offers various functions used both in sources and sinks
In addition, it offers various functions used both in sources and sinks
(e.g., find and create tables, query a table, etc.).
"""

Expand All @@ -328,7 +328,7 @@ class BigQueryWrapper(object):
def __init__(self, client=None, temp_dataset_id=None, temp_table_ref=None):
self.client = client or bigquery.BigqueryV2(
http=get_new_http(),
credentials=auth.get_service_credentials(),
credentials=auth.get_service_credentials(None),
response_encoding='utf8',
additional_http_headers={
"user-agent": "apache-beam-%s" % apache_beam.__version__
Expand Down
31 changes: 19 additions & 12 deletions sdks/python/apache_beam/io/gcp/gcsfilesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class GCSFileSystem(FileSystem):
CHUNK_SIZE = gcsio.MAX_BATCH_OPERATION_SIZE # Chunk size in batch operations
GCS_PREFIX = 'gs://'

def __init__(self, pipeline_options):
  """Initializes the filesystem and remembers the pipeline options.

  Args:
    pipeline_options: Options forwarded to the base ``FileSystem`` and kept
      on the instance so they can be passed to ``gcsio.GcsIO`` clients.
  """
  super().__init__(pipeline_options)
  self._pipeline_options = pipeline_options

@classmethod
def scheme(cls):
"""URI scheme for the FileSystem
Expand Down Expand Up @@ -127,12 +131,15 @@ def _list(self, dir_or_prefix):
``BeamIOError``: if listing fails, but not if no files were found.
"""
try:
for path, (size, updated) in gcsio.GcsIO().list_prefix(
for path, (size, updated) in self._gcsIO().list_prefix(
dir_or_prefix, with_metadata=True).items():
yield FileMetadata(path, size, updated)
except Exception as e: # pylint: disable=broad-except
raise BeamIOError("List operation failed", {dir_or_prefix: e})

def _gcsIO(self):
  """Returns a new ``gcsio.GcsIO`` client.

  The client is constructed with this filesystem's stored pipeline options.
  """
  gcs_client = gcsio.GcsIO(pipeline_options=self._pipeline_options)
  return gcs_client

def _path_open(
self,
path,
Expand All @@ -143,7 +150,7 @@ def _path_open(
"""
compression_type = FileSystem._get_compression_type(path, compression_type)
mime_type = CompressionTypes.mime_type(compression_type, mime_type)
raw_file = gcsio.GcsIO().open(path, mode, mime_type=mime_type)
raw_file = self._gcsIO().open(path, mode, mime_type=mime_type)
if compression_type == CompressionTypes.UNCOMPRESSED:
return raw_file
return CompressedFile(raw_file, compression_type=compression_type)
Expand Down Expand Up @@ -206,9 +213,9 @@ def _copy_path(source, destination):
raise ValueError('Destination %r must be GCS path.' % destination)
# Use copy_tree if the path ends with / as it is a directory
if source.endswith('/'):
gcsio.GcsIO().copytree(source, destination)
self._gcsIO().copytree(source, destination)
else:
gcsio.GcsIO().copy(source, destination)
self._gcsIO().copy(source, destination)

exceptions = {}
for source, destination in zip(source_file_names, destination_file_names):
Expand Down Expand Up @@ -249,15 +256,15 @@ def rename(self, source_file_names, destination_file_names):
# Execute GCS renames if any and return exceptions.
exceptions = {}
for batch in gcs_batches:
copy_statuses = gcsio.GcsIO().copy_batch(batch)
copy_statuses = self._gcsIO().copy_batch(batch)
copy_succeeded = []
for src, dest, exception in copy_statuses:
if exception:
exceptions[(src, dest)] = exception
else:
copy_succeeded.append((src, dest))
delete_batch = [src for src, dest in copy_succeeded]
delete_statuses = gcsio.GcsIO().delete_batch(delete_batch)
delete_statuses = self._gcsIO().delete_batch(delete_batch)
for i, (src, exception) in enumerate(delete_statuses):
dest = copy_succeeded[i][1]
if exception:
Expand All @@ -274,7 +281,7 @@ def exists(self, path):
Returns: boolean flag indicating if path exists
"""
return gcsio.GcsIO().exists(path)
return self._gcsIO().exists(path)

def size(self, path):
"""Get size of path on the FileSystem.
Expand All @@ -287,7 +294,7 @@ def size(self, path):
Raises:
``BeamIOError``: if path doesn't exist.
"""
return gcsio.GcsIO().size(path)
return self._gcsIO().size(path)

def last_updated(self, path):
"""Get UNIX Epoch time in seconds on the FileSystem.
Expand All @@ -300,7 +307,7 @@ def last_updated(self, path):
Raises:
``BeamIOError``: if path doesn't exist.
"""
return gcsio.GcsIO().last_updated(path)
return self._gcsIO().last_updated(path)

def checksum(self, path):
"""Fetch checksum metadata of a file on the
Expand All @@ -315,7 +322,7 @@ def checksum(self, path):
``BeamIOError``: if path isn't a file or doesn't exist.
"""
try:
return gcsio.GcsIO().checksum(path)
return self._gcsIO().checksum(path)
except Exception as e: # pylint: disable=broad-except
raise BeamIOError("Checksum operation failed", {path: e})

Expand All @@ -332,7 +339,7 @@ def metadata(self, path):
``BeamIOError``: if path isn't a file or doesn't exist.
"""
try:
file_metadata = gcsio.GcsIO()._status(path)
file_metadata = self._gcsIO()._status(path)
return FileMetadata(
path, file_metadata['size'], file_metadata['last_updated'])
except Exception as e: # pylint: disable=broad-except
Expand All @@ -353,7 +360,7 @@ def _delete_path(path):
else:
path_to_use = path
match_result = self.match([path_to_use])[0]
statuses = gcsio.GcsIO().delete_batch(
statuses = self._gcsIO().delete_batch(
[m.path for m in match_result.metadata_list])
# pylint: disable=used-before-assignment
failures = [e for (_, e) in statuses if e is not None]
Expand Down
Loading

0 comments on commit b774133

Please sign in to comment.