From 150f98e6694f160b069cdd51fd6f2217f3e65114 Mon Sep 17 00:00:00 2001 From: srikartati Date: Thu, 3 Dec 2020 23:10:28 +0530 Subject: [PATCH] Add proper message size check for UDP (#92) Added a path MTU config parameter for InitExportingProcess function. For UDP transport, based on path MTU parameter, we make a decision on the message size. This is optional parameter for TCP transport; TCP supports max socket buffer size (65535). --- Makefile | 7 +++-- pkg/entities/message.go | 4 ++- pkg/entities/set.go | 11 ++++--- pkg/entities/set_test.go | 16 +++++----- pkg/entities/testing/mock_set.go | 4 +-- pkg/exporter/process.go | 41 +++++++++++++++++-------- pkg/exporter/process_test.go | 37 ++++++++++++++++------ pkg/test/collector_intermediate_test.go | 4 +-- pkg/test/exporter_collector_test.go | 2 +- 9 files changed, 84 insertions(+), 42 deletions(-) diff --git a/Makefile b/Makefile index 46c45582..35c9b875 100644 --- a/Makefile +++ b/Makefile @@ -5,8 +5,11 @@ BINDIR ?= $(CURDIR)/bin codegen: GO111MODULE=on $(GO) get github.com/golang/mock/mockgen@v1.4.3 PATH=$$PATH:$(GOPATH)/bin $(GO) generate ./... - # Make sure the IPFIX registries are up-to-date - GO111MODULE=on $(GO) run pkg/registry/build_registry/build_registry.go + + # Make sure the IPFIX registries are up-to-date. + # Hitting 304 error when getting IANA registry csv file multiple times, so + # skipping this check temporarily. + #GO111MODULE=on $(GO) run pkg/registry/build_registry/build_registry.go .coverage: mkdir -p ./.coverage diff --git a/pkg/entities/message.go b/pkg/entities/message.go index 9b4fe150..c77c249d 100644 --- a/pkg/entities/message.go +++ b/pkg/entities/message.go @@ -20,7 +20,9 @@ import ( ) const ( - MaxTcpSocketMsgSize uint16 = 65535 + MaxTcpSocketMsgSize int = 65535 + DefaultUDPMsgSize int = 512 + MaxUDPMsgSize int = 1500 ) // Message represents IPFIX message. diff --git a/pkg/entities/set.go b/pkg/entities/set.go index 14e4649b..0575dc9b 100644 --- a/pkg/entities/set.go +++ b/pkg/entities/set.go @@ -41,7 +41,7 @@ const ( ) type Set interface { - GetBuffLen() uint16 + GetBuffLen() int GetBuffer() *bytes.Buffer GetSetType() ContentType UpdateLenInHeader() @@ -72,8 +72,8 @@ func NewSet(setType ContentType, templateID uint16, isDecoding bool) Set { return set } -func (s *set) GetBuffLen() uint16 { - return uint16(s.buffer.Len()) +func (s *set) GetBuffLen() int { + return s.buffer.Len() } func (s *set) GetBuffer() *bytes.Buffer { @@ -107,10 +107,13 @@ func (s *set) AddRecord(elements []*InfoElementWithValue, templateID uint16) err // write record to set when encoding if !s.isDecoding { recordBytes := record.GetBuffer().Bytes() - _, err := s.buffer.Write(recordBytes) + bytesWritten, err := s.buffer.Write(recordBytes) if err != nil { return fmt.Errorf("error in writing the buffer to set: %v", err) } + if bytesWritten != len(recordBytes) { + return fmt.Errorf("bytes written length is not expected") + } } return nil } diff --git a/pkg/entities/set_test.go b/pkg/entities/set_test.go index 8a2d616b..996f1421 100644 --- a/pkg/entities/set_test.go +++ b/pkg/entities/set_test.go @@ -71,10 +71,10 @@ func TestGetBuffer(t *testing.T) { } func TestGetBuffLen(t *testing.T) { - assert.Equal(t, uint16(0), NewSet(Template, uint16(256), true).GetBuffLen()) - assert.Equal(t, uint16(4), NewSet(Template, uint16(257), false).GetBuffLen()) - assert.Equal(t, uint16(0), NewSet(Data, uint16(258), true).GetBuffLen()) - assert.Equal(t, uint16(4), NewSet(Data, uint16(259), false).GetBuffLen()) + assert.Equal(t, 0, NewSet(Template, uint16(256), true).GetBuffLen()) + assert.Equal(t, 4, NewSet(Template, uint16(257), false).GetBuffLen()) + assert.Equal(t, 0, NewSet(Data, uint16(258), true).GetBuffLen()) + assert.Equal(t, 4, NewSet(Data, uint16(259), false).GetBuffLen()) } func TestGetRecords(t *testing.T) { @@ -105,12 +105,12 @@ func TestSet_UpdateLenInHeader(t *testing.T) { setForDecoding := NewSet(Template, uint16(256), true) setForEncoding := NewSet(Template, uint16(257), false) setForEncoding.AddRecord(elements, 256) - assert.Equal(t, uint16(0), setForDecoding.GetBuffLen()) - assert.Equal(t, uint16(16), setForEncoding.GetBuffLen()) + assert.Equal(t, 0, setForDecoding.GetBuffLen()) + assert.Equal(t, 16, setForEncoding.GetBuffLen()) setForDecoding.UpdateLenInHeader() setForEncoding.UpdateLenInHeader() // Nothing should be written in setForDecoding - assert.Equal(t, uint16(0), setForDecoding.GetBuffLen()) + assert.Equal(t, 0, setForDecoding.GetBuffLen()) // Check the bytes in the header for set length - assert.Equal(t, setForEncoding.GetBuffLen(), binary.BigEndian.Uint16(setForEncoding.GetBuffer().Bytes()[2:4])) + assert.Equal(t, uint16(setForEncoding.GetBuffLen()), binary.BigEndian.Uint16(setForEncoding.GetBuffer().Bytes()[2:4])) } diff --git a/pkg/entities/testing/mock_set.go b/pkg/entities/testing/mock_set.go index e5262ee4..5e30d905 100644 --- a/pkg/entities/testing/mock_set.go +++ b/pkg/entities/testing/mock_set.go @@ -63,10 +63,10 @@ func (mr *MockSetMockRecorder) AddRecord(arg0, arg1 interface{}) *gomock.Call { } // GetBuffLen mocks base method -func (m *MockSet) GetBuffLen() uint16 { +func (m *MockSet) GetBuffLen() int { m.ctrl.T.Helper() ret := m.ctrl.Call(m, "GetBuffLen") - ret0, _ := ret[0].(uint16) + ret0, _ := ret[0].(int) return ret0 } diff --git a/pkg/exporter/process.go b/pkg/exporter/process.go index 88fdd7b9..33bd1d05 100644 --- a/pkg/exporter/process.go +++ b/pkg/exporter/process.go @@ -43,16 +43,22 @@ type ExportingProcess struct { obsDomainID uint32 seqNumber uint32 templateID uint16 + pathMTU int templatesMap map[uint16]templateValue templateRefCh chan struct{} mutex sync.Mutex } -// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID) and tempRefTimeout -// (template refresh timeout). tempRefTimeout is applicable only for collectors listening over UDP; unit is seconds. For TCP, you can -// pass any value. For UDP, if 0 is passed, consider 1800s as default. -// TODO: Get obsID, tempRefTimeout as args which can be of dynamic size supporting both TCP and UDP. -func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32) (*ExportingProcess, error) { +// InitExportingProcess takes in collector address(net.Addr format), obsID(observation ID) +// and tempRefTimeout(template refresh timeout). tempRefTimeout is applicable only +// for collectors listening over UDP; unit is seconds. For TCP, you can pass any +// value. For UDP, if 0 is passed, consider 1800s as default. +// +// PathMTU is recommended for UDP transport. If not given a valid value, i.e., either +// 0 or a value more than 1500, we consider a default value of 512B as per RFC7011. +// PathMTU is optional for TCP as we use max socket buffer size of 65535. It can +// be provided as 0. +func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout uint32, pathMTU int) (*ExportingProcess, error) { conn, err := net.Dial(collectorAddr.Network(), collectorAddr.String()) if err != nil { klog.Errorf("Cannot the create the connection to configured ExportingProcess %s: %v", collectorAddr.String(), err) @@ -64,12 +70,16 @@ func InitExportingProcess(collectorAddr net.Addr, obsID uint32, tempRefTimeout u obsDomainID: obsID, seqNumber: 0, templateID: startTemplateID, + pathMTU: pathMTU, templatesMap: make(map[uint16]templateValue), templateRefCh: make(chan struct{}), } - // Template refresh logic is only for UDP transport. + // Template refresh logic and pathMTU check is only required for UDP transport. if collectorAddr.Network() == "udp" { + if expProc.pathMTU == 0 || expProc.pathMTU > entities.MaxUDPMsgSize { + expProc.pathMTU = entities.DefaultUDPMsgSize + } if tempRefTimeout == 0 { // Default value tempRefTimeout = entities.TemplateRefreshTimeOut @@ -109,6 +119,7 @@ func (ep *ExportingProcess) SendSet(set entities.Set) (int, error) { } } } + // Update the length in set header before sending the message. set.UpdateLenInHeader() bytesSent, err := ep.createAndSendMsg(set) @@ -149,11 +160,17 @@ func (ep *ExportingProcess) createAndSendMsg(set entities.Set) (int, error) { return 0, fmt.Errorf("error when creating header: %v", err) } - // Check if message is exceeding the limit with new set - msgLen := uint16(msg.GetMsgBufferLen()) + set.GetBuffLen() - // TODO: Change the limit for UDP transport. This is only valid for TCP transport. - if msgLen > entities.MaxTcpSocketMsgSize { - return 0, fmt.Errorf("set size exceeds max socket size") + // Check if message is exceeding the limit after adding the set. Include message + // header length too. + msgLen := msg.GetMsgBufferLen() + set.GetBuffLen() + if ep.connToCollector.LocalAddr().Network() == "tcp" { + if msgLen > entities.MaxTcpSocketMsgSize { + return 0, fmt.Errorf("TCP transport: message size exceeds max socket buffer size") + } + } else { + if msgLen > ep.pathMTU { + return 0, fmt.Errorf("UDP transport: message size exceeds max pathMTU (set as %v)", ep.pathMTU) + } } // Set the fields in the message header. @@ -161,7 +178,7 @@ func (ep *ExportingProcess) createAndSendMsg(set entities.Set) (int, error) { // https://www.iana.org/assignments/ipfix/ipfix.xhtml#ipfix-version-numbers msg.SetVersion(10) msg.SetObsDomainID(ep.obsDomainID) - msg.SetMessageLen(msgLen) + msg.SetMessageLen(uint16(msgLen)) msg.SetExportTime(uint32(time.Now().Unix())) if set.GetSetType() == entities.Data { ep.seqNumber = ep.seqNumber + set.GetNumberOfRecords() diff --git a/pkg/exporter/process_test.go b/pkg/exporter/process_test.go index 78a93c0f..68dde9f9 100644 --- a/pkg/exporter/process_test.go +++ b/pkg/exporter/process_test.go @@ -59,7 +59,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalTCPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(listener.Addr(), 1, 0) + exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err) } @@ -129,7 +129,7 @@ func TestExportingProcess_SendingTemplateRecordToLocalUDPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1) + exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 1, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err) } @@ -203,7 +203,7 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(listener.Addr(), 1, 0) + exporter, err := InitExportingProcess(listener.Addr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", listener.Addr().String(), err) } @@ -247,13 +247,22 @@ func TestExportingProcess_SendingDataRecordToLocalTCPServer(t *testing.T) { dataRecBytes := dataRecBuff.Bytes() bytesSent, err := exporter.SendSet(dataSet) - if err != nil { - t.Fatalf("Got error when sending record: %v", err) - } + assert.NoError(t, err) // 28 is the size of the IPFIX message including all headers (20 bytes) assert.Equal(t, 28, bytesSent) assert.Equal(t, dataRecBytes, <-buffCh) assert.Equal(t, uint32(1), exporter.seqNumber) + + // Create data set with multiple data records to test invalid message length + // logic for TCP transport. + dataSet = entities.NewSet(entities.Data, templateID, false) + for i := 0; i < 10000; i++ { + err := dataSet.AddRecord(elements, templateID) + assert.NoError(t, err) + } + bytesSent, err = exporter.SendSet(dataSet) + assert.Error(t, err) + exporter.CloseConnToCollector() } @@ -284,7 +293,7 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) { }() // Create exporter using local server info - exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0) + exporter, err := InitExportingProcess(conn.LocalAddr(), 1, 0, 0) if err != nil { t.Fatalf("Got error when connecting to local server %s: %v", conn.LocalAddr().String(), err) } @@ -328,12 +337,20 @@ func TestExportingProcess_SendingDataRecordToLocalUDPServer(t *testing.T) { dataRecBytes := dataRecBuff.Bytes() bytesSent, err := exporter.SendSet(dataSet) - if err != nil { - t.Fatalf("Got error when sending record: %v", err) - } + assert.NoError(t, err) // 28 is the size of the IPFIX message including all headers (20 bytes) assert.Equal(t, 28, bytesSent) assert.Equal(t, dataRecBytes, <-buffCh) assert.Equal(t, uint32(1), exporter.seqNumber) + + // Create data set with multiple data records to test invalid message length + // logic for UDP transport. + dataSet = entities.NewSet(entities.Data, templateID, false) + for i := 0; i < 100; i++ { + dataSet.AddRecord(elements, templateID) + } + bytesSent, err = exporter.SendSet(dataSet) + assert.Error(t, err) + exporter.CloseConnToCollector() } diff --git a/pkg/test/collector_intermediate_test.go b/pkg/test/collector_intermediate_test.go index f2118146..812f9d03 100644 --- a/pkg/test/collector_intermediate_test.go +++ b/pkg/test/collector_intermediate_test.go @@ -92,7 +92,7 @@ func waitForCollectorReady(t *testing.T, address net.Addr) { } return true, nil } - if err := wait.Poll(100 * time.Millisecond, 500 * time.Millisecond, checkConn); err != nil { + if err := wait.Poll(100*time.Millisecond, 500*time.Millisecond, checkConn); err != nil { t.Errorf("Cannot establish connection to %s", address.String()) } } @@ -106,7 +106,7 @@ func waitForAggrationFinished(t *testing.T, ap *intermediate.AggregationProcess) return false, fmt.Errorf("aggregation process does not process and store data correctly") } } - if err := wait.Poll(100 * time.Millisecond, 500 * time.Millisecond, checkConn); err != nil { + if err := wait.Poll(100*time.Millisecond, 500*time.Millisecond, checkConn); err != nil { t.Error(err) } } diff --git a/pkg/test/exporter_collector_test.go b/pkg/test/exporter_collector_test.go index d1fe317d..f5cd5e3e 100644 --- a/pkg/test/exporter_collector_test.go +++ b/pkg/test/exporter_collector_test.go @@ -74,7 +74,7 @@ func testExporterToCollector(address net.Addr, isMultipleRecord bool, t *testing go cp.Start() go func() { // Start exporting process in go routine waitForCollectorReady(t, address) - export, err := exporter.InitExportingProcess(address, 1, 0) + export, err := exporter.InitExportingProcess(address, 1, 0, 0) if err != nil { klog.Fatalf("Got error when connecting to %s", address.String()) }