diff --git a/examples/broadcast/main.go b/examples/broadcast/main.go index 02026a08526..b35e35a1245 100644 --- a/examples/broadcast/main.go +++ b/examples/broadcast/main.go @@ -68,7 +68,7 @@ func main() { // nolint:gocognit rtpBuf := make([]byte, 1400) for { - i, readErr := remoteTrack.Read(rtpBuf) + i, _, readErr := remoteTrack.Read(rtpBuf) if readErr != nil { panic(readErr) } diff --git a/examples/reflect/main.go b/examples/reflect/main.go index cb1037a7460..6ba283091b9 100644 --- a/examples/reflect/main.go +++ b/examples/reflect/main.go @@ -82,7 +82,7 @@ func main() { fmt.Printf("Track has started, of type %d: %s \n", track.PayloadType(), track.Codec().MimeType) for { // Read RTP packets being sent to Pion - rtp, readErr := track.ReadRTP() + rtp, _, readErr := track.ReadRTP() if readErr != nil { panic(readErr) } diff --git a/examples/rtp-forwarder/main.go b/examples/rtp-forwarder/main.go index 4b0166ce872..e16ef6e9d25 100644 --- a/examples/rtp-forwarder/main.go +++ b/examples/rtp-forwarder/main.go @@ -116,7 +116,7 @@ func main() { b := make([]byte, 1500) for { // Read - n, readErr := track.Read(b) + n, _, readErr := track.Read(b) if readErr != nil { panic(readErr) } diff --git a/examples/save-to-disk/main.go b/examples/save-to-disk/main.go index bf05bf24634..593929815e3 100644 --- a/examples/save-to-disk/main.go +++ b/examples/save-to-disk/main.go @@ -23,7 +23,7 @@ func saveToDisk(i media.Writer, track *webrtc.TrackRemote) { }() for { - rtpPacket, err := track.ReadRTP() + rtpPacket, _, err := track.ReadRTP() if err != nil { panic(err) } diff --git a/examples/simulcast/main.go b/examples/simulcast/main.go index 819bb8f0ff2..d361767ce42 100644 --- a/examples/simulcast/main.go +++ b/examples/simulcast/main.go @@ -92,7 +92,7 @@ func main() { }() for { // Read RTP packets being sent to Pion - packet, readErr := track.ReadRTP() + packet, _, readErr := track.ReadRTP() if readErr != nil { panic(readErr) } diff --git a/examples/swap-tracks/main.go b/examples/swap-tracks/main.go index 1c60ef8f22d..c405c17178f 100644 --- a/examples/swap-tracks/main.go +++ b/examples/swap-tracks/main.go @@ -85,7 +85,7 @@ func main() { // nolint:gocognit var isCurrTrack bool for { // Read RTP packets being sent to Pion - rtp, readErr := track.ReadRTP() + rtp, _, readErr := track.ReadRTP() if readErr != nil { panic(readErr) } diff --git a/go.mod b/go.mod index a4e6f05887a..664f3d015d2 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/pion/datachannel v1.4.21 github.com/pion/dtls/v2 v2.0.4 github.com/pion/ice/v2 v2.0.14 - github.com/pion/interceptor v0.0.5 + github.com/pion/interceptor v0.0.6 github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/rtcp v1.2.6 diff --git a/go.sum b/go.sum index b6ccbc8ac90..87f76a0f792 100644 --- a/go.sum +++ b/go.sum @@ -40,8 +40,8 @@ github.com/pion/dtls/v2 v2.0.4 h1:WuUcqi6oYMu/noNTz92QrF1DaFj4eXbhQ6dzaaAwOiI= github.com/pion/dtls/v2 v2.0.4/go.mod h1:qAkFscX0ZHoI1E07RfYPoRw3manThveu+mlTDdOxoGI= github.com/pion/ice/v2 v2.0.14 h1:FxXxauyykf89SWAtkQCfnHkno6G8+bhRkNguSh9zU+4= github.com/pion/ice/v2 v2.0.14/go.mod h1:wqaUbOq5ObDNU5ox1hRsEst0rWfsKuH1zXjQFEWiZwM= -github.com/pion/interceptor v0.0.5 h1:BOwlubM1lntji3eNaVrhW1Qk3u1UoemrhM4mbv24XGM= -github.com/pion/interceptor v0.0.5/go.mod h1:lPVrf5xfosI989ZcmgPS4WwwRhd+XAyTFaYI2wHf7nU= +github.com/pion/interceptor v0.0.6 h1:530EdZi757pZEx510kvO25FkEuKm2mrb0p9NA+Xfj8E= +github.com/pion/interceptor v0.0.6/go.mod h1:QHkPVN5uyuw54wHqqL1KS9fxf3M3RzOlVKg/YrtK1so= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/mdns v0.0.4 h1:O4vvVqr4DGX63vzmO6Fw9vpy3lfztVWHGCQfyw0ZLSY= diff --git a/interceptor.go b/interceptor.go index fd886707ca4..4e9dbf1809f 100644 --- a/interceptor.go +++ b/interceptor.go @@ -3,14 +3,16 @@ package webrtc import ( + "sync/atomic" + "github.com/pion/interceptor" + "github.com/pion/rtp" ) // RegisterDefaultInterceptors will register some useful interceptors. If you want to customize which interceptors are loaded, // you should copy the code from this method and remove unwanted interceptors. func RegisterDefaultInterceptors(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Registry) error { - err := ConfigureNack(mediaEngine, interceptorRegistry) - if err != nil { + if err := ConfigureNack(mediaEngine, interceptorRegistry); err != nil { return err } @@ -24,3 +26,47 @@ func ConfigureNack(mediaEngine *MediaEngine, interceptorRegistry *interceptor.Re interceptorRegistry.Add(&interceptor.NACK{}) return nil } + +type interceptorToTrackLocalWriter struct{ interceptor atomic.Value } // interceptor.RTPWriter } + +func (i *interceptorToTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) { + if writer, ok := i.interceptor.Load().(interceptor.RTPWriter); ok && writer != nil { + return writer.Write(header, payload, interceptor.Attributes{}) + } + + return 0, nil +} + +func (i *interceptorToTrackLocalWriter) Write(b []byte) (int, error) { + packet := &rtp.Packet{} + if err := packet.Unmarshal(b); err != nil { + return 0, err + } + + return i.WriteRTP(&packet.Header, packet.Payload) +} + +func createStreamInfo(id string, ssrc SSRC, payloadType PayloadType, codec RTPCodecCapability, webrtcHeaderExtensions []RTPHeaderExtensionParameter) interceptor.StreamInfo { + headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(webrtcHeaderExtensions)) + for _, h := range webrtcHeaderExtensions { + headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) + } + + feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback)) + for _, f := range codec.RTCPFeedback { + feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) + } + + return interceptor.StreamInfo{ + ID: id, + Attributes: interceptor.Attributes{}, + SSRC: uint32(ssrc), + PayloadType: uint8(payloadType), + RTPHeaderExtensions: headerExtensions, + MimeType: codec.MimeType, + ClockRate: codec.ClockRate, + Channels: codec.Channels, + SDPFmtpLine: codec.SDPFmtpLine, + RTCPFeedback: feedbacks, + } +} diff --git a/interceptor_test.go b/interceptor_test.go index 4ae6ce3ba4c..1ba0fe7d6d4 100644 --- a/interceptor_test.go +++ b/interceptor_test.go @@ -2,14 +2,13 @@ package webrtc +// import ( - "sync" - "sync/atomic" + "context" "testing" "time" "github.com/pion/interceptor" - "github.com/pion/rtcp" "github.com/pion/rtp" "github.com/pion/transport/test" "github.com/pion/webrtc/v3/pkg/media" @@ -17,68 +16,37 @@ import ( ) type testInterceptor struct { - t *testing.T - extensionID uint8 - rtcpWriter atomic.Value - lastRTCP atomic.Value interceptor.NoOp + + t *testing.T } func (t *testInterceptor) BindLocalStream(_ *interceptor.StreamInfo, writer interceptor.RTPWriter) interceptor.RTPWriter { - return interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { + return interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { // set extension on outgoing packet - p.Header.Extension = true - p.Header.ExtensionProfile = 0xBEDE - assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("write"))) + header.Extension = true + header.ExtensionProfile = 0xBEDE + assert.NoError(t.t, header.SetExtension(2, []byte("foo"))) - return writer.Write(p, attributes) + return writer.Write(header, payload, attributes) }) } -func (t *testInterceptor) BindRemoteStream(info *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { - return interceptor.RTPReaderFunc(func() (*rtp.Packet, interceptor.Attributes, error) { - p, attributes, err := reader.Read() - if err != nil { - return nil, nil, err +func (t *testInterceptor) BindRemoteStream(_ *interceptor.StreamInfo, reader interceptor.RTPReader) interceptor.RTPReader { + return interceptor.RTPReaderFunc(func(b []byte, a interceptor.Attributes) (int, interceptor.Attributes, error) { + if a == nil { + a = interceptor.Attributes{} } - // set extension on incoming packet - p.Header.Extension = true - p.Header.ExtensionProfile = 0xBEDE - assert.NoError(t.t, p.Header.SetExtension(t.extensionID, []byte("read"))) - - // write back a pli - rtcpWriter := t.rtcpWriter.Load().(interceptor.RTCPWriter) - pli := &rtcp.PictureLossIndication{SenderSSRC: info.SSRC, MediaSSRC: info.SSRC} - _, err = rtcpWriter.Write([]rtcp.Packet{pli}, make(interceptor.Attributes)) - assert.NoError(t.t, err) - - return p, attributes, nil - }) -} -func (t *testInterceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.RTCPReader { - return interceptor.RTCPReaderFunc(func() ([]rtcp.Packet, interceptor.Attributes, error) { - pkts, attributes, err := reader.Read() - if err != nil { - return nil, nil, err - } - - t.lastRTCP.Store(pkts[0]) - - return pkts, attributes, nil + a.Set("attribute", "value") + return reader.Read(b, a) }) } -func (t *testInterceptor) lastReadRTCP() rtcp.Packet { - p, _ := t.lastRTCP.Load().(rtcp.Packet) - return p -} - -func (t *testInterceptor) BindRTCPWriter(writer interceptor.RTCPWriter) interceptor.RTCPWriter { - t.rtcpWriter.Store(writer) - return writer -} - +// E2E test of the features of Interceptors +// * Assert an extension can be set on an outbound packet +// * Assert an extension can be read on an outbound packet +// * Assert that attributes set by an interceptor are returned to the Reader func TestPeerConnection_Interceptor(t *testing.T) { to := test.TimeOut(time.Second * 20) defer to.Stop() @@ -86,12 +54,12 @@ func TestPeerConnection_Interceptor(t *testing.T) { report := test.CheckRoutines(t) defer report() - createPC := func(i interceptor.Interceptor) *PeerConnection { + createPC := func() *PeerConnection { m := &MediaEngine{} assert.NoError(t, m.RegisterDefaultCodecs()) ir := &interceptor.Registry{} - ir.Add(i) + ir.Add(&testInterceptor{t: t}) pc, err := NewAPI(WithMediaEngine(m), WithInterceptorRegistry(ir)).NewPeerConnection(Configuration{}) assert.NoError(t, err) @@ -99,75 +67,41 @@ func TestPeerConnection_Interceptor(t *testing.T) { return pc } - sendInterceptor := &testInterceptor{t: t, extensionID: 1} - senderPC := createPC(sendInterceptor) - receiverPC := createPC(&testInterceptor{t: t, extensionID: 2}) + offerer := createPC() + answerer := createPC() track, err := NewTrackLocalStaticSample(RTPCodecCapability{MimeType: "video/vp8"}, "video", "pion") assert.NoError(t, err) - sender, err := senderPC.AddTrack(track) + _, err = offerer.AddTrack(track) assert.NoError(t, err) - pending := new(int32) - wg := &sync.WaitGroup{} - - wg.Add(1) - *pending++ - receiverPC.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { - p, readErr := track.ReadRTP() + seenRTP, seenRTPCancel := context.WithCancel(context.Background()) + answerer.OnTrack(func(track *TrackRemote, receiver *RTPReceiver) { + p, attributes, readErr := track.ReadRTP() assert.NoError(t, readErr) assert.Equal(t, p.Extension, true) - assert.Equal(t, "write", string(p.GetExtension(1))) - assert.Equal(t, "read", string(p.GetExtension(2))) - atomic.AddInt32(pending, -1) - wg.Done() + assert.Equal(t, "foo", string(p.GetExtension(2))) + assert.Equal(t, "value", attributes.Get("attribute")) - for { - if _, readErr = track.ReadRTP(); readErr != nil { - return - } - } + seenRTPCancel() }) - wg.Add(1) - *pending++ - go func() { - _, readErr := sender.ReadRTCP() - assert.NoError(t, readErr) - atomic.AddInt32(pending, -1) - wg.Done() + assert.NoError(t, signalPair(offerer, answerer)) + func() { + ticker := time.NewTicker(time.Millisecond * 20) for { - if _, readErr = sender.ReadRTCP(); readErr != nil { + select { + case <-seenRTP.Done(): return + case <-ticker.C: + assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) } } }() - assert.NoError(t, signalPair(senderPC, receiverPC)) - - wg.Add(1) - go func() { - defer wg.Done() - for { - time.Sleep(time.Millisecond * 100) - - assert.NoError(t, track.WriteSample(media.Sample{Data: []byte{0x00}, Duration: time.Second})) - - if atomic.LoadInt32(pending) == 0 { - return - } - } - }() - - wg.Wait() - assert.NoError(t, senderPC.Close()) - assert.NoError(t, receiverPC.Close()) - - pli, _ := sendInterceptor.lastReadRTCP().(*rtcp.PictureLossIndication) - if pli == nil || pli.SenderSSRC == 0 { - t.Errorf("pli not found by send interceptor") - } + assert.NoError(t, offerer.Close()) + assert.NoError(t, answerer.Close()) } diff --git a/interceptor_track_local.go b/interceptor_track_local.go deleted file mode 100644 index c68fdeef5d1..00000000000 --- a/interceptor_track_local.go +++ /dev/null @@ -1,27 +0,0 @@ -// +build !js - -package webrtc - -import ( - "sync/atomic" - - "github.com/pion/interceptor" - "github.com/pion/rtp" -) - -type interceptorTrackLocalWriter struct { - TrackLocalWriter - rtpWriter atomic.Value -} - -func (i *interceptorTrackLocalWriter) setRTPWriter(writer interceptor.RTPWriter) { - i.rtpWriter.Store(writer) -} - -func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) { - if writer, ok := i.rtpWriter.Load().(interceptor.RTPWriter); ok && writer != nil { - return writer.Write(&rtp.Packet{Header: *header, Payload: payload}, make(interceptor.Attributes)) - } - - return 0, nil -} diff --git a/peerconnection.go b/peerconnection.go index d504869f5ba..88d7e7e339d 100644 --- a/peerconnection.go +++ b/peerconnection.go @@ -1152,7 +1152,6 @@ func (pc *PeerConnection) startReceiver(incoming trackDetails, receiver *RTPRece receiver.Track().kind = receiver.kind receiver.Track().codec = params.Codecs[0] receiver.Track().params = params - receiver.Track().bindInterceptor() receiver.Track().mu.Unlock() pc.onTrack(receiver.Track(), receiver) diff --git a/peerconnection_media_test.go b/peerconnection_media_test.go index 488f6261ff1..3c66a91b259 100644 --- a/peerconnection_media_test.go +++ b/peerconnection_media_test.go @@ -105,7 +105,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { }() go func() { - _, routineErr := receiver.Read(make([]byte, 1400)) + _, _, routineErr := receiver.Read(make([]byte, 1400)) if routineErr != nil { awaitRTCPReceiverRecv <- routineErr } else { @@ -115,7 +115,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { haveClosedAwaitRTPRecv := false for { - p, routineErr := track.ReadRTP() + p, _, routineErr := track.ReadRTP() if routineErr != nil { close(awaitRTPRecvClosed) return @@ -168,7 +168,7 @@ func TestPeerConnection_Media_Sample(t *testing.T) { }() go func() { - if _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil { + if _, _, routineErr := sender.Read(make([]byte, 1400)); routineErr == nil { close(awaitRTCPSenderRecv) } }() @@ -688,11 +688,11 @@ func TestRtpSenderReceiver_ReadClose_Error(t *testing.T) { sender, receiver := tr.Sender(), tr.Receiver() assert.NoError(t, sender.Stop()) - _, err = sender.Read(make([]byte, 0, 1400)) + _, _, err = sender.Read(make([]byte, 0, 1400)) assert.Error(t, err, io.ErrClosedPipe) assert.NoError(t, receiver.Stop()) - _, err = receiver.Read(make([]byte, 0, 1400)) + _, _, err = receiver.Read(make([]byte, 0, 1400)) assert.Error(t, err, io.ErrClosedPipe) assert.NoError(t, pc.Close()) diff --git a/peerconnection_renegotiation_test.go b/peerconnection_renegotiation_test.go index 70ef5b95df3..cbb95f726b8 100644 --- a/peerconnection_renegotiation_test.go +++ b/peerconnection_renegotiation_test.go @@ -360,7 +360,7 @@ func TestPeerConnection_Renegotiation_CodecChange(t *testing.T) { pcAnswer.OnTrack(func(track *TrackRemote, r *RTPReceiver) { tracksCh <- track for { - if _, readErr := track.ReadRTP(); readErr == io.EOF { + if _, _, readErr := track.ReadRTP(); readErr == io.EOF { tracksClosed <- struct{}{} return } @@ -450,7 +450,7 @@ func TestPeerConnection_Renegotiation_RemoveTrack(t *testing.T) { onTrackFiredFunc() for { - if _, err := track.ReadRTP(); err == io.EOF { + if _, _, err := track.ReadRTP(); err == io.EOF { trackClosedFunc() return } diff --git a/rtpreceiver.go b/rtpreceiver.go index 4e2fa26fcfc..f387dbca519 100644 --- a/rtpreceiver.go +++ b/rtpreceiver.go @@ -10,14 +10,19 @@ import ( "github.com/pion/interceptor" "github.com/pion/rtcp" "github.com/pion/srtp/v2" + "github.com/pion/webrtc/v3/internal/util" ) // trackStreams maintains a mapping of RTP/RTCP streams to a specific track // a RTPReceiver may contain multiple streams if we are dealing with Multicast type trackStreams struct { - track *TrackRemote + track *TrackRemote + rtpReadStream *srtp.ReadStreamSRTP - rtcpReadStream *srtp.ReadStreamSRTCP + rtpInterceptor interceptor.RTPReader + + rtcpReadStream *srtp.ReadStreamSRTCP + rtcpInterceptor interceptor.RTCPReader } // RTPReceiver allows an application to inspect the receipt of a TrackRemote @@ -32,8 +37,6 @@ type RTPReceiver struct { // A reference to the associated api object api *API - - interceptorRTCPReader interceptor.RTCPReader } // NewRTPReceiver constructs a new RTPReceiver @@ -50,7 +53,6 @@ func (api *API) NewRTPReceiver(kind RTPCodecType, transport *DTLSTransport) (*RT received: make(chan interface{}), tracks: []trackStreams{}, } - r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP)) return r, nil } @@ -115,8 +117,7 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { } var err error - t.rtpReadStream, t.rtcpReadStream, err = r.streamsForSSRC(parameters.Encodings[0].SSRC) - if err != nil { + if t.rtpReadStream, t.rtpInterceptor, t.rtcpReadStream, t.rtcpInterceptor, err = r.streamsForSSRC(parameters.Encodings[0].SSRC, interceptor.StreamInfo{}); err != nil { return err } @@ -138,41 +139,35 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error { } // Read reads incoming RTCP for this RTPReceiver -func (r *RTPReceiver) Read(b []byte) (n int, err error) { +func (r *RTPReceiver) Read(b []byte) (n int, a interceptor.Attributes, err error) { select { case <-r.received: - return r.tracks[0].rtcpReadStream.Read(b) + return r.tracks[0].rtcpInterceptor.Read(b, a) case <-r.closed: - return 0, io.ErrClosedPipe + return 0, nil, io.ErrClosedPipe } } // ReadSimulcast reads incoming RTCP for this RTPReceiver for given rid -func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, err error) { +func (r *RTPReceiver) ReadSimulcast(b []byte, rid string) (n int, a interceptor.Attributes, err error) { select { case <-r.received: for _, t := range r.tracks { if t.track != nil && t.track.rid == rid { - return t.rtcpReadStream.Read(b) + return t.rtcpInterceptor.Read(b, a) } } - return 0, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) + return 0, nil, fmt.Errorf("%w: %s", errRTPReceiverForRIDTrackStreamNotFound, rid) case <-r.closed: - return 0, io.ErrClosedPipe + return 0, nil, io.ErrClosedPipe } } // ReadRTCP is a convenience method that wraps Read and unmarshal for you. // It also runs any configured interceptors. -func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, error) { - pkts, _, err := r.interceptorRTCPReader.Read() - return pkts, err -} - -// ReadRTCP is a convenience method that wraps Read and unmarshal for you -func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { +func (r *RTPReceiver) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) - i, err := r.Read(b) + i, attributes, err := r.Read(b) if err != nil { return nil, nil, err } @@ -182,18 +177,19 @@ func (r *RTPReceiver) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) return nil, nil, err } - return pkts, make(interceptor.Attributes), nil + return pkts, attributes, nil } // ReadSimulcastRTCP is a convenience method that wraps ReadSimulcast and unmarshal for you -func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, error) { +func (r *RTPReceiver) ReadSimulcastRTCP(rid string) ([]rtcp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) - i, err := r.ReadSimulcast(b, rid) + i, attributes, err := r.ReadSimulcast(b, rid) if err != nil { - return nil, err + return nil, nil, err } - return rtcp.Unmarshal(b[:i]) + pkts, err := rtcp.Unmarshal(b[:i]) + return pkts, attributes, err } func (r *RTPReceiver) haveReceived() bool { @@ -209,32 +205,34 @@ func (r *RTPReceiver) haveReceived() bool { func (r *RTPReceiver) Stop() error { r.mu.Lock() defer r.mu.Unlock() + var err error select { case <-r.closed: - return nil + return err default: } select { case <-r.received: for i := range r.tracks { + errs := []error{} + if r.tracks[i].rtcpReadStream != nil { - if err := r.tracks[i].rtcpReadStream.Close(); err != nil { - return err - } + errs = append(errs, r.tracks[i].rtcpReadStream.Close()) } + if r.tracks[i].rtpReadStream != nil { - if err := r.tracks[i].rtpReadStream.Close(); err != nil { - return err - } + errs = append(errs, r.tracks[i].rtpReadStream.Close()) } + + err = util.FlattenErrs(errs) } default: } close(r.closed) - return nil + return err } func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { @@ -247,13 +245,13 @@ func (r *RTPReceiver) streamsForTrack(t *TrackRemote) *trackStreams { } // readRTP should only be called by a track, this only exists so we can keep state in one place -func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, err error) { +func (r *RTPReceiver) readRTP(b []byte, reader *TrackRemote) (n int, a interceptor.Attributes, err error) { <-r.received if t := r.streamsForTrack(reader); t != nil { - return t.rtpReadStream.Read(b) + return t.rtpInterceptor.Read(b, a) } - return 0, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) + return 0, nil, fmt.Errorf("%w: %d", errRTPReceiverWithSSRCTrackStreamNotFound, reader.SSRC()) } // receiveForRid is the sibling of Receive expect for RIDs instead of SSRCs @@ -269,12 +267,11 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) r.tracks[i].track.codec = params.Codecs[0] r.tracks[i].track.params = params r.tracks[i].track.ssrc = ssrc - r.tracks[i].track.bindInterceptor() + streamInfo := createStreamInfo("", ssrc, params.Codecs[0].PayloadType, params.Codecs[0].RTPCodecCapability, params.HeaderExtensions) r.tracks[i].track.mu.Unlock() var err error - r.tracks[i].rtpReadStream, r.tracks[i].rtcpReadStream, err = r.streamsForSSRC(ssrc) - if err != nil { + if r.tracks[0].rtpReadStream, r.tracks[0].rtpInterceptor, r.tracks[0].rtcpReadStream, r.tracks[0].rtcpInterceptor, err = r.streamsForSSRC(ssrc, streamInfo); err != nil { return nil, err } @@ -285,26 +282,36 @@ func (r *RTPReceiver) receiveForRid(rid string, params RTPParameters, ssrc SSRC) return nil, fmt.Errorf("%w: %d", errRTPReceiverForSSRCTrackStreamNotFound, ssrc) } -func (r *RTPReceiver) streamsForSSRC(ssrc SSRC) (*srtp.ReadStreamSRTP, *srtp.ReadStreamSRTCP, error) { +func (r *RTPReceiver) streamsForSSRC(ssrc SSRC, streamInfo interceptor.StreamInfo) (*srtp.ReadStreamSRTP, interceptor.RTPReader, *srtp.ReadStreamSRTCP, interceptor.RTCPReader, error) { srtpSession, err := r.transport.getSRTPSession() if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } rtpReadStream, err := srtpSession.OpenReadStream(uint32(ssrc)) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } + rtpInterceptor := r.api.interceptor.BindRemoteStream(&streamInfo, interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtpReadStream.Read(in) + return n, a, err + })) + srtcpSession, err := r.transport.getSRTCPSession() if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } rtcpReadStream, err := srtcpSession.OpenReadStream(uint32(ssrc)) if err != nil { - return nil, nil, err + return nil, nil, nil, nil, err } - return rtpReadStream, rtcpReadStream, nil + rtcpInterceptor := r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = rtcpReadStream.Read(in) + return n, a, err + })) + + return rtpReadStream, rtpInterceptor, rtcpReadStream, rtcpInterceptor, nil } diff --git a/rtpsender.go b/rtpsender.go index e54770c0e39..c29cbeedee9 100644 --- a/rtpsender.go +++ b/rtpsender.go @@ -16,8 +16,10 @@ import ( type RTPSender struct { track TrackLocal - srtpStream *srtpWriterFuture - context TrackLocalContext + srtpStream *srtpWriterFuture + rtcpInterceptor interceptor.RTCPReader + + context TrackLocalContext transport *DTLSTransport @@ -36,8 +38,6 @@ type RTPSender struct { mu sync.RWMutex sendCalled, stopCalled chan struct{} - - interceptorRTCPReader interceptor.RTCPReader } // NewRTPSender constructs a new RTPSender @@ -64,9 +64,13 @@ func (api *API) NewRTPSender(track TrackLocal, transport *DTLSTransport) (*RTPSe srtpStream: &srtpWriterFuture{}, } - r.interceptorRTCPReader = api.interceptor.BindRTCPReader(interceptor.RTCPReaderFunc(r.readRTCP)) r.srtpStream.rtpSender = r + r.rtcpInterceptor = r.api.interceptor.BindRTCPReader(interceptor.RTPReaderFunc(func(in []byte, a interceptor.Attributes) (n int, attributes interceptor.Attributes, err error) { + n, err = r.srtpStream.Read(in) + return n, a, err + })) + return r, nil } @@ -156,8 +160,7 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { return errRTPSenderSendAlreadyCalled } - writeStream := &interceptorTrackLocalWriter{TrackLocalWriter: r.srtpStream} - + writeStream := &interceptorToTrackLocalWriter{} r.context = TrackLocalContext{ id: r.id, params: r.api.mediaEngine.getRTPParametersByKind(r.track.Kind(), []RTPTransceiverDirection{RTPTransceiverDirectionSendonly}), @@ -171,33 +174,11 @@ func (r *RTPSender) Send(parameters RTPSendParameters) error { } r.context.params.Codecs = []RTPCodecParameters{codec} - headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(r.context.params.HeaderExtensions)) - for _, h := range r.context.params.HeaderExtensions { - headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) - } - feedbacks := make([]interceptor.RTCPFeedback, 0, len(codec.RTCPFeedback)) - for _, f := range codec.RTCPFeedback { - feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) - } - info := &interceptor.StreamInfo{ - ID: r.context.id, - Attributes: interceptor.Attributes{}, - SSRC: uint32(r.context.ssrc), - PayloadType: uint8(codec.PayloadType), - RTPHeaderExtensions: headerExtensions, - MimeType: codec.MimeType, - ClockRate: codec.ClockRate, - Channels: codec.Channels, - SDPFmtpLine: codec.SDPFmtpLine, - RTCPFeedback: feedbacks, - } - writeStream.setRTPWriter( - r.api.interceptor.BindLocalStream( - info, - interceptor.RTPWriterFunc(func(p *rtp.Packet, attributes interceptor.Attributes) (int, error) { - return r.srtpStream.WriteRTP(&p.Header, p.Payload) - }), - )) + streamInfo := createStreamInfo(r.id, parameters.Encodings[0].SSRC, codec.PayloadType, codec.RTPCodecCapability, parameters.HeaderExtensions) + rtpInterceptor := r.api.interceptor.BindLocalStream(&streamInfo, interceptor.RTPWriterFunc(func(header *rtp.Header, payload []byte, attributes interceptor.Attributes) (int, error) { + return r.srtpStream.WriteRTP(header, payload) + })) + writeStream.interceptor.Store(rtpInterceptor) close(r.sendCalled) return nil @@ -227,25 +208,19 @@ func (r *RTPSender) Stop() error { } // Read reads incoming RTCP for this RTPReceiver -func (r *RTPSender) Read(b []byte) (n int, err error) { +func (r *RTPSender) Read(b []byte) (n int, a interceptor.Attributes, err error) { select { case <-r.sendCalled: - return r.srtpStream.Read(b) + return r.rtcpInterceptor.Read(b, a) case <-r.stopCalled: - return 0, io.ErrClosedPipe + return 0, nil, io.ErrClosedPipe } } // ReadRTCP is a convenience method that wraps Read and unmarshals for you. -// It also runs any configured interceptors. -func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, error) { - pkts, _, err := r.interceptorRTCPReader.Read() - return pkts, err -} - -func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { +func (r *RTPSender) ReadRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) - i, err := r.Read(b) + i, attributes, err := r.Read(b) if err != nil { return nil, nil, err } @@ -255,7 +230,7 @@ func (r *RTPSender) readRTCP() ([]rtcp.Packet, interceptor.Attributes, error) { return nil, nil, err } - return pkts, make(interceptor.Attributes), nil + return pkts, attributes, nil } // hasSent tells if data has been ever sent for this instance diff --git a/rtpsender_test.go b/rtpsender_test.go index 0f182b5b316..f8ac51247a6 100644 --- a/rtpsender_test.go +++ b/rtpsender_test.go @@ -50,7 +50,7 @@ func Test_RTPSender_ReplaceTrack(t *testing.T) { assert.Equal(t, uint64(1), atomic.AddUint64(&onTrackCount, 1)) for { - pkt, err := track.ReadRTP() + pkt, _, err := track.ReadRTP() if err != nil { assert.True(t, errors.Is(io.EOF, err)) return diff --git a/track_remote.go b/track_remote.go index e195c637839..c150a3a796b 100644 --- a/track_remote.go +++ b/track_remote.go @@ -23,46 +23,18 @@ type TrackRemote struct { params RTPParameters rid string - receiver *RTPReceiver - peeked []byte - - interceptorRTPReader interceptor.RTPReader + receiver *RTPReceiver + peeked []byte + peekedAttributes interceptor.Attributes } func newTrackRemote(kind RTPCodecType, ssrc SSRC, rid string, receiver *RTPReceiver) *TrackRemote { - t := &TrackRemote{ + return &TrackRemote{ kind: kind, ssrc: ssrc, rid: rid, receiver: receiver, } - t.interceptorRTPReader = interceptor.RTPReaderFunc(t.readRTP) - - return t -} - -func (t *TrackRemote) bindInterceptor() { - headerExtensions := make([]interceptor.RTPHeaderExtension, 0, len(t.params.HeaderExtensions)) - for _, h := range t.params.HeaderExtensions { - headerExtensions = append(headerExtensions, interceptor.RTPHeaderExtension{ID: h.ID, URI: h.URI}) - } - feedbacks := make([]interceptor.RTCPFeedback, 0, len(t.codec.RTCPFeedback)) - for _, f := range t.codec.RTCPFeedback { - feedbacks = append(feedbacks, interceptor.RTCPFeedback{Type: f.Type, Parameter: f.Parameter}) - } - info := &interceptor.StreamInfo{ - ID: t.id, - Attributes: interceptor.Attributes{}, - SSRC: uint32(t.ssrc), - PayloadType: uint8(t.payloadType), - RTPHeaderExtensions: headerExtensions, - MimeType: t.codec.MimeType, - ClockRate: t.codec.ClockRate, - Channels: t.codec.Channels, - SDPFmtpLine: t.codec.SDPFmtpLine, - RTCPFeedback: feedbacks, - } - t.interceptorRTPReader = t.receiver.api.interceptor.BindRemoteStream(info, interceptor.RTPReaderFunc(t.readRTP)) } // ID is the unique identifier for this Track. This should be unique for the @@ -125,7 +97,7 @@ func (t *TrackRemote) Codec() RTPCodecParameters { } // Read reads data from the track. -func (t *TrackRemote) Read(b []byte) (n int, err error) { +func (t *TrackRemote) Read(b []byte) (n int, attributes interceptor.Attributes, err error) { t.mu.RLock() r := t.receiver peeked := t.peeked != nil @@ -134,7 +106,10 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) { if peeked { t.mu.Lock() data := t.peeked + attributes = t.peekedAttributes + t.peeked = nil + t.peekedAttributes = nil t.mu.Unlock() // someone else may have stolen our packet when we // released the lock. Deal with it. @@ -147,34 +122,10 @@ func (t *TrackRemote) Read(b []byte) (n int, err error) { return r.readRTP(b, t) } -// peek is like Read, but it doesn't discard the packet read -func (t *TrackRemote) peek(b []byte) (n int, err error) { - n, err = t.Read(b) - if err != nil { - return - } - - t.mu.Lock() - // this might overwrite data if somebody peeked between the Read - // and us getting the lock. Oh well, we'll just drop a packet in - // that case. - data := make([]byte, n) - n = copy(data, b[:n]) - t.peeked = data - t.mu.Unlock() - return -} - // ReadRTP is a convenience method that wraps Read and unmarshals for you. -// It also runs any configured interceptors. -func (t *TrackRemote) ReadRTP() (*rtp.Packet, error) { - p, _, err := t.interceptorRTPReader.Read() - return p, err -} - -func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) { +func (t *TrackRemote) ReadRTP() (*rtp.Packet, interceptor.Attributes, error) { b := make([]byte, receiveMTU) - i, err := t.Read(b) + i, attributes, err := t.Read(b) if err != nil { return nil, nil, err } @@ -183,14 +134,14 @@ func (t *TrackRemote) readRTP() (*rtp.Packet, interceptor.Attributes, error) { if err := r.Unmarshal(b[:i]); err != nil { return nil, nil, err } - return r, interceptor.Attributes{}, nil + return r, attributes, nil } // determinePayloadType blocks and reads a single packet to determine the PayloadType for this Track // this is useful because we can't announce it to the user until we know the payloadType func (t *TrackRemote) determinePayloadType() error { b := make([]byte, receiveMTU) - n, err := t.peek(b) + n, _, err := t.peek(b) if err != nil { return err } @@ -205,3 +156,22 @@ func (t *TrackRemote) determinePayloadType() error { return nil } + +// peek is like Read, but it doesn't discard the packet read +func (t *TrackRemote) peek(b []byte) (n int, a interceptor.Attributes, err error) { + n, a, err = t.Read(b) + if err != nil { + return + } + + t.mu.Lock() + // this might overwrite data if somebody peeked between the Read + // and us getting the lock. Oh well, we'll just drop a packet in + // that case. + data := make([]byte, n) + n = copy(data, b[:n]) + t.peeked = data + t.peekedAttributes = a + t.mu.Unlock() + return +}