diff --git a/iptables/iptables.go b/iptables/iptables.go index 6a6d380..6c5bbd7 100644 --- a/iptables/iptables.go +++ b/iptables/iptables.go @@ -112,8 +112,20 @@ func Timeout(timeout int) option { } } -// New creates a new IPTables configured with the options passed as parameter. -// For backwards compatibility, by default always uses IPv4 and timeout 0. +func Path(path string) option { + return func(ipt *IPTables) { + ipt.path = path + } +} + +// New creates a new IPTables configured with the options passed as parameters. +// Supported parameters are: +// +// IPFamily(Protocol) +// Timeout(int) +// Path(string) +// +// For backwards compatibility, by default New uses IPv4 and timeout 0. // i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing // the IPFamily and Timeout options as follow: // @@ -123,13 +135,21 @@ func New(opts ...option) (*IPTables, error) { ipt := &IPTables{ proto: ProtocolIPv4, timeout: 0, + path: "", } for _, opt := range opts { opt(ipt) } - path, err := exec.LookPath(getIptablesCommand(ipt.proto)) + // if path wasn't preset through New(Path()), autodiscover it + cmd := "" + if ipt.path == "" { + cmd = getIptablesCommand(ipt.proto) + } else { + cmd = ipt.path + } + path, err := exec.LookPath(cmd) if err != nil { return nil, err } diff --git a/iptables/iptables_test.go b/iptables/iptables_test.go index cc2de33..b6e1c39 100644 --- a/iptables/iptables_test.go +++ b/iptables/iptables_test.go @@ -70,6 +70,54 @@ func TestTimeout(t *testing.T) { } +// force usage of -legacy or -nft commands and check that they're detected correctly +func TestLegacyDetection(t *testing.T) { + testCases := []struct { + in string + mode string + err bool + }{ + { + "iptables-legacy", + "legacy", + false, + }, + { + "ip6tables-legacy", + "legacy", + false, + }, + { + "iptables-nft", + "nf_tables", + false, + }, + { + "ip6tables-nft", + "nf_tables", + false, + }, + } + + for i, tt := range testCases { + t.Run(fmt.Sprint(i), func(t *testing.T) { + ipt, err := New(Path(tt.in)) + if err == nil && tt.err { + t.Fatal("expected err, got none") + } else if err != nil && !tt.err { + t.Fatalf("unexpected err %s", err) + } + + if !strings.Contains(ipt.path, tt.in) { + t.Fatalf("Expected path %s in %s", tt.in, ipt.path) + } + if ipt.mode != tt.mode { + t.Fatalf("Expected %s iptables, but got %s", tt.mode, ipt.mode) + } + }) + } +} + func randChain(t *testing.T) string { n, err := rand.Int(rand.Reader, big.NewInt(1000000)) if err != nil {