From 0e95fc467cb57b77c120f5f46e13e7fce03c16f5 Mon Sep 17 00:00:00 2001 From: Srikar Tati Date: Wed, 2 Dec 2020 12:10:04 -0800 Subject: [PATCH] Add proper message size check for UDP Added a path MTU config parameter for InitExportingProcess function. For UDP transport, based on path MTU parameter, we make a decision on the message size. This is optional parameter for TCP transport; TCP supports max socket buffer size (65535). --- pkg/entities/message.go | 2 ++ pkg/exporter/process.go | 37 +++++++++++++++++++++-------- pkg/exporter/process_test.go | 36 ++++++++++++++++++++-------- pkg/test/exporter_collector_test.go | 2 +- 4 files changed, 56 insertions(+), 21 deletions(-) diff --git a/pkg/entities/message.go b/pkg/entities/message.go index 9b4fe150..d58a8fae 100644 --- a/pkg/entities/message.go +++ b/pkg/entities/message.go @@ -21,6 +21,8 @@ import ( const ( MaxTcpSocketMsgSize uint16 = 65535 + DefaultUDPMsgSize uint16 = 512 + MaxUDPMsgSize uint16 = 1500 ) // Message represents IPFIX message. diff --git a/pkg/exporter/process.go b/pkg/exporter/process.go index 88fdd7b9..b291a6b7 100644 --- a/pkg/exporter/process.go +++ b/pkg/exporter/process.go @@ -43,16 +43,22 @@ type ExportingProcess struct { obsDomainID uint32 seqNumber uint32 templateID uint16 + pathMTU uint16 templatesMap map[uint16]templateValue templateRefCh chan struct{} mutex sync.Mutex } -// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID) and tempRefTimeout -// (template refresh timeout). tempRefTimeout is applicable only for collectors listening over UDP; unit is seconds. For TCP, you can -// pass any value. For UDP, if 0 is passed, consider 1800s as default. -// TODO: Get obsID, tempRefTimeout as args which can be of dynamic size supporting both TCP and UDP. -func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32) (*ExportingProcess, error) { +// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID) +// and tempRefTimeout(template refresh timeout). tempRefTimeout is applicable only +// for collectors listening over UDP; unit is seconds. For TCP, you can pass any +// value. For UDP, if 0 is passed, consider 1800s as default. +// +// PathMTU is recommended for UDP transport. If not given a valid value, i.e., either +// 0 or a value more than 1500, we consider a default value of 512B as per RFC7011. +// PathMTU is optional for TCP as we use max socket buffer size of 65535. It can +// be provided as 0. +func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32, pathMTU uint16) (*ExportingProcess, error) { conn, err := net.Dial(collectorAddr.Network(), collectorAddr.String()) if err != nil { klog.Errorf("Cannot the create the connection to configured ExportingProcess %s: %v", collectorAddr.String(), err) @@ -64,12 +70,16 @@ func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout u obsDomainID: obsID, seqNumber: 0, templateID: startTemplateID, + pathMTU: pathMTU, templatesMap: make(map[uint16]templateValue), templateRefCh: make(chan struct{}), } - // Template refresh logic is only for UDP transport. + // Template refresh logic and pathMTU check is only required for UDP transport. if collectorAddr.Network() == "udp" { + if expProc.pathMTU == 0 || expProc.pathMTU > entities.MaxUDPMsgSize { + expProc.pathMTU = entities.DefaultUDPMsgSize + } if tempRefTimeout == 0 { // Default value tempRefTimeout = entities.TemplateRefreshTimeOut @@ -109,6 +119,7 @@ func (ep *ExportingProcess) SendSet(set entities.Set) (int, error) { } } } + // Update the length in set header before sending the message. set.UpdateLenInHeader() bytesSent, err := ep.createAndSendMsg(set) @@ -149,11 +160,17 @@ func (ep *ExportingProcess) createAndSendMsg(set entities.Set) (int, error) { return 0, fmt.Errorf("error when creating header: %v", err) } - // Check if message is exceeding the limit with new set + // Check if message is exceeding the limit after adding the set. Include message + // header length too. msgLen := uint16(msg.GetMsgBufferLen()) + set.GetBuffLen() - // TODO: Change the limit for UDP transport. This is only valid for TCP transport. - if msgLen > entities.MaxTcpSocketMsgSize { - return 0, fmt.Errorf("set size exceeds max socket size") + if ep.connToCollector.LocalAddr().Network() == "tcp" { + if msgLen > entities.MaxTcpSocketMsgSize { + return 0, fmt.Errorf("TCP transport: message size exceeds max socket buffer size") + } + } else { + if msgLen > ep.pathMTU { + return 0, fmt.Errorf("UDP transport: message size exceeds max pathMTU (set as %v)", ep.pathMTU) + } } // Set the fields in the message header. diff --git a/pkg/exporter/process_test.go b/pkg/exporter/process_test.go index b134791b..ac2ec41e 100644 --- a/pkg/exporter/process_test.go +++ b/pkg/exporter/process_test.go @@ -59,7 +59,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalTCPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(listener.Addr(), 1, 0) + exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err) } @@ -129,7 +129,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalUDPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1) + exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err) } @@ -203,7 +203,7 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(listener.Addr(), 1, 0) + exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err) } @@ -247,13 +247,21 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) { dataRecBytes := dataRecBuff.Bytes() bytesSent, err := exporter.SendSet(dataSet) - if err != nil { - t.Fatalf("Got error when sending record: %v", err) - } + assert.NoError(t, err) // 28 is the size of the IPFIX message including all headers (20 bytes) assert.Equal(t, 28, bytesSent) assert.Equal(t, dataRecBytes, <-buffCh) assert.Equal(t, uint32(1), exporter.seqNumber) + + // Create data set with multiple data records to test invalid message length + // logic for TCP transport. + dataSet = entities.NewSet(entities.Data, templateID, false) + for i := 0; i < 10000; i++ { + dataSet.AddRecord(elements, templateID) + } + bytesSent, err = exporter.SendSet(dataSet) + assert.Error(t, err) + exporter.CloseConnToCollector() } @@ -284,7 +292,7 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0) + exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err) } @@ -328,12 +336,20 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) { dataRecBytes := dataRecBuff.Bytes() bytesSent, err := exporter.SendSet(dataSet) - if err != nil { - t.Fatalf("Got error when sending record: %v", err) - } + assert.NoError(t, err) // 28 is the size of the IPFIX message including all headers (20 bytes) assert.Equal(t, 28, bytesSent) assert.Equal(t, dataRecBytes, <-buffCh) assert.Equal(t, uint32(1), exporter.seqNumber) + + // Create data set with multiple data records to test invalid message length + // logic for + dataSet = entities.NewSet(entities.Data, templateID, false) + for i := 0; i < 100; i++ { + dataSet.AddRecord(elements, templateID) + } + bytesSent, err = exporter.SendSet(dataSet) + assert.Error(t, err) + exporter.CloseConnToCollector() } diff --git a/pkg/test/exporter_collector_test.go b/pkg/test/exporter_collector_test.go index e9fe80d0..7bf8da16 100644 --- a/pkg/test/exporter_collector_test.go +++ b/pkg/test/exporter_collector_test.go @@ -71,7 +71,7 @@ func testExporterToCollector(address net.Addr, isMultipleRecord bool, t *testing go func() { // Start exporting process in go routine time.Sleep(2 * time.Second) // wait for collector to be ready - export, err := exporter.InitExportingProcess(address, 1, 0) + export, err := exporter.InitExportingProcess(address, 1, 0, 0) if err != nil { klog.Fatalf("Got error when connecting to %s", address.String()) }