From 4ed7a631d635ed793760f73f76917c6012f44af5 Mon Sep 17 00:00:00 2001 From: Ankit Singhal <30610298+singankit@users.noreply.github.com> Date: Mon, 3 Oct 2022 14:18:56 -0700 Subject: [PATCH] Workspace using unified identity classes (#26588) * Workspace using unified credential classes * Updating tests * Adding missing import * Updating test to use constants for identity type * Updating params for Identity config class * Updating changelog.md file --- sdk/ml/azure-ai-ml/CHANGELOG.md | 1 + .../azure/ai/ml/_schema/workspace/identity.py | 6 +- .../azure/ai/ml/entities/__init__.py | 4 - .../azure/ai/ml/entities/_credentials.py | 56 ++++++++++++- .../ai/ml/entities/_workspace/identity.py | 83 ------------------- .../ai/ml/entities/_workspace/workspace.py | 8 +- .../ai/ml/operations/_workspace_operations.py | 12 ++- .../workspace/e2etests/test_workspace.py | 2 +- .../unittests/test_workspace_operations.py | 16 ++-- 9 files changed, 79 insertions(+), 109 deletions(-) delete mode 100644 sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py diff --git a/sdk/ml/azure-ai-ml/CHANGELOG.md b/sdk/ml/azure-ai-ml/CHANGELOG.md index abe2d3d49f7b..35e4b24c0263 100644 --- a/sdk/ml/azure-ai-ml/CHANGELOG.md +++ b/sdk/ml/azure-ai-ml/CHANGELOG.md @@ -17,6 +17,7 @@ - OnlineDeploymentOperations.delete has been renamed to begin_attach. - Datastore credentials are switched to use unified credential configuration classes. - UserAssignedIdentity is replaced by ManagedIdentityConfiguration +- Workspace ManagedServiceIdentity has been replaced by IdentityConfiguration. ### Bugs Fixed - Fix identity passthrough job with single file code diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/identity.py b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/identity.py index ac0ecdba416f..0c4235a9f6b7 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/identity.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/identity.py @@ -11,7 +11,7 @@ from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel from azure.ai.ml.constants._workspace import ManagedServiceIdentityType -from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity, UserAssignedIdentity +from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta): @@ -20,7 +20,7 @@ class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta): @post_load def make(self, data, **kwargs): - return UserAssignedIdentity(**data) + return ManagedIdentityConfiguration(**data) class IdentitySchema(metaclass=PatchedSchemaMeta): @@ -43,4 +43,4 @@ class IdentitySchema(metaclass=PatchedSchemaMeta): @post_load def make(self, data, **kwargs): data["type"] = snake_to_camel(data.pop("type")) - return ManagedServiceIdentity(**data) + return IdentityConfiguration(**data) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py index d199f41952c3..5e35f88d985b 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py @@ -84,8 +84,6 @@ from ._validation import ValidationResult from ._workspace.connections.workspace_connection import WorkspaceConnection from ._workspace.customer_managed_key import CustomerManagedKey -from ._workspace.identity import ManagedServiceIdentity -from ._workspace.identity import UserAssignedIdentity as WorkspaceUserAssignedIdentity from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint from ._workspace.workspace import Workspace from ._workspace.workspace_keys import WorkspaceKeys, NotebookAccessKeys, ContainerRegistryCredential @@ -145,8 +143,6 @@ "Workspace", "WorkspaceKeys", "WorkspaceConnection", - "ManagedServiceIdentity", - "WorkspaceUserAssignedIdentity", # pylint: disable=naming-mismatch "PrivateEndpoint", "EndpointConnection", "CustomerManagedKey", diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py index c3e7e8f6c145..a929f171cce8 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py @@ -44,6 +44,8 @@ Identity as RestIdentityConfiguration ) +from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestWorkspaceIdentityConfiguration +from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestWorkspaceUserAssignedIdentity from azure.ai.ml._restclient.v2022_10_01_preview.models import ( ManagedServiceIdentity as RestRegistryManagedIdentity ) @@ -319,12 +321,14 @@ def __init__( client_id: str = None, resource_id: str = None, object_id: str = None, + principal_id: str = None ): self.type = camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY) self.client_id = client_id # TODO: Check if both client_id and resource_id are required self.resource_id = resource_id self.object_id = object_id + self.principal_id = principal_id def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity: return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id) @@ -363,6 +367,19 @@ def _from_identity_configuration_rest_object(cls, rest_obj: RestUserAssignedIden result.__dict__.update(rest_obj.as_dict()) return result + def _to_workspace_rest_object(self) -> RestWorkspaceUserAssignedIdentity: + return RestWorkspaceUserAssignedIdentity( + principal_id=self.principal_id, + client_id=self.client_id, + ) + + @classmethod + def _from_workspace_rest_object(cls, obj: RestWorkspaceUserAssignedIdentity) -> "ManagedIdentityConfiguration": + return cls( + principal_id=obj.principal_id, + client_id=obj.client_id, + ) + def __eq__(self, other: object) -> bool: if not isinstance(other, ManagedIdentityConfiguration): return NotImplemented @@ -405,7 +422,7 @@ def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration": class IdentityConfiguration(RestTranslatableMixin): """Managed identity specification.""" - def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None): + def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None, **kwargs): """Managed identity specification. :param type: Managed identity type, defaults to None @@ -416,8 +433,8 @@ def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityC self.type = type self.user_assigned_identities = user_assigned_identities - self.principal_id = None - self.tenant_id = None + self.principal_id = kwargs.pop("principal_id", None) + self.tenant_id = kwargs.pop("tenant_id", None) def _to_compute_rest_object(self) -> RestIdentityConfiguration: rest_user_assigned_identities = ( @@ -446,6 +463,38 @@ def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityC result.tenant_id = obj.tenant_id return result + @classmethod + def _from_workspace_rest_object(cls, obj: RestWorkspaceIdentityConfiguration) -> "IdentityConfiguration": + user_assigned_identities = None + if obj.user_assigned_identities: + user_assigned_identities = {} + for k, v in obj.user_assigned_identities.items(): + metadata = None + if v and isinstance(v, RestUserAssignedIdentity): + metadata = ManagedIdentityConfiguration._from_workspace_rest_object(v) # pylint: disable=protected-access + user_assigned_identities[k] = metadata + return cls( + type=obj.type, + principal_id=obj.principal_id, + tenant_id=obj.tenant_id, + user_assigned_identities=user_assigned_identities, + ) + + def _to_workspace_rest_object(self) -> RestWorkspaceIdentityConfiguration: + + user_assigned_identities = ( + {uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities} + if self.user_assigned_identities + else None + ) + + return RestWorkspaceIdentityConfiguration( + type=snake_to_pascal(self.type), + principal_id=self.principal_id, + tenant_id=self.tenant_id, + user_assigned_identities=user_assigned_identities, + ) + def _to_rest_object(self) -> RestRegistryManagedIdentity: return RestRegistryManagedIdentity( type=self.type, @@ -462,6 +511,7 @@ def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfigu result.principal_id = obj.principal_id result.tenant_id = obj.tenant_id return result + class NoneCredentialConfiguration(RestTranslatableMixin): """None Credential Configuration.""" diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py deleted file mode 100644 index ba5cd6527a30..000000000000 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py +++ /dev/null @@ -1,83 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- - -from typing import Dict, Optional, Union - -from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestManagedServiceIdentity -from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestUserAssignedIdentity -from azure.ai.ml.constants._workspace import ManagedServiceIdentityType - - -class ManagedServiceIdentity: - """Managed service identity (system assigned and/or user assigned identities).""" - - def __init__( - self, - *, - type: Union[str, ManagedServiceIdentityType], # pylint: disable=redefined-builtin - principal_id: str = None, - tenant_id: str = None, - user_assigned_identities: Optional[Dict[str, "UserAssignedIdentity"]] = None, - ): - self.type = type - self.principal_id = principal_id - self.tenant_id = tenant_id - self.user_assigned_identities = user_assigned_identities - - def _to_rest_object(self) -> RestManagedServiceIdentity: - user_assigned_identities = None - if self.user_assigned_identities: - user_assigned_identities = {} - for k, v in self.user_assigned_identities.items(): - user_assigned_identities[k] = v._to_rest_object() if v else None # pylint: disable=protected-access - - return RestManagedServiceIdentity( - type=self.type, - principal_id=self.principal_id, - tenant_id=self.tenant_id, - user_assigned_identities=user_assigned_identities, - ) - - @classmethod - def _from_rest_object(cls, obj: RestManagedServiceIdentity) -> "ManagedServiceIdentity": - user_assigned_identities = None - if obj.user_assigned_identities: - user_assigned_identities = {} - for k, v in obj.user_assigned_identities.items(): - metadata = None - if v and isinstance(v, RestUserAssignedIdentity): - metadata = UserAssignedIdentity._from_rest_object(v) # pylint: disable=protected-access - user_assigned_identities[k] = metadata - return cls( - type=obj.type, - principal_id=obj.principal_id, - tenant_id=obj.tenant_id, - user_assigned_identities=user_assigned_identities, - ) - - -class UserAssignedIdentity: - """User assigned identity properties.""" - - def __init__( - self, - *, - principal_id: str = None, - client_id: str = None, - ): - self.principal_id = principal_id - self.client_id = client_id - - def _to_rest_object(self) -> RestUserAssignedIdentity: - return RestUserAssignedIdentity( - principal_id=self.principal_id, - client_id=self.client_id, - ) - - @classmethod - def _from_rest_object(cls, obj: RestUserAssignedIdentity) -> "UserAssignedIdentity": - return cls( - principal_id=obj.principal_id, - client_id=obj.client_id, - ) diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/workspace.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/workspace.py index 8f9983613981..0555a8c8fbc4 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/workspace.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/workspace.py @@ -13,9 +13,9 @@ from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema from azure.ai.ml._utils.utils import dump_yaml_to_file from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, WorkspaceResourceConstants +from azure.ai.ml.entities._credentials import IdentityConfiguration from azure.ai.ml.entities._resource import Resource from azure.ai.ml.entities._util import load_from_dict -from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity from .customer_managed_key import CustomerManagedKey @@ -38,7 +38,7 @@ def __init__( customer_managed_key: CustomerManagedKey = None, image_build_compute: str = None, public_network_access: str = None, - identity: ManagedServiceIdentity = None, + identity: IdentityConfiguration = None, primary_user_assigned_identity: str = None, **kwargs, ): @@ -83,7 +83,7 @@ def __init__( when a workspace is private link enabled. :type public_network_access: str :param identity: workspace's Managed Identity (user assigned, or system assigned) - :type identity: ManagedServiceIdentity + :type identity: IdentityConfiguration :param primary_user_assigned_identity: The workspace's primary user assigned identity :type primary_user_assigned_identity: str :param kwargs: A dictionary of additional configuration parameters. @@ -185,7 +185,7 @@ def _from_rest_object(cls, rest_obj: RestWorkspace) -> "Workspace": group = None if len(armid_parts) < 4 else armid_parts[4] identity = None if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity): - identity = ManagedServiceIdentity._from_rest_object(rest_obj.identity) # pylint: disable=protected-access + identity = IdentityConfiguration._from_workspace_rest_object(rest_obj.identity) # pylint: disable=protected-access return Workspace( name=rest_obj.name, id=rest_obj.id, diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py index 4ee9d448bb51..a01c5511318d 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_workspace_operations.py @@ -25,10 +25,12 @@ get_resource_and_group_name, get_resource_group_location, ) +from azure.ai.ml._utils.utils import camel_to_snake from azure.ai.ml._version import VERSION from azure.ai.ml.constants import ManagedServiceIdentityType from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants -from azure.ai.ml.entities import ManagedServiceIdentity, Workspace, WorkspaceKeys +from azure.ai.ml.entities import Workspace, WorkspaceKeys +from azure.ai.ml.entities._credentials import IdentityConfiguration from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException from azure.core.credentials import TokenCredential from azure.core.polling import LROPoller @@ -226,7 +228,7 @@ def begin_update( """ identity = kwargs.get("identity", workspace.identity) if identity: - identity = identity._to_rest_object() + identity = identity._to_workspace_rest_object() update_param = WorkspaceUpdateParameters( tags=workspace.tags, description=kwargs.get("description", workspace.description), @@ -494,9 +496,11 @@ def _populate_arm_paramaters(self, workspace: Workspace) -> Tuple[dict, dict, di identity = None if workspace.identity: - identity = workspace.identity._to_rest_object() + identity = workspace.identity._to_workspace_rest_object() else: - identity = ManagedServiceIdentity(type=ManagedServiceIdentityType.SYSTEM_ASSIGNED)._to_rest_object() + # pylint: disable=protected-access + identity = IdentityConfiguration( + type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED))._to_workspace_rest_object() _set_val(param["identity"], identity) if workspace.primary_user_assigned_identity: diff --git a/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace.py b/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace.py index beb4d9e40816..fc15b89a42b6 100644 --- a/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace.py +++ b/sdk/ml/azure-ai-ml/tests/workspace/e2etests/test_workspace.py @@ -8,7 +8,7 @@ from azure.ai.ml import MLClient, load_workspace from azure.ai.ml.constants._common import PublicNetworkAccess -from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentityType +from azure.ai.ml.constants._workspace import ManagedServiceIdentityType from azure.core.paging import ItemPaged from azure.mgmt.msi._managed_service_identity_client import ManagedServiceIdentityClient diff --git a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_operations.py b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_operations.py index 1fc15eb8eeea..35f83337b707 100644 --- a/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_operations.py +++ b/sdk/ml/azure-ai-ml/tests/workspace/unittests/test_workspace_operations.py @@ -2,11 +2,13 @@ from unittest.mock import DEFAULT, Mock, call, patch import pytest +from azure.ai.ml._utils.utils import camel_to_snake from pytest_mock import MockFixture from azure.ai.ml._scope_dependent_operations import OperationScope from azure.ai.ml.constants import ManagedServiceIdentityType -from azure.ai.ml.entities import CustomerManagedKey, ManagedServiceIdentity, Workspace, WorkspaceUserAssignedIdentity +from azure.ai.ml.entities import CustomerManagedKey, Workspace, \ + IdentityConfiguration, ManagedIdentityConfiguration from azure.ai.ml.operations import WorkspaceOperations from azure.core.polling import LROPoller @@ -119,12 +121,12 @@ def test_update(self, mock_workspace_operation: WorkspaceOperations) -> None: public_network_access="Enabled", container_registry="foo_conntainer_registry", application_insights="foo_application_insights", - identity=ManagedServiceIdentity( - type=ManagedServiceIdentityType.USER_ASSIGNED, - user_assigned_identities={ - "resource1": WorkspaceUserAssignedIdentity(), - "resource2": WorkspaceUserAssignedIdentity(), - }, + identity=IdentityConfiguration( + type=camel_to_snake(ManagedServiceIdentityType.USER_ASSIGNED), + user_assigned_identities=[ + ManagedIdentityConfiguration(resource_id="resource1"), + ManagedIdentityConfiguration(resource_id="resource2") + ], ), primary_user_assigned_identity="resource2", )