Skip to content

Commit

Permalink
Merge pull request #51 from squeed/iptables-nft
Browse files Browse the repository at this point in the history
Add support for iptables in nftables mode.
  • Loading branch information
Casey Callendrello authored Aug 3, 2018
2 parents 25d087f + 5c15b20 commit 47f22b0
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 31 deletions.
97 changes: 76 additions & 21 deletions iptables/iptables.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,15 @@ import (
// Adds the output of stderr to exec.ExitError
type Error struct {
exec.ExitError
cmd exec.Cmd
msg string
cmd exec.Cmd
msg string
exitStatus *int //for overriding
}

func (e *Error) ExitStatus() int {
if e.exitStatus != nil {
return *e.exitStatus
}
return e.Sys().(syscall.WaitStatus).ExitStatus()
}

Expand Down Expand Up @@ -65,6 +69,7 @@ type IPTables struct {
v1 int
v2 int
v3 int
mode string // the underlying iptables operating mode, e.g. nf_tables
}

// New creates a new IPTables.
Expand All @@ -81,12 +86,10 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
return nil, err
}
vstring, err := getIptablesVersionString(path)
v1, v2, v3, err := extractIptablesVersion(vstring)
v1, v2, v3, mode, err := extractIptablesVersion(vstring)

checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)

checkPresent, waitPresent, randomFullyPresent, err := getIptablesCommandSupport(v1, v2, v3)
if err != nil {
return nil, fmt.Errorf("error checking iptables version: %v", err)
}
ipt := IPTables{
path: path,
proto: proto,
Expand All @@ -96,6 +99,7 @@ func NewWithProtocol(proto Protocol) (*IPTables, error) {
v1: v1,
v2: v2,
v3: v3,
mode: mode,
}
return &ipt, nil
}
Expand Down Expand Up @@ -266,10 +270,27 @@ func (ipt *IPTables) executeList(args []string) ([]string, error) {
}

rules := strings.Split(stdout.String(), "\n")

// strip trailing newline
if len(rules) > 0 && rules[len(rules)-1] == "" {
rules = rules[:len(rules)-1]
}

// nftables mode doesn't return an error code when listing a non-existent
// chain. Patch that up.
if len(rules) == 0 && ipt.mode == "nf_tables" {
v := 1
return nil, &Error{
cmd: exec.Cmd{Args: args},
msg: "iptables: No chain/target/match by that name.",
exitStatus: &v,
}
}

for i, rule := range rules {
rules[i] = filterRuleOutput(rule)
}

return rules, nil
}

Expand All @@ -284,11 +305,18 @@ func (ipt *IPTables) NewChain(table, chain string) error {
func (ipt *IPTables) ClearChain(table, chain string) error {
err := ipt.NewChain(table, chain)

// the exit code for "this table already exists" is different for
// different iptables modes
existsErr := 1
if ipt.mode == "nf_tables" {
existsErr = 4
}

eerr, eok := err.(*Error)
switch {
case err == nil:
return nil
case eok && eerr.ExitStatus() == 1:
case eok && eerr.ExitStatus() == existsErr:
// chain already exists. Flush (clear) it.
return ipt.run("-t", table, "-F", chain)
default:
Expand Down Expand Up @@ -357,7 +385,7 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
if err := cmd.Run(); err != nil {
switch e := err.(type) {
case *exec.ExitError:
return &Error{*e, cmd, stderr.String()}
return &Error{*e, cmd, stderr.String(), nil}
default:
return err
}
Expand All @@ -376,36 +404,40 @@ func getIptablesCommand(proto Protocol) string {
}

// Checks if iptables has the "-C" and "--wait" flag
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool, error) {

return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3), nil
func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
}

// getIptablesVersion returns the first three components of the iptables version.
// e.g. "iptables v1.3.66" would return (1, 3, 66, nil)
func extractIptablesVersion(str string) (int, int, int, error) {
versionMatcher := regexp.MustCompile("v([0-9]+)\\.([0-9]+)\\.([0-9]+)")
// getIptablesVersion returns the first three components of the iptables version
// and the operating mode (e.g. nf_tables or legacy)
// e.g. "iptables v1.3.66" would return (1, 3, 66, legacy, nil)
func extractIptablesVersion(str string) (int, int, int, string, error) {
versionMatcher := regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)
result := versionMatcher.FindStringSubmatch(str)
if result == nil {
return 0, 0, 0, fmt.Errorf("no iptables version found in string: %s", str)
return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
}

v1, err := strconv.Atoi(result[1])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v2, err := strconv.Atoi(result[2])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

v3, err := strconv.Atoi(result[3])
if err != nil {
return 0, 0, 0, err
return 0, 0, 0, "", err
}

return v1, v2, v3, nil
mode := "legacy"
if result[4] != "" {
mode = result[4]
}
return v1, v2, v3, mode, nil
}

// Runs "iptables --version" to get the version string
Expand Down Expand Up @@ -473,3 +505,26 @@ func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string
}
return strings.Contains(stdout.String(), rs), nil
}

// counterRegex is the regex used to detect nftables counter format
var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)

// filterRuleOutput works around some inconsistencies in output.
// For example, when iptables is in legacy vs. nftables mode, it produces
// different results.
func filterRuleOutput(rule string) string {
out := rule

// work around an output difference in nftables mode where counters
// are output in iptables-save format, rather than iptables -S format
// The string begins with "[0:0]"
//
// Fixes #49
if groups := counterRegex.FindStringSubmatch(out); groups != nil {
// drop the brackets
out = out[len(groups[0]):]
out = fmt.Sprintf("%s -c %s %s", out, groups[1], groups[2])
}

return out
}
98 changes: 90 additions & 8 deletions iptables/iptables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,10 @@ func mustTestableIptables() []*IPTables {
}

func TestChain(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runChainTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runChainTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -179,8 +181,10 @@ func runChainTests(t *testing.T, ipt *IPTables) {
}

func TestRules(t *testing.T) {
for _, ipt := range mustTestableIptables() {
runRulesTests(t, ipt)
for i, ipt := range mustTestableIptables() {
t.Run(fmt.Sprint(i), func(t *testing.T) {
runRulesTests(t, ipt)
})
}
}

Expand Down Expand Up @@ -265,12 +269,17 @@ func runRulesTests(t *testing.T, ipt *IPTables) {
t.Fatalf("ListWithCounters failed: %v", err)
}

suffix := " -c 0 0 -j ACCEPT"
if ipt.mode == "nf_tables" {
suffix = " -j ACCEPT -c 0 0"
}

expected = []string{
"-N " + chain,
"-A " + chain + " -s " + subnet1 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet2 + " -d " + address1 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + address1 + " -d " + subnet2 + " -c 0 0 -j ACCEPT",
"-A " + chain + " -s " + subnet1 + " -d " + address1 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address2 + suffix,
"-A " + chain + " -s " + subnet2 + " -d " + address1 + suffix,
"-A " + chain + " -s " + address1 + " -d " + subnet2 + suffix,
}

if !reflect.DeepEqual(rules, expected) {
Expand Down Expand Up @@ -408,3 +417,76 @@ func TestIsNotExist(t *testing.T) {
t.Fatal("IsNotExist returned false, expected true")
}
}

func TestFilterRuleOutput(t *testing.T) {
testCases := []struct {
name string
in string
out string
}{
{
"legacy output",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
},
{
"nft output",
"[99:42] -A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT",
"-A foo1 -p tcp -m tcp --dport 1337 -j ACCEPT -c 99 42",
},
}

for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
actual := filterRuleOutput(tt.in)
if actual != tt.out {
t.Fatalf("expect %s actual %s", tt.out, actual)
}
})
}
}

func TestExtractIptablesVersion(t *testing.T) {
testCases := []struct {
in string
v1, v2, v3 int
mode string
err bool
}{
{
"iptables v1.8.0 (nf_tables)",
1, 8, 0,
"nf_tables",
false,
},
{
"iptables v1.8.0 (legacy)",
1, 8, 0,
"legacy",
false,
},
{
"iptables v1.6.2",
1, 6, 2,
"legacy",
false,
},
}

for i, tt := range testCases {
t.Run(fmt.Sprint(i), func(t *testing.T) {
v1, v2, v3, mode, err := extractIptablesVersion(tt.in)
if err == nil && tt.err {
t.Fatal("expected err, got none")
} else if err != nil && !tt.err {
t.Fatalf("unexpected err %s", err)
}

if v1 != tt.v1 || v2 != tt.v2 || v3 != tt.v3 || mode != tt.mode {
t.Fatalf("expected %d %d %d %s, got %d %d %d %s",
tt.v1, tt.v2, tt.v3, tt.mode,
v1, v2, v3, mode)
}
})
}
}
7 changes: 5 additions & 2 deletions test
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,14 @@ split=(${TEST// / })
TEST=${split[@]/#/${REPO_PATH}/}

echo "Running tests..."
go test -i ${TEST}
bin=$(mktemp)

go test -c -o ${bin} ${COVER} -i ${TEST}
if [[ -z "$SUDO_PERMITTED" ]]; then
echo "Test aborted for safety reasons. Please set the SUDO_PERMITTED variable."
exit 1
fi

sudo -E bash -c "PATH=\$GOROOT/bin:\$PATH go test ${COVER} $@ ${TEST}"
sudo -E bash -c "${bin} $@ ${TEST}"
echo "Success"
rm "${bin}"

0 comments on commit 47f22b0

Please sign in to comment.