From 32238796f4e17392cd49b9173f32d0d75df1597c Mon Sep 17 00:00:00 2001 From: streamer45 Date: Mon, 7 Oct 2024 11:04:04 -0600 Subject: [PATCH] Implement rtc stats monitor --- client/client.go | 5 + client/rtc.go | 87 +++++++++++++++- client/rtc_monitor.go | 196 +++++++++++++++++++++++++++++++++++++ client/rtc_monitor_test.go | 121 +++++++++++++++++++++++ 4 files changed, 407 insertions(+), 2 deletions(-) create mode 100644 client/rtc_monitor.go create mode 100644 client/rtc_monitor_test.go diff --git a/client/client.go b/client/client.go index 739854f..7ef0556 100644 --- a/client/client.go +++ b/client/client.go @@ -104,6 +104,7 @@ type Client struct { receivers map[string][]*webrtc.RTPReceiver voiceSender *webrtc.RTPSender screenTransceivers []*webrtc.RTPTransceiver + rtcMon *rtcMonitor state int32 @@ -225,6 +226,10 @@ func (c *Client) emit(eventType EventType, ctx any) { func (c *Client) close() { atomic.StoreInt32(&c.state, clientStateClosed) + if c.rtcMon != nil { + c.rtcMon.Stop() + } + if c.pc != nil { if err := c.pc.Close(); err != nil { c.log.Error("failed to close peer connection", slog.String("err", err.Error())) diff --git a/client/rtc.go b/client/rtc.go index 9ca11a2..d7d3781 100644 --- a/client/rtc.go +++ b/client/rtc.go @@ -12,10 +12,12 @@ import ( "io" "log/slog" "sync/atomic" + "time" "github.com/mattermost/rtcd/service/rtc/dc" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) @@ -25,8 +27,10 @@ const ( signalMsgOffer = "offer" signalMsgAnswer = "answer" - iceChSize = 20 - receiveMTU = 1460 + iceChSize = 20 + receiveMTU = 1460 + rtcMonitorInterval = 4 * time.Second + pingInterval = time.Second ) var ( @@ -186,6 +190,17 @@ func (c *Client) initRTCSession() error { } i := interceptor.Registry{} + + statsInterceptorFactory, err := stats.NewInterceptor() + if err != nil { + return fmt.Errorf("failed to create stats interceptor: %w", err) + } + var statsGetter stats.Getter + statsInterceptorFactory.OnNewPeerConnection(func(_ string, g stats.Getter) { + statsGetter = g + }) + i.Add(statsInterceptorFactory) + if err := webrtc.RegisterDefaultInterceptors(&m, &i); err != nil { return fmt.Errorf("failed to register default interceptors: %w", err) } @@ -206,6 +221,13 @@ func (c *Client) initRTCSession() error { c.pc = pc c.mut.Unlock() + rtcMon := newRTCMonitor(c.log, pc, statsGetter, rtcMonitorInterval) + rtcMon.Start() + + c.mut.Lock() + c.rtcMon = rtcMon + c.mut.Unlock() + pc.OnICECandidate(func(candidate *webrtc.ICECandidate) { if candidate == nil { c.log.Debug("local ICE gathering completed") @@ -374,6 +396,64 @@ func (c *Client) initRTCSession() error { } c.dc.Store(dataCh) + lastPingTS := new(int64) + lastRTT := new(int64) + go func() { + pingTicker := time.NewTicker(pingInterval) + for { + select { + case <-pingTicker.C: + msg, err := dc.EncodeMessage(dc.MessageTypePing, nil) + if err != nil { + c.log.Error("failed to encode ping msg", slog.String("err", err.Error())) + continue + } + + if err := dataCh.Send(msg); err != nil { + c.log.Error("failed to send ping msg", slog.String("err", err.Error())) + continue + } + + atomic.StoreInt64(lastPingTS, time.Now().UnixMilli()) + case stats := <-rtcMon.StatsCh(): + if stats.lossRate >= 0 { + msg, err := dc.EncodeMessage(dc.MessageTypeLossRate, stats.lossRate) + if err != nil { + c.log.Error("failed to encode loss rate msg", slog.String("err", err.Error())) + } else { + if err := dataCh.Send(msg); err != nil { + c.log.Error("failed to send loss rate msg", slog.String("err", err.Error())) + } + } + } + + if rtt := atomic.LoadInt64(lastRTT); rtt > 0 { + msg, err := dc.EncodeMessage(dc.MessageTypeRoundTripTime, rtt) + if err != nil { + c.log.Error("failed to encode rtt msg", slog.String("err", err.Error())) + } else { + if err := dataCh.Send(msg); err != nil { + c.log.Error("failed to send rtt msg", slog.String("err", err.Error())) + } + } + } + + if stats.jitter > 0 { + msg, err := dc.EncodeMessage(dc.MessageTypeJitter, stats.jitter) + if err != nil { + c.log.Error("failed to encode jitter msg", slog.String("err", err.Error())) + } else { + if err := dataCh.Send(msg); err != nil { + c.log.Error("failed to send jitter msg", slog.String("err", err.Error())) + } + } + } + case <-c.wsCloseCh: + return + } + } + }() + dataCh.OnMessage(func(msg webrtc.DataChannelMessage) { mt, payload, err := dc.DecodeMessage(msg.Data) if err != nil { @@ -383,6 +463,9 @@ func (c *Client) initRTCSession() error { switch mt { case dc.MessageTypePong: + if ts := atomic.LoadInt64(lastPingTS); ts > 0 { + atomic.StoreInt64(lastRTT, time.Now().UnixMilli()-ts) + } case dc.MessageTypeSDP: var sdp webrtc.SessionDescription if err := json.Unmarshal(payload.([]byte), &sdp); err != nil { diff --git a/client/rtc_monitor.go b/client/rtc_monitor.go new file mode 100644 index 0000000..38b8847 --- /dev/null +++ b/client/rtc_monitor.go @@ -0,0 +1,196 @@ +// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package client + +import ( + "log/slog" + "time" + + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/webrtc/v3" +) + +type rtcMonitor struct { + log *slog.Logger + pc *webrtc.PeerConnection + statsGetter stats.Getter + interval time.Duration + + lastSndStats map[webrtc.SSRC]*stats.Stats + lastRcvStats map[webrtc.SSRC]*stats.Stats + + statsCh chan rtcStats + stopCh chan struct{} + doneCh chan struct{} +} + +type rtcStats struct { + lossRate float64 + jitter float64 +} + +func newRTCMonitor(log *slog.Logger, pc *webrtc.PeerConnection, sg stats.Getter, intv time.Duration) *rtcMonitor { + return &rtcMonitor{ + log: log, + pc: pc, + statsGetter: sg, + interval: intv, + statsCh: make(chan rtcStats, 1), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + lastSndStats: make(map[webrtc.SSRC]*stats.Stats), + lastRcvStats: make(map[webrtc.SSRC]*stats.Stats), + } +} + +func (m *rtcMonitor) gatherStats() (map[webrtc.SSRC]*stats.Stats, map[webrtc.SSRC]*stats.Stats) { + sndStats := make(map[webrtc.SSRC]*stats.Stats) + for _, snd := range m.pc.GetSenders() { + if snd == nil { + continue + } + for i, enc := range snd.GetParameters().Encodings { + // For simplicity we only consider audio streams. + // This lets us more easily make assumptions on the clock rate. + if snd.GetParameters().Codecs[i].MimeType != webrtc.MimeTypeOpus { + continue + } + + stats := m.statsGetter.Get(uint32(enc.SSRC)) + if stats != nil { + sndStats[enc.SSRC] = stats + } + } + } + + rcvStats := make(map[webrtc.SSRC]*stats.Stats) + for _, rcv := range m.pc.GetReceivers() { + if rcv == nil { + continue + } + + track := rcv.Track() + + if track == nil { + continue + } + + // For simplicity we only consider audio streams. + // This lets us more easily make assumptions on the clock rate. + if track.Codec().MimeType != webrtc.MimeTypeOpus { + continue + } + + stats := m.statsGetter.Get(uint32(track.SSRC())) + if stats != nil { + rcvStats[track.SSRC()] = stats + } + } + + return sndStats, rcvStats +} + +func (m *rtcMonitor) getAvgSenderStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) { + var totalJitter, totalLossRate float64 + + for ssrc, s := range stats { + if prevStats := m.lastSndStats[ssrc]; prevStats == nil || s.OutboundRTPStreamStats.PacketsSent == prevStats.OutboundRTPStreamStats.PacketsSent { + continue + } + + totalLossRate += s.RemoteInboundRTPStreamStats.FractionLost + totalJitter += s.RemoteInboundRTPStreamStats.Jitter + statsCount++ + } + + if statsCount > 0 { + avgJitter = (totalJitter / statsCount) * 1000 + avgLossRate = totalLossRate / statsCount + } + + return +} + +func (m *rtcMonitor) getAvgReceiverStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) { + var totalJitter, totalLost, totalReceived float64 + + for ssrc, s := range stats { + prevStats := m.lastRcvStats[ssrc] + if prevStats == nil || s.InboundRTPStreamStats.PacketsReceived == prevStats.InboundRTPStreamStats.PacketsReceived { + continue + } + + receivedDiff := s.InboundRTPStreamStats.PacketsReceived - prevStats.InboundRTPStreamStats.PacketsReceived + potentiallyLost := int64(s.RemoteOutboundRTPStreamStats.PacketsSent) - int64(s.InboundRTPStreamStats.PacketsReceived) + prevPotentiallyLost := int64(prevStats.RemoteOutboundRTPStreamStats.PacketsSent) - int64(prevStats.InboundRTPStreamStats.PacketsReceived) + var lostDiff int64 + if prevPotentiallyLost >= 0 && potentiallyLost > prevPotentiallyLost { + lostDiff = potentiallyLost - prevPotentiallyLost + } + totalLost += float64(lostDiff) + totalReceived += float64(receivedDiff) + + // pion stores inbound jitter in RTP units rather than seconds. + // 960 is the expected frame size for opus packets. (20ms at 48000Hz) + totalJitter += s.InboundRTPStreamStats.Jitter / 960 + + statsCount++ + } + + if statsCount > 0 { + avgJitter = (totalJitter / statsCount) + avgLossRate = totalLost / totalReceived + } + + return +} + +func (m *rtcMonitor) processStats(sndStats, rcvStats map[webrtc.SSRC]*stats.Stats) { + defer func() { + // cache stats for the next iteration + m.lastSndStats = sndStats + m.lastRcvStats = rcvStats + }() + + sndLossRate, sndJitter, sndCnt := m.getAvgSenderStats(sndStats) + rcvLossRate, rcvJitter, rcvCnt := m.getAvgReceiverStats(rcvStats) + + // nothing to do if we didn't process any stats + if sndCnt == 0 && rcvCnt == 0 { + return + } + + select { + case m.statsCh <- rtcStats{lossRate: max(sndLossRate, rcvLossRate), jitter: max(sndJitter, rcvJitter)}: + default: + m.log.Error("failed to send stats: channel is full") + } +} + +func (m *rtcMonitor) Start() { + go func() { + defer close(m.doneCh) + ticker := time.NewTicker(m.interval) + + for { + select { + case <-ticker.C: + sndStats, rcvStats := m.gatherStats() + m.processStats(sndStats, rcvStats) + case <-m.stopCh: + return + } + } + }() +} + +func (m *rtcMonitor) Stop() { + close(m.stopCh) + <-m.doneCh + close(m.statsCh) +} + +func (m *rtcMonitor) StatsCh() <-chan rtcStats { + return m.statsCh +} diff --git a/client/rtc_monitor_test.go b/client/rtc_monitor_test.go new file mode 100644 index 0000000..bf5e650 --- /dev/null +++ b/client/rtc_monitor_test.go @@ -0,0 +1,121 @@ +// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved. +// See LICENSE.txt for license information. + +package client + +import ( + "log/slog" + "os" + "sync" + "testing" + "time" + + "github.com/pion/interceptor/pkg/stats" + "github.com/pion/webrtc/v3" + + "github.com/stretchr/testify/require" +) + +type statsGetter struct{} + +func (sg *statsGetter) Get(_ uint32) *stats.Stats { + return nil +} + +func TestRTCMonitor(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + })) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + + var sg statsGetter + rtcMon := newRTCMonitor(logger, pc, &sg, time.Second) + require.NotNil(t, rtcMon) + + rtcMon.Start() + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + s := <-rtcMon.StatsCh() + require.Empty(t, s) + }() + + time.Sleep(2 * time.Second) + + rtcMon.Stop() + wg.Wait() +} + +func TestRTCMonitorProcessStats(t *testing.T) { + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelDebug, + })) + + pc, err := webrtc.NewPeerConnection(webrtc.Configuration{}) + require.NoError(t, err) + defer pc.Close() + + var sg statsGetter + rtcMon := newRTCMonitor(logger, pc, &sg, time.Second) + require.NotNil(t, rtcMon) + + rtcMon.lastSndStats = map[webrtc.SSRC]*stats.Stats{ + 45454545: { + RemoteInboundRTPStreamStats: stats.RemoteInboundRTPStreamStats{}, + OutboundRTPStreamStats: stats.OutboundRTPStreamStats{ + SentRTPStreamStats: stats.SentRTPStreamStats{ + PacketsSent: 45, + }, + }, + }, + } + rtcMon.lastRcvStats = map[webrtc.SSRC]*stats.Stats{ + 45454545: { + InboundRTPStreamStats: stats.InboundRTPStreamStats{ + ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{ + PacketsReceived: 45, + }, + }, + }, + } + rtcMon.processStats(map[webrtc.SSRC]*stats.Stats{ + 45454545: { + RemoteInboundRTPStreamStats: stats.RemoteInboundRTPStreamStats{ + FractionLost: 0.45, + ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{ + Jitter: 0.0045, + }, + }, + OutboundRTPStreamStats: stats.OutboundRTPStreamStats{ + SentRTPStreamStats: stats.SentRTPStreamStats{ + PacketsSent: 4545, + }, + }, + }, + }, map[webrtc.SSRC]*stats.Stats{ + 45454545: { + InboundRTPStreamStats: stats.InboundRTPStreamStats{ + ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{ + PacketsReceived: 4545, + }, + }, + }, + }) + + select { + case stats := <-rtcMon.StatsCh(): + require.Equal(t, rtcStats{ + lossRate: 0.45, + jitter: 4.5, + }, stats) + default: + require.Fail(t, "channel should have stats") + } +}