Skip to content

Commit

Permalink
[MM-60561] RTC client metrics (#159)
Browse files Browse the repository at this point in the history
* Client metrics

* Implement rtc stats monitor

* Remove unnecessary parenthesis
  • Loading branch information
streamer45 authored Oct 9, 2024
1 parent a22df51 commit 5e0ced3
Show file tree
Hide file tree
Showing 10 changed files with 504 additions and 13 deletions.
5 changes: 5 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ type Client struct {
receivers map[string][]*webrtc.RTPReceiver
voiceSender *webrtc.RTPSender
screenTransceivers []*webrtc.RTPTransceiver
rtcMon *rtcMonitor

state int32

Expand Down Expand Up @@ -225,6 +226,10 @@ func (c *Client) emit(eventType EventType, ctx any) {
func (c *Client) close() {
atomic.StoreInt32(&c.state, clientStateClosed)

if c.rtcMon != nil {
c.rtcMon.Stop()
}

if c.pc != nil {
if err := c.pc.Close(); err != nil {
c.log.Error("failed to close peer connection", slog.String("err", err.Error()))
Expand Down
2 changes: 2 additions & 0 deletions client/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Config struct {
// EnableDCSignaling controls whether the client should use data channels
// for signaling of media tracks.
EnableDCSignaling bool
// EnableRTCMonitor controls whether the RTC monitor component should be enabled.
EnableRTCMonitor bool

wsURL string
}
Expand Down
14 changes: 8 additions & 6 deletions client/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,19 @@ func setupTestHelper(tb testing.TB, channelName string) *TestHelper {
}

th.adminClient, err = New(Config{
SiteURL: th.apiURL,
AuthToken: th.adminAPIClient.AuthToken,
ChannelID: channelID,
SiteURL: th.apiURL,
AuthToken: th.adminAPIClient.AuthToken,
ChannelID: channelID,
EnableRTCMonitor: true,
}, WithLogger(logger))
require.NoError(tb, err)
require.NotNil(tb, th.adminClient)

th.userClient, err = New(Config{
SiteURL: th.apiURL,
AuthToken: th.userAPIClient.AuthToken,
ChannelID: channelID,
SiteURL: th.apiURL,
AuthToken: th.userAPIClient.AuthToken,
ChannelID: channelID,
EnableRTCMonitor: true,
}, WithLogger(logger))
require.NoError(tb, err)
require.NotNil(tb, th.userClient)
Expand Down
93 changes: 91 additions & 2 deletions client/rtc.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ import (
"io"
"log/slog"
"sync/atomic"
"time"

"github.com/mattermost/rtcd/service/rtc/dc"

"github.com/pion/interceptor"
"github.com/pion/interceptor/pkg/stats"
"github.com/pion/rtcp"
"github.com/pion/webrtc/v3"
)
Expand All @@ -25,8 +27,10 @@ const (
signalMsgOffer = "offer"
signalMsgAnswer = "answer"

iceChSize = 20
receiveMTU = 1460
iceChSize = 20
receiveMTU = 1460
rtcMonitorInterval = 4 * time.Second
pingInterval = time.Second
)

var (
Expand Down Expand Up @@ -186,6 +190,17 @@ func (c *Client) initRTCSession() error {
}

i := interceptor.Registry{}

statsInterceptorFactory, err := stats.NewInterceptor()
if err != nil {
return fmt.Errorf("failed to create stats interceptor: %w", err)
}
var statsGetter stats.Getter
statsInterceptorFactory.OnNewPeerConnection(func(_ string, g stats.Getter) {
statsGetter = g
})
i.Add(statsInterceptorFactory)

if err := webrtc.RegisterDefaultInterceptors(&m, &i); err != nil {
return fmt.Errorf("failed to register default interceptors: %w", err)
}
Expand All @@ -206,6 +221,14 @@ func (c *Client) initRTCSession() error {
c.pc = pc
c.mut.Unlock()

rtcMon := newRTCMonitor(c.log, pc, statsGetter, rtcMonitorInterval)
if c.cfg.EnableRTCMonitor {
c.mut.Lock()
c.rtcMon = rtcMon
c.mut.Unlock()
rtcMon.Start()
}

pc.OnICECandidate(func(candidate *webrtc.ICECandidate) {
if candidate == nil {
c.log.Debug("local ICE gathering completed")
Expand Down Expand Up @@ -374,6 +397,69 @@ func (c *Client) initRTCSession() error {
}
c.dc.Store(dataCh)

lastPingTS := new(int64)
lastRTT := new(int64)
go func() {
pingTicker := time.NewTicker(pingInterval)
for {
select {
case <-pingTicker.C:
msg, err := dc.EncodeMessage(dc.MessageTypePing, nil)
if err != nil {
c.log.Error("failed to encode ping msg", slog.String("err", err.Error()))
continue
}

if err := dataCh.Send(msg); err != nil {
c.log.Error("failed to send ping msg", slog.String("err", err.Error()))
continue
}

atomic.StoreInt64(lastPingTS, time.Now().UnixMilli())
case stats := <-rtcMon.StatsCh():
c.log.Debug("rtc stats",
slog.Float64("lossRate", stats.lossRate),
slog.Int64("rtt", atomic.LoadInt64(lastRTT)),
slog.Float64("jitter", stats.jitter))

if stats.lossRate >= 0 {
msg, err := dc.EncodeMessage(dc.MessageTypeLossRate, stats.lossRate)
if err != nil {
c.log.Error("failed to encode loss rate msg", slog.String("err", err.Error()))
} else {
if err := dataCh.Send(msg); err != nil {
c.log.Error("failed to send loss rate msg", slog.String("err", err.Error()))
}
}
}

if rtt := atomic.LoadInt64(lastRTT); rtt > 0 {
msg, err := dc.EncodeMessage(dc.MessageTypeRoundTripTime, float64(rtt/1000))
if err != nil {
c.log.Error("failed to encode rtt msg", slog.String("err", err.Error()))
} else {
if err := dataCh.Send(msg); err != nil {
c.log.Error("failed to send rtt msg", slog.String("err", err.Error()))
}
}
}

if stats.jitter > 0 {
msg, err := dc.EncodeMessage(dc.MessageTypeJitter, stats.jitter)
if err != nil {
c.log.Error("failed to encode jitter msg", slog.String("err", err.Error()))
} else {
if err := dataCh.Send(msg); err != nil {
c.log.Error("failed to send jitter msg", slog.String("err", err.Error()))
}
}
}
case <-c.wsCloseCh:
return
}
}
}()

dataCh.OnMessage(func(msg webrtc.DataChannelMessage) {
mt, payload, err := dc.DecodeMessage(msg.Data)
if err != nil {
Expand All @@ -383,6 +469,9 @@ func (c *Client) initRTCSession() error {

switch mt {
case dc.MessageTypePong:
if ts := atomic.LoadInt64(lastPingTS); ts > 0 {
atomic.StoreInt64(lastRTT, time.Now().UnixMilli()-ts)
}
case dc.MessageTypeSDP:
var sdp webrtc.SessionDescription
if err := json.Unmarshal(payload.([]byte), &sdp); err != nil {
Expand Down
195 changes: 195 additions & 0 deletions client/rtc_monitor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Copyright (c) 2022-present Mattermost, Inc. All Rights Reserved.
// See LICENSE.txt for license information.

package client

import (
"log/slog"
"time"

"github.com/pion/interceptor/pkg/stats"
"github.com/pion/webrtc/v3"
)

type rtcMonitor struct {
log *slog.Logger
pc *webrtc.PeerConnection
statsGetter stats.Getter
interval time.Duration

lastSndStats map[webrtc.SSRC]*stats.Stats
lastRcvStats map[webrtc.SSRC]*stats.Stats

statsCh chan rtcStats
stopCh chan struct{}
doneCh chan struct{}
}

type rtcStats struct {
lossRate float64
jitter float64
}

func newRTCMonitor(log *slog.Logger, pc *webrtc.PeerConnection, sg stats.Getter, intv time.Duration) *rtcMonitor {
return &rtcMonitor{
log: log,
pc: pc,
statsGetter: sg,
interval: intv,
statsCh: make(chan rtcStats, 1),
stopCh: make(chan struct{}),
doneCh: make(chan struct{}),
lastSndStats: make(map[webrtc.SSRC]*stats.Stats),
lastRcvStats: make(map[webrtc.SSRC]*stats.Stats),
}
}

func (m *rtcMonitor) gatherStats() (map[webrtc.SSRC]*stats.Stats, map[webrtc.SSRC]*stats.Stats) {
sndStats := make(map[webrtc.SSRC]*stats.Stats)
for _, snd := range m.pc.GetSenders() {
if snd == nil {
continue
}
for i, enc := range snd.GetParameters().Encodings {
// For simplicity we only consider audio streams.
// This lets us more easily make assumptions on the clock rate.
if snd.GetParameters().Codecs[i].MimeType != webrtc.MimeTypeOpus {
continue
}

stats := m.statsGetter.Get(uint32(enc.SSRC))
if stats != nil {
sndStats[enc.SSRC] = stats
}
}
}

rcvStats := make(map[webrtc.SSRC]*stats.Stats)
for _, rcv := range m.pc.GetReceivers() {
if rcv == nil {
continue
}

track := rcv.Track()

if track == nil {
continue
}

// For simplicity we only consider audio streams.
// This lets us more easily make assumptions on the clock rate.
if track.Codec().MimeType != webrtc.MimeTypeOpus {
continue
}

stats := m.statsGetter.Get(uint32(track.SSRC()))
if stats != nil {
rcvStats[track.SSRC()] = stats
}
}

return sndStats, rcvStats
}

func (m *rtcMonitor) getAvgSenderStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) {
var totalJitter, totalLossRate float64

for ssrc, s := range stats {
if prevStats := m.lastSndStats[ssrc]; prevStats == nil || s.OutboundRTPStreamStats.PacketsSent == prevStats.OutboundRTPStreamStats.PacketsSent {
continue
}

totalLossRate += s.RemoteInboundRTPStreamStats.FractionLost
totalJitter += s.RemoteInboundRTPStreamStats.Jitter
statsCount++
}

if statsCount > 0 {
avgJitter = totalJitter / statsCount
avgLossRate = totalLossRate / statsCount
}

return
}

func (m *rtcMonitor) getAvgReceiverStats(stats map[webrtc.SSRC]*stats.Stats) (avgLossRate, avgJitter, statsCount float64) {
var totalJitter, totalLost, totalReceived float64

for ssrc, s := range stats {
prevStats := m.lastRcvStats[ssrc]
if prevStats == nil || s.InboundRTPStreamStats.PacketsReceived == prevStats.InboundRTPStreamStats.PacketsReceived {
continue
}

receivedDiff := s.InboundRTPStreamStats.PacketsReceived - prevStats.InboundRTPStreamStats.PacketsReceived
potentiallyLost := int64(s.RemoteOutboundRTPStreamStats.PacketsSent) - int64(s.InboundRTPStreamStats.PacketsReceived)
prevPotentiallyLost := int64(prevStats.RemoteOutboundRTPStreamStats.PacketsSent) - int64(prevStats.InboundRTPStreamStats.PacketsReceived)
var lostDiff int64
if prevPotentiallyLost >= 0 && potentiallyLost > prevPotentiallyLost {
lostDiff = potentiallyLost - prevPotentiallyLost
}
totalLost += float64(lostDiff)
totalReceived += float64(receivedDiff)
totalJitter += s.InboundRTPStreamStats.Jitter

statsCount++
}

if statsCount > 0 {
avgJitter = totalJitter / statsCount
avgLossRate = totalLost / totalReceived
}

return
}

func (m *rtcMonitor) processStats(sndStats, rcvStats map[webrtc.SSRC]*stats.Stats) {
defer func() {
// cache stats for the next iteration
m.lastSndStats = sndStats
m.lastRcvStats = rcvStats
}()

sndLossRate, sndJitter, sndCnt := m.getAvgSenderStats(sndStats)
rcvLossRate, rcvJitter, rcvCnt := m.getAvgReceiverStats(rcvStats)

// nothing to do if we didn't process any stats
if sndCnt == 0 && rcvCnt == 0 {
return
}

select {
case m.statsCh <- rtcStats{lossRate: max(sndLossRate, rcvLossRate), jitter: max(sndJitter, rcvJitter)}:
default:
m.log.Error("failed to send stats: channel is full")
}
}

func (m *rtcMonitor) Start() {
m.log.Debug("starting rtc monitor")
go func() {
defer close(m.doneCh)
ticker := time.NewTicker(m.interval)

for {
select {
case <-ticker.C:
sndStats, rcvStats := m.gatherStats()
m.processStats(sndStats, rcvStats)
case <-m.stopCh:
return
}
}
}()
}

func (m *rtcMonitor) Stop() {
m.log.Debug("stopping rtc monitor")
close(m.stopCh)
<-m.doneCh
close(m.statsCh)
}

func (m *rtcMonitor) StatsCh() <-chan rtcStats {
return m.statsCh
}
Loading

0 comments on commit 5e0ced3

Please sign in to comment.