Skip to content

Commit

Permalink
Implement WAL rollback mechanism for Role Assignments (#110) (#114)
Browse files Browse the repository at this point in the history
* Implement Role Assignment WAL and rollback

* Improve error handling around unassignment of non-existent role assignment ID

* Better error handling in test, and guarding against nil or empty values

* Add clarity to rollback log message, and check if there were no Azure Roles associated with Role

* Further improve error handling, fix failing test, add guard against size mismatch between number of roles and assignmentIDs, parameterize Resource Group in test

* Fix rollback test, and clean up left over debug line

* Add missing error check for spRevoke during test, use errors.New instead of Errorf for AzureRoles and assignmentIDs check

* Add warning about resources potentially still existing if WAL has expired

Co-authored-by: davidadeleon <[email protected]>
  • Loading branch information
jasonodonnell and davidadeleon authored Nov 22, 2022
1 parent 2083997 commit d8b573c
Show file tree
Hide file tree
Showing 4 changed files with 378 additions and 39 deletions.
16 changes: 9 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,17 +126,19 @@ func (c *client) deleteApp(ctx context.Context, appObjectID string) error {
}

// assignRoles assigns Azure roles to a service principal.
func (c *client) assignRoles(ctx context.Context, spID string, roles []*AzureRole) ([]string, error) {
func (c *client) assignRoles(ctx context.Context, spID string, roles []*AzureRole, assignmentIDs []string) ([]string, error) {
var ids []string

for _, role := range roles {
assignmentID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
if len(roles) != len(assignmentIDs) {
return nil, errors.New("number of Azure Roles and assignment IDs do not match")
}

for i, role := range roles {
resultRaw, err := retry(ctx, func() (interface{}, bool, error) {
ra, err := c.provider.CreateRoleAssignment(ctx, role.Scope, assignmentID,
if assignmentIDs[i] == "" {
return nil, true, fmt.Errorf("assignmentID at index %d was empty", i)
}
ra, err := c.provider.CreateRoleAssignment(ctx, role.Scope, assignmentIDs[i],
authorization.RoleAssignmentCreateParameters{
RoleAssignmentProperties: &authorization.RoleAssignmentProperties{
RoleDefinitionID: &role.RoleID,
Expand Down
31 changes: 29 additions & 2 deletions path_service_principal.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/Azure/go-autorest/autorest/to"
uuid "github.com/hashicorp/go-uuid"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/helper/locksutil"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -122,8 +123,30 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora
return nil, err
}

// Pre-generate UUIDs to be provided to assignRoles so we can rollback if we need to
var assignmentIDs []string

for i := 0; i < len(role.AzureRoles); i++ {
assignmentID, err := uuid.GenerateUUID()
if err != nil {
return nil, err
}
assignmentIDs = append(assignmentIDs, assignmentID)
}

// Write a second WAL entry in case the Role assignments don't complete
rWALID, err := framework.PutWAL(ctx, s, walAppRoleAssignment, &walAppRoleAssign{
SpID: spID,
AssignmentIDs: assignmentIDs,
AzureRoles: role.AzureRoles,
Expiration: time.Now().Add(maxWALAge),
})
if err != nil {
return nil, fmt.Errorf("error writing WAL: %w", err)
}

// Assign Azure roles to the new SP
raIDs, err := c.assignRoles(ctx, spID, role.AzureRoles)
raIDs, err := c.assignRoles(ctx, spID, role.AzureRoles, assignmentIDs)
if err != nil {
return nil, err
}
Expand All @@ -133,11 +156,15 @@ func (b *azureSecretBackend) createSPSecret(ctx context.Context, s logical.Stora
return nil, err
}

// SP is fully created so delete the WAL
// SP is fully created so delete the WALs
if err := framework.DeleteWAL(ctx, s, walID); err != nil {
return nil, fmt.Errorf("error deleting WAL: %w", err)
}

if err := framework.DeleteWAL(ctx, s, rWALID); err != nil {
return nil, fmt.Errorf("error deleting role assignment WAL: %w", err)
}

data := map[string]interface{}{
"client_id": appID,
"client_secret": password,
Expand Down
294 changes: 266 additions & 28 deletions path_service_principal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,37 +113,63 @@ func assertEmptyWAL(t *testing.T, b *azureSecretBackend, emp api.AzureProvider,
t.Fatal(err)
}

// Decode the WAL data
var app walApp
d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeHookFunc(time.RFC3339),
Result: &app,
})
if err != nil {
t.Fatal(err)
}
err = d.Decode(entry.Data)
if err != nil {
t.Fatal(err)
}
switch entry.Kind {
case walAppKey:
// Decode the WAL data
var app walApp
d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeHookFunc(time.RFC3339),
Result: &app,
})
if err != nil {
t.Fatal(err)
}
err = d.Decode(entry.Data)
if err != nil {
t.Fatal(err)
}

_, err = emp.GetApplication(context.Background(), app.AppObjID)
if err != nil {
t.Fatalf("expected to find application (%s), but wasn't found", app.AppObjID)
}
_, err = emp.GetApplication(context.Background(), app.AppObjID)
if err != nil {
t.Fatalf("expected to find application (%s), but wasn't found", app.AppObjID)
}

err = b.walRollback(ctx, req, entry.Kind, entry.Data)
if err != nil {
t.Fatal(err)
}
if err := framework.DeleteWAL(ctx, s, v); err != nil {
t.Fatal(err)
}
err = b.walRollback(ctx, req, entry.Kind, entry.Data)
if err != nil {
t.Fatal(err)
}
if err := framework.DeleteWAL(ctx, s, v); err != nil {
t.Fatal(err)
}

_, err = emp.GetApplication(context.Background(), app.AppObjID)
if err == nil {
t.Fatalf("expected error getting application")
_, err = emp.GetApplication(context.Background(), app.AppObjID)
if err == nil {
t.Fatalf("expected error getting application")
}
case walAppRoleAssignment:
// Decode the WAL data
d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeHookFunc(time.RFC3339),
Result: &entry,
})
if err != nil {
t.Fatal(err)
}
err = d.Decode(entry.Data)
if err != nil {
t.Fatal(err)
}

err = b.walRollback(ctx, req, entry.Kind, entry.Data)
if err != nil {
t.Fatal(err)
}

if err := framework.DeleteWAL(ctx, s, v); err != nil {
t.Fatal(err)
}
}

}
}

Expand Down Expand Up @@ -483,7 +509,219 @@ func TestCredentialReadProviderError(t *testing.T) {
}
}

// TestCredentialInteg is an integration test against the live Azure service. It requires
// TestRoleAssignmentWALRollback tests rolling back any
// role assignments that may have taken place prior to
// a subsequent failure resulting in the need to rollback
// an App or SP. This test requires valid, sufficiently-privileged
// Azure credentials in env variables.
func TestRoleAssignmentWALRollback(t *testing.T) {
if os.Getenv("VAULT_ACC") != "1" {
t.SkipNow()
}

if os.Getenv("AZURE_CLIENT_SECRET") == "" {
t.Skip("Azure Secrets: Azure environment variables not set. Skipping.")
}

t.Run("service principals", func(t *testing.T) {
t.Parallel()

skipIfMissingEnvVars(t,
"AZURE_SUBSCRIPTION_ID",
"AZURE_CLIENT_ID",
"AZURE_CLIENT_SECRET",
"AZURE_TENANT_ID",
"AZURE_TEST_RESOURCE_GROUP",
)

b := backend()
s := new(logical.InmemStorage)
subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID")
clientID := os.Getenv("AZURE_CLIENT_ID")
clientSecret := os.Getenv("AZURE_CLIENT_SECRET")
tenantID := os.Getenv("AZURE_TENANT_ID")
resourceGroup := os.Getenv("AZURE_TEST_RESOURCE_GROUP")

config := &logical.BackendConfig{
Logger: logging.NewVaultLogger(log.Trace),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLHr,
MaxLeaseTTLVal: maxLeaseTTLHr,
},
StorageView: s,
}
err := b.Setup(context.Background(), config)
assertErrorIsNil(t, err)

configData := map[string]interface{}{
"subscription_id": subscriptionID,
"client_id": clientID,
"client_secret": clientSecret,
"tenant_id": tenantID,
}

configResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: "config",
Data: configData,
Storage: s,
})
assertRespNoError(t, configResp, err)

roleName := "test_role_rawalrollback"

roleData := map[string]interface{}{
"azure_roles": fmt.Sprintf(`[
{
"role_name": "Storage Blob Data Owner",
"scope": "/subscriptions/%s/resourceGroups/%s"
},
{
"role_name": "Reader",
"scope": "/subscriptions/%s/resourceGroups/%s"
}]`, subscriptionID, resourceGroup, subscriptionID, resourceGroup),
}

roleResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.CreateOperation,
Path: fmt.Sprintf("roles/%s", roleName),
Data: roleData,
Storage: s,
})
assertRespNoError(t, roleResp, err)

credsResp, err := b.HandleRequest(context.Background(), &logical.Request{
Operation: logical.ReadOperation,
Path: fmt.Sprintf("creds/%s", roleName),
Storage: s,
})
assertRespNoError(t, credsResp, err)

appID := credsResp.Data["client_id"].(string)

// Use the underlying provider to access clients directly for testing
client, err := b.getClient(context.Background(), s)
assertErrorIsNil(t, err)
provider := client.provider.(*provider)
spObjID := findServicePrincipalID(t, provider.spClient, appID)

assertServicePrincipalExists(t, provider.spClient, spObjID)

// Verify that the role assignments were created. Get the assignment
// info from Azure and verify it matches the Reader role.
raIDs := credsResp.Secret.InternalData["role_assignment_ids"].([]string)
equal(t, 2, len(raIDs))

ra, err := provider.raClient.GetByID(context.Background(), raIDs[0])
assertErrorIsNil(t, err)

roleDefs, err := provider.ListRoleDefinitions(context.Background(), fmt.Sprintf("subscriptions/%s", subscriptionID), "")
assertErrorIsNil(t, err)

defID := *ra.RoleAssignmentPropertiesWithScope.RoleDefinitionID
found := false
for _, def := range roleDefs {
if *def.ID == defID && *def.RoleName == "Storage Blob Data Owner" {
found = true
break
}
}

if !found {
t.Fatal("'Storage Blob Data Owner' role assignment not found")
}

// Parse the assignment IDs
var assignmentIDs []string
for _, raID := range raIDs {
t := strings.Split(raID, "/")
tRa := t[len(t)-1]
assignmentIDs = append(assignmentIDs, strings.Replace(tRa, " ", "", -1))
}

// Remove one of the RA IDs to simulate a failure to assign a role
if err := client.unassignRoles(context.Background(), []string{raIDs[0]}); err != nil {
t.Fatalf("error unassigning Role: %s", err.Error())
}

rEntry, err := s.Get(context.Background(), fmt.Sprintf("%s/%s", "roles", roleName))
if err != nil {
t.Fatalf("error getting role from storage: %s", err.Error())
}

if rEntry == nil {
t.Fatalf("role entry was nil: %s", err.Error())
}

// Decode returned Role Entry
role := new(roleEntry)
if err := rEntry.DecodeJSON(role); err != nil {
t.Fatalf("unable to decode role entry: %s", err.Error())
}

// Manually Create Role Assignment WAL
rWALID, err := framework.PutWAL(context.Background(), s, walAppRoleAssignment, &walAppRoleAssign{
SpID: spObjID,
AssignmentIDs: assignmentIDs,
AzureRoles: role.AzureRoles,
Expiration: time.Now().Add(maxWALAge),
})
if err != nil {
t.Fatalf("error creating role assignment WAL: %s", err.Error())
}

// Retrieve WAL
entry, err := framework.GetWAL(context.Background(), s, rWALID)
if err != nil {
t.Fatalf("error retrieving role assignment WAL: %s", err.Error())
}

// Decode the WAL data
var appRoleAssign walAppRoleAssign
d, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeHookFunc(time.RFC3339),
Result: &appRoleAssign,
})
if err != nil {
t.Fatalf("error decoding WAL data: %s", err.Error())
}
err = d.Decode(entry.Data)
if err != nil {
t.Fatalf("error decoding WAL data: %s", err.Error())
}

req := &logical.Request{
Storage: s,
}

// Initiate Role Assignment Rollback
err = b.walRollback(context.Background(), req, entry.Kind, entry.Data)
if err != nil {
t.Fatalf("error rolling back WAL: %s", err.Error())
}

// Serialize and deserialize the secret to remove typing, as will really happen.
fakeSaveLoad(credsResp.Secret)

// Revoke the Service Principal by sending back the secret we just received
req = &logical.Request{
Secret: credsResp.Secret,
Storage: s,
}

_, err = b.spRevoke(context.Background(), req, nil)
if err != nil {
t.Fatalf("error revoking service principal: %s", err.Error())
}

// Verify that SP get is an error after delete. Expected there
// to be a delay and that this step would take some time/retries,
// but that seems not to be the case.
assertServicePrincipalDoesNotExist(t, provider.spClient, spObjID)
})
}

// This is an integration test against the live Azure service. It requires
// valid, sufficiently-privileged Azure credentials in env variables.
func TestCredentialInteg_aad(t *testing.T) {
if os.Getenv("VAULT_ACC") != "1" {
Expand Down
Loading

0 comments on commit d8b573c

Please sign in to comment.