Skip to content

Commit

Permalink
Better
Browse files Browse the repository at this point in the history
  • Loading branch information
majst01 committed Feb 16, 2024
1 parent 287f439 commit b13c9a8
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 185 deletions.
173 changes: 66 additions & 107 deletions cmd/firewall.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package cmd
import (
"encoding/base64"
"fmt"
"net/netip"
"strconv"
"strings"
"os"
"time"

"github.com/metal-stack/metal-go/api/client/firewall"
Expand All @@ -16,6 +14,7 @@ import (
"github.com/metal-stack/metalctl/cmd/sorters"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"gopkg.in/yaml.v3"
)

type firewallCmd struct {
Expand Down Expand Up @@ -47,10 +46,53 @@ func newFirewallCmd(c *config) *cobra.Command {
CreateCmdMutateFn: func(cmd *cobra.Command) {
c.addMachineCreateFlags(cmd, "firewall")
cmd.Aliases = []string{"allocate"}
cmd.Flags().StringSlice("egress", nil, "egress firewall rule to deploy on creation format: tcp|udp@cidr#cidr@port#port@comment")
cmd.Flags().StringSlice("ingress", nil, "ingress firewall rule to deploy on creation format: tcp|udp@cidr#cidr@port#port@comment")
must(cmd.RegisterFlagCompletionFunc("egress", c.comp.FirewallEgressCompletion))
must(cmd.RegisterFlagCompletionFunc("ingress", c.comp.FirewallIngressCompletion))
cmd.Flags().String("firewall-rules-file", "", `firewall rules specified in a yaml file
Example:
$ metalctl firewall create ..mandatory args.. --firewall-rules-file rules.yaml
rules.yaml
---
egress:
- comment: allow outgoing https
ports:
- 443
protocol: TCP
to:
- 0.0.0.0/0
- comment: allow outgoing dns via tcp
ports:
- 53
protocol: TCP
to:
- 0.0.0.0/0
- comment: allow outgoing dns and ntp via udp
ports:
- 53
- 123
protocol: UDP
to:
- 0.0.0.0/0
ingress:
- comment: allow incoming ssh only to one ip
ports:
- 22
protocol: TCP
from:
- 0.0.0.0/0
- 1.2.3.4/32
to:
- 212.34.83.19/32
- comment: allow incoming https to all targets
ports:
- 80
- 433
protocol: TCP
from:
- 0.0.0.0/0
`)
},
ListCmdMutateFn: func(cmd *cobra.Command) {
cmd.Flags().String("id", "", "ID to filter [optional]")
Expand Down Expand Up @@ -134,10 +176,10 @@ func (c firewallCmd) Update(rq any) (*models.V1FirewallResponse, error) {
}

func (c firewallCmd) Convert(r *models.V1FirewallResponse) (string, *models.V1FirewallCreateRequest, any, error) {
// if r.ID == nil {
// return "", nil, nil, fmt.Errorf("id is nil")
// }
return "", firewallResponseToCreate(r), nil, nil
if r.ID == nil {
return "", nil, nil, fmt.Errorf("id is nil")
}
return *r.ID, firewallResponseToCreate(r), nil, nil
}

func firewallResponseToCreate(r *models.V1FirewallResponse) *models.V1FirewallCreateRequest {
Expand Down Expand Up @@ -179,23 +221,11 @@ func (c *firewallCmd) createRequestFromCLI() (*models.V1FirewallCreateRequest, e
return nil, fmt.Errorf("firewall create error:%w", err)
}

egress, err := parseEgressFlags(viper.GetStringSlice("egress"))
if err != nil {
return nil, fmt.Errorf("firewall create error:%w", err)
}
ingress, err := parseIngressFlags(viper.GetStringSlice("ingress"))
firewallRules, err := parseFirewallRulesFile()
if err != nil {
return nil, fmt.Errorf("firewall create error:%w", err)
}

var firewallRules *models.V1FirewallRules
if len(egress) > 0 || len(ingress) > 0 {
firewallRules = &models.V1FirewallRules{
Egress: egress,
Ingress: ingress,
}
}

return &models.V1FirewallCreateRequest{
Description: mcr.Description,
Filesystemlayoutid: mcr.Filesystemlayoutid,
Expand All @@ -214,95 +244,24 @@ func (c *firewallCmd) createRequestFromCLI() (*models.V1FirewallCreateRequest, e
FirewallRules: firewallRules,
}, nil
}

// parseEgressFlags input must be in the form of
// proto@cidr#cidr@port#port#port@comment
// [email protected]/24#2.3.4.1/32@80#443#8080#8443@"Allow apt update"
func parseEgressFlags(inputs []string) ([]*models.V1FirewallEgressRule, error) {
var rules []*models.V1FirewallEgressRule

for _, input := range inputs {
r, err := parseRuleSpec(input)
if err != nil {
return nil, err
}

rule := &models.V1FirewallEgressRule{
Protocol: r.protocol,
To: r.cidrs,
Ports: r.ports,
Comment: r.comment,
}
rules = append(rules, rule)
func parseFirewallRulesFile() (*models.V1FirewallRules, error) {
if !viper.IsSet("firewall-rules-file") {
return nil, nil
}

return rules, nil
}

func parseIngressFlags(inputs []string) ([]*models.V1FirewallIngressRule, error) {
var rules []*models.V1FirewallIngressRule

for _, input := range inputs {
r, err := parseRuleSpec(input)
if err != nil {
return nil, err
}

rule := &models.V1FirewallIngressRule{
Protocol: r.protocol,
From: r.cidrs,
Ports: r.ports,
Comment: r.comment,
}
rules = append(rules, rule)
}

return rules, nil
}

type rule struct {
comment string
protocol string
cidrs []string
ports []int32
}

func parseRuleSpec(spec string) (*rule, error) {
parts := strings.Split(spec, "@")
if len(parts) < 3 {
return nil, fmt.Errorf("at least proto, cidrs and ports must be given, spec:%q parts:%q", spec, parts)
}
if len(parts) > 4 {
return nil, fmt.Errorf("malformed rule spec:%q", spec)
firewallRulesFile := viper.GetString("firewall-rules-file")
if firewallRulesFile == "" {
return nil, nil
}

r := &rule{}
comment := ""
if len(parts) == 4 {
comment = parts[3]
}
r.comment = comment
r.protocol = parts[0]

cidrs := strings.Split(parts[1], "#")
ports := strings.Split(parts[2], "#")

for _, cidr := range cidrs {
p, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, err
}
r.cidrs = append(r.cidrs, p.String())
firewallRules, err := os.ReadFile(firewallRulesFile)
if err != nil {
return nil, err
}

for _, port := range ports {
p, err := strconv.ParseInt(port, 10, 32)
if err != nil {
return nil, err
}
r.ports = append(r.ports, int32(p))
}
return r, nil
var fwrules models.V1FirewallRules
err = yaml.Unmarshal([]byte(firewallRules), &fwrules)
return &fwrules, err
}

func (c *firewallCmd) firewallSSH(args []string) (err error) {
Expand Down
37 changes: 1 addition & 36 deletions cmd/firewall_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"reflect"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -299,8 +298,7 @@ ID AGE HOSTNAME PROJECT NETWORKS IPS PARTITION
"--sshpublickey", pointer.FirstOrZero(want.Allocation.SSHPubKeys),
"--tags", strings.Join(want.Tags, ","),
"--userdata", want.Allocation.UserData,
"--egress", "",
"--ingress", "",
"--firewall-rules-file", "",
}
assertExhaustiveArgs(t, args, commonExcludedFileArgs()...)
return args
Expand All @@ -319,36 +317,3 @@ ID AGE HOSTNAME PROJECT NETWORKS IPS PARTITION
tt.testCmd(t)
}
}

func Test_parseRuleSpec(t *testing.T) {
tests := []struct {
name string
spec string
want *rule
wantErr bool
}{
{
name: "simple egress",
spec: "[email protected]/24#0.0.0.0/0@80#443#8080#8443@apt update",
want: &rule{
protocol: "tcp",
comment: "apt update",
ports: []int32{80, 443, 8080, 8443},
cidrs: []string{"1.2.3.0/24", "0.0.0.0/0"},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseRuleSpec(tt.spec)
if (err != nil) != tt.wantErr {
t.Errorf("parseRuleSpec() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseRuleSpec() = %v, want %v", got, tt.want)
}
})
}
}
Loading

0 comments on commit b13c9a8

Please sign in to comment.