Skip to content

Commit

Permalink
Workspace using unified identity classes (Azure#26588)
Browse files Browse the repository at this point in the history
* Workspace using unified credential classes

* Updating tests

* Adding missing import

* Updating test to use constants for identity type

* Updating params for Identity config class

* Updating changelog.md file
  • Loading branch information
singankit authored Oct 3, 2022
1 parent 1b298f9 commit 4ed7a63
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 109 deletions.
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- OnlineDeploymentOperations.delete has been renamed to begin_attach.
- Datastore credentials are switched to use unified credential configuration classes.
- UserAssignedIdentity is replaced by ManagedIdentityConfiguration
- Workspace ManagedServiceIdentity has been replaced by IdentityConfiguration.

### Bugs Fixed
- Fix identity passthrough job with single file code
Expand Down
6 changes: 3 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/workspace/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity, UserAssignedIdentity
from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration


class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):
Expand All @@ -20,7 +20,7 @@ class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):

@post_load
def make(self, data, **kwargs):
return UserAssignedIdentity(**data)
return ManagedIdentityConfiguration(**data)


class IdentitySchema(metaclass=PatchedSchemaMeta):
Expand All @@ -43,4 +43,4 @@ class IdentitySchema(metaclass=PatchedSchemaMeta):
@post_load
def make(self, data, **kwargs):
data["type"] = snake_to_camel(data.pop("type"))
return ManagedServiceIdentity(**data)
return IdentityConfiguration(**data)
4 changes: 0 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@
from ._validation import ValidationResult
from ._workspace.connections.workspace_connection import WorkspaceConnection
from ._workspace.customer_managed_key import CustomerManagedKey
from ._workspace.identity import ManagedServiceIdentity
from ._workspace.identity import UserAssignedIdentity as WorkspaceUserAssignedIdentity
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
from ._workspace.workspace import Workspace
from ._workspace.workspace_keys import WorkspaceKeys, NotebookAccessKeys, ContainerRegistryCredential
Expand Down Expand Up @@ -145,8 +143,6 @@
"Workspace",
"WorkspaceKeys",
"WorkspaceConnection",
"ManagedServiceIdentity",
"WorkspaceUserAssignedIdentity", # pylint: disable=naming-mismatch
"PrivateEndpoint",
"EndpointConnection",
"CustomerManagedKey",
Expand Down
56 changes: 53 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
Identity as RestIdentityConfiguration
)

from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestWorkspaceIdentityConfiguration
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestWorkspaceUserAssignedIdentity
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
ManagedServiceIdentity as RestRegistryManagedIdentity
)
Expand Down Expand Up @@ -319,12 +321,14 @@ def __init__(
client_id: str = None,
resource_id: str = None,
object_id: str = None,
principal_id: str = None
):
self.type = camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY)
self.client_id = client_id
# TODO: Check if both client_id and resource_id are required
self.resource_id = resource_id
self.object_id = object_id
self.principal_id = principal_id

def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity:
return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id)
Expand Down Expand Up @@ -363,6 +367,19 @@ def _from_identity_configuration_rest_object(cls, rest_obj: RestUserAssignedIden
result.__dict__.update(rest_obj.as_dict())
return result

def _to_workspace_rest_object(self) -> RestWorkspaceUserAssignedIdentity:
return RestWorkspaceUserAssignedIdentity(
principal_id=self.principal_id,
client_id=self.client_id,
)

@classmethod
def _from_workspace_rest_object(cls, obj: RestWorkspaceUserAssignedIdentity) -> "ManagedIdentityConfiguration":
return cls(
principal_id=obj.principal_id,
client_id=obj.client_id,
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ManagedIdentityConfiguration):
return NotImplemented
Expand Down Expand Up @@ -405,7 +422,7 @@ def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration":
class IdentityConfiguration(RestTranslatableMixin):
"""Managed identity specification."""

def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None):
def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None, **kwargs):
"""Managed identity specification.
:param type: Managed identity type, defaults to None
Expand All @@ -416,8 +433,8 @@ def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityC

self.type = type
self.user_assigned_identities = user_assigned_identities
self.principal_id = None
self.tenant_id = None
self.principal_id = kwargs.pop("principal_id", None)
self.tenant_id = kwargs.pop("tenant_id", None)

def _to_compute_rest_object(self) -> RestIdentityConfiguration:
rest_user_assigned_identities = (
Expand Down Expand Up @@ -446,6 +463,38 @@ def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityC
result.tenant_id = obj.tenant_id
return result

@classmethod
def _from_workspace_rest_object(cls, obj: RestWorkspaceIdentityConfiguration) -> "IdentityConfiguration":
user_assigned_identities = None
if obj.user_assigned_identities:
user_assigned_identities = {}
for k, v in obj.user_assigned_identities.items():
metadata = None
if v and isinstance(v, RestUserAssignedIdentity):
metadata = ManagedIdentityConfiguration._from_workspace_rest_object(v) # pylint: disable=protected-access
user_assigned_identities[k] = metadata
return cls(
type=obj.type,
principal_id=obj.principal_id,
tenant_id=obj.tenant_id,
user_assigned_identities=user_assigned_identities,
)

def _to_workspace_rest_object(self) -> RestWorkspaceIdentityConfiguration:

user_assigned_identities = (
{uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities}
if self.user_assigned_identities
else None
)

return RestWorkspaceIdentityConfiguration(
type=snake_to_pascal(self.type),
principal_id=self.principal_id,
tenant_id=self.tenant_id,
user_assigned_identities=user_assigned_identities,
)

def _to_rest_object(self) -> RestRegistryManagedIdentity:
return RestRegistryManagedIdentity(
type=self.type,
Expand All @@ -462,6 +511,7 @@ def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfigu
result.principal_id = obj.principal_id
result.tenant_id = obj.tenant_id
return result

class NoneCredentialConfiguration(RestTranslatableMixin):
"""None Credential Configuration."""

Expand Down
83 changes: 0 additions & 83 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema
from azure.ai.ml._utils.utils import dump_yaml_to_file
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, WorkspaceResourceConstants
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.entities._resource import Resource
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity

from .customer_managed_key import CustomerManagedKey

Expand All @@ -38,7 +38,7 @@ def __init__(
customer_managed_key: CustomerManagedKey = None,
image_build_compute: str = None,
public_network_access: str = None,
identity: ManagedServiceIdentity = None,
identity: IdentityConfiguration = None,
primary_user_assigned_identity: str = None,
**kwargs,
):
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
when a workspace is private link enabled.
:type public_network_access: str
:param identity: workspace's Managed Identity (user assigned, or system assigned)
:type identity: ManagedServiceIdentity
:type identity: IdentityConfiguration
:param primary_user_assigned_identity: The workspace's primary user assigned identity
:type primary_user_assigned_identity: str
:param kwargs: A dictionary of additional configuration parameters.
Expand Down Expand Up @@ -185,7 +185,7 @@ def _from_rest_object(cls, rest_obj: RestWorkspace) -> "Workspace":
group = None if len(armid_parts) < 4 else armid_parts[4]
identity = None
if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity):
identity = ManagedServiceIdentity._from_rest_object(rest_obj.identity) # pylint: disable=protected-access
identity = IdentityConfiguration._from_workspace_rest_object(rest_obj.identity) # pylint: disable=protected-access
return Workspace(
name=rest_obj.name,
id=rest_obj.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
get_resource_and_group_name,
get_resource_group_location,
)
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml._version import VERSION
from azure.ai.ml.constants import ManagedServiceIdentityType
from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants
from azure.ai.ml.entities import ManagedServiceIdentity, Workspace, WorkspaceKeys
from azure.ai.ml.entities import Workspace, WorkspaceKeys
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
from azure.core.credentials import TokenCredential
from azure.core.polling import LROPoller
Expand Down Expand Up @@ -226,7 +228,7 @@ def begin_update(
"""
identity = kwargs.get("identity", workspace.identity)
if identity:
identity = identity._to_rest_object()
identity = identity._to_workspace_rest_object()
update_param = WorkspaceUpdateParameters(
tags=workspace.tags,
description=kwargs.get("description", workspace.description),
Expand Down Expand Up @@ -494,9 +496,11 @@ def _populate_arm_paramaters(self, workspace: Workspace) -> Tuple[dict, dict, di

identity = None
if workspace.identity:
identity = workspace.identity._to_rest_object()
identity = workspace.identity._to_workspace_rest_object()
else:
identity = ManagedServiceIdentity(type=ManagedServiceIdentityType.SYSTEM_ASSIGNED)._to_rest_object()
# pylint: disable=protected-access
identity = IdentityConfiguration(
type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED))._to_workspace_rest_object()
_set_val(param["identity"], identity)

if workspace.primary_user_assigned_identity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from azure.ai.ml import MLClient, load_workspace
from azure.ai.ml.constants._common import PublicNetworkAccess
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentityType
from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
from azure.core.paging import ItemPaged
from azure.mgmt.msi._managed_service_identity_client import ManagedServiceIdentityClient

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from unittest.mock import DEFAULT, Mock, call, patch

import pytest
from azure.ai.ml._utils.utils import camel_to_snake
from pytest_mock import MockFixture

from azure.ai.ml._scope_dependent_operations import OperationScope
from azure.ai.ml.constants import ManagedServiceIdentityType
from azure.ai.ml.entities import CustomerManagedKey, ManagedServiceIdentity, Workspace, WorkspaceUserAssignedIdentity
from azure.ai.ml.entities import CustomerManagedKey, Workspace, \
IdentityConfiguration, ManagedIdentityConfiguration
from azure.ai.ml.operations import WorkspaceOperations
from azure.core.polling import LROPoller

Expand Down Expand Up @@ -119,12 +121,12 @@ def test_update(self, mock_workspace_operation: WorkspaceOperations) -> None:
public_network_access="Enabled",
container_registry="foo_conntainer_registry",
application_insights="foo_application_insights",
identity=ManagedServiceIdentity(
type=ManagedServiceIdentityType.USER_ASSIGNED,
user_assigned_identities={
"resource1": WorkspaceUserAssignedIdentity(),
"resource2": WorkspaceUserAssignedIdentity(),
},
identity=IdentityConfiguration(
type=camel_to_snake(ManagedServiceIdentityType.USER_ASSIGNED),
user_assigned_identities=[
ManagedIdentityConfiguration(resource_id="resource1"),
ManagedIdentityConfiguration(resource_id="resource2")
],
),
primary_user_assigned_identity="resource2",
)
Expand Down

0 comments on commit 4ed7a63

Please sign in to comment.