From b93c4d4abf2fb8a85144732ce487294ca01f871e Mon Sep 17 00:00:00 2001 From: Adam Kiss Date: Sun, 29 Nov 2020 12:31:43 +0100 Subject: [PATCH] review fixes, rename receiver -> generator, sender -> responder, linter fixes --- go.mod | 2 +- go.sum | 4 +- {test => internal/test}/stream.go | 14 +++ {test => internal/test}/stream_test.go | 0 pkg/nack/errors.go | 7 ++ .../nack/generator_interceptor.go | 88 ++++++++----------- .../nack/generator_interceptor_test.go | 36 ++++---- pkg/nack/generator_option.go | 40 +++++++++ {nack => pkg/nack}/receive_log.go | 77 ++++++++-------- {nack => pkg/nack}/receive_log_test.go | 10 +-- .../nack/responder_interceptor.go | 51 ++++++----- .../nack/responder_interceptor_test.go | 35 +++++--- pkg/nack/responder_option.go | 21 +++++ {nack => pkg/nack}/send_buffer.go | 42 ++++----- {nack => pkg/nack}/send_buffer_test.go | 8 +- 15 files changed, 255 insertions(+), 180 deletions(-) rename {test => internal/test}/stream.go (75%) rename {test => internal/test}/stream_test.go (100%) create mode 100644 pkg/nack/errors.go rename nack/receiver_interceptor.go => pkg/nack/generator_interceptor.go (53%) rename nack/receiver_interceptor_test.go => pkg/nack/generator_interceptor_test.go (65%) create mode 100644 pkg/nack/generator_option.go rename {nack => pkg/nack}/receive_log.go (58%) rename {nack => pkg/nack}/receive_log_test.go (95%) rename nack/sender_interceptor.go => pkg/nack/responder_interceptor.go (63%) rename nack/sender_interceptor_test.go => pkg/nack/responder_interceptor_test.go (61%) create mode 100644 pkg/nack/responder_option.go rename {nack => pkg/nack}/send_buffer.go (51%) rename {nack => pkg/nack}/send_buffer_test.go (89%) diff --git a/go.mod b/go.mod index 289bb9ba..4091812e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.15 require ( github.com/pion/logging v0.2.2 - github.com/pion/rtcp v1.2.4 + github.com/pion/rtcp v1.2.6 github.com/pion/rtp v1.6.1 github.com/stretchr/testify v1.6.1 ) diff --git a/go.sum b/go.sum index e9da74ab..60ea4354 100644 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= -github.com/pion/rtcp v1.2.4 h1:NT3H5LkUGgaEapvp0HGik+a+CpflRF7KTD7H+o7OWIM= -github.com/pion/rtcp v1.2.4/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= +github.com/pion/rtcp v1.2.6 h1:1zvwBbyd0TeEuuWftrd/4d++m+/kZSeiguxU61LFWpo= +github.com/pion/rtcp v1.2.6/go.mod h1:52rMNPWFsjr39z9B9MhnkqhPLoeHTv1aN63o/42bWE0= github.com/pion/rtp v1.6.1 h1:2Y2elcVBrahYnHKN2X7rMHX/r1R4TEBMP1LaVu/wNhk= github.com/pion/rtp v1.6.1/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/test/stream.go b/internal/test/stream.go similarity index 75% rename from test/stream.go rename to internal/test/stream.go index 3e233ef3..f5a073cb 100644 --- a/test/stream.go +++ b/internal/test/stream.go @@ -1,3 +1,4 @@ +// Package test provides helpers for testing interceptors package test import ( @@ -8,6 +9,7 @@ import ( "github.com/pion/rtp" ) +// Stream is a helper struct for testing interceptors. type Stream struct { interceptor interceptor.Interceptor @@ -26,16 +28,19 @@ type Stream struct { rtpInModified chan RTPWithError } +// RTPWithError is used to send an rtp packet or an error on a channel type RTPWithError struct { Packet *rtp.Packet Err error } +// RTCPWithError is used to send a batch of rtcp packets or an error on a channel type RTCPWithError struct { Packets []rtcp.Packet Err error } +// NewStream creates a new Stream func NewStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Stream { s := &Stream{ interceptor: i, @@ -107,40 +112,49 @@ func NewStream(info *interceptor.StreamInfo, i interceptor.Interceptor) *Stream return s } +// WriteRTCP writes a batch of rtcp packet to the stream, using the interceptor func (s *Stream) WriteRTCP(pkts []rtcp.Packet) error { _, err := s.rtcpWriter.Write(pkts, interceptor.Attributes{}) return err } +// WriteRTP writes an rtp packet to the stream, using the interceptor func (s *Stream) WriteRTP(p *rtp.Packet) error { _, err := s.rtpWriter.Write(p, interceptor.Attributes{}) return err } +// ReceiveRTCP schedules a new rtcp batch, so it can be read be the stream func (s *Stream) ReceiveRTCP(pkts []rtcp.Packet) { s.rtcpIn <- pkts } +// ReceiveRTP schedules a rtp packet, so it can be read be the stream func (s *Stream) ReceiveRTP(packet *rtp.Packet) { s.rtpIn <- packet } +// WrittenRTCP returns a channel containing the rtcp batches written, modified by the interceptor func (s *Stream) WrittenRTCP() chan []rtcp.Packet { return s.rtcpOutModified } +// WrittenRTP returns a channel containing rtp packets written, modified by the interceptor func (s *Stream) WrittenRTP() chan *rtp.Packet { return s.rtpOutModified } +// ReadRTCP returns a channel containing the rtcp batched read, modified by the interceptor func (s *Stream) ReadRTCP() chan RTCPWithError { return s.rtcpInModified } +// ReadRTP returns a channel containing the rtp packets read, modified by the interceptor func (s *Stream) ReadRTP() chan RTPWithError { return s.rtpInModified } +// Close closes the stream and the underlying interceptor func (s *Stream) Close() error { close(s.rtcpIn) close(s.rtpIn) diff --git a/test/stream_test.go b/internal/test/stream_test.go similarity index 100% rename from test/stream_test.go rename to internal/test/stream_test.go diff --git a/pkg/nack/errors.go b/pkg/nack/errors.go new file mode 100644 index 00000000..588d3149 --- /dev/null +++ b/pkg/nack/errors.go @@ -0,0 +1,7 @@ +// Package nack provides interceptors to implement sending and receiving negative acknowledgements +package nack + +import "errors" + +// ErrInvalidSize is returned by newReceiveLog/newSendBuffer, when an incorrect buffer size is supplied. +var ErrInvalidSize = errors.New("invalid buffer size") diff --git a/nack/receiver_interceptor.go b/pkg/nack/generator_interceptor.go similarity index 53% rename from nack/receiver_interceptor.go rename to pkg/nack/generator_interceptor.go index bd95e314..abe528f5 100644 --- a/nack/receiver_interceptor.go +++ b/pkg/nack/generator_interceptor.go @@ -11,8 +11,8 @@ import ( "github.com/pion/rtp" ) -// ReceiverInterceptor interceptor generates nack messages. -type ReceiverInterceptor struct { +// GeneratorInterceptor interceptor generates nack feedback messages. +type GeneratorInterceptor struct { interceptor.NoOp size uint16 skipLastN uint16 @@ -24,38 +24,42 @@ type ReceiverInterceptor struct { log logging.LeveledLogger } -// NewReceiverInterceptor returns a new ReceiverInterceptor interceptor -func NewReceiverInterceptor(size uint16, skipLastN uint16, interval time.Duration, log logging.LeveledLogger) (*ReceiverInterceptor, error) { - _, err := NewReceiveLog(size) - if err != nil { - return nil, err - } - - return &ReceiverInterceptor{ +// NewGeneratorInterceptor returns a new GeneratorInterceptor interceptor +func NewGeneratorInterceptor(opts ...GeneratorOption) (*GeneratorInterceptor, error) { + r := &GeneratorInterceptor{ NoOp: interceptor.NoOp{}, - size: size, - skipLastN: skipLastN, - interval: interval, + size: 8192, + skipLastN: 0, + interval: time.Millisecond * 100, receiveLogs: &sync.Map{}, close: make(chan struct{}), - log: log, - }, nil + log: logging.NewDefaultLoggerFactory().NewLogger("nack_generator"), + } + + for _, opt := range opts { + opt(r) + } + + if _, err := newReceiveLog(r.size); err != nil { + return nil, err + } + + return r, nil } // BindRTCPWriter lets you modify any outgoing RTCP packets. It is called once per PeerConnection. The returned method // will be called once per packet batch. -func (n *ReceiverInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { +func (n *GeneratorInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { n.m.Lock() + defer n.m.Unlock() select { case <-n.close: // already closed - n.m.Unlock() return writer default: } n.wg.Add(1) - n.m.Unlock() go n.loop(writer) @@ -64,7 +68,7 @@ func (n *ReceiverInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) inte // BindRemoteStream lets you modify any incoming RTP packets. It is called once for per RemoteStream. The returned method // will be called once per rtp packet. -func (n *ReceiverInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { +func (n *GeneratorInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { hasNack := false for _, fb := range info.RTCPFeedback { if fb.Type == "nack" && fb.Parameter == "" { @@ -76,8 +80,8 @@ func (n *ReceiverInterceptor) BindRemoteStream(info *interceptor.StreamInfo, rea return reader } - // error is already checked in NewReceiverInterceptor - receiveLog, _ := NewReceiveLog(n.size) + // error is already checked in NewGeneratorInterceptor + receiveLog, _ := newReceiveLog(n.size) n.receiveLogs.Store(info.SSRC, receiveLog) return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) { @@ -86,18 +90,19 @@ func (n *ReceiverInterceptor) BindRemoteStream(info *interceptor.StreamInfo, rea return nil, nil, err } - receiveLog.Add(p.SequenceNumber) + receiveLog.add(p.SequenceNumber) return p, attr, nil }) } // UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. -func (n *ReceiverInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { +func (n *GeneratorInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { n.receiveLogs.Delete(info.SSRC) } -func (n *ReceiverInterceptor) Close() error { +// Close closes the interceptor +func (n *GeneratorInterceptor) Close() error { defer n.wg.Wait() n.m.Lock() defer n.m.Unlock() @@ -114,10 +119,10 @@ func (n *ReceiverInterceptor) Close() error { return nil } -func (n *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { +func (n *GeneratorInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { defer n.wg.Done() - senderSSRC := rand.Uint32() + senderSSRC := rand.Uint32() // #nosec ticker := time.NewTicker(n.interval) for { @@ -125,9 +130,9 @@ func (n *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { case <-ticker.C: n.receiveLogs.Range(func(key, value interface{}) bool { ssrc := key.(uint32) - receiveLog := value.(*ReceiveLog) + receiveLog := value.(*receiveLog) - missing := receiveLog.MissingSeqNumbers(n.skipLastN) + missing := receiveLog.missingSeqNumbers(n.skipLastN) if len(missing) == 0 { return true } @@ -135,11 +140,10 @@ func (n *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { nack := &rtcp.TransportLayerNack{ SenderSSRC: senderSSRC, MediaSSRC: ssrc, - Nacks: nackPairs(missing), + Nacks: rtcp.NackPairsFromSequenceNumbers(missing), } - _, err := rtcpWriter.Write([]rtcp.Packet{nack}, interceptor.Attributes{}) - if err != nil { + if _, err := rtcpWriter.Write([]rtcp.Packet{nack}, interceptor.Attributes{}); err != nil { n.log.Warnf("failed sending nack: %+v", err) } @@ -151,25 +155,3 @@ func (n *ReceiverInterceptor) loop(rtcpWriter interceptor.RTCPWriter) { } } } - -func nackPairs(seqNums []uint16) []rtcp.NackPair { - // TODO: I think this shoud be moved to rtcp package - pairs := make([]rtcp.NackPair, 0) - startSeq := seqNums[0] - nackPair := &rtcp.NackPair{PacketID: startSeq} - for i := 1; i < len(seqNums); i++ { - m := seqNums[i] - - if m-nackPair.PacketID > 16 { - pairs = append(pairs, *nackPair) - nackPair = &rtcp.NackPair{PacketID: m} - continue - } - - nackPair.LostPackets |= 1 << (m - nackPair.PacketID - 1) - } - - pairs = append(pairs, *nackPair) - - return pairs -} diff --git a/nack/receiver_interceptor_test.go b/pkg/nack/generator_interceptor_test.go similarity index 65% rename from nack/receiver_interceptor_test.go rename to pkg/nack/generator_interceptor_test.go index a892de7b..e92ee737 100644 --- a/nack/receiver_interceptor_test.go +++ b/pkg/nack/generator_interceptor_test.go @@ -5,16 +5,21 @@ import ( "time" "github.com/pion/interceptor" - "github.com/pion/interceptor/test" + "github.com/pion/interceptor/internal/test" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/stretchr/testify/assert" ) -func TestReceiverInterceptor(t *testing.T) { +func TestGeneratorInterceptor(t *testing.T) { const interval = time.Millisecond * 10 - i, err := NewReceiverInterceptor(64, 2, interval, logging.NewDefaultLoggerFactory().NewLogger("test")) + i, err := NewGeneratorInterceptor( + GeneratorSize(64), + GeneratorSkipLastN(2), + GeneratorInterval(interval), + GeneratorLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + ) if err != nil { t.Fatal(err) } @@ -24,10 +29,7 @@ func TestReceiverInterceptor(t *testing.T) { RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, }, i) defer func() { - err := stream.Close() - if err != nil { - t.Errorf("error closing stream: %v", err) - } + assert.NoError(t, stream.Close()) }() for _, seqNum := range []uint16{10, 11, 12, 14, 16, 18} { @@ -35,9 +37,7 @@ func TestReceiverInterceptor(t *testing.T) { select { case r := <-stream.ReadRTP(): - if r.Err != nil { - t.Fatal(r.Err) - } + assert.NoError(t, r.Err) assert.Equal(t, seqNum, r.Packet.SequenceNumber) case <-time.After(10 * time.Millisecond): t.Fatal("receiver rtp packet not found") @@ -54,17 +54,19 @@ func TestReceiverInterceptor(t *testing.T) { select { case pkts := <-stream.WrittenRTCP(): - if len(pkts) != 1 { - t.Fatalf("single packet rtcp batch expected, found: %v", pkts) - } + assert.Equal(t, len(pkts), 1, "single packet RTCP Compound Packet expected") + p, ok := pkts[0].(*rtcp.TransportLayerNack) - if !ok { - t.Fatalf("TransportLayerNack rtcp packet expected, found: %T", pkts[0]) - } + assert.True(t, ok, "TransportLayerNack rtcp packet expected, found: %T", pkts[0]) assert.Equal(t, uint16(13), p.Nacks[0].PacketID) - assert.Equal(t, rtcp.PacketBitmap(0b10), p.Nacks[0].LostPackets) // we want packets: 13, 15 (not packet 17, because skipLastN is set to 2) + assert.Equal(t, rtcp.PacketBitmap(0b10), p.Nacks[0].LostPackets) // we want packets: 13, 15 (not packet 17, because skipLastN is setReceived to 2) case <-time.After(10 * time.Millisecond): t.Fatal("written rtcp packet not found") } } + +func TestGeneratorInterceptor_InvalidSize(t *testing.T) { + _, err := NewGeneratorInterceptor(GeneratorSize(5)) + assert.Error(t, err, ErrInvalidSize) +} diff --git a/pkg/nack/generator_option.go b/pkg/nack/generator_option.go new file mode 100644 index 00000000..86c627ba --- /dev/null +++ b/pkg/nack/generator_option.go @@ -0,0 +1,40 @@ +package nack + +import ( + "time" + + "github.com/pion/logging" +) + +// GeneratorOption can be used to configure GeneratorInterceptor +type GeneratorOption func(r *GeneratorInterceptor) + +// GeneratorSize sets the size of the interceptor. +// Size must be one of: 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 +func GeneratorSize(size uint16) GeneratorOption { + return func(r *GeneratorInterceptor) { + r.size = size + } +} + +// GeneratorSkipLastN sets the number of packets (n-1 packets before the last received packets) to ignore when generating +// nack requests. +func GeneratorSkipLastN(skipLastN uint16) GeneratorOption { + return func(r *GeneratorInterceptor) { + r.skipLastN = skipLastN + } +} + +// GeneratorLog sets a logger for the interceptor +func GeneratorLog(log logging.LeveledLogger) GeneratorOption { + return func(r *GeneratorInterceptor) { + r.log = log + } +} + +// GeneratorInterval sets the nack send interval for the interceptor +func GeneratorInterval(interval time.Duration) GeneratorOption { + return func(r *GeneratorInterceptor) { + r.interval = interval + } +} diff --git a/nack/receive_log.go b/pkg/nack/receive_log.go similarity index 58% rename from nack/receive_log.go rename to pkg/nack/receive_log.go index 1656944f..8107f59a 100644 --- a/nack/receive_log.go +++ b/pkg/nack/receive_log.go @@ -1,27 +1,11 @@ package nack import ( - "errors" - "strconv" + "fmt" "sync" ) -var ( - allowedReceiveLogSizes map[uint16]bool - invalidReceiveLogSizeError string -) - -func init() { - allowedReceiveLogSizes = make(map[uint16]bool, 15) - invalidReceiveLogSizeError = "invalid ReceiveLog size, must be one of: " - for i := 6; i < 16; i++ { - allowedReceiveLogSizes[1< end (with counting for rollovers) for i := s.end + 1; i != seq; i++ { // clear packets between end and seq (these may contain packets from a "size" ago) - s.del(i) + s.delReceived(i) } s.end = seq @@ -70,18 +65,16 @@ func (s *ReceiveLog) Add(seq uint16) { s.lastConsecutive = seq - s.size s.fixLastConsecutive() // there might be valid packets at the beginning of the buffer now } - } else { + case s.lastConsecutive+1 == seq: // negative diff, seq < end (with counting for rollovers) - if s.lastConsecutive+1 == seq { - s.lastConsecutive = seq - s.fixLastConsecutive() // there might be other valid packets after seq - } + s.lastConsecutive = seq + s.fixLastConsecutive() // there might be other valid packets after seq } - s.set(seq) + s.setReceived(seq) } -func (s *ReceiveLog) Get(seq uint16) bool { +func (s *receiveLog) get(seq uint16) bool { s.m.RLock() defer s.m.RUnlock() @@ -94,10 +87,10 @@ func (s *ReceiveLog) Get(seq uint16) bool { return false } - return s.get(seq) + return s.getReceived(seq) } -func (s *ReceiveLog) MissingSeqNumbers(skipLastN uint16) []uint16 { +func (s *receiveLog) missingSeqNumbers(skipLastN uint16) []uint16 { s.m.RLock() defer s.m.RUnlock() @@ -109,7 +102,7 @@ func (s *ReceiveLog) MissingSeqNumbers(skipLastN uint16) []uint16 { missingPacketSeqNums := make([]uint16, 0) for i := s.lastConsecutive + 1; i != until+1; i++ { - if !s.get(i) { + if !s.getReceived(i) { missingPacketSeqNums = append(missingPacketSeqNums, i) } } @@ -117,24 +110,24 @@ func (s *ReceiveLog) MissingSeqNumbers(skipLastN uint16) []uint16 { return missingPacketSeqNums } -func (s *ReceiveLog) set(seq uint16) { +func (s *receiveLog) setReceived(seq uint16) { pos := seq % s.size s.packets[pos/64] |= 1 << (pos % 64) } -func (s *ReceiveLog) del(seq uint16) { +func (s *receiveLog) delReceived(seq uint16) { pos := seq % s.size s.packets[pos/64] &^= 1 << (pos % 64) } -func (s *ReceiveLog) get(seq uint16) bool { +func (s *receiveLog) getReceived(seq uint16) bool { pos := seq % s.size return (s.packets[pos/64] & (1 << (pos % 64))) != 0 } -func (s *ReceiveLog) fixLastConsecutive() { +func (s *receiveLog) fixLastConsecutive() { i := s.lastConsecutive + 1 - for ; i != s.end+1 && s.get(i); i++ { + for ; i != s.end+1 && s.getReceived(i); i++ { // find all consecutive packets } s.lastConsecutive = i - 1 diff --git a/nack/receive_log_test.go b/pkg/nack/receive_log_test.go similarity index 95% rename from nack/receive_log_test.go rename to pkg/nack/receive_log_test.go index a3acfd6e..a631a03e 100644 --- a/nack/receive_log_test.go +++ b/pkg/nack/receive_log_test.go @@ -9,7 +9,7 @@ func TestReceivedBuffer(t *testing.T) { for _, start := range []uint16{0, 1, 127, 128, 129, 511, 512, 513, 32767, 32768, 32769, 65407, 65408, 65409, 65534, 65535} { start := start - rl, err := NewReceiveLog(128) + rl, err := newReceiveLog(128) if err != nil { t.Fatalf("%+v", err) } @@ -32,7 +32,7 @@ func TestReceivedBuffer(t *testing.T) { add := func(nums ...uint16) { for _, n := range nums { seq := start + n - rl.Add(seq) + rl.add(seq) } } @@ -40,7 +40,7 @@ func TestReceivedBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - if !rl.Get(seq) { + if !rl.get(seq) { t.Errorf("not found: %d", seq) } } @@ -49,14 +49,14 @@ func TestReceivedBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - if rl.Get(seq) { + if rl.get(seq) { t.Errorf("packet found for %d", seq) } } } assertMissing := func(skipLastN uint16, nums []uint16) { t.Helper() - missing := rl.MissingSeqNumbers(skipLastN) + missing := rl.missingSeqNumbers(skipLastN) if missing == nil { missing = []uint16{} } diff --git a/nack/sender_interceptor.go b/pkg/nack/responder_interceptor.go similarity index 63% rename from nack/sender_interceptor.go rename to pkg/nack/responder_interceptor.go index 0e2389cf..b9fdc539 100644 --- a/nack/sender_interceptor.go +++ b/pkg/nack/responder_interceptor.go @@ -9,36 +9,43 @@ import ( "github.com/pion/rtp" ) -type SenderInterceptor struct { +// ResponderInterceptor responds to nack feedback messages +type ResponderInterceptor struct { interceptor.NoOp size uint16 streams *sync.Map log logging.LeveledLogger } -type senderNackStream struct { - sendBuffer *SendBuffer +type localStream struct { + sendBuffer *sendBuffer rtpWriter interceptor.RTPWriter } -// NewSenderInterceptor returns a new ReceiverInterceptor interceptor -func NewSenderInterceptor(size uint16, log logging.LeveledLogger) (*SenderInterceptor, error) { - _, err := NewSendBuffer(size) +// NewResponderInterceptor returns a new GeneratorInterceptor interceptor +func NewResponderInterceptor(opts ...ResponderOption) (*ResponderInterceptor, error) { + r := &ResponderInterceptor{ + NoOp: interceptor.NoOp{}, + size: 8192, + streams: &sync.Map{}, + log: logging.NewDefaultLoggerFactory().NewLogger("nack_responder"), + } + + for _, opt := range opts { + opt(r) + } + + _, err := newSendBuffer(r.size) if err != nil { return nil, err } - return &SenderInterceptor{ - NoOp: interceptor.NoOp{}, - size: size, - streams: &sync.Map{}, - log: log, - }, nil + return r, nil } // BindRTCPReader lets you modify any incoming RTCP packets. It is called once per sender/receiver, however this might // change in the future. The returned method will be called once per packet batch. -func (n *SenderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { +func (n *ResponderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) { pkts, attr, err := reader.Read() if err != nil { @@ -60,7 +67,7 @@ func (n *SenderInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interc // BindLocalStream lets you modify any outgoing RTP packets. It is called once for per LocalStream. The returned method // will be called once per rtp packet. -func (n *SenderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { +func (n *ResponderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { hasNack := false for _, fb := range info.RTCPFeedback { if fb.Type == "nack" && fb.Parameter == "" { @@ -72,33 +79,33 @@ func (n *SenderInterceptor) BindLocalStream(info *interceptor.StreamInfo, writer return writer } - // error is already checked in NewReceiverInterceptor - sendBuffer, _ := NewSendBuffer(n.size) - n.streams.Store(info.SSRC, &senderNackStream{sendBuffer: sendBuffer, rtpWriter: writer}) + // error is already checked in NewGeneratorInterceptor + sendBuffer, _ := newSendBuffer(n.size) + n.streams.Store(info.SSRC, &localStream{sendBuffer: sendBuffer, rtpWriter: writer}) return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { - sendBuffer.Add(p) + sendBuffer.add(p) return writer.Write(p, attributes) }) } // UnbindLocalStream is called when the Stream is removed. It can be used to clean up any data related to that track. -func (n *SenderInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { +func (n *ResponderInterceptor) UnbindLocalStream(info *interceptor.StreamInfo) { n.streams.Delete(info.SSRC) } -func (n *SenderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { +func (n *ResponderInterceptor) resendPackets(nack *rtcp.TransportLayerNack) { v, ok := n.streams.Load(nack.MediaSSRC) if !ok { return } - stream := v.(*senderNackStream) + stream := v.(*localStream) seqNums := nackParsToSequenceNumbers(nack.Nacks) for _, seq := range seqNums { - p := stream.sendBuffer.Get(seq) + p := stream.sendBuffer.get(seq) if p == nil { continue } diff --git a/nack/sender_interceptor_test.go b/pkg/nack/responder_interceptor_test.go similarity index 61% rename from nack/sender_interceptor_test.go rename to pkg/nack/responder_interceptor_test.go index 30ab846a..8efe46d0 100644 --- a/nack/sender_interceptor_test.go +++ b/pkg/nack/responder_interceptor_test.go @@ -1,19 +1,23 @@ package nack import ( + "errors" "testing" "time" "github.com/pion/interceptor" - "github.com/pion/interceptor/test" + "github.com/pion/interceptor/internal/test" "github.com/pion/logging" "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/stretchr/testify/assert" ) -func TestSenderInterceptor(t *testing.T) { - i, err := NewSenderInterceptor(8, logging.NewDefaultLoggerFactory().NewLogger("test")) +func TestResponderInterceptor(t *testing.T) { + i, err := NewResponderInterceptor( + ResponderSize(8), + ResponderLog(logging.NewDefaultLoggerFactory().NewLogger("test")), + ) if err != nil { t.Fatal(err) } @@ -23,9 +27,9 @@ func TestSenderInterceptor(t *testing.T) { RTCPFeedback: []interceptor.RTCPFeedback{{Type: "nack"}}, }, i) defer func() { - err := stream.Close() - if err != nil { - t.Errorf("error closing stream: %v", err) + closeErr := stream.Close() + if closeErr != nil { + t.Errorf("error closing stream: %v", closeErr) } }() @@ -43,10 +47,14 @@ func TestSenderInterceptor(t *testing.T) { } } - stream.ReceiveRTCP([]rtcp.Packet{&rtcp.TransportLayerNack{ - MediaSSRC: 1, - SenderSSRC: 2, - Nacks: []rtcp.NackPair{{PacketID: 11, LostPackets: 0b1011}}}, // sequence numbers: 11, 12, 13, 15 + stream.ReceiveRTCP([]rtcp.Packet{ + &rtcp.TransportLayerNack{ + MediaSSRC: 1, + SenderSSRC: 2, + Nacks: []rtcp.NackPair{ + {PacketID: 11, LostPackets: 0b1011}, // sequence numbers: 11, 12, 13, 15 + }, + }, }) // seq number 13 was never sent, so it can't be resent @@ -65,3 +73,10 @@ func TestSenderInterceptor(t *testing.T) { case <-time.After(10 * time.Millisecond): } } + +func TestResponderInterceptor_InvalidSize(t *testing.T) { + _, err := NewResponderInterceptor(ResponderSize(5)) + if err == nil || !errors.Is(err, ErrInvalidSize) { + t.Fatalf("expected invalid size error, got: %v", err) + } +} diff --git a/pkg/nack/responder_option.go b/pkg/nack/responder_option.go new file mode 100644 index 00000000..a4a1d292 --- /dev/null +++ b/pkg/nack/responder_option.go @@ -0,0 +1,21 @@ +package nack + +import "github.com/pion/logging" + +// ResponderOption can be used to configure ResponderInterceptor +type ResponderOption func(s *ResponderInterceptor) + +// ResponderSize sets the size of the interceptor. +// Size must be one of: 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768 +func ResponderSize(size uint16) ResponderOption { + return func(r *ResponderInterceptor) { + r.size = size + } +} + +// ResponderLog sets a logger for the interceptor +func ResponderLog(log logging.LeveledLogger) ResponderOption { + return func(r *ResponderInterceptor) { + r.log = log + } +} diff --git a/nack/send_buffer.go b/pkg/nack/send_buffer.go similarity index 51% rename from nack/send_buffer.go rename to pkg/nack/send_buffer.go index 8a4db653..cf3f020e 100644 --- a/nack/send_buffer.go +++ b/pkg/nack/send_buffer.go @@ -1,8 +1,7 @@ package nack import ( - "errors" - "strconv" + "fmt" "github.com/pion/rtp" ) @@ -11,40 +10,35 @@ const ( uint16SizeHalf = 1 << 15 ) -var ( - allowedSendBufferSizes map[uint16]bool - invalidSendBufferSizeError string -) - -func init() { - allowedSendBufferSizes = make(map[uint16]bool, 15) - invalidSendBufferSizeError = "invalid sendBuffer size, must be one of: " - for i := 0; i < 16; i++ { - allowedSendBufferSizes[1<= uint16SizeHalf { return nil diff --git a/nack/send_buffer_test.go b/pkg/nack/send_buffer_test.go similarity index 89% rename from nack/send_buffer_test.go rename to pkg/nack/send_buffer_test.go index be04fcf4..e4be7c96 100644 --- a/nack/send_buffer_test.go +++ b/pkg/nack/send_buffer_test.go @@ -10,7 +10,7 @@ func TestSendBuffer(t *testing.T) { for _, start := range []uint16{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 511, 512, 513, 32767, 32768, 32769, 65527, 65528, 65529, 65530, 65531, 65532, 65533, 65534, 65535} { start := start - sb, err := NewSendBuffer(8) + sb, err := newSendBuffer(8) if err != nil { t.Fatalf("%+v", err) } @@ -18,7 +18,7 @@ func TestSendBuffer(t *testing.T) { add := func(nums ...uint16) { for _, n := range nums { seq := start + n - sb.Add(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq}}) + sb.add(&rtp.Packet{Header: rtp.Header{SequenceNumber: seq}}) } } @@ -26,7 +26,7 @@ func TestSendBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - packet := sb.Get(seq) + packet := sb.get(seq) if packet == nil { t.Errorf("packet not found: %d", seq) continue @@ -40,7 +40,7 @@ func TestSendBuffer(t *testing.T) { t.Helper() for _, n := range nums { seq := start + n - packet := sb.Get(seq) + packet := sb.get(seq) if packet != nil { t.Errorf("packet found for %d: %d", seq, packet.SequenceNumber) }