From a42f827422a2e72e94b76c9c581a690db7299f8d Mon Sep 17 00:00:00 2001 From: streamer45 Date: Thu, 4 Jul 2024 17:21:25 +0200 Subject: [PATCH] Suppress push notifications to call thread if receiver is in call --- server/db/calls_sessions_store.go | 31 ++++ server/db/calls_sessions_store_test.go | 58 ++++++++ server/push_notifications.go | 18 +++ server/push_notifications_test.go | 198 +++++++++++++++++++++++++ 4 files changed, 305 insertions(+) create mode 100644 server/push_notifications_test.go diff --git a/server/db/calls_sessions_store.go b/server/db/calls_sessions_store.go index b97ca06bd..c87282bb0 100644 --- a/server/db/calls_sessions_store.go +++ b/server/db/calls_sessions_store.go @@ -215,3 +215,34 @@ func (s *Store) GetCallSessionsCount(callID string, opts GetCallSessionOpts) (in return count, nil } + +func (s *Store) IsUserInCall(userID, callID string, opts GetCallSessionOpts) (bool, error) { + s.metrics.IncStoreOp("IsUserInCall") + defer func(start time.Time) { + s.metrics.ObserveStoreMethodsTime("IsUserInCall", time.Since(start).Seconds()) + }(time.Now()) + + qb := getQueryBuilder(s.driverName).Select("1"). + From("calls_sessions"). + Where( + sq.And{ + sq.Eq{"CallID": callID}, + sq.Eq{"UserID": userID}, + }) + + q, args, err := qb.ToSql() + if err != nil { + return false, fmt.Errorf("failed to prepare query: %w", err) + } + + var ok bool + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*s.settings.QueryTimeout)*time.Second) + defer cancel() + if err := s.dbXFromGetOpts(opts).GetContext(ctx, &ok, q, args...); err == sql.ErrNoRows { + return false, nil + } else if err != nil { + return false, fmt.Errorf("failed to get user in call: %w", err) + } + + return ok, nil +} diff --git a/server/db/calls_sessions_store_test.go b/server/db/calls_sessions_store_test.go index 3a136c9f6..12424e44d 100644 --- a/server/db/calls_sessions_store_test.go +++ b/server/db/calls_sessions_store_test.go @@ -24,6 +24,7 @@ func TestCallsSessionsStore(t *testing.T) { "TestGetCallSessions": testGetCallSessions, "TestDeleteCallsSessions": testDeleteCallsSessions, "TestGetCallSessionsCount": testGetCallSessionsCount, + "TestIsUserInCall": testIsUserInCall, }) } @@ -253,3 +254,60 @@ func testGetCallSessionsCount(t *testing.T, store *Store) { require.Equal(t, 10, cnt) }) } + +func testIsUserInCall(t *testing.T, store *Store) { + t.Run("no sessions", func(t *testing.T) { + ok, err := store.IsUserInCall(model.NewId(), model.NewId(), GetCallSessionOpts{}) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("multiple sessions, user not in call", func(t *testing.T) { + sessions := map[string]*public.CallSession{} + callID := model.NewId() + for i := 0; i < 10; i++ { + session := &public.CallSession{ + ID: model.NewId(), + CallID: callID, + UserID: model.NewId(), + JoinAt: time.Now().UnixMilli(), + } + + err := store.CreateCallSession(session) + require.NoError(t, err) + + sessions[session.ID] = session + } + + ok, err := store.IsUserInCall(model.NewId(), callID, GetCallSessionOpts{}) + require.NoError(t, err) + require.False(t, ok) + }) + + t.Run("multiple sessions, user in call", func(t *testing.T) { + sessions := map[string]*public.CallSession{} + callID := model.NewId() + userID := model.NewId() + for i := 0; i < 10; i++ { + if i > 0 { + userID = model.NewId() + } + + session := &public.CallSession{ + ID: model.NewId(), + CallID: callID, + UserID: userID, + JoinAt: time.Now().UnixMilli(), + } + + err := store.CreateCallSession(session) + require.NoError(t, err) + + sessions[session.ID] = session + } + + ok, err := store.IsUserInCall(userID, callID, GetCallSessionOpts{}) + require.NoError(t, err) + require.True(t, ok) + }) +} diff --git a/server/push_notifications.go b/server/push_notifications.go index 583bae23f..ccd250cbd 100644 --- a/server/push_notifications.go +++ b/server/push_notifications.go @@ -1,13 +1,31 @@ package main import ( + "errors" "fmt" + "github.com/mattermost/mattermost-plugin-calls/server/db" + "github.com/mattermost/mattermost/server/public/model" "github.com/mattermost/mattermost/server/public/shared/i18n" ) func (p *Plugin) NotificationWillBePushed(notification *model.PushNotification, userID string) (*model.PushNotification, string) { + // If the user is in a call we suppress notifications for replies to the call thread. + call, err := p.store.GetActiveCallByChannelID(notification.ChannelId, db.GetCallOpts{}) + if err == nil && call.ThreadID == notification.RootId { + isUserInCall, err := p.store.IsUserInCall(userID, call.ID, db.GetCallSessionOpts{}) + if err != nil { + p.LogError("store.IsUserInCall failed", "err", err.Error()) + } else if isUserInCall { + msg := "calls: suppressing notification on call thread for connected user" + p.LogDebug(msg, "userID", userID, "channelID", notification.ChannelId, "threadID", call.ThreadID, "callID", call.ID) + return nil, msg + } + } else if err != nil && !errors.Is(err, db.ErrNotFound) { + p.LogError("store.GetActiveCallByChannelID failed", "err", err.Error()) + } + // We will use our own notifications if: // 1. This is a call start post // 2. We have enabled ringing diff --git a/server/push_notifications_test.go b/server/push_notifications_test.go new file mode 100644 index 000000000..6fb6a47b3 --- /dev/null +++ b/server/push_notifications_test.go @@ -0,0 +1,198 @@ +package main + +import ( + "testing" + "time" + + "github.com/mattermost/mattermost-plugin-calls/server/public" + + pluginMocks "github.com/mattermost/mattermost-plugin-calls/server/mocks/github.com/mattermost/mattermost/server/public/plugin" + + "github.com/mattermost/mattermost/server/public/model" + "github.com/mattermost/mattermost/server/public/plugin" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func TestNotificationWillBePushed(t *testing.T) { + mockAPI := &pluginMocks.MockAPI{} + + store, tearDown := NewTestStore(t) + t.Cleanup(tearDown) + + p := Plugin{ + MattermostPlugin: plugin.MattermostPlugin{ + API: mockAPI, + }, + store: store, + } + + t.Run("not a call post", func(t *testing.T) { + res, msg := p.NotificationWillBePushed(&model.PushNotification{}, "userID") + require.Nil(t, res) + require.Empty(t, msg) + }) + + t.Run("user not in call", func(t *testing.T) { + channelID := model.NewId() + threadID := model.NewId() + + err := p.store.CreateCall(&public.Call{ + ID: model.NewId(), + CreateAt: time.Now().UnixMilli(), + ChannelID: channelID, + StartAt: time.Now().UnixMilli(), + PostID: model.NewId(), + ThreadID: threadID, + OwnerID: model.NewId(), + }) + require.NoError(t, err) + + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + ChannelId: channelID, + RootId: threadID, + }, "userID") + require.Nil(t, res) + require.Empty(t, msg) + }) + + t.Run("user in call", func(t *testing.T) { + defer mockAPI.AssertExpectations(t) + + channelID := model.NewId() + threadID := model.NewId() + userID := model.NewId() + callID := model.NewId() + + err := p.store.CreateCall(&public.Call{ + ID: callID, + CreateAt: time.Now().UnixMilli(), + ChannelID: channelID, + StartAt: time.Now().UnixMilli(), + PostID: model.NewId(), + ThreadID: threadID, + OwnerID: model.NewId(), + }) + require.NoError(t, err) + + err = p.store.CreateCallSession(&public.CallSession{ + ID: model.NewId(), + CallID: callID, + UserID: userID, + JoinAt: time.Now().UnixMilli(), + }) + require.NoError(t, err) + + mockAPI.On("LogDebug", "calls: suppressing notification on call thread for connected user", + "origin", mock.AnythingOfType("string"), + "userID", userID, "channelID", channelID, "threadID", threadID, "callID", callID).Once() + + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + ChannelId: channelID, + RootId: threadID, + }, userID) + require.Nil(t, res) + require.Equal(t, "calls: suppressing notification on call thread for connected user", msg) + }) + + t.Run("DM/GM ringing", func(t *testing.T) { + defer mockAPI.AssertExpectations(t) + var cfg configuration + cfg.SetDefaults() + + mockAPI.On("GetLicense").Return(&model.License{}, nil).Times(2) + *cfg.EnableRinging = true + err := p.setConfiguration(cfg.Clone()) + require.NoError(t, err) + require.True(t, *p.getConfiguration().EnableRinging) + + t.Run("not a call post", func(t *testing.T) { + res, msg := p.NotificationWillBePushed(&model.PushNotification{}, "userID") + require.Nil(t, res) + require.Empty(t, msg) + }) + + t.Run("call post but ringing disabled", func(t *testing.T) { + *cfg.EnableRinging = false + err := p.setConfiguration(cfg.Clone()) + require.NoError(t, err) + require.False(t, *p.getConfiguration().EnableRinging) + + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + PostType: callStartPostType, + }, "userID") + require.Nil(t, res) + require.Empty(t, msg) + }) + + t.Run("custom notification for DMs/GMs", func(t *testing.T) { + *cfg.EnableRinging = true + err := p.setConfiguration(cfg.Clone()) + require.NoError(t, err) + require.True(t, *p.getConfiguration().EnableRinging) + + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeDirect, + }, "userID") + require.Nil(t, res) + require.Equal(t, "calls plugin will handle this notification", msg) + + res, msg = p.NotificationWillBePushed(&model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeGroup, + }, "userID") + require.Nil(t, res) + require.Equal(t, "calls plugin will handle this notification", msg) + }) + + t.Run("regular channel", func(t *testing.T) { + mockAPI.On("GetUser", "receiverID").Return(&model.User{ + FirstName: "Firstname", + LastName: "Lastname", + }, nil).Twice() + + var serverConfig model.Config + serverConfig.SetDefaults() + mockAPI.On("GetConfig").Return(&serverConfig).Once() + + mockAPI.On("GetPreferencesForUser", "receiverID").Return([]model.Preference{}, nil).Once() + + mockAPI.On("GetUser", "senderID").Return(&model.User{ + FirstName: "Sender Firstname", + LastName: "Sender Lastname", + }, nil).Once() + + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeOpen, + SenderId: "senderID", + }, "receiverID") + require.Equal(t, &model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeOpen, + SenderId: "senderID", + Message: "\u200bapp.push_notification.inviting_message", + }, res) + require.Empty(t, msg) + + t.Run("id loaded", func(t *testing.T) { + res, msg := p.NotificationWillBePushed(&model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeOpen, + SenderId: "senderID", + IsIdLoaded: true, + }, "receiverID") + require.Equal(t, &model.PushNotification{ + PostType: callStartPostType, + ChannelType: model.ChannelTypeOpen, + SenderId: "senderID", + IsIdLoaded: true, + Message: "app.push_notification.generic_message", + }, res) + require.Empty(t, msg) + }) + }) + }) +}