From c37be351f66ead11426af70225954642918a2999 Mon Sep 17 00:00:00 2001 From: turekt <32360115+turekt@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:04:40 +0000 Subject: [PATCH] Objects implementation refactor Refactored obj.go to a more generic approach Added object support for already implemented expressions Added test for limit object Small lint changes Fixes https://github.com/google/nftables/issues/253 --- counter.go | 17 ++- expr/bitwise.go | 5 +- expr/bitwise_test.go | 2 +- expr/byteorder.go | 5 +- expr/connlimit.go | 5 +- expr/counter.go | 5 +- expr/ct.go | 7 +- expr/dup.go | 6 +- expr/dynset.go | 7 +- expr/expr.go | 167 +++++++++++++++--------- expr/exthdr.go | 5 +- expr/exthdr_test.go | 2 +- expr/fib.go | 6 +- expr/flow_offload.go | 5 +- expr/hash.go | 5 +- expr/immediate.go | 5 +- expr/limit.go | 5 +- expr/log.go | 5 +- expr/lookup.go | 5 +- expr/match.go | 6 +- expr/match_test.go | 2 +- expr/meta_test.go | 2 +- expr/nat.go | 5 +- expr/notrack.go | 2 +- expr/numgen.go | 5 +- expr/objref.go | 5 +- expr/payload.go | 7 +- expr/queue.go | 5 +- expr/quota.go | 5 +- expr/range.go | 6 +- expr/redirect.go | 5 +- expr/reject.go | 5 +- expr/rt.go | 5 +- expr/socket.go | 5 +- expr/socket_test.go | 2 +- expr/target.go | 6 +- expr/target_test.go | 2 +- expr/tproxy.go | 5 +- expr/verdict.go | 5 +- flowtable.go | 2 +- internal/parseexprfunc/parseexprfunc.go | 4 +- nftables_test.go | 137 +++++++++++++++---- obj.go | 146 +++++++++++++++------ quota.go | 25 +++- rule.go | 2 +- set.go | 10 +- 46 files changed, 495 insertions(+), 190 deletions(-) diff --git a/counter.go b/counter.go index 25d37d8..34c36aa 100644 --- a/counter.go +++ b/counter.go @@ -16,11 +16,12 @@ package nftables import ( "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -// CounterObj implements Obj. +// Deprecated: Use ObjAttr instead type CounterObj struct { Table *Table Name string // e.g. “fwded” @@ -41,6 +42,20 @@ func (c *CounterObj) unmarshal(ad *netlink.AttributeDecoder) error { return ad.Err() } +func (c *CounterObj) data() expr.Any { + return &expr.Counter{ + Bytes: c.Bytes, + Packets: c.Packets, + } +} + +func (c *CounterObj) name() string { + return c.Name +} +func (c *CounterObj) objType() ObjType { + return ObjTypeCounter +} + func (c *CounterObj) table() *Table { return c.Table } diff --git a/expr/bitwise.go b/expr/bitwise.go index 62f7f9b..d66500d 100644 --- a/expr/bitwise.go +++ b/expr/bitwise.go @@ -30,7 +30,7 @@ type Bitwise struct { Xor []byte } -func (e *Bitwise) marshal(fam byte) ([]byte, error) { +func (e *Bitwise) marshal(fam byte, dataOnly bool) ([]byte, error) { mask, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: e.Mask}, }) @@ -54,6 +54,9 @@ func (e *Bitwise) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("bitwise\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/bitwise_test.go b/expr/bitwise_test.go index 35fc3b3..1777d18 100644 --- a/expr/bitwise_test.go +++ b/expr/bitwise_test.go @@ -32,7 +32,7 @@ func TestBitwise(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { nbw := Bitwise{} - data, err := tt.bw.marshal(0 /* don't care in this test */) + data, err := tt.bw.marshal(0 /* don't care in this test */, false) if err != nil { t.Fatalf("marshal error: %+v", err) diff --git a/expr/byteorder.go b/expr/byteorder.go index 2450e8f..9875bc6 100644 --- a/expr/byteorder.go +++ b/expr/byteorder.go @@ -37,7 +37,7 @@ type Byteorder struct { Size uint32 } -func (e *Byteorder) marshal(fam byte) ([]byte, error) { +func (e *Byteorder) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_BYTEORDER_SREG, Data: binaryutil.BigEndian.PutUint32(e.SourceRegister)}, {Type: unix.NFTA_BYTEORDER_DREG, Data: binaryutil.BigEndian.PutUint32(e.DestRegister)}, @@ -48,6 +48,9 @@ func (e *Byteorder) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("byteorder\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/connlimit.go b/expr/connlimit.go index b712967..3480574 100644 --- a/expr/connlimit.go +++ b/expr/connlimit.go @@ -36,7 +36,7 @@ type Connlimit struct { Flags uint32 } -func (e *Connlimit) marshal(fam byte) ([]byte, error) { +func (e *Connlimit) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: NFTA_CONNLIMIT_COUNT, Data: binaryutil.BigEndian.PutUint32(e.Count)}, {Type: NFTA_CONNLIMIT_FLAGS, Data: binaryutil.BigEndian.PutUint32(e.Flags)}, @@ -44,6 +44,9 @@ func (e *Connlimit) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("connlimit\x00")}, diff --git a/expr/counter.go b/expr/counter.go index dd6eab3..27760da 100644 --- a/expr/counter.go +++ b/expr/counter.go @@ -27,7 +27,7 @@ type Counter struct { Packets uint64 } -func (e *Counter) marshal(fam byte) ([]byte, error) { +func (e *Counter) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_COUNTER_BYTES, Data: binaryutil.BigEndian.PutUint64(e.Bytes)}, {Type: unix.NFTA_COUNTER_PACKETS, Data: binaryutil.BigEndian.PutUint64(e.Packets)}, @@ -35,6 +35,9 @@ func (e *Counter) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("counter\x00")}, diff --git a/expr/ct.go b/expr/ct.go index 1a0ee68..7511b13 100644 --- a/expr/ct.go +++ b/expr/ct.go @@ -63,8 +63,8 @@ type Ct struct { Key CtKey } -func (e *Ct) marshal(fam byte) ([]byte, error) { - regData := []byte{} +func (e *Ct) marshal(fam byte, dataOnly bool) ([]byte, error) { + var regData []byte exprData, err := netlink.MarshalAttributes( []netlink.Attribute{ {Type: unix.NFTA_CT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, @@ -90,6 +90,9 @@ func (e *Ct) marshal(fam byte) ([]byte, error) { return nil, err } exprData = append(exprData, regData...) + if dataOnly { + return exprData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("ct\x00")}, diff --git a/expr/dup.go b/expr/dup.go index 0114fa7..e2aa709 100644 --- a/expr/dup.go +++ b/expr/dup.go @@ -28,7 +28,7 @@ type Dup struct { IsRegDevSet bool } -func (e *Dup) marshal(fam byte) ([]byte, error) { +func (e *Dup) marshal(fam byte, dataOnly bool) ([]byte, error) { attrs := []netlink.Attribute{ {Type: unix.NFTA_DUP_SREG_ADDR, Data: binaryutil.BigEndian.PutUint32(e.RegAddr)}, } @@ -38,10 +38,12 @@ func (e *Dup) marshal(fam byte) ([]byte, error) { } data, err := netlink.MarshalAttributes(attrs) - if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("dup\x00")}, diff --git a/expr/dynset.go b/expr/dynset.go index e44f772..cd5c711 100644 --- a/expr/dynset.go +++ b/expr/dynset.go @@ -43,7 +43,7 @@ type Dynset struct { Exprs []Any } -func (e *Dynset) marshal(fam byte) ([]byte, error) { +func (e *Dynset) marshal(fam byte, dataOnly bool) ([]byte, error) { // See: https://git.netfilter.org/libnftnl/tree/src/expr/dynset.c var opAttrs []netlink.Attribute opAttrs = append(opAttrs, netlink.Attribute{Type: unix.NFTA_DYNSET_SREG_KEY, Data: binaryutil.BigEndian.PutUint32(e.SrcRegKey)}) @@ -95,6 +95,9 @@ func (e *Dynset) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return opData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("dynset\x00")}, @@ -125,7 +128,7 @@ func (e *Dynset) unmarshal(fam byte, data []byte) error { case unix.NFTA_DYNSET_FLAGS: e.Invert = (ad.Uint32() & unix.NFT_DYNSET_F_INV) != 0 case unix.NFTA_DYNSET_EXPR: - exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes()) + exprs, err := parseexprfunc.ParseExprBytesFunc(fam, ad) if err != nil { return err } diff --git a/expr/expr.go b/expr/expr.go index a4d970f..8b06a37 100644 --- a/expr/expr.go +++ b/expr/expr.go @@ -25,8 +25,8 @@ import ( ) func init() { - parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) { - exprs, err := exprsFromBytes(fam, ad, b) + parseexprfunc.ParseExprBytesFunc = func(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]interface{}, error) { + exprs, err := exprsFromBytes(fam, ad, args...) if err != nil { return nil, err } @@ -36,7 +36,7 @@ func init() { } return result, nil } - parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte) ([]interface{}, error) { + parseexprfunc.ParseExprMsgFunc = func(fam byte, b []byte, args ...string) ([]interface{}, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { return nil, err @@ -44,7 +44,7 @@ func init() { ad.ByteOrder = binary.BigEndian var exprs []interface{} for ad.Next() { - e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, b) + e, err := parseexprfunc.ParseExprBytesFunc(fam, ad, args...) if err != nil { return e, err } @@ -56,7 +56,11 @@ func init() { // Marshal serializes the specified expression into a byte slice. func Marshal(fam byte, e Any) ([]byte, error) { - return e.marshal(fam) + return e.marshal(fam, false) +} + +func MarshalExprData(fam byte, e Any) ([]byte, error) { + return e.marshal(fam, true) } // Unmarshal fills an expression from the specified byte slice. @@ -66,8 +70,20 @@ func Unmarshal(fam byte, data []byte, e Any) error { // exprsFromBytes parses nested raw expressions bytes // to construct nftables expressions -func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, error) { +func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]Any, error) { var exprs []Any + if len(args) > 0 { + e := exprFromName(args[0]) + ad.Do(func(b []byte) error { + if err := Unmarshal(fam, b, e); err != nil { + return err + } + exprs = append(exprs, e) + return nil + }) + return exprs, ad.Err() + } + ad.Do(func(b []byte) error { ad, err := netlink.NewAttributeDecoder(b) if err != nil { @@ -75,74 +91,29 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er } ad.ByteOrder = binary.BigEndian var name string + if len(args) > 0 { + name = args[0] + } for ad.Next() { switch ad.Type() { case unix.NFTA_EXPR_NAME: + if name != "" { + continue + } name = ad.String() if name == "notrack" { e := &Notrack{} exprs = append(exprs, e) } + case unix.NFTA_OBJ_DATA: + fallthrough case unix.NFTA_EXPR_DATA: - var e Any - switch name { - case "ct": - e = &Ct{} - case "range": - e = &Range{} - case "meta": - e = &Meta{} - case "cmp": - e = &Cmp{} - case "counter": - e = &Counter{} - case "objref": - e = &Objref{} - case "payload": - e = &Payload{} - case "lookup": - e = &Lookup{} - case "immediate": - e = &Immediate{} - case "bitwise": - e = &Bitwise{} - case "redir": - e = &Redir{} - case "nat": - e = &NAT{} - case "limit": - e = &Limit{} - case "quota": - e = &Quota{} - case "dynset": - e = &Dynset{} - case "log": - e = &Log{} - case "exthdr": - e = &Exthdr{} - case "match": - e = &Match{} - case "target": - e = &Target{} - case "connlimit": - e = &Connlimit{} - case "queue": - e = &Queue{} - case "flow_offload": - e = &FlowOffload{} - case "reject": - e = &Reject{} - case "masq": - e = &Masq{} - case "hash": - e = &Hash{} - } + e := exprFromName(name) if e == nil { // TODO: introduce an opaque expression type so that users know // something is here. continue // unsupported expression type } - ad.Do(func(b []byte) error { if err := Unmarshal(fam, b, e); err != nil { return err @@ -166,9 +137,66 @@ func exprsFromBytes(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]Any, er return exprs, ad.Err() } +func exprFromName(name string) Any { + var e Any + switch name { + case "ct": + e = &Ct{} + case "range": + e = &Range{} + case "meta": + e = &Meta{} + case "cmp": + e = &Cmp{} + case "counter": + e = &Counter{} + case "objref": + e = &Objref{} + case "payload": + e = &Payload{} + case "lookup": + e = &Lookup{} + case "immediate": + e = &Immediate{} + case "bitwise": + e = &Bitwise{} + case "redir": + e = &Redir{} + case "nat": + e = &NAT{} + case "limit": + e = &Limit{} + case "quota": + e = &Quota{} + case "dynset": + e = &Dynset{} + case "log": + e = &Log{} + case "exthdr": + e = &Exthdr{} + case "match": + e = &Match{} + case "target": + e = &Target{} + case "connlimit": + e = &Connlimit{} + case "queue": + e = &Queue{} + case "flow_offload": + e = &FlowOffload{} + case "reject": + e = &Reject{} + case "masq": + e = &Masq{} + case "hash": + e = &Hash{} + } + return e +} + // Any is an interface implemented by any expression type. type Any interface { - marshal(fam byte) ([]byte, error) + marshal(fam byte, dataOnly bool) ([]byte, error) unmarshal(fam byte, data []byte) error } @@ -213,8 +241,8 @@ type Meta struct { Register uint32 } -func (e *Meta) marshal(fam byte) ([]byte, error) { - regData := []byte{} +func (e *Meta) marshal(fam byte, dataOnly bool) ([]byte, error) { + var regData []byte exprData, err := netlink.MarshalAttributes( []netlink.Attribute{ {Type: unix.NFTA_META_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, @@ -240,6 +268,9 @@ func (e *Meta) marshal(fam byte) ([]byte, error) { return nil, err } exprData = append(exprData, regData...) + if dataOnly { + return exprData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("meta\x00")}, @@ -290,7 +321,7 @@ const ( NF_NAT_RANGE_PREFIX = unix.NF_NAT_RANGE_NETMAP ) -func (e *Masq) marshal(fam byte) ([]byte, error) { +func (e *Masq) marshal(fam byte, dataOnly bool) ([]byte, error) { msgData := []byte{} if !e.ToPorts { flags := uint32(0) @@ -327,6 +358,9 @@ func (e *Masq) marshal(fam byte) ([]byte, error) { msgData = append(msgData, regsData...) } } + if dataOnly { + return msgData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("masq\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: msgData}, @@ -376,7 +410,7 @@ type Cmp struct { Data []byte } -func (e *Cmp) marshal(fam byte) ([]byte, error) { +func (e *Cmp) marshal(fam byte, dataOnly bool) ([]byte, error) { cmpData, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, }) @@ -391,6 +425,9 @@ func (e *Cmp) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return cmpData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("cmp\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, diff --git a/expr/exthdr.go b/expr/exthdr.go index df0c7db..e3683db 100644 --- a/expr/exthdr.go +++ b/expr/exthdr.go @@ -39,7 +39,7 @@ type Exthdr struct { SourceRegister uint32 } -func (e *Exthdr) marshal(fam byte) ([]byte, error) { +func (e *Exthdr) marshal(fam byte, dataOnly bool) ([]byte, error) { var attr []netlink.Attribute // Operations are differentiated by the Op and whether the SourceRegister @@ -68,6 +68,9 @@ func (e *Exthdr) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("exthdr\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/exthdr_test.go b/expr/exthdr_test.go index b211818..5c92437 100644 --- a/expr/exthdr_test.go +++ b/expr/exthdr_test.go @@ -44,7 +44,7 @@ func TestExthdr(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { neh := Exthdr{} - data, err := tt.eh.marshal(0 /* don't care in this test */) + data, err := tt.eh.marshal(0 /* don't care in this test */, false) if err != nil { t.Fatalf("marshal error: %+v", err) diff --git a/expr/fib.go b/expr/fib.go index f7ee704..aba30a3 100644 --- a/expr/fib.go +++ b/expr/fib.go @@ -36,7 +36,7 @@ type Fib struct { FlagPRESENT bool } -func (e *Fib) marshal(fam byte) ([]byte, error) { +func (e *Fib) marshal(fam byte, dataOnly bool) ([]byte, error) { data := []byte{} reg, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_FIB_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, @@ -92,7 +92,9 @@ func (e *Fib) marshal(fam byte) ([]byte, error) { } data = append(data, rslt...) } - + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("fib\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/flow_offload.go b/expr/flow_offload.go index 54f956f..5bed4b4 100644 --- a/expr/flow_offload.go +++ b/expr/flow_offload.go @@ -27,13 +27,16 @@ type FlowOffload struct { Name string } -func (e *FlowOffload) marshal(fam byte) ([]byte, error) { +func (e *FlowOffload) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: NFTNL_EXPR_FLOW_TABLE_NAME, Data: []byte(e.Name)}, }) if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("flow_offload\x00")}, diff --git a/expr/hash.go b/expr/hash.go index e8506b9..8994f8d 100644 --- a/expr/hash.go +++ b/expr/hash.go @@ -40,7 +40,7 @@ type Hash struct { Type HashType } -func (e *Hash) marshal(fam byte) ([]byte, error) { +func (e *Hash) marshal(fam byte, dataOnly bool) ([]byte, error) { hashAttrs := []netlink.Attribute{ {Type: unix.NFTA_HASH_SREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.SourceRegister))}, {Type: unix.NFTA_HASH_DREG, Data: binaryutil.BigEndian.PutUint32(uint32(e.DestRegister))}, @@ -60,6 +60,9 @@ func (e *Hash) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("hash\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/immediate.go b/expr/immediate.go index 99531f8..0b491d5 100644 --- a/expr/immediate.go +++ b/expr/immediate.go @@ -28,7 +28,7 @@ type Immediate struct { Data []byte } -func (e *Immediate) marshal(fam byte) ([]byte, error) { +func (e *Immediate) marshal(fam byte, dataOnly bool) ([]byte, error) { immData, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_DATA_VALUE, Data: e.Data}, }) @@ -43,6 +43,9 @@ func (e *Immediate) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/limit.go b/expr/limit.go index 9ecb41f..0bbdcb9 100644 --- a/expr/limit.go +++ b/expr/limit.go @@ -71,7 +71,7 @@ type Limit struct { Burst uint32 } -func (l *Limit) marshal(fam byte) ([]byte, error) { +func (l *Limit) marshal(fam byte, dataOnly bool) ([]byte, error) { var flags uint32 if l.Over { flags = unix.NFT_LIMIT_F_INV @@ -88,6 +88,9 @@ func (l *Limit) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("limit\x00")}, diff --git a/expr/log.go b/expr/log.go index a712b99..76edaa2 100644 --- a/expr/log.go +++ b/expr/log.go @@ -68,7 +68,7 @@ type Log struct { Data []byte } -func (e *Log) marshal(fam byte) ([]byte, error) { +func (e *Log) marshal(fam byte, dataOnly bool) ([]byte, error) { // Per https://git.netfilter.org/libnftnl/tree/src/expr/log.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n129 attrs := make([]netlink.Attribute, 0) if e.Key&(1< 0 { attrs = append(attrs, netlink.Attribute{Type: unix.NFTA_REDIR_REG_PROTO_MIN, Data: binaryutil.BigEndian.PutUint32(e.RegisterProtoMin)}) @@ -44,6 +44,9 @@ func (e *Redir) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("redir\x00")}, diff --git a/expr/reject.go b/expr/reject.go index a742626..5988ebc 100644 --- a/expr/reject.go +++ b/expr/reject.go @@ -27,7 +27,7 @@ type Reject struct { Code uint8 } -func (e *Reject) marshal(fam byte) ([]byte, error) { +func (e *Reject) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_REJECT_TYPE, Data: binaryutil.BigEndian.PutUint32(e.Type)}, {Type: unix.NFTA_REJECT_ICMP_CODE, Data: []byte{e.Code}}, @@ -35,6 +35,9 @@ func (e *Reject) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("reject\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/rt.go b/expr/rt.go index c3be7ff..f0f1af3 100644 --- a/expr/rt.go +++ b/expr/rt.go @@ -36,7 +36,7 @@ type Rt struct { Key RtKey } -func (e *Rt) marshal(fam byte) ([]byte, error) { +func (e *Rt) marshal(fam byte, dataOnly bool) ([]byte, error) { data, err := netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_RT_KEY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Key))}, {Type: unix.NFTA_RT_DREG, Data: binaryutil.BigEndian.PutUint32(e.Register)}, @@ -44,6 +44,9 @@ func (e *Rt) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("rt\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/socket.go b/expr/socket.go index 1b6bc24..82a4fcf 100644 --- a/expr/socket.go +++ b/expr/socket.go @@ -48,7 +48,7 @@ const ( SocketKeyCgroupv2 SocketKey = NFT_SOCKET_CGROUPV2 ) -func (e *Socket) marshal(fam byte) ([]byte, error) { +func (e *Socket) marshal(fam byte, dataOnly bool) ([]byte, error) { // NOTE: Socket.Level is only used when Socket.Key == SocketKeyCgroupv2. But `nft` always encoding it. Check link below: // http://git.netfilter.org/nftables/tree/src/netlink_linearize.c?id=0583bac241ea18c9d7f61cb20ca04faa1e043b78#n319 exprData, err := netlink.MarshalAttributes( @@ -62,6 +62,9 @@ func (e *Socket) marshal(fam byte) ([]byte, error) { return nil, err } + if dataOnly { + return exprData, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("socket\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: exprData}, diff --git a/expr/socket_test.go b/expr/socket_test.go index 25eddb2..007f98e 100644 --- a/expr/socket_test.go +++ b/expr/socket_test.go @@ -74,7 +74,7 @@ func TestSocket(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { nSocket := Socket{} - data, err := tt.socket.marshal(0 /* don't care in this test */) + data, err := tt.socket.marshal(0 /* don't care in this test */, false) if err != nil { t.Fatalf("marshal error: %+v", err) diff --git a/expr/target.go b/expr/target.go index e531a9f..ed89cbc 100644 --- a/expr/target.go +++ b/expr/target.go @@ -20,7 +20,7 @@ type Target struct { Info xt.InfoAny } -func (e *Target) marshal(fam byte) ([]byte, error) { +func (e *Target) marshal(fam byte, dataOnly bool) ([]byte, error) { // Per https://git.netfilter.org/libnftnl/tree/src/expr/target.c?id=09456c720e9c00eecc08e41ac6b7c291b3821ee5#n38 name := e.Name // limit the extension name as (some) user-space tools do and leave room for @@ -44,7 +44,9 @@ func (e *Target) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } - + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("target\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/target_test.go b/expr/target_test.go index e630e86..2529746 100644 --- a/expr/target_test.go +++ b/expr/target_test.go @@ -30,7 +30,7 @@ func TestTarget(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ntgt := Target{} - data, err := tt.tgt.marshal(0 /* don't care in this test */) + data, err := tt.tgt.marshal(0 /* don't care in this test */, false) if err != nil { t.Fatalf("marshal error: %+v", err) diff --git a/expr/tproxy.go b/expr/tproxy.go index 2846aab..f52be9c 100644 --- a/expr/tproxy.go +++ b/expr/tproxy.go @@ -39,7 +39,7 @@ type TProxy struct { RegPort uint32 } -func (e *TProxy) marshal(fam byte) ([]byte, error) { +func (e *TProxy) marshal(fam byte, dataOnly bool) ([]byte, error) { attrs := []netlink.Attribute{ {Type: NFTA_TPROXY_FAMILY, Data: binaryutil.BigEndian.PutUint32(uint32(e.Family))}, {Type: NFTA_TPROXY_REG_PORT, Data: binaryutil.BigEndian.PutUint32(e.RegPort)}, @@ -56,6 +56,9 @@ func (e *TProxy) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("tproxy\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/expr/verdict.go b/expr/verdict.go index 421fa06..d3b7d4b 100644 --- a/expr/verdict.go +++ b/expr/verdict.go @@ -53,7 +53,7 @@ const ( VerdictStop ) -func (e *Verdict) marshal(fam byte) ([]byte, error) { +func (e *Verdict) marshal(fam byte, dataOnly bool) ([]byte, error) { // A verdict is a tree of netlink attributes structured as follows: // NFTA_LIST_ELEM | NLA_F_NESTED { // NFTA_EXPR_NAME { "immediate\x00" } @@ -90,6 +90,9 @@ func (e *Verdict) marshal(fam byte) ([]byte, error) { if err != nil { return nil, err } + if dataOnly { + return data, nil + } return netlink.MarshalAttributes([]netlink.Attribute{ {Type: unix.NFTA_EXPR_NAME, Data: []byte("immediate\x00")}, {Type: unix.NLA_F_NESTED | unix.NFTA_EXPR_DATA, Data: data}, diff --git a/flowtable.go b/flowtable.go index 01df08e..93dbcb5 100644 --- a/flowtable.go +++ b/flowtable.go @@ -219,7 +219,7 @@ func (cc *Conn) getFlowtables(t *Table) ([]netlink.Message, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %v", err) + return nil, fmt.Errorf("receiveAckAware: %v", err) } return reply, nil diff --git a/internal/parseexprfunc/parseexprfunc.go b/internal/parseexprfunc/parseexprfunc.go index 523859d..ae840b4 100644 --- a/internal/parseexprfunc/parseexprfunc.go +++ b/internal/parseexprfunc/parseexprfunc.go @@ -5,6 +5,6 @@ import ( ) var ( - ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, b []byte) ([]interface{}, error) - ParseExprMsgFunc func(fam byte, b []byte) ([]interface{}, error) + ParseExprBytesFunc func(fam byte, ad *netlink.AttributeDecoder, args ...string) ([]interface{}, error) + ParseExprMsgFunc func(fam byte, b []byte, args ...string) ([]interface{}, error) ) diff --git a/nftables_test.go b/nftables_test.go index be8b83b..5ac8ab1 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -1783,7 +1783,7 @@ func TestListChainByName(t *testing.T) { } func TestListChainByNameUsingLasting(t *testing.T) { - conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -1882,8 +1882,7 @@ func TestListTableByName(t *testing.T) { } // not specifying correct family should return err since no table in ipv4 - tr, err = conn.ListTable(table2.Name) - if err == nil { + if _, err = conn.ListTable(table2.Name); err == nil { t.Fatalf("conn.ListTable() should have failed") } @@ -2114,9 +2113,9 @@ func TestGetObjReset(t *testing.T) { t.Fatal(err) } - co, ok := obj.(*nftables.CounterObj) + co, ok := obj.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj) + t.Fatalf("unexpected type: got %T, want *nftables.ObjAttr", obj) } if got, want := co.Table.Name, filter.Name; got != want { t.Errorf("unexpected table name: got %q, want %q", got, want) @@ -2124,10 +2123,14 @@ func TestGetObjReset(t *testing.T) { if got, want := co.Table.Family, filter.Family; got != want { t.Errorf("unexpected table family: got %d, want %d", got, want) } - if got, want := co.Packets, uint64(9); got != want { + o, ok := co.Obj.(*expr.Counter) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Counter", o) + } + if got, want := o.Packets, uint64(9); got != want { t.Errorf("unexpected number of packets: got %d, want %d", got, want) } - if got, want := co.Bytes, uint64(1121); got != want { + if got, want := o.Bytes, uint64(1121); got != want { t.Errorf("unexpected number of bytes: got %d, want %d", got, want) } } @@ -2223,10 +2226,9 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter1) failed: %v failed", err) } - rcounter1, ok := obj1.(*nftables.CounterObj) - + rcounter1, ok := obj1.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter1) + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj1) } if rcounter1.Name != "fwded1" { @@ -2238,10 +2240,9 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter2) failed: %v failed", err) } - rcounter2, ok := obj2.(*nftables.CounterObj) - + rcounter2, ok := obj2.(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", rcounter2) + t.Fatalf("unexpected type: got %T, want *nftables.CounterObj", obj2) } if rcounter2.Name != "fwded2" { @@ -2260,7 +2261,7 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter1) failed: %v failed", err) } - if counter1 := obj1.(*nftables.CounterObj); counter1.Packets > 0 { + if counter1 := obj1.(*nftables.ObjAttr).Obj.(*expr.Counter); counter1.Packets > 0 { t.Errorf("unexpected packets number: got %d, want %d", counter1.Packets, 0) } @@ -2270,7 +2271,7 @@ func TestObjAPI(t *testing.T) { t.Errorf("c.GetObject(counter2) failed: %v failed", err) } - if counter2 := obj2.(*nftables.CounterObj); counter2.Packets != 1 { + if counter2 := obj2.(*nftables.ObjAttr).Obj.(*expr.Counter); counter2.Packets != 1 { t.Errorf("unexpected packets number: got %d, want %d", counter2.Packets, 1) } @@ -2767,7 +2768,7 @@ func TestCreateUseAnonymousSet(t *testing.T) { } func TestCappedErrMsgOnSets(t *testing.T) { - c, newNS := nftest.OpenSystemConn(t, *enableSysTests) + _, newNS := nftest.OpenSystemConn(t, *enableSysTests) c, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting()) if err != nil { t.Fatalf("nftables.New() failed: %v", err) @@ -6285,6 +6286,84 @@ func TestGetRulesObjref(t *testing.T) { } } +func TestAddLimitObj(t *testing.T) { + conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) + defer nftest.CleanupSystemConn(t, newNS) + conn.FlushRuleset() + defer conn.FlushRuleset() + + table := &nftables.Table{ + Name: "limit_demo", + Family: nftables.TableFamilyIPv4, + } + tr := conn.AddTable(table) + + c := &nftables.Chain{ + Name: "filter", + Table: table, + } + conn.AddChain(c) + + l := &expr.Limit{ + Type: expr.LimitTypePkts, + Rate: 400, + Unit: expr.LimitTimeMinute, + Burst: 5, + Over: false, + } + o := &nftables.ObjAttr{ + Table: tr, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + Obj: l, + } + conn.AddObj(o) + + if err := conn.Flush(); err != nil { + t.Errorf("conn.Flush() failed: %v", err) + } + + obj, err := conn.GetObj(&nftables.ObjAttr{ + Table: table, + Name: "limit_test", + Type: nftables.ObjTypeLimit, + }) + if err != nil { + t.Fatalf("conn.GetObj() failed: %v", err) + } + + if got, want := len(obj), 1; got != want { + t.Fatalf("unexpected object list length: got %d, want %d", got, want) + } + + o1, ok := obj[0].(*nftables.ObjAttr) + if !ok { + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) + } + if got, want := o1.Name, o.Name; got != want { + t.Fatalf("limit name mismatch: got %s, want %s", got, want) + } + q, ok := o1.Obj.(*expr.Limit) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + if got, want := q.Burst, l.Burst; got != want { + t.Fatalf("limit burst mismatch: got %d, want %d", got, want) + } + if got, want := q.Unit, l.Unit; got != want { + t.Fatalf("limit unit mismatch: got %d, want %d", got, want) + } + if got, want := q.Rate, l.Rate; got != want { + t.Fatalf("limit rate mismatch: got %v, want %v", got, want) + } + if got, want := q.Over, l.Over; got != want { + t.Fatalf("limit over mismatch: got %v, want %v", got, want) + } + if got, want := q.Type, l.Type; got != want { + t.Fatalf("limit type mismatch: got %v, want %v", got, want) + } +} + func TestAddQuotaObj(t *testing.T) { conn, newNS := nftest.OpenSystemConn(t, *enableSysTests) defer nftest.CleanupSystemConn(t, newNS) @@ -6328,20 +6407,24 @@ func TestAddQuotaObj(t *testing.T) { t.Fatalf("unexpected object list length: got %d, want %d", got, want) } - o1, ok := obj[0].(*nftables.QuotaObj) + o1, ok := obj[0].(*nftables.ObjAttr) if !ok { - t.Fatalf("unexpected type: got %T, want *QuotaObj", obj[0]) + t.Fatalf("unexpected type: got %T, want *ObjAttr", obj[0]) } if got, want := o1.Name, o.Name; got != want { t.Fatalf("quota name mismatch: got %s, want %s", got, want) } - if got, want := o1.Bytes, o.Bytes; got != want { + q, ok := o1.Obj.(*expr.Quota) + if !ok { + t.Fatalf("unexpected type: got %T, want *expr.Quota", o1.Obj) + } + if got, want := q.Bytes, o.Bytes; got != want { t.Fatalf("quota bytes mismatch: got %d, want %d", got, want) } - if got, want := o1.Consumed, o.Consumed; got != want { + if got, want := q.Consumed, o.Consumed; got != want { t.Fatalf("quota consumed mismatch: got %d, want %d", got, want) } - if got, want := o1.Over, o.Over; got != want { + if got, want := q.Over, o.Over; got != want { t.Fatalf("quota over mismatch: got %v, want %v", got, want) } } @@ -6452,7 +6535,17 @@ func TestDeleteQuotaObj(t *testing.T) { t.Fatalf("unexpected number of objects: got %d, want %d", got, want) } - if got, want := obj[0], o; !reflect.DeepEqual(got, want) { + want := &nftables.ObjAttr{ + Table: tr, + Name: "q_test", + Type: nftables.ObjTypeQuota, + Obj: &expr.Quota{ + Bytes: o.Bytes, + Consumed: o.Consumed, + Over: o.Over, + }, + } + if got, want := obj[0], want; !reflect.DeepEqual(got, want) { t.Errorf("got = %+v, want = %+v", got, want) } diff --git a/obj.go b/obj.go index c468a63..fdbeb08 100644 --- a/obj.go +++ b/obj.go @@ -18,6 +18,9 @@ import ( "encoding/binary" "fmt" + "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" + "github.com/google/nftables/internal/parseexprfunc" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) @@ -27,13 +30,70 @@ var ( delObjHeaderType = netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELOBJ) ) +type ObjType uint32 + +// https://git.netfilter.org/libnftnl/tree/include/linux/netfilter/nf_tables.h?id=be0bae0ad31b0adb506f96de083f52a2bd0d4fbf#n1612 +const ( + ObjTypeCounter ObjType = unix.NFT_OBJECT_COUNTER + ObjTypeQuota ObjType = unix.NFT_OBJECT_QUOTA + ObjTypeCtHelper ObjType = unix.NFT_OBJECT_CT_HELPER + ObjTypeLimit ObjType = unix.NFT_OBJECT_LIMIT + ObjTypeConnLimit ObjType = unix.NFT_OBJECT_CONNLIMIT + ObjTypeTunnel ObjType = unix.NFT_OBJECT_TUNNEL + ObjTypeCtTimeout ObjType = unix.NFT_OBJECT_CT_TIMEOUT + ObjTypeSecMark ObjType = unix.NFT_OBJECT_SECMARK + ObjTypeCtExpect ObjType = unix.NFT_OBJECT_CT_EXPECT + ObjTypeSynProxy ObjType = unix.NFT_OBJECT_SYNPROXY +) + +var objByObjTypeMagic = map[ObjType]string{ + ObjTypeCounter: "counter", + ObjTypeQuota: "quota", + ObjTypeLimit: "limit", + ObjTypeConnLimit: "connlimit", + ObjTypeCtHelper: "cthelper", // not implemented in expr + ObjTypeTunnel: "tunnel", // not implemented in expr + ObjTypeCtTimeout: "cttimeout", // not implemented in expr + ObjTypeSecMark: "secmark", // not implemented in expr + ObjTypeCtExpect: "ctexpect", // not implemented in expr + ObjTypeSynProxy: "synproxy", // not implemented in expr +} + // Obj represents a netfilter stateful object. See also // https://wiki.nftables.org/wiki-nftables/index.php/Stateful_objects type Obj interface { table() *Table family() TableFamily - unmarshal(*netlink.AttributeDecoder) error - marshal(data bool) ([]byte, error) + data() expr.Any + name() string + objType() ObjType +} + +type ObjAttr struct { + Table *Table + Name string + Type ObjType + Obj expr.Any +} + +func (o *ObjAttr) table() *Table { + return o.Table +} + +func (o *ObjAttr) family() TableFamily { + return o.Table.Family +} + +func (o *ObjAttr) data() expr.Any { + return o.Obj +} + +func (o *ObjAttr) name() string { + return o.Name +} + +func (o *ObjAttr) objType() ObjType { + return o.Type } // AddObject adds the specified Obj. Alias of AddObj. @@ -46,18 +106,27 @@ func (cc *Conn) AddObject(o Obj) Obj { func (cc *Conn) AddObj(o Obj) Obj { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(true) + data, err := expr.MarshalExprData(byte(o.family()), o.data()) if err != nil { cc.setErr(err) return nil } + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + if len(data) > 0 { + attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: data}) + } + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, }, - Data: append(extraHeader(uint8(o.family()), 0), data...), + Data: append(extraHeader(uint8(o.family()), 0), cc.marshalAttr(attrs)...), }) return o } @@ -66,12 +135,12 @@ func (cc *Conn) AddObj(o Obj) Obj { func (cc *Conn) DeleteObject(o Obj) { cc.mu.Lock() defer cc.mu.Unlock() - data, err := o.marshal(false) - if err != nil { - cc.setErr(err) - return + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, } - + data := cc.marshalAttr(attrs) data = append(data, cc.marshalAttr([]netlink.Attribute{{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA}})...) cc.messages = append(cc.messages, netlink.Message{ @@ -150,38 +219,26 @@ func objFromMsg(msg netlink.Message) (Obj, error) { case unix.NFTA_OBJ_TYPE: objectType = ad.Uint32() case unix.NFTA_OBJ_DATA: - switch objectType { - case unix.NFT_OBJECT_COUNTER: - o := CounterObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() - case NFT_OBJECT_QUOTA: - o := QuotaObj{ - Table: table, - Name: name, - } - - ad.Do(func(b []byte) error { - ad, err := netlink.NewAttributeDecoder(b) - if err != nil { - return err - } - ad.ByteOrder = binary.BigEndian - return o.unmarshal(ad) - }) - return &o, ad.Err() + o := ObjAttr{ + Table: table, + Name: name, + Type: ObjType(objectType), + } + + objs, err := parseexprfunc.ParseExprBytesFunc(byte(o.family()), ad, objByObjTypeMagic[o.Type]) + if err != nil { + return nil, err + } + exprs := make([]expr.Any, len(objs)) + for i := range exprs { + exprs[i] = objs[i].(expr.Any) } + if len(exprs) == 0 { + return nil, fmt.Errorf("objFromMsg: exprs is empty for obj %v", o) + } + + o.Obj = exprs[0] + return &o, ad.Err() } } if err := ad.Err(); err != nil { @@ -201,7 +258,12 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { var flags netlink.HeaderFlags if o != nil { - data, err = o.marshal(false) + attrs := []netlink.Attribute{ + {Type: unix.NFTA_OBJ_TABLE, Data: []byte(o.table().Name + "\x00")}, + {Type: unix.NFTA_OBJ_NAME, Data: []byte(o.name() + "\x00")}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(uint32(o.objType()))}, + } + data = cc.marshalAttr(attrs) } else { flags = netlink.Dump data, err = netlink.MarshalAttributes([]netlink.Attribute{ @@ -226,7 +288,7 @@ func (cc *Conn) getObj(o Obj, t *Table, msgType uint16) ([]Obj, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %v", err) + return nil, fmt.Errorf("receiveAckAware: %v", err) } var objs []Obj for _, msg := range reply { diff --git a/quota.go b/quota.go index 71cb9bb..e3c71b1 100644 --- a/quota.go +++ b/quota.go @@ -16,15 +16,12 @@ package nftables import ( "github.com/google/nftables/binaryutil" + "github.com/google/nftables/expr" "github.com/mdlayher/netlink" "golang.org/x/sys/unix" ) -const ( - NFTA_OBJ_USERDATA = 8 - NFT_OBJECT_QUOTA = 2 -) - +// Deprecated: Use ObjAttr instead type QuotaObj struct { Table *Table Name string @@ -63,7 +60,7 @@ func (q *QuotaObj) marshal(data bool) ([]byte, error) { attrs := []netlink.Attribute{ {Type: unix.NFTA_OBJ_TABLE, Data: []byte(q.Table.Name + "\x00")}, {Type: unix.NFTA_OBJ_NAME, Data: []byte(q.Name + "\x00")}, - {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(NFT_OBJECT_QUOTA)}, + {Type: unix.NFTA_OBJ_TYPE, Data: binaryutil.BigEndian.PutUint32(unix.NFT_OBJECT_QUOTA)}, } if data { attrs = append(attrs, netlink.Attribute{Type: unix.NLA_F_NESTED | unix.NFTA_OBJ_DATA, Data: obj}) @@ -78,3 +75,19 @@ func (q *QuotaObj) table() *Table { func (q *QuotaObj) family() TableFamily { return q.Table.Family } + +func (q *QuotaObj) data() expr.Any { + return &expr.Quota{ + Bytes: q.Bytes, + Consumed: q.Consumed, + Over: q.Over, + } +} + +func (q *QuotaObj) name() string { + return q.Name +} + +func (q *QuotaObj) objType() ObjType { + return ObjTypeQuota +} diff --git a/rule.go b/rule.go index 8bcfda1..0706834 100644 --- a/rule.go +++ b/rule.go @@ -92,7 +92,7 @@ func (cc *Conn) GetRules(t *Table, c *Chain) ([]*Rule, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %v", err) + return nil, fmt.Errorf("receiveAckAware: %v", err) } var rules []*Rule for _, msg := range reply { diff --git a/set.go b/set.go index 192c619..401b9b4 100644 --- a/set.go +++ b/set.go @@ -321,7 +321,7 @@ func (s *SetElement) decode(fam byte) func(b []byte) error { case unix.NFTA_SET_ELEM_EXPIRATION: s.Expires = time.Millisecond * time.Duration(ad.Uint64()) case unix.NFTA_SET_ELEM_EXPR: - elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad, ad.Bytes()) + elems, err := parseexprfunc.ParseExprBytesFunc(fam, ad) if err != nil { return err } @@ -832,7 +832,7 @@ func (cc *Conn) GetSets(t *Table) ([]*Set, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %v", err) + return nil, fmt.Errorf("receiveAckAware: %v", err) } var sets []*Set for _, msg := range reply { @@ -877,11 +877,11 @@ func (cc *Conn) GetSetByName(t *Table, name string) (*Set, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %w", err) + return nil, fmt.Errorf("receiveAckAware: %w", err) } if len(reply) != 1 { - return nil, fmt.Errorf("Receive: expected to receive 1 message but got %d", len(reply)) + return nil, fmt.Errorf("receiveAckAware: expected to receive 1 message but got %d", len(reply)) } rs, err := setsFromMsg(reply[0]) if err != nil { @@ -922,7 +922,7 @@ func (cc *Conn) GetSetElements(s *Set) ([]SetElement, error) { reply, err := receiveAckAware(conn, message.Header.Flags) if err != nil { - return nil, fmt.Errorf("Receive: %v", err) + return nil, fmt.Errorf("receiveAckAware: %v", err) } var elems []SetElement for _, msg := range reply {