Skip to content

Commit

Permalink
Fix credentials intialization revealed by mypy version of google auth (
Browse files Browse the repository at this point in the history
  • Loading branch information
potiuk authored and romsharon98 committed Jul 26, 2024
1 parent 9822618 commit 33be1a7
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_custom_job_object(
base_output_dir=base_output_dir,
project=project,
location=location,
credentials=self.get_credentials,
credentials=self.get_credentials(),
labels=labels,
encryption_spec_key_name=encryption_spec_key_name,
staging_bucket=staging_bucket,
Expand Down
19 changes: 13 additions & 6 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,8 @@ def get_credentials_and_project(self) -> tuple[Credentials, str]:
:return: Google Auth Credentials
"""
if self.is_anonymous:
credentials, project_id = AnonymousCredentials(), ""
credentials: Credentials = AnonymousCredentials()
project_id = ""
else:
if self.key_path:
credentials, project_id = self._get_credentials_using_key_path()
Expand Down Expand Up @@ -273,18 +274,22 @@ def get_credentials_and_project(self) -> tuple[Credentials, str]:

return credentials, project_id

def _get_credentials_using_keyfile_dict(self):
def _get_credentials_using_keyfile_dict(self) -> tuple[Credentials, str]:
self._log_debug("Getting connection using JSON Dict")
# Depending on how the JSON was formatted, it may contain
# escaped newlines. Convert those to actual newlines.
if self.keyfile_dict is None:
raise ValueError("The keyfile_dict field is None, and we need it for keyfile_dict auth.")
self.keyfile_dict["private_key"] = self.keyfile_dict["private_key"].replace("\\n", "\n")
credentials = google.oauth2.service_account.Credentials.from_service_account_info(
self.keyfile_dict, scopes=self.scopes
)
project_id = credentials.project_id
return credentials, project_id

def _get_credentials_using_key_path(self):
def _get_credentials_using_key_path(self) -> tuple[Credentials, str]:
if self.key_path is None:
raise ValueError("The ky_path field is None, and we need it for keyfile_dict auth.")
if self.key_path.endswith(".p12"):
raise AirflowException("Legacy P12 key file are not supported, use a JSON key file.")

Expand All @@ -298,13 +303,15 @@ def _get_credentials_using_key_path(self):
project_id = credentials.project_id
return credentials, project_id

def _get_credentials_using_key_secret_name(self):
def _get_credentials_using_key_secret_name(self) -> tuple[Credentials, str]:
self._log_debug("Getting connection using JSON key data from GCP secret: %s", self.key_secret_name)

# Use ADC to access GCP Secret Manager.
adc_credentials, adc_project_id = google.auth.default(scopes=self.scopes)
secret_manager_client = _SecretManagerClient(credentials=adc_credentials)

if self.key_secret_name is None:
raise ValueError("The key_secret_name field is None, and we need it for keyfile_dict auth.")
if not secret_manager_client.is_valid_secret_name(self.key_secret_name):
raise AirflowException("Invalid secret name specified for fetching JSON key data.")

Expand All @@ -326,7 +333,7 @@ def _get_credentials_using_key_secret_name(self):
project_id = credentials.project_id
return credentials, project_id

def _get_credentials_using_credential_config_file(self):
def _get_credentials_using_credential_config_file(self) -> tuple[Credentials, str]:
if isinstance(self.credential_config_file, str) and os.path.exists(self.credential_config_file):
self._log_info(
f"Getting connection using credential configuration file: `{self.credential_config_file}`"
Expand All @@ -350,7 +357,7 @@ def _get_credentials_using_credential_config_file(self):

return credentials, project_id

def _get_credentials_using_adc(self):
def _get_credentials_using_adc(self) -> tuple[Credentials, str]:
self._log_info(
"Getting connection using `google.auth.default()` since no explicit credentials are provided."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def _get_gce_credentials(


def get_default_id_token_credentials(
target_audience: str | None, request: google.auth.transport.Request = None
target_audience: str | None, request: google.auth.transport.Request | None = None
) -> google_auth_credentials.Credentials:
"""Get the default ID Token credentials for the current environment.
Expand Down

0 comments on commit 33be1a7

Please sign in to comment.