Skip to content

Commit

Permalink
chore: clean and improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangda committed Nov 3, 2024
1 parent 527a1ef commit 69c5a49
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 72 deletions.
5 changes: 4 additions & 1 deletion internal/scim/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
// patchGroupOperations assembles the operations for patch groups
// bases in the limits of operations we can execute in a single request.
func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMembers) []*aws.PatchGroupRequest {
patchOperations := []*aws.PatchGroupRequest{}
patchOperations := make([]*aws.PatchGroupRequest, 0)

if len(pvs) > MaxPatchGroupMembersPerRequest {
for i := 0; i < len(pvs); i += MaxPatchGroupMembersPerRequest {
Expand All @@ -33,9 +33,11 @@ func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMem
},
},
}

patchOperations = append(patchOperations, patchGroupRequest)
}
} else {

patchGroupRequest := &aws.PatchGroupRequest{
Group: aws.Group{
ID: gms.Group.SCIMID,
Expand All @@ -52,6 +54,7 @@ func patchGroupOperations(op, path string, pvs []patchValue, gms *model.GroupMem
},
},
}

patchOperations = append(patchOperations, patchGroupRequest)
}

Expand Down
11 changes: 11 additions & 0 deletions internal/scim/operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,14 @@ func Test_patchGroupOperations(t *testing.T) {
})
}
}

func Benchmark_patchGroupOperations(b *testing.B) {
for i := 0; i < b.N; i++ {
patchGroupOperations("add", "members", patchValueGenerator(1, 350), &model.GroupMembers{
Group: &model.Group{
SCIMID: "016722b2be-ee23ed58-6e4e-4b2f-a94a-3ace8456a36e",
Name: "group 1",
},
})
}
}
78 changes: 38 additions & 40 deletions internal/scim/scim.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ type AWSSCIMProvider interface {
// ListUsers lists users in SCIM Provider
ListUsers(ctx context.Context, filter string) (*aws.ListUsersResponse, error)

// CreateUser creates a user in SCIM Provider
CreateUser(ctx context.Context, u *aws.CreateUserRequest) (*aws.CreateUserResponse, error)

// CreateOrGetUser creates a user in SCIM Provider
CreateOrGetUser(ctx context.Context, u *aws.CreateUserRequest) (*aws.CreateUserResponse, error)

Expand All @@ -39,9 +36,6 @@ type AWSSCIMProvider interface {
// ListGroups lists groups in SCIM Provider
ListGroups(ctx context.Context, filter string) (*aws.ListGroupsResponse, error)

// CreateGroup creates a group in SCIM Provider
CreateGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error)

// CreateOrGetGroup creates a group in SCIM Provider
CreateOrGetGroup(ctx context.Context, g *aws.CreateGroupRequest) (*aws.CreateGroupResponse, error)

Expand Down Expand Up @@ -81,24 +75,28 @@ func (s *Provider) GetGroups(ctx context.Context) (*model.GroupsResult, error) {

groups := make([]*model.Group, len(groupsResponse.Resources))
for i, group := range groupsResponse.Resources {
e := model.GroupBuilder().
g := model.GroupBuilder().
WithSCIMID(group.ID).
WithName(group.DisplayName).
WithIPID(group.ExternalID).
Build()

groups[i] = e
groups[i] = g

}

groupsResult := model.GroupsResultBuilder().WithResources(groups).Build()

slog.Debug("scim: GetGroups()", "groups", len(groups))

return groupsResult, nil
}

// CreateGroups creates groups in SCIM Provider
func (s *Provider) CreateGroups(ctx context.Context, gr *model.GroupsResult) (*model.GroupsResult, error) {
if gr == nil {
return nil, fmt.Errorf("scim: error creating groups, groups result is nil")
}

groups := make([]*model.Group, len(gr.Resources))

for i, group := range gr.Resources {
Expand All @@ -114,18 +112,17 @@ func (s *Provider) CreateGroups(ctx context.Context, gr *model.GroupsResult) (*m
return nil, fmt.Errorf("scim: error creating group: %w", err)
}

e := model.GroupBuilder().
g := model.GroupBuilder().
WithSCIMID(r.ID).
WithName(group.Name).
WithIPID(group.IPID).
WithEmail(group.Email).
Build()

groups[i] = e
groups[i] = g
}

groupsResult := model.GroupsResultBuilder().WithResources(groups).Build()

slog.Debug("scim: CreateGroups()", "groups", len(groups))

return groupsResult, nil
Expand Down Expand Up @@ -162,14 +159,14 @@ func (s *Provider) UpdateGroups(ctx context.Context, gr *model.GroupsResult) (*m
}

// return the same group
e := model.GroupBuilder().
g := model.GroupBuilder().
WithSCIMID(group.SCIMID).
WithName(group.Name).
WithIPID(group.IPID).
WithEmail(group.Email).
Build()

groups[i] = e
groups[i] = g
}

groupsResult := model.GroupsResultBuilder().WithResources(groups).Build()
Expand Down Expand Up @@ -200,8 +197,8 @@ func (s *Provider) GetUsers(ctx context.Context) (*model.UsersResult, error) {

users := make([]*model.User, len(usersResponse.Resources))
for i, user := range usersResponse.Resources {
e := buildUser(user)
users[i] = e
u := buildUser(user)
users[i] = u
}

usersResult := model.UsersResultBuilder().WithResources(users).Build()
Expand Down Expand Up @@ -286,13 +283,13 @@ type patchValue struct {

// CreateGroupsMembers creates groups members in SCIM Provider given a list of groups members
func (s *Provider) CreateGroupsMembers(ctx context.Context, gmr *model.GroupsMembersResult) (*model.GroupsMembersResult, error) {
groupsMembers := make([]*model.GroupMembers, 0)
groupsMembers := make([]*model.GroupMembers, len(gmr.Resources))

for _, groupMembers := range gmr.Resources {
members := make([]*model.Member, 0)
membersIDValue := []patchValue{}
for i, groupMembers := range gmr.Resources {
members := make([]*model.Member, len(groupMembers.Resources))
membersIDValue := make([]patchValue, len(groupMembers.Resources))

for _, member := range groupMembers.Resources {
for j, member := range groupMembers.Resources {
if member.SCIMID == "" {
u, err := s.scim.GetUserByUserName(ctx, member.Email)
if err != nil {
Expand All @@ -301,28 +298,27 @@ func (s *Provider) CreateGroupsMembers(ctx context.Context, gmr *model.GroupsMem
member.SCIMID = u.ID
}

membersIDValue = append(membersIDValue, patchValue{
membersIDValue[j] = patchValue{
Value: member.SCIMID,
})
}

e := model.MemberBuilder().
m := model.MemberBuilder().
WithIPID(member.IPID).
WithSCIMID(member.SCIMID).
WithEmail(member.Email).
WithStatus(member.Status).
Build()

slog.Warn("adding member to group", "group", groupMembers.Group.Name, "email", member.Email)
members = append(members, e)

members[j] = m
}

e := model.GroupMembersBuilder().
gm := model.GroupMembersBuilder().
WithGroup(groupMembers.Group).
WithResources(members).
Build()

groupsMembers = append(groupsMembers, e)
groupsMembers[i] = gm

patchOperations := patchGroupOperations("add", "members", membersIDValue, groupMembers)

Expand Down Expand Up @@ -397,9 +393,9 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult)
}

for _, gr := range lgr.Resources {
members := make([]*model.Member, 0)
members := make([]*model.Member, len(gr.Members))

for _, member := range gr.Members {
for j, member := range gr.Members {
u, err := s.scim.GetUser(ctx, member.Value)
if err != nil {
return nil, fmt.Errorf("scim: error getting user: %s, error %w", member.Value, err)
Expand All @@ -410,15 +406,15 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult)
WithEmail(u.Emails[0].Value).
Build()

members = append(members, m)
members[j] = m
}

e := model.GroupMembersBuilder().
gms := model.GroupMembersBuilder().
WithGroup(group).
WithResources(members).
Build()

groupMembers = append(groupMembers, e)
groupMembers = append(groupMembers, gms)
}
}

Expand All @@ -431,22 +427,23 @@ func (s *Provider) GetGroupsMembers(ctx context.Context, gr *model.GroupsResult)
// GetGroupsMembersBruteForce returns a list of groups and their members from the SCIM Provider
// NOTE: this is an bad alternative to the method GetGroupsMembers, because read the note in the method.
func (s *Provider) GetGroupsMembersBruteForce(ctx context.Context, gr *model.GroupsResult, ur *model.UsersResult) (*model.GroupsMembersResult, error) {
groupMembers := make([]*model.GroupMembers, 0)
groupMembers := make([]*model.GroupMembers, len(gr.Resources))

// brute force implemented here thanks to the fxxckin' aws sso scim api
for _, group := range gr.Resources {
for i, group := range gr.Resources {
members := make([]*model.Member, 0)

for _, user := range ur.Resources {

// https://docs.aws.amazon.com/singlesignon/latest/developerguide/listgroups.html
f := fmt.Sprintf("id eq %q and members eq %q", group.SCIMID, user.SCIMID)
lgr, err := s.scim.ListGroups(ctx, f)
filter := fmt.Sprintf("id eq %q and members eq %q", group.SCIMID, user.SCIMID)
lgr, err := s.scim.ListGroups(ctx, filter)
if err != nil {
return nil, fmt.Errorf("scim: error listing groups: %w", err)
}

if lgr.TotalResults > 0 { // crazy thing of the AWS SSO SCIM API, it doesn't return the member into the Resources array
// AWS SSO SCIM API, it doesn't return the member into the Resources array
if lgr.TotalResults > 0 {
m := model.MemberBuilder().
WithIPID(user.IPID).
WithSCIMID(user.SCIMID).
Expand All @@ -460,12 +457,13 @@ func (s *Provider) GetGroupsMembersBruteForce(ctx context.Context, gr *model.Gro
members = append(members, m)
}
}
e := model.GroupMembersBuilder().

gms := model.GroupMembersBuilder().
WithGroup(group).
WithResources(members).
Build()

groupMembers = append(groupMembers, e)
groupMembers[i] = gms
}

slog.Debug("scim: GetGroupsMembersBruteForce()", "groups_members", len(groupMembers))
Expand Down
Loading

0 comments on commit 69c5a49

Please sign in to comment.