Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workspace using unified identity classes #26588

Merged
merged 12 commits into from
Oct 3, 2022
1 change: 1 addition & 0 deletions sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- Workspace.list_keys renamed to Workspace.get_keys.
- Datastore credentials are switched to use unified credential configuration classes.
- UserAssignedIdentity is replaced by ManagedIdentityConfiguration
- Workspace ManagedServiceIdentity has been replaced by IdentityConfiguration.

### Bugs Fixed
- Fix identity passthrough job with single file code
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azure.ai.ml._schema.core.schema_meta import PatchedSchemaMeta
from azure.ai.ml._utils.utils import camel_to_snake, snake_to_camel
from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity, UserAssignedIdentity
from azure.ai.ml.entities._credentials import IdentityConfiguration, ManagedIdentityConfiguration


class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):
Expand All @@ -20,7 +20,7 @@ class UserAssignedIdentitySchema(metaclass=PatchedSchemaMeta):

@post_load
def make(self, data, **kwargs):
return UserAssignedIdentity(**data)
return ManagedIdentityConfiguration(**data)


class IdentitySchema(metaclass=PatchedSchemaMeta):
Expand All @@ -43,4 +43,4 @@ class IdentitySchema(metaclass=PatchedSchemaMeta):
@post_load
def make(self, data, **kwargs):
data["type"] = snake_to_camel(data.pop("type"))
return ManagedServiceIdentity(**data)
return IdentityConfiguration(**data)
4 changes: 0 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@
from ._validation import ValidationResult
from ._workspace.connections.workspace_connection import WorkspaceConnection
from ._workspace.customer_managed_key import CustomerManagedKey
from ._workspace.identity import ManagedServiceIdentity
from ._workspace.identity import UserAssignedIdentity as WorkspaceUserAssignedIdentity
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
from ._workspace.workspace import Workspace
from ._workspace.workspace_keys import WorkspaceKeys
Expand Down Expand Up @@ -145,8 +143,6 @@
"Workspace",
"WorkspaceKeys",
"WorkspaceConnection",
"ManagedServiceIdentity",
"WorkspaceUserAssignedIdentity", # pylint: disable=naming-mismatch
"PrivateEndpoint",
"EndpointConnection",
"CustomerManagedKey",
Expand Down
56 changes: 53 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
Identity as RestIdentityConfiguration
)

from azure.ai.ml._restclient.v2022_05_01.models import ManagedServiceIdentity as RestWorkspaceIdentityConfiguration
from azure.ai.ml._restclient.v2022_05_01.models import UserAssignedIdentity as RestWorkspaceUserAssignedIdentity
from azure.ai.ml._restclient.v2022_10_01_preview.models import (
ManagedServiceIdentity as RestRegistryManagedIdentity
)
Expand Down Expand Up @@ -319,12 +321,14 @@ def __init__(
client_id: str = None,
resource_id: str = None,
object_id: str = None,
principal_id: str = None
):
self.type = camel_to_snake(ConnectionAuthType.MANAGED_IDENTITY)
self.client_id = client_id
# TODO: Check if both client_id and resource_id are required
self.resource_id = resource_id
self.object_id = object_id
self.principal_id = principal_id

def _to_workspace_connection_rest_object(self) -> RestWorkspaceConnectionManagedIdentity:
return RestWorkspaceConnectionManagedIdentity(client_id=self.client_id, resource_id=self.resource_id)
Expand Down Expand Up @@ -363,6 +367,19 @@ def _from_identity_configuration_rest_object(cls, rest_obj: RestUserAssignedIden
result.__dict__.update(rest_obj.as_dict())
return result

def _to_workspace_rest_object(self) -> RestWorkspaceUserAssignedIdentity:
return RestWorkspaceUserAssignedIdentity(
principal_id=self.principal_id,
client_id=self.client_id,
)

@classmethod
def _from_workspace_rest_object(cls, obj: RestWorkspaceUserAssignedIdentity) -> "ManagedIdentityConfiguration":
return cls(
principal_id=obj.principal_id,
client_id=obj.client_id,
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, ManagedIdentityConfiguration):
return NotImplemented
Expand Down Expand Up @@ -405,7 +422,7 @@ def _from_job_rest_object(cls, obj: RestAmlToken) -> "AmlTokenConfiguration":
class IdentityConfiguration(RestTranslatableMixin):
"""Managed identity specification."""

def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None):
def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityConfiguration] = None, **kwargs):
"""Managed identity specification.

:param type: Managed identity type, defaults to None
Expand All @@ -416,8 +433,8 @@ def __init__(self, *, type: str, user_assigned_identities: List[ManagedIdentityC

self.type = type
self.user_assigned_identities = user_assigned_identities
self.principal_id = None
self.tenant_id = None
self.principal_id = kwargs.pop("principal_id", None)
self.tenant_id = kwargs.pop("tenant_id", None)

def _to_compute_rest_object(self) -> RestIdentityConfiguration:
rest_user_assigned_identities = (
Expand Down Expand Up @@ -446,6 +463,38 @@ def _from_compute_rest_object(cls, obj: RestIdentityConfiguration) -> "IdentityC
result.tenant_id = obj.tenant_id
return result

@classmethod
def _from_workspace_rest_object(cls, obj: RestWorkspaceIdentityConfiguration) -> "IdentityConfiguration":
user_assigned_identities = None
if obj.user_assigned_identities:
user_assigned_identities = {}
for k, v in obj.user_assigned_identities.items():
metadata = None
if v and isinstance(v, RestUserAssignedIdentity):
metadata = ManagedIdentityConfiguration._from_workspace_rest_object(v) # pylint: disable=protected-access
user_assigned_identities[k] = metadata
return cls(
type=obj.type,
principal_id=obj.principal_id,
tenant_id=obj.tenant_id,
user_assigned_identities=user_assigned_identities,
)

def _to_workspace_rest_object(self) -> RestWorkspaceIdentityConfiguration:

user_assigned_identities = (
{uai.resource_id: uai._to_workspace_rest_object() for uai in self.user_assigned_identities}
if self.user_assigned_identities
else None
)

return RestWorkspaceIdentityConfiguration(
type=snake_to_pascal(self.type),
principal_id=self.principal_id,
tenant_id=self.tenant_id,
user_assigned_identities=user_assigned_identities,
)

def _to_rest_object(self) -> RestRegistryManagedIdentity:
return RestRegistryManagedIdentity(
type=self.type,
Expand All @@ -462,6 +511,7 @@ def _from_rest_object(cls, obj: RestRegistryManagedIdentity) -> "IdentityConfigu
result.principal_id = obj.principal_id
result.tenant_id = obj.tenant_id
return result

class NoneCredentialConfiguration(RestTranslatableMixin):
"""None Credential Configuration."""

Expand Down
83 changes: 0 additions & 83 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_workspace/identity.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
from azure.ai.ml._schema.workspace.workspace import WorkspaceSchema
from azure.ai.ml._utils.utils import dump_yaml_to_file
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, PARAMS_OVERRIDE_KEY, WorkspaceResourceConstants
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.entities._resource import Resource
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentity

from .customer_managed_key import CustomerManagedKey

Expand All @@ -38,7 +38,7 @@ def __init__(
customer_managed_key: CustomerManagedKey = None,
image_build_compute: str = None,
public_network_access: str = None,
identity: ManagedServiceIdentity = None,
identity: IdentityConfiguration = None,
primary_user_assigned_identity: str = None,
**kwargs,
):
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
when a workspace is private link enabled.
:type public_network_access: str
:param identity: workspace's Managed Identity (user assigned, or system assigned)
:type identity: ManagedServiceIdentity
:type identity: IdentityConfiguration
:param primary_user_assigned_identity: The workspace's primary user assigned identity
:type primary_user_assigned_identity: str
:param kwargs: A dictionary of additional configuration parameters.
Expand Down Expand Up @@ -185,7 +185,7 @@ def _from_rest_object(cls, rest_obj: RestWorkspace) -> "Workspace":
group = None if len(armid_parts) < 4 else armid_parts[4]
identity = None
if rest_obj.identity and isinstance(rest_obj.identity, RestManagedServiceIdentity):
identity = ManagedServiceIdentity._from_rest_object(rest_obj.identity) # pylint: disable=protected-access
identity = IdentityConfiguration._from_workspace_rest_object(rest_obj.identity) # pylint: disable=protected-access
return Workspace(
name=rest_obj.name,
id=rest_obj.id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
get_resource_and_group_name,
get_resource_group_location,
)
from azure.ai.ml._utils.utils import camel_to_snake
from azure.ai.ml._version import VERSION
from azure.ai.ml.constants import ManagedServiceIdentityType
from azure.ai.ml.constants._common import ArmConstants, LROConfigurations, WorkspaceResourceConstants
from azure.ai.ml.entities import ManagedServiceIdentity, Workspace, WorkspaceKeys
from azure.ai.ml.entities import Workspace, WorkspaceKeys
from azure.ai.ml.entities._credentials import IdentityConfiguration
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationException
from azure.core.credentials import TokenCredential
from azure.core.polling import LROPoller
Expand Down Expand Up @@ -226,7 +228,7 @@ def begin_update(
"""
identity = kwargs.get("identity", workspace.identity)
if identity:
identity = identity._to_rest_object()
identity = identity._to_workspace_rest_object()
update_param = WorkspaceUpdateParameters(
tags=workspace.tags,
description=kwargs.get("description", workspace.description),
Expand Down Expand Up @@ -494,9 +496,11 @@ def _populate_arm_paramaters(self, workspace: Workspace) -> Tuple[dict, dict, di

identity = None
if workspace.identity:
identity = workspace.identity._to_rest_object()
identity = workspace.identity._to_workspace_rest_object()
else:
identity = ManagedServiceIdentity(type=ManagedServiceIdentityType.SYSTEM_ASSIGNED)._to_rest_object()
# pylint: disable=protected-access
identity = IdentityConfiguration(
type=camel_to_snake(ManagedServiceIdentityType.SYSTEM_ASSIGNED))._to_workspace_rest_object()
_set_val(param["identity"], identity)

if workspace.primary_user_assigned_identity:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from azure.ai.ml import MLClient, load_workspace
from azure.ai.ml.constants._common import PublicNetworkAccess
from azure.ai.ml.entities._workspace.identity import ManagedServiceIdentityType
from azure.ai.ml.constants._workspace import ManagedServiceIdentityType
from azure.core.paging import ItemPaged
from azure.mgmt.msi._managed_service_identity_client import ManagedServiceIdentityClient

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
from unittest.mock import DEFAULT, Mock, call, patch

import pytest
from azure.ai.ml._utils.utils import camel_to_snake
from pytest_mock import MockFixture

from azure.ai.ml._scope_dependent_operations import OperationScope
from azure.ai.ml.constants import ManagedServiceIdentityType
from azure.ai.ml.entities import CustomerManagedKey, ManagedServiceIdentity, Workspace, WorkspaceUserAssignedIdentity
from azure.ai.ml.entities import CustomerManagedKey, Workspace, \
IdentityConfiguration, ManagedIdentityConfiguration
from azure.ai.ml.operations import WorkspaceOperations
from azure.core.polling import LROPoller

Expand Down Expand Up @@ -119,12 +121,12 @@ def test_update(self, mock_workspace_operation: WorkspaceOperations) -> None:
public_network_access="Enabled",
container_registry="foo_conntainer_registry",
application_insights="foo_application_insights",
identity=ManagedServiceIdentity(
type=ManagedServiceIdentityType.USER_ASSIGNED,
user_assigned_identities={
"resource1": WorkspaceUserAssignedIdentity(),
"resource2": WorkspaceUserAssignedIdentity(),
},
identity=IdentityConfiguration(
type=camel_to_snake(ManagedServiceIdentityType.USER_ASSIGNED),
user_assigned_identities=[
ManagedIdentityConfiguration(resource_id="resource1"),
ManagedIdentityConfiguration(resource_id="resource2")
],
),
primary_user_assigned_identity="resource2",
)
Expand Down