Skip to content

Commit

Permalink
Add proper message size check for UDP (#92)
Browse files Browse the repository at this point in the history
Added a path MTU config parameter for InitExportingProcess function.
For UDP transport, based on path MTU parameter, we make a decision on
the message size. This is optional parameter for TCP transport; TCP
supports max socket buffer size (65535).
  • Loading branch information
srikartati authored Dec 3, 2020
1 parent 499c7d9 commit 150f98e
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 42 deletions.
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ BINDIR ?= $(CURDIR)/bin
codegen:
GO111MODULE=on $(GO) get github.com/golang/mock/[email protected]
PATH=$$PATH:$(GOPATH)/bin $(GO) generate ./...
# Make sure the IPFIX registries are up-to-date
GO111MODULE=on $(GO) run pkg/registry/build_registry/build_registry.go

# Make sure the IPFIX registries are up-to-date.
# Hitting 304 error when getting IANA registry csv file multiple times, so
# skipping this check temporarily.
#GO111MODULE=on $(GO) run pkg/registry/build_registry/build_registry.go

.coverage:
mkdir -p ./.coverage
Expand Down
4 changes: 3 additions & 1 deletion pkg/entities/message.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import (
)

const (
MaxTcpSocketMsgSize uint16 = 65535
MaxTcpSocketMsgSize int = 65535
DefaultUDPMsgSize int = 512
MaxUDPMsgSize int = 1500
)

// Message represents IPFIX message.
Expand Down
11 changes: 7 additions & 4 deletions pkg/entities/set.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ const (
)

type Set interface {
GetBuffLen() uint16
GetBuffLen() int
GetBuffer() *bytes.Buffer
GetSetType() ContentType
UpdateLenInHeader()
Expand Down Expand Up @@ -72,8 +72,8 @@ func NewSet(setType ContentType, templateID uint16, isDecoding bool) Set {
return set
}

func (s *set) GetBuffLen() uint16 {
return uint16(s.buffer.Len())
func (s *set) GetBuffLen() int {
return s.buffer.Len()
}

func (s *set) GetBuffer() *bytes.Buffer {
Expand Down Expand Up @@ -107,10 +107,13 @@ func (s *set) AddRecord(elements []*InfoElementWithValue, templateID uint16) err
// write record to set when encoding
if !s.isDecoding {
recordBytes := record.GetBuffer().Bytes()
_, err := s.buffer.Write(recordBytes)
bytesWritten, err := s.buffer.Write(recordBytes)
if err != nil {
return fmt.Errorf("error in writing the buffer to set: %v", err)
}
if bytesWritten != len(recordBytes) {
return fmt.Errorf("bytes written length is not expected")
}
}
return nil
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/entities/set_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ func TestGetBuffer(t *testing.T) {
}

func TestGetBuffLen(t *testing.T) {
assert.Equal(t, uint16(0), NewSet(Template, uint16(256), true).GetBuffLen())
assert.Equal(t, uint16(4), NewSet(Template, uint16(257), false).GetBuffLen())
assert.Equal(t, uint16(0), NewSet(Data, uint16(258), true).GetBuffLen())
assert.Equal(t, uint16(4), NewSet(Data, uint16(259), false).GetBuffLen())
assert.Equal(t, 0, NewSet(Template, uint16(256), true).GetBuffLen())
assert.Equal(t, 4, NewSet(Template, uint16(257), false).GetBuffLen())
assert.Equal(t, 0, NewSet(Data, uint16(258), true).GetBuffLen())
assert.Equal(t, 4, NewSet(Data, uint16(259), false).GetBuffLen())
}

func TestGetRecords(t *testing.T) {
Expand Down Expand Up @@ -105,12 +105,12 @@ func TestSet_UpdateLenInHeader(t *testing.T) {
setForDecoding := NewSet(Template, uint16(256), true)
setForEncoding := NewSet(Template, uint16(257), false)
setForEncoding.AddRecord(elements, 256)
assert.Equal(t, uint16(0), setForDecoding.GetBuffLen())
assert.Equal(t, uint16(16), setForEncoding.GetBuffLen())
assert.Equal(t, 0, setForDecoding.GetBuffLen())
assert.Equal(t, 16, setForEncoding.GetBuffLen())
setForDecoding.UpdateLenInHeader()
setForEncoding.UpdateLenInHeader()
// Nothing should be written in setForDecoding
assert.Equal(t, uint16(0), setForDecoding.GetBuffLen())
assert.Equal(t, 0, setForDecoding.GetBuffLen())
// Check the bytes in the header for set length
assert.Equal(t, setForEncoding.GetBuffLen(), binary.BigEndian.Uint16(setForEncoding.GetBuffer().Bytes()[2:4]))
assert.Equal(t, uint16(setForEncoding.GetBuffLen()), binary.BigEndian.Uint16(setForEncoding.GetBuffer().Bytes()[2:4]))
}
4 changes: 2 additions & 2 deletions pkg/entities/testing/mock_set.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 29 additions & 12 deletions pkg/exporter/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,22 @@ type ExportingProcess struct {
obsDomainID uint32
seqNumber uint32
templateID uint16
pathMTU int
templatesMap map[uint16]templateValue
templateRefCh chan struct{}
mutex sync.Mutex
}

// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID) and tempRefTimeout
// (template refresh timeout). tempRefTimeout is applicable only for collectors listening over UDP; unit is seconds. For TCP, you can
// pass any value. For UDP, if 0 is passed, consider 1800s as default.
// TODO: Get obsID, tempRefTimeout as args which can be of dynamic size supporting both TCP and UDP.
func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32) (*ExportingProcess, error) {
// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID)
// and tempRefTimeout(template refresh timeout). tempRefTimeout is applicable only
// for collectors listening over UDP; unit is seconds. For TCP, you can pass any
// value. For UDP, if 0 is passed, consider 1800s as default.
//
// PathMTU is recommended for UDP transport. If not given a valid value, i.e., either
// 0 or a value more than 1500, we consider a default value of 512B as per RFC7011.
// PathMTU is optional for TCP as we use max socket buffer size of 65535. It can
// be provided as 0.
func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32, pathMTU int) (*ExportingProcess, error) {
conn, err := net.Dial(collectorAddr.Network(), collectorAddr.String())
if err != nil {
klog.Errorf("Cannot the create the connection to configured ExportingProcess %s: %v", collectorAddr.String(), err)
Expand All @@ -64,12 +70,16 @@ func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout u
obsDomainID: obsID,
seqNumber: 0,
templateID: startTemplateID,
pathMTU: pathMTU,
templatesMap: make(map[uint16]templateValue),
templateRefCh: make(chan struct{}),
}

// Template refresh logic is only for UDP transport.
// Template refresh logic and pathMTU check is only required for UDP transport.
if collectorAddr.Network() == "udp" {
if expProc.pathMTU == 0 || expProc.pathMTU > entities.MaxUDPMsgSize {
expProc.pathMTU = entities.DefaultUDPMsgSize
}
if tempRefTimeout == 0 {
// Default value
tempRefTimeout = entities.TemplateRefreshTimeOut
Expand Down Expand Up @@ -109,6 +119,7 @@ func (ep *ExportingProcess) SendSet(set entities.Set) (int, error) {
}
}
}

// Update the length in set header before sending the message.
set.UpdateLenInHeader()
bytesSent, err := ep.createAndSendMsg(set)
Expand Down Expand Up @@ -149,19 +160,25 @@ func (ep *ExportingProcess) createAndSendMsg(set entities.Set) (int, error) {
return 0, fmt.Errorf("error when creating header: %v", err)
}

// Check if message is exceeding the limit with new set
msgLen := uint16(msg.GetMsgBufferLen()) + set.GetBuffLen()
// TODO: Change the limit for UDP transport. This is only valid for TCP transport.
if msgLen > entities.MaxTcpSocketMsgSize {
return 0, fmt.Errorf("set size exceeds max socket size")
// Check if message is exceeding the limit after adding the set. Include message
// header length too.
msgLen := msg.GetMsgBufferLen() + set.GetBuffLen()
if ep.connToCollector.LocalAddr().Network() == "tcp" {
if msgLen > entities.MaxTcpSocketMsgSize {
return 0, fmt.Errorf("TCP transport: message size exceeds max socket buffer size")
}
} else {
if msgLen > ep.pathMTU {
return 0, fmt.Errorf("UDP transport: message size exceeds max pathMTU (set as %v)", ep.pathMTU)
}
}

// Set the fields in the message header.
// IPFIX version number is 10.
// https://www.iana.org/assignments/ipfix/ipfix.xhtml#ipfix-version-numbers
msg.SetVersion(10)
msg.SetObsDomainID(ep.obsDomainID)
msg.SetMessageLen(msgLen)
msg.SetMessageLen(uint16(msgLen))
msg.SetExportTime(uint32(time.Now().Unix()))
if set.GetSetType() == entities.Data {
ep.seqNumber = ep.seqNumber + set.GetNumberOfRecords()
Expand Down
37 changes: 27 additions & 10 deletions pkg/exporter/process_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalTCPServer(t *testing.T) {
}()

// Create exporter using local server info
exporter, err := InitExportingProcess(listener.Addr(), 1, 0)
exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0)
if err != nil {
t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err)
}
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalUDPServer(t *testing.T) {
}()

// Create exporter using local server info
exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1)
exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1, 0)
if err != nil {
t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err)
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) {
}()

// Create exporter using local server info
exporter, err := InitExportingProcess(listener.Addr(), 1, 0)
exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0)
if err != nil {
t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err)
}
Expand Down Expand Up @@ -247,13 +247,22 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) {
dataRecBytes := dataRecBuff.Bytes()

bytesSent, err := exporter.SendSet(dataSet)
if err != nil {
t.Fatalf("Got error when sending record: %v", err)
}
assert.NoError(t, err)
// 28 is the size of the IPFIX message including all headers (20 bytes)
assert.Equal(t, 28, bytesSent)
assert.Equal(t, dataRecBytes, <-buffCh)
assert.Equal(t, uint32(1), exporter.seqNumber)

// Create data set with multiple data records to test invalid message length
// logic for TCP transport.
dataSet = entities.NewSet(entities.Data, templateID, false)
for i := 0; i < 10000; i++ {
err := dataSet.AddRecord(elements, templateID)
assert.NoError(t, err)
}
bytesSent, err = exporter.SendSet(dataSet)
assert.Error(t, err)

exporter.CloseConnToCollector()
}

Expand Down Expand Up @@ -284,7 +293,7 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) {
}()

// Create exporter using local server info
exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0)
exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0, 0)
if err != nil {
t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err)
}
Expand Down Expand Up @@ -328,12 +337,20 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) {
dataRecBytes := dataRecBuff.Bytes()

bytesSent, err := exporter.SendSet(dataSet)
if err != nil {
t.Fatalf("Got error when sending record: %v", err)
}
assert.NoError(t, err)
// 28 is the size of the IPFIX message including all headers (20 bytes)
assert.Equal(t, 28, bytesSent)
assert.Equal(t, dataRecBytes, <-buffCh)
assert.Equal(t, uint32(1), exporter.seqNumber)

// Create data set with multiple data records to test invalid message length
// logic for UDP transport.
dataSet = entities.NewSet(entities.Data, templateID, false)
for i := 0; i < 100; i++ {
dataSet.AddRecord(elements, templateID)
}
bytesSent, err = exporter.SendSet(dataSet)
assert.Error(t, err)

exporter.CloseConnToCollector()
}
4 changes: 2 additions & 2 deletions pkg/test/collector_intermediate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func waitForCollectorReady(t *testing.T, address net.Addr) {
}
return true, nil
}
if err := wait.Poll(100 * time.Millisecond, 500 * time.Millisecond, checkConn); err != nil {
if err := wait.Poll(100*time.Millisecond, 500*time.Millisecond, checkConn); err != nil {
t.Errorf("Cannot establish connection to %s", address.String())
}
}
Expand All @@ -106,7 +106,7 @@ func waitForAggrationFinished(t *testing.T, ap *intermediate.AggregationProcess)
return false, fmt.Errorf("aggregation process does not process and store data correctly")
}
}
if err := wait.Poll(100 * time.Millisecond, 500 * time.Millisecond, checkConn); err != nil {
if err := wait.Poll(100*time.Millisecond, 500*time.Millisecond, checkConn); err != nil {
t.Error(err)
}
}
2 changes: 1 addition & 1 deletion pkg/test/exporter_collector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func testExporterToCollector(address net.Addr, isMultipleRecord bool, t *testing
go cp.Start()
go func() { // Start exporting process in go routine
waitForCollectorReady(t, address)
export, err := exporter.InitExportingProcess(address, 1, 0)
export, err := exporter.InitExportingProcess(address, 1, 0, 0)
if err != nil {
klog.Fatalf("Got error when connecting to %s", address.String())
}
Expand Down

0 comments on commit 150f98e

Please sign in to comment.