Skip to content

Commit

Permalink
Merge pull request #7511 from hashicorp/pr-7319
Browse files Browse the repository at this point in the history
provider/aws: AWS prefix lists to enable security group egress to a VPC Endpoint (supersedes #7319)
  • Loading branch information
catsby authored Jul 7, 2016
2 parents 21e2173 + 1d488bd commit 17931c7
Show file tree
Hide file tree
Showing 10 changed files with 520 additions and 37 deletions.
163 changes: 127 additions & 36 deletions builtin/providers/aws/resource_aws_security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,12 @@ func resourceAwsSecurityGroup() *schema.Resource {
Elem: &schema.Schema{Type: schema.TypeString},
},

"prefix_list_ids": &schema.Schema{
Type: schema.TypeList,
Optional: true,
Elem: &schema.Schema{Type: schema.TypeString},
},

"security_groups": &schema.Schema{
Type: schema.TypeSet,
Optional: true,
Expand Down Expand Up @@ -397,6 +403,18 @@ func resourceAwsSecurityGroupRuleHash(v interface{}) int {
buf.WriteString(fmt.Sprintf("%s-", v))
}
}
if v, ok := m["prefix_list_ids"]; ok {
vs := v.([]interface{})
s := make([]string, len(vs))
for i, raw := range vs {
s[i] = raw.(string)
}
sort.Strings(s)

for _, v := range s {
buf.WriteString(fmt.Sprintf("%s-", v))
}
}
if v, ok := m["security_groups"]; ok {
vs := v.(*schema.Set).List()
s := make([]string, len(vs))
Expand Down Expand Up @@ -449,6 +467,20 @@ func resourceAwsSecurityGroupIPPermGather(groupId string, permissions []*ec2.IpP
m["cidr_blocks"] = list
}

if len(perm.PrefixListIds) > 0 {
raw, ok := m["prefix_list_ids"]
if !ok {
raw = make([]string, 0, len(perm.PrefixListIds))
}
list := raw.([]string)

for _, pl := range perm.PrefixListIds {
list = append(list, *pl.PrefixListId)
}

m["prefix_list_ids"] = list
}

groups := flattenSecurityGroups(perm.UserIdGroupPairs, ownerId)
for i, g := range groups {
if *g.GroupId == groupId {
Expand Down Expand Up @@ -658,16 +690,21 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
// local rule we're examining
rHash := idHash(rType, r["protocol"].(string), r["to_port"].(int64), r["from_port"].(int64), remoteSelfVal)
if rHash == localHash {
var numExpectedCidrs, numExpectedSGs, numRemoteCidrs, numRemoteSGs int
var numExpectedCidrs, numExpectedPrefixLists, numExpectedSGs, numRemoteCidrs, numRemotePrefixLists, numRemoteSGs int
var matchingCidrs []string
var matchingSGs []string
var matchingPrefixLists []string

// grab the local/remote cidr and sg groups, capturing the expected and
// actual counts
lcRaw, ok := l["cidr_blocks"]
if ok {
numExpectedCidrs = len(l["cidr_blocks"].([]interface{}))
}
lpRaw, ok := l["prefix_list_ids"]
if ok {
numExpectedPrefixLists = len(l["prefix_list_ids"].([]interface{}))
}
lsRaw, ok := l["security_groups"]
if ok {
numExpectedSGs = len(l["security_groups"].(*schema.Set).List())
Expand All @@ -677,6 +714,10 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
if ok {
numRemoteCidrs = len(r["cidr_blocks"].([]string))
}
rpRaw, ok := r["prefix_list_ids"]
if ok {
numRemotePrefixLists = len(r["prefix_list_ids"].([]string))
}

rsRaw, ok := r["security_groups"]
if ok {
Expand All @@ -688,6 +729,10 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
log.Printf("[DEBUG] Local rule has more CIDR blocks, continuing (%d/%d)", numExpectedCidrs, numRemoteCidrs)
continue
}
if numExpectedPrefixLists > numRemotePrefixLists {
log.Printf("[DEBUG] Local rule has more prefix lists, continuing (%d/%d)", numExpectedPrefixLists, numRemotePrefixLists)
continue
}
if numExpectedSGs > numRemoteSGs {
log.Printf("[DEBUG] Local rule has more Security Groups, continuing (%d/%d)", numExpectedSGs, numRemoteSGs)
continue
Expand Down Expand Up @@ -721,6 +766,34 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
}
}

// match prefix lists by converting both to sets, and using Set methods
var localPrefixLists []interface{}
if lpRaw != nil {
localPrefixLists = lpRaw.([]interface{})
}
localPrefixListsSet := schema.NewSet(schema.HashString, localPrefixLists)

// remote prefix lists are presented as a slice of strings, so we need to
// reformat them into a slice of interfaces to be used in creating the
// remote prefix list set
var remotePrefixLists []string
if rpRaw != nil {
remotePrefixLists = rpRaw.([]string)
}
// convert remote prefix lists to a set, for easy comparison
list = nil
for _, s := range remotePrefixLists {
list = append(list, s)
}
remotePrefixListsSet := schema.NewSet(schema.HashString, list)

// Build up a list of local prefix lists that are found in the remote set
for _, s := range localPrefixListsSet.List() {
if remotePrefixListsSet.Contains(s) {
matchingPrefixLists = append(matchingPrefixLists, s.(string))
}
}

// match SGs. Both local and remote are already sets
var localSGSet *schema.Set
if lsRaw == nil {
Expand Down Expand Up @@ -748,41 +821,57 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
// match, and then remove those elements from the remote rule, so that
// this remote rule can still be considered by other local rules
if numExpectedCidrs == len(matchingCidrs) {
if numExpectedSGs == len(matchingSGs) {
// confirm that self references match
var lSelf bool
var rSelf bool
if _, ok := l["self"]; ok {
lSelf = l["self"].(bool)
}
if _, ok := r["self"]; ok {
rSelf = r["self"].(bool)
}
if rSelf == lSelf {
delete(r, "self")
// pop local cidrs from remote
diffCidr := remoteCidrSet.Difference(localCidrSet)
var newCidr []string
for _, cRaw := range diffCidr.List() {
newCidr = append(newCidr, cRaw.(string))
if numExpectedPrefixLists == len(matchingPrefixLists) {
if numExpectedSGs == len(matchingSGs) {
// confirm that self references match
var lSelf bool
var rSelf bool
if _, ok := l["self"]; ok {
lSelf = l["self"].(bool)
}

// reassigning
if len(newCidr) > 0 {
r["cidr_blocks"] = newCidr
} else {
delete(r, "cidr_blocks")
if _, ok := r["self"]; ok {
rSelf = r["self"].(bool)
}

// pop local sgs from remote
diffSGs := remoteSGSet.Difference(localSGSet)
if len(diffSGs.List()) > 0 {
r["security_groups"] = diffSGs
} else {
delete(r, "security_groups")
if rSelf == lSelf {
delete(r, "self")
// pop local cidrs from remote
diffCidr := remoteCidrSet.Difference(localCidrSet)
var newCidr []string
for _, cRaw := range diffCidr.List() {
newCidr = append(newCidr, cRaw.(string))
}

// reassigning
if len(newCidr) > 0 {
r["cidr_blocks"] = newCidr
} else {
delete(r, "cidr_blocks")
}

// pop local prefix lists from remote
diffPrefixLists := remotePrefixListsSet.Difference(localPrefixListsSet)
var newPrefixLists []string
for _, pRaw := range diffPrefixLists.List() {
newPrefixLists = append(newPrefixLists, pRaw.(string))
}

// reassigning
if len(newPrefixLists) > 0 {
r["prefix_list_ids"] = newPrefixLists
} else {
delete(r, "prefix_list_ids")
}

// pop local sgs from remote
diffSGs := remoteSGSet.Difference(localSGSet)
if len(diffSGs.List()) > 0 {
r["security_groups"] = diffSGs
} else {
delete(r, "security_groups")
}

saves = append(saves, l)
}

saves = append(saves, l)
}
}
}
Expand All @@ -795,11 +884,13 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
// matched locally, and let the graph sort things out. This will happen when
// rules are added externally to Terraform
for _, r := range remote {
var lenCidr, lenSGs int
var lenCidr, lenPrefixLists, lenSGs int
if rCidrs, ok := r["cidr_blocks"]; ok {
lenCidr = len(rCidrs.([]string))
}

if rPrefixLists, ok := r["prefix_list_ids"]; ok {
lenPrefixLists = len(rPrefixLists.([]string))
}
if rawSGs, ok := r["security_groups"]; ok {
lenSGs = len(rawSGs.(*schema.Set).List())
}
Expand All @@ -810,7 +901,7 @@ func matchRules(rType string, local []interface{}, remote []map[string]interface
}
}

if lenSGs+lenCidr > 0 {
if lenSGs+lenCidr+lenPrefixLists > 0 {
log.Printf("[DEBUG] Found a remote Rule that wasn't empty: (%#v)", r)
saves = append(saves, r)
}
Expand Down
51 changes: 51 additions & 0 deletions builtin/providers/aws/resource_aws_security_group_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ func resourceAwsSecurityGroupRule() *schema.Resource {
Elem: &schema.Schema{Type: schema.TypeString},
},

"prefix_list_ids": &schema.Schema{
Type: schema.TypeList,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
},

"security_group_id": &schema.Schema{
Type: schema.TypeString,
Required: true,
Expand Down Expand Up @@ -363,6 +370,19 @@ func findRuleMatch(p *ec2.IpPermission, rules []*ec2.IpPermission, isVPC bool) *
continue
}

remaining = len(p.PrefixListIds)
for _, pl := range p.PrefixListIds {
for _, rpl := range r.PrefixListIds {
if *pl.PrefixListId == *rpl.PrefixListId {
remaining--
}
}
}

if remaining > 0 {
continue
}

remaining = len(p.UserIdGroupPairs)
for _, ip := range p.UserIdGroupPairs {
for _, rip := range r.UserIdGroupPairs {
Expand Down Expand Up @@ -413,6 +433,18 @@ func ipPermissionIDHash(sg_id, ruleType string, ip *ec2.IpPermission) string {
}
}

if len(ip.PrefixListIds) > 0 {
s := make([]string, len(ip.PrefixListIds))
for i, pl := range ip.PrefixListIds {
s[i] = *pl.PrefixListId
}
sort.Strings(s)

for _, v := range s {
buf.WriteString(fmt.Sprintf("%s-", v))
}
}

if len(ip.UserIdGroupPairs) > 0 {
sort.Sort(ByGroupPair(ip.UserIdGroupPairs))
for _, pair := range ip.UserIdGroupPairs {
Expand Down Expand Up @@ -494,6 +526,18 @@ func expandIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup) (*ec2.IpPermiss
}
}

if raw, ok := d.GetOk("prefix_list_ids"); ok {
list := raw.([]interface{})
perm.PrefixListIds = make([]*ec2.PrefixListId, len(list))
for i, v := range list {
prefixListID, ok := v.(string)
if !ok {
return nil, fmt.Errorf("empty element found in prefix_list_ids - consider using the compact function")
}
perm.PrefixListIds[i] = &ec2.PrefixListId{PrefixListId: aws.String(prefixListID)}
}
}

return &perm, nil
}

Expand All @@ -514,6 +558,13 @@ func setFromIPPerm(d *schema.ResourceData, sg *ec2.SecurityGroup, rule *ec2.IpPe
// 'self' is false by default. Below, we range over the group ids and set true
// if the parent sg id is found
d.Set("self", false)

var pl []string
for _, p := range rule.PrefixListIds {
pl = append(pl, *p.PrefixListId)
}
d.Set("prefix_list_ids", pl)

if len(rule.UserIdGroupPairs) > 0 {
s := rule.UserIdGroupPairs[0]

Expand Down
Loading

0 comments on commit 17931c7

Please sign in to comment.