diff --git a/client/client.go b/client/client.go index 739854f..7ef0556 100644 --- a/client/client.go +++ b/client/client.go @@ -104,6 +104,7 @@ type Client struct { receivers map[string][]*webrtc.RTPReceiver voiceSender *webrtc.RTPSender screenTransceivers []*webrtc.RTPTransceiver + rtcMon *rtcMonitor state int32 @@ -225,6 +226,10 @@ func (c *Client) emit(eventType EventType, ctx any) { func (c *Client) close() { atomic.StoreInt32(&c.state, clientStateClosed) + if c.rtcMon != nil { + c.rtcMon.Stop() + } + if c.pc != nil { if err := c.pc.Close(); err != nil { c.log.Error("failed to close peer connection", slog.String("err", err.Error())) diff --git a/client/config.go b/client/config.go index 323cdf7..1edc37c 100644 --- a/client/config.go +++ b/client/config.go @@ -28,6 +28,8 @@ type Config struct { // EnableDCSignaling controls whether the client should use data channels // for signaling of media tracks. EnableDCSignaling bool + // EnableRTCMonitor controls whether the RTC monitor component should be enabled. 
+ EnableRTCMonitor bool wsURL string } diff --git a/client/helper_test.go b/client/helper_test.go index 8b39e5f..d3318d4 100644 --- a/client/helper_test.go +++ b/client/helper_test.go @@ -357,17 +357,19 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper { } th.adminClient, err = New(Config{ - SiteURL: th.apiURL, - AuthToken: th.adminAPIClient.AuthToken, - ChannelID: channelID, + SiteURL: th.apiURL, + AuthToken: th.adminAPIClient.AuthToken, + ChannelID: channelID, + EnableRTCMonitor: true, }, WithLogger(logger)) require.NoError(tb, err) require.NotNil(tb, th.adminClient) th.userClient, err = New(Config{ - SiteURL: th.apiURL, - AuthToken: th.userAPIClient.AuthToken, - ChannelID: channelID, + SiteURL: th.apiURL, + AuthToken: th.userAPIClient.AuthToken, + ChannelID: channelID, + EnableRTCMonitor: true, }, WithLogger(logger)) require.NoError(tb, err) require.NotNil(tb, th.userClient) diff --git a/client/rtc.go b/client/rtc.go index 9ca11a2..65ef05d 100644 --- a/client/rtc.go +++ b/client/rtc.go @@ -12,10 +12,12 @@ import ( "io" "log/slog" "sync/atomic" + "time" "github.com/mattermost/rtcd/service/rtc/dc" "github.com/pion/interceptor" + "github.com/pion/interceptor/pkg/stats" "github.com/pion/rtcp" "github.com/pion/webrtc/v3" ) @@ -25,8 +27,10 @@ const ( signalMsgOffer = "offer" signalMsgAnswer = "answer" - iceChSize = 20 - receiveMTU = 1460 + iceChSize = 20 + receiveMTU = 1460 + rtcMonitorInterval = 4 * time.Second + pingInterval = time.Second ) var ( @@ -186,6 +190,17 @@ func (c *Client) initRTCSession() error { } i := interceptor.Registry{} + + statsInterceptorFactory, err := stats.NewInterceptor() + if err != nil { + return fmt.Errorf("failed to create stats interceptor: %w", err) + } + var statsGetter stats.Getter + statsInterceptorFactory.OnNewPeerConnection(func(_ string, g stats.Getter) { + statsGetter = g + }) + i.Add(statsInterceptorFactory) + if err := webrtc.RegisterDefaultInterceptors(&m, &i); err != nil { return fmt.Errorf("failed to 
register default interceptors: %w", err)
 	}
@@ -206,6 +221,14 @@ func (c *Client) initRTCSession() error {
 	c.pc = pc
 	c.mut.Unlock()
 
+	rtcMon := newRTCMonitor(c.log, pc, statsGetter, rtcMonitorInterval)
+	if c.cfg.EnableRTCMonitor {
+		c.mut.Lock()
+		c.rtcMon = rtcMon
+		c.mut.Unlock()
+		rtcMon.Start()
+	}
+
 	pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
 		if candidate == nil {
 			c.log.Debug("local ICE gathering completed")
@@ -374,6 +397,88 @@ func (c *Client) initRTCSession() error {
 	}
 	c.dc.Store(dataCh)
 
+	// Ping the server over the data channel every pingInterval to measure the
+	// round trip time, and forward the stats sampled by the RTC monitor.
+	lastPingTS := new(int64)
+	lastRTT := new(int64)
+	go func() {
+		pingTicker := time.NewTicker(pingInterval)
+		// Release the ticker once this goroutine exits.
+		defer pingTicker.Stop()
+
+		// statsCh is set to nil once the monitor closes it: receiving from a
+		// nil channel blocks forever, which disables that select case.
+		statsCh := rtcMon.StatsCh()
+
+		for {
+			select {
+			case <-pingTicker.C:
+				msg, err := dc.EncodeMessage(dc.MessageTypePing, nil)
+				if err != nil {
+					c.log.Error("failed to encode ping msg", slog.String("err", err.Error()))
+					continue
+				}
+
+				if err := dataCh.Send(msg); err != nil {
+					c.log.Error("failed to send ping msg", slog.String("err", err.Error()))
+					continue
+				}
+
+				atomic.StoreInt64(lastPingTS, time.Now().UnixMilli())
+			case st, ok := <-statsCh:
+				if !ok {
+					// The monitor was stopped and closed its channel. Without
+					// this check the case would fire in a tight loop with
+					// zero-valued stats.
+					statsCh = nil
+					continue
+				}
+
+				c.log.Debug("rtc stats",
+					slog.Float64("lossRate", st.lossRate),
+					slog.Int64("rtt", atomic.LoadInt64(lastRTT)),
+					slog.Float64("jitter", st.jitter))
+
+				if st.lossRate >= 0 {
+					msg, err := dc.EncodeMessage(dc.MessageTypeLossRate, st.lossRate)
+					if err != nil {
+						c.log.Error("failed to encode loss rate msg", slog.String("err", err.Error()))
+					} else {
+						if err := dataCh.Send(msg); err != nil {
+							c.log.Error("failed to send loss rate msg", slog.String("err", err.Error()))
+						}
+					}
+				}
+
+				if rtt := atomic.LoadInt64(lastRTT); rtt > 0 {
+					// rtt is in milliseconds; convert to float before dividing so
+					// that sub-second values are not truncated to zero.
+					msg, err := dc.EncodeMessage(dc.MessageTypeRoundTripTime, float64(rtt)/1000)
+					if err != nil {
+						c.log.Error("failed to encode rtt msg", slog.String("err", err.Error()))
+					} else {
+						if err := dataCh.Send(msg); err != nil {
+							c.log.Error("failed to send rtt msg", slog.String("err", err.Error()))
+						}
+					}
+				}
+
+				if st.jitter > 0 {
+					msg, err := dc.EncodeMessage(dc.MessageTypeJitter, st.jitter)
+					if err != nil {
+						c.log.Error("failed to encode jitter msg", slog.String("err", err.Error()))
+					} else {
+						if err := dataCh.Send(msg); err != nil {
+							c.log.Error("failed to send jitter msg", slog.String("err", err.Error()))
+						}
+					}
+				}
+			case <-c.wsCloseCh:
+				return
+			}
+		}
+	}()
+
 	dataCh.OnMessage(func(msg webrtc.DataChannelMessage) {
 		mt, payload, err := dc.DecodeMessage(msg.Data)
 		if err != nil {
@@ -383,6 +488,9 @@ func (c *Client) initRTCSession() error {
 
 		switch mt {
 		case dc.MessageTypePong:
+			if ts := atomic.LoadInt64(lastPingTS); ts > 0 {
+				atomic.StoreInt64(lastRTT, time.Now().UnixMilli()-ts)
+			}
 		case dc.MessageTypeSDP:
 			var sdp webrtc.SessionDescription
 			if err := json.Unmarshal(payload.([]byte), &sdp); err != nil {
diff --git a/client/rtc_monitor.go b/client/rtc_monitor.go
new file mode 100644
index 0000000..2acc51f
--- /dev/null
+++ b/client/rtc_monitor.go
@@ -0,0 +1,195 @@
+// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved.
+// See LICENSE.txt for license information.
+
+package client
+
+import (
+	"log/slog"
+	"time"
+
+	"github.com/pion/interceptor/pkg/stats"
+	"github.com/pion/webrtc/v3"
+)
+
+// rtcMonitor periodically samples the RTP stream stats of the Opus senders
+// and receivers of a peer connection and publishes the aggregated loss rate
+// and jitter on a channel.
+type rtcMonitor struct {
+	log         *slog.Logger
+	pc          *webrtc.PeerConnection
+	statsGetter stats.Getter
+	interval    time.Duration
+
+	// Stats gathered on the previous tick, keyed by SSRC. They are used to
+	// tell whether a stream made progress since the last sample.
+	lastSndStats map[webrtc.SSRC]*stats.Stats
+	lastRcvStats map[webrtc.SSRC]*stats.Stats
+
+	statsCh chan rtcStats
+	stopCh  chan struct{}
+	doneCh  chan struct{}
+}
+
+// rtcStats holds the aggregated values computed from a single sample.
+type rtcStats struct {
+	lossRate float64
+	jitter   float64
+}
+
+// newRTCMonitor returns a monitor for pc that reads stats from sg every intv.
+// The returned monitor does nothing until Start is called.
+func newRTCMonitor(log *slog.Logger, pc *webrtc.PeerConnection, sg stats.Getter, intv time.Duration) *rtcMonitor {
+	return &rtcMonitor{
+		log:          log,
+		pc:           pc,
+		statsGetter:  sg,
+		interval:     intv,
+		statsCh:      make(chan rtcStats, 1),
+		stopCh:       make(chan struct{}),
+		doneCh:       make(chan struct{}),
+		lastSndStats: make(map[webrtc.SSRC]*stats.Stats),
+		lastRcvStats: make(map[webrtc.SSRC]*stats.Stats),
+	}
+}
+
+// gatherStats returns the current stats of the Opus senders and receivers of
+// the peer connection, keyed by SSRC. Streams for which the getter has no
+// stats are left out. Both maps are empty when no getter was provided.
+func (m *rtcMonitor) gatherStats() (map[webrtc.SSRC]*stats.Stats, map[webrtc.SSRC]*stats.Stats) {
+	sndStats := make(map[webrtc.SSRC]*stats.Stats)
+	rcvStats := make(map[webrtc.SSRC]*stats.Stats)
+
+	// Without a getter there is nothing to sample; return early instead of
+	// panicking on a nil interface call.
+	if m.statsGetter == nil {
+		return sndStats, rcvStats
+	}
+
+	for _, snd := range m.pc.GetSenders() {
+		if snd == nil {
+			continue
+		}
+		params := snd.GetParameters()
+		for i, enc := range params.Encodings {
+			// For simplicity we only consider audio streams.
+			// This lets us more easily make assumptions on the clock rate.
+			// Codecs is indexed by encoding position, so guard against a
+			// shorter slice to avoid an out-of-range panic.
+			if i >= len(params.Codecs) || params.Codecs[i].MimeType != webrtc.MimeTypeOpus {
+				continue
+			}
+
+			if st := m.statsGetter.Get(uint32(enc.SSRC)); st != nil {
+				sndStats[enc.SSRC] = st
+			}
+		}
+	}
+
+	for _, rcv := range m.pc.GetReceivers() {
+		if rcv == nil {
+			continue
+		}
+
+		track := rcv.Track()
+		if track == nil {
+			continue
+		}
+
+		// For simplicity we only consider audio streams.
+		// This lets us more easily make assumptions on the clock rate.
+		if track.Codec().MimeType != webrtc.MimeTypeOpus {
+			continue
+		}
+
+		if st := m.statsGetter.Get(uint32(track.SSRC())); st != nil {
+			rcvStats[track.SSRC()] = st
+		}
+	}
+
+	return sndStats, rcvStats
+}
+
+// getAvgSenderStats averages the remotely reported loss fraction and jitter
+// over the sender streams that have a previous sample and whose sent packet
+// count changed since then. statsCount is the number of streams considered.
+func (m *rtcMonitor) getAvgSenderStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) {
+	var totalJitter, totalLossRate float64
+
+	for ssrc, s := range stats {
+		if prevStats := m.lastSndStats[ssrc]; prevStats == nil || s.OutboundRTPStreamStats.PacketsSent == prevStats.OutboundRTPStreamStats.PacketsSent {
+			continue
+		}
+
+		totalLossRate += s.RemoteInboundRTPStreamStats.FractionLost
+		totalJitter += s.RemoteInboundRTPStreamStats.Jitter
+		statsCount++
+	}
+
+	if statsCount > 0 {
+		avgJitter = totalJitter / statsCount
+		avgLossRate = totalLossRate / statsCount
+	}
+
+	return
+}
+
+// getAvgReceiverStats computes, over the receiver streams that made progress
+// since the previous sample, the average jitter and the ratio of newly lost
+// packets to newly received packets. statsCount is the number of streams
+// considered.
+func (m *rtcMonitor) getAvgReceiverStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) {
+	var totalJitter, totalLost, totalReceived float64
+
+	for ssrc, s := range stats {
+		prevStats := m.lastRcvStats[ssrc]
+		// Skip streams without a previous sample and streams whose packet
+		// counter did not advance. Using <= (rather than ==) also skips a
+		// counter that went backwards, for which the unsigned subtraction
+		// below would wrap around.
+		if prevStats == nil || s.InboundRTPStreamStats.PacketsReceived <= prevStats.InboundRTPStreamStats.PacketsReceived {
+			continue
+		}
+
+		receivedDiff := s.InboundRTPStreamStats.PacketsReceived - prevStats.InboundRTPStreamStats.PacketsReceived
+		potentiallyLost := int64(s.RemoteOutboundRTPStreamStats.PacketsSent) - int64(s.InboundRTPStreamStats.PacketsReceived)
+		prevPotentiallyLost := int64(prevStats.RemoteOutboundRTPStreamStats.PacketsSent) - int64(prevStats.InboundRTPStreamStats.PacketsReceived)
+		// Only count a loss when the gap between sent and received packets
+		// grew since the previous sample.
+		var lostDiff int64
+		if prevPotentiallyLost >= 0 && potentiallyLost > prevPotentiallyLost {
+			lostDiff = potentiallyLost - prevPotentiallyLost
+		}
+		totalLost += float64(lostDiff)
+		totalReceived += float64(receivedDiff)
+		totalJitter += s.InboundRTPStreamStats.Jitter
+
+		statsCount++
+	}
+
+	// statsCount > 0 implies totalReceived > 0 because every counted stream
+	// received at least one new packet.
+	if statsCount > 0 {
+		avgJitter = totalJitter / statsCount
+		avgLossRate = totalLost / totalReceived
+	}
+
+	return
+}
+
+// processStats aggregates the given sender and receiver stats against the
+// previous sample and publishes the larger loss rate and the larger jitter of
+// the two directions. Nothing is published when no stream made progress. The
+// send is non-blocking: the sample is dropped and an error is logged when the
+// channel is full.
+func (m *rtcMonitor) processStats(sndStats, rcvStats map[webrtc.SSRC]*stats.Stats) {
+	defer func() {
+		// cache stats for the next iteration
+		m.lastSndStats = sndStats
+		m.lastRcvStats = rcvStats
+	}()
+
+	sndLossRate, sndJitter, sndCnt := m.getAvgSenderStats(sndStats)
+	rcvLossRate, rcvJitter, rcvCnt := m.getAvgReceiverStats(rcvStats)
+
+	// nothing to do if we didn't process any stats
+	if sndCnt == 0 && rcvCnt == 0 {
+		return
+	}
+
+	select {
+	case m.statsCh <- rtcStats{lossRate: max(sndLossRate, rcvLossRate), jitter: max(sndJitter, rcvJitter)}:
+	default:
+		m.log.Error("failed to send stats: channel is full")
+	}
+}
+
+// Start launches the goroutine that gathers and processes stats on every
+// interval tick until Stop is called.
+func (m *rtcMonitor) Start() {
+	m.log.Debug("starting rtc monitor")
+	go func() {
+		defer close(m.doneCh)
+		ticker := time.NewTicker(m.interval)
+		// Release the ticker once the monitor is stopped.
+		defer ticker.Stop()
+
+		for {
+			select {
+			case <-ticker.C:
+				sndStats, rcvStats := m.gatherStats()
+				m.processStats(sndStats, rcvStats)
+			case <-m.stopCh:
+				return
+			}
+		}
+	}()
+}
+
+// Stop signals the goroutine launched by Start to exit, waits for it and then
+// closes the stats channel. It must be called at most once and only after
+// Start: it blocks until that goroutine is done, and a second call would
+// close an already closed channel.
+func (m *rtcMonitor) Stop() {
+	m.log.Debug("stopping rtc monitor")
+	close(m.stopCh)
+	<-m.doneCh
+	close(m.statsCh)
+}
+
+// StatsCh returns the channel on which samples are published. The channel is
+// closed by Stop.
+func (m *rtcMonitor) StatsCh() <-chan rtcStats {
+	return m.statsCh
+}
diff --git a/client/rtc_monitor_test.go b/client/rtc_monitor_test.go
new file mode 100644
index 0000000..bf0737d
--- /dev/null
+++ b/client/rtc_monitor_test.go
@@ -0,0 +1,121 @@
+// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved.
+// See LICENSE.txt for license information.
+
+package client
+
+import (
+	"log/slog"
+	"os"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/pion/interceptor/pkg/stats"
+	"github.com/pion/webrtc/v3"
+
+	"github.com/stretchr/testify/require"
+)
+
+// statsGetter is a stats.Getter stub that has no stats for any SSRC.
+type statsGetter struct{}
+
+func (sg *statsGetter) Get(_ uint32) *stats.Stats {
+	return nil
+}
+
+// TestRTCMonitor starts and stops a monitor that never has stats to publish
+// and checks that a pending receiver is released with a zero value.
+func TestRTCMonitor(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+		AddSource: true,
+		Level:     slog.LevelDebug,
+	}))
+
+	pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
+	require.NoError(t, err)
+	defer pc.Close()
+
+	var sg statsGetter
+	rtcMon := newRTCMonitor(logger, pc, &sg, time.Second)
+	require.NotNil(t, rtcMon)
+
+	rtcMon.Start()
+
+	var wg sync.WaitGroup
+	var s rtcStats
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		// The stub getter yields no stats, so this receive only returns once
+		// Stop closes the channel.
+		s = <-rtcMon.StatsCh()
+	}()
+
+	time.Sleep(2 * time.Second)
+
+	rtcMon.Stop()
+	wg.Wait()
+
+	// Assert on the test goroutine: require calls FailNow on failure, which
+	// must not be called from another goroutine.
+	require.Empty(t, s)
+}
+
+// TestRTCMonitorProcessStats seeds a previous sample and checks the values
+// published by processStats for a stream that made progress.
+func TestRTCMonitorProcessStats(t *testing.T) {
+	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{
+		AddSource: true,
+		Level:     slog.LevelDebug,
+	}))
+
+	pc, err := webrtc.NewPeerConnection(webrtc.Configuration{})
+	require.NoError(t, err)
+	defer pc.Close()
+
+	var sg statsGetter
+	rtcMon := newRTCMonitor(logger, pc, &sg, time.Second)
+	require.NotNil(t, rtcMon)
+
+	rtcMon.lastSndStats = map[webrtc.SSRC]*stats.Stats{
+		45454545: {
+			RemoteInboundRTPStreamStats: stats.RemoteInboundRTPStreamStats{},
+			OutboundRTPStreamStats: stats.OutboundRTPStreamStats{
+				SentRTPStreamStats: stats.SentRTPStreamStats{
+					PacketsSent: 45,
+				},
+			},
+		},
+	}
+	rtcMon.lastRcvStats = map[webrtc.SSRC]*stats.Stats{
+		45454545: {
+			InboundRTPStreamStats: stats.InboundRTPStreamStats{
+				ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{
+					PacketsReceived: 45,
+				},
+			},
+		},
+	}
+	rtcMon.processStats(map[webrtc.SSRC]*stats.Stats{
+		45454545: {
+			RemoteInboundRTPStreamStats: stats.RemoteInboundRTPStreamStats{
+				FractionLost: 0.45,
+				ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{
+					Jitter: 0.4545,
+				},
+			},
+			OutboundRTPStreamStats: stats.OutboundRTPStreamStats{
+				SentRTPStreamStats: stats.SentRTPStreamStats{
+					PacketsSent: 4545,
+				},
+			},
+		},
+	}, map[webrtc.SSRC]*stats.Stats{
+		45454545: {
+			InboundRTPStreamStats: stats.InboundRTPStreamStats{
+				ReceivedRTPStreamStats: stats.ReceivedRTPStreamStats{
+					PacketsReceived: 4545,
+				},
+			},
+		},
+	})
+
+	// processStats sends without blocking, so the sample must already be
+	// buffered in the channel at this point.
+	select {
+	case got := <-rtcMon.StatsCh():
+		require.Equal(t, rtcStats{
+			lossRate: 0.45,
+			jitter:   0.4545,
+		}, got)
+	default:
+		require.Fail(t, "channel should have stats")
+	}
+}
diff --git a/service/perf/metrics.go b/service/perf/metrics.go
index 57d887f..b614671 100644
--- a/service/perf/metrics.go
+++ b/service/perf/metrics.go
@@ -12,8 +12,9 @@ import (
 )
 
 const (
-	metricsSubSystemRTC = "rtc"
-	metricsSubSystemWS  = "ws"
+	metricsSubSystemRTC       = "rtc"
+	metricsSubSystemRTCClient = "rtc_client"
+	metricsSubSystemWS        = "ws"
 )
 
 type Metrics struct {
@@ -25,6 +26,10 @@ type Metrics struct {
 	RTCConnStateCounters *prometheus.CounterVec
 	RTCErrors            *prometheus.CounterVec
 
+	RTCClientLoss   *prometheus.HistogramVec
+	RTCClientRTT    *prometheus.HistogramVec
+	RTCClientJitter *prometheus.HistogramVec
+
 	WSConnections     *prometheus.GaugeVec
 	WSMessageCounters *prometheus.CounterVec
 }
@@ -119,6 +124,41 @@ func NewMetrics(namespace string, registry *prometheus.Registry) *Metrics {
 	)
 	m.registry.MustRegister(m.WSMessageCounters)
 
+	// Client metrics
+
+	m.RTCClientLoss = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Namespace: namespace,
+			Subsystem: metricsSubSystemRTCClient,
+			Name:      "loss_rate",
+			Help:      "Client loss rate",
+		},
+		[]string{"groupID"},
+	)
+	m.registry.MustRegister(m.RTCClientLoss)
+
+	m.RTCClientRTT = prometheus.NewHistogramVec(
+		prometheus.HistogramOpts{
+			Namespace: namespace,
+			Subsystem: metricsSubSystemRTCClient,
+			Name:      "rtt",
+			Help:      "Client round trip time",
+		},
+		[]string{"groupID"},
+	)
+	m.registry.MustRegister(m.RTCClientRTT)
+
+	m.RTCClientJitter = prometheus.NewHistogramVec(
+
prometheus.HistogramOpts{
+			Namespace: namespace,
+			Subsystem: metricsSubSystemRTCClient,
+			Name:      "jitter",
+			Help:      "Client latency jitter",
+		},
+		[]string{"groupID"},
+	)
+	m.registry.MustRegister(m.RTCClientJitter)
+
 	return &m
 }
 
@@ -165,3 +205,18 @@ func (m *Metrics) Handler() http.Handler {
 func (m *Metrics) ObserveRTPTracksWrite(groupID, trackType string, dur float64) {
 	m.RTPTrackWrites.With(prometheus.Labels{"groupID": groupID, "type": trackType}).Observe(dur)
 }
+
+// ObserveRTCClientLossRate records val in the client loss rate histogram for groupID.
+func (m *Metrics) ObserveRTCClientLossRate(groupID string, val float64) {
+	m.RTCClientLoss.With(prometheus.Labels{"groupID": groupID}).Observe(val)
+}
+
+// ObserveRTCClientRTT records val in the client round trip time histogram for groupID.
+func (m *Metrics) ObserveRTCClientRTT(groupID string, val float64) {
+	m.RTCClientRTT.With(prometheus.Labels{"groupID": groupID}).Observe(val)
+}
+
+// ObserveRTCClientJitter records val in the client jitter histogram for groupID.
+func (m *Metrics) ObserveRTCClientJitter(groupID string, val float64) {
+	m.RTCClientJitter.With(prometheus.Labels{"groupID": groupID}).Observe(val)
+}
diff --git a/service/rtc/dc/msg.go b/service/rtc/dc/msg.go
index 558da44..3952f0b 100644
--- a/service/rtc/dc/msg.go
+++ b/service/rtc/dc/msg.go
@@ -19,9 +19,12 @@ import (
 type MessageType uint8
 
 const (
-	MessageTypePing MessageType = iota + 1 // no payload
-	MessageTypePong                        // no payload
-	MessageTypeSDP                         // MessageSDP
+	MessageTypePing          MessageType = iota + 1 // no payload
+	MessageTypePong                                 // no payload
+	MessageTypeSDP                                  // MessageSDP
+	MessageTypeLossRate                             // float64
+	MessageTypeRoundTripTime                        // float64
+	MessageTypeJitter                               // float64
 )
 
 // Supported payloads
@@ -104,6 +107,14 @@ func DecodeMessage(msg []byte) (MessageType, any, error) {
 			return 0, nil, fmt.Errorf("failed to unpack sdp data: %w", err)
 		}
 		return MessageTypeSDP, unpacked, nil
+	case MessageTypeLossRate, MessageTypeRoundTripTime, MessageTypeJitter:
+		// These three message types all carry a single float64 payload.
+		var payload float64
+		err := dec.Decode(&payload)
+		if err != nil {
+			return 0, nil, fmt.Errorf("failed to decode message type %d: %w", t, err)
+		}
+		return MessageType(t), payload, nil
 	}
 
 	return 0, nil, fmt.Errorf("unexpected dc message type: %d", t)
diff --git a/service/rtc/metrics.go b/service/rtc/metrics.go
index e17eae8..604df72 100644
--- a/service/rtc/metrics.go
+++ b/service/rtc/metrics.go
@@ -11,4 +11,9 @@ type Metrics interface {
 	IncRTPTracks(groupID string, direction, trackType string)
 	DecRTPTracks(groupID string, direction, trackType string)
 	ObserveRTPTracksWrite(groupID, trackType string, dur float64)
+
+	// Client metrics
+	ObserveRTCClientLossRate(groupID string, val float64)
+	ObserveRTCClientRTT(groupID string, val float64)
+	ObserveRTCClientJitter(groupID string, val float64)
 }
diff --git a/service/rtc/server.go b/service/rtc/server.go
index 38543b5..3b7d011 100644
--- a/service/rtc/server.go
+++ b/service/rtc/server.go
@@ -403,6 +403,12 @@ func (s *Server) handleDCMessage(data []byte, us *session, dataCh *webrtc.DataCh
 		if err := s.handleIncomingSDP(us, us.dcSDPCh, payload.([]byte)); err != nil {
 			return fmt.Errorf("failed to handle incoming sdp message: %w", err)
 		}
+	case dc.MessageTypeLossRate:
+		s.metrics.ObserveRTCClientLossRate(us.cfg.GroupID, payload.(float64))
+	case dc.MessageTypeRoundTripTime:
+		s.metrics.ObserveRTCClientRTT(us.cfg.GroupID, payload.(float64))
+	case dc.MessageTypeJitter:
+		s.metrics.ObserveRTCClientJitter(us.cfg.GroupID, payload.(float64))
 	}
 
 	return nil