diff --git a/app/app.go b/app/app.go index da789c7401..4d7a02c4fa 100644 --- a/app/app.go +++ b/app/app.go @@ -17,6 +17,7 @@ import ( "github.com/target/goalert/alert/alertmetrics" "github.com/target/goalert/app/lifecycle" "github.com/target/goalert/auth" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/auth/basic" "github.com/target/goalert/auth/nonce" "github.com/target/goalert/calsub" @@ -116,9 +117,10 @@ type App struct { LimitStore *limit.Store HeartbeatStore *heartbeat.Store - OAuthKeyring keyring.Keyring - SessionKeyring keyring.Keyring - APIKeyring keyring.Keyring + OAuthKeyring keyring.Keyring + SessionKeyring keyring.Keyring + APIKeyring keyring.Keyring + AuthLinkKeyring keyring.Keyring NonceStore *nonce.Store LabelStore *label.Store @@ -126,6 +128,7 @@ type App struct { NCStore *notificationchannel.Store TimeZoneStore *timezone.Store NoticeStore *notice.Store + AuthLinkStore *authlink.Store } // NewApp constructs a new App and binds the listening socket. diff --git a/app/initengine.go b/app/initengine.go index 71ce1a2356..ed6403f3d6 100644 --- a/app/initengine.go +++ b/app/initengine.go @@ -37,6 +37,7 @@ func (app *App) initEngine(ctx context.Context) error { NCStore: app.NCStore, OnCallStore: app.OnCallStore, ScheduleStore: app.ScheduleStore, + AuthLinkStore: app.AuthLinkStore, ConfigSource: app.ConfigStore, diff --git a/app/initgraphql.go b/app/initgraphql.go index 922413bb68..a84d8e7be6 100644 --- a/app/initgraphql.go +++ b/app/initgraphql.go @@ -7,7 +7,6 @@ import ( ) func (app *App) initGraphQL(ctx context.Context) error { - app.graphql2 = &graphqlapp.App{ DB: app.db, AuthBasicStore: app.AuthBasicStore, @@ -35,11 +34,12 @@ func (app *App) initGraphQL(ctx context.Context) error { NotificationStore: app.NotificationStore, SlackStore: app.slackChan, HeartbeatStore: app.HeartbeatStore, - NoticeStore: *app.NoticeStore, + NoticeStore: app.NoticeStore, Twilio: app.twilioConfig, AuthHandler: app.AuthHandler, FormatDestFunc: app.notificationManager.FormatDestValue, - NotificationManager: *app.notificationManager, + NotificationManager: app.notificationManager, + AuthLinkStore: app.AuthLinkStore, } return nil diff --git a/app/initstores.go b/app/initstores.go index 86bf76d16b..f30d4f8692 100644 --- a/app/initstores.go +++ b/app/initstores.go @@ -7,6 +7,7 @@ import ( "github.com/target/goalert/alert" "github.com/target/goalert/alert/alertlog" "github.com/target/goalert/alert/alertmetrics" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/auth/basic" "github.com/target/goalert/auth/nonce" "github.com/target/goalert/calsub" @@ -78,6 +79,18 @@ func (app *App) initStores(ctx context.Context) error { return errors.Wrap(err, "init oauth state keyring") } + if app.AuthLinkKeyring == nil { + app.AuthLinkKeyring, err = keyring.NewDB(ctx, app.cfg.Logger, app.db, &keyring.Config{ + Name: "auth-link", + RotationDays: 1, + MaxOldKeys: 1, + Keys: app.cfg.EncryptionKeys, + }) + } + if err != nil { + return errors.Wrap(err, "init oauth state keyring") + } + if app.SessionKeyring == nil { app.SessionKeyring, err = keyring.NewDB(ctx, app.cfg.Logger, app.db, &keyring.Config{ Name: "browser-sessions", @@ -101,9 +114,19 @@ func (app *App) initStores(ctx context.Context) error { return errors.Wrap(err, "init API keyring") } + if app.AuthLinkStore == nil { + app.AuthLinkStore, err = authlink.NewStore(ctx, app.db, app.AuthLinkKeyring) + } + if err != nil { + return errors.Wrap(err, "init auth link store") + } + if app.AlertMetricsStore == nil { app.AlertMetricsStore, err = alertmetrics.NewStore(ctx, app.db) } + if err != nil { + return errors.Wrap(err, "init alert metrics store") + } if app.AlertLogStore == nil { app.AlertLogStore, err = alertlog.NewStore(ctx, app.db) diff --git a/app/shutdown.go b/app/shutdown.go index bff03f0237..16ac18b235 100644 --- a/app/shutdown.go +++ b/app/shutdown.go @@ -69,6 +69,7 @@ func (app *App) _Shutdown(ctx context.Context) error { shut(app.SessionKeyring, "session keyring") shut(app.OAuthKeyring, "oauth keyring") shut(app.APIKeyring, "API keyring") + shut(app.AuthLinkKeyring, "auth link keyring") shut(app.NonceStore, "nonce store") shut(app.ConfigStore, "config store") shut(app.requestLock, "context locker") diff --git a/auth/authlink/store.go b/auth/authlink/store.go new file mode 100644 index 0000000000..941f1517ab --- /dev/null +++ b/auth/authlink/store.go @@ -0,0 +1,186 @@ +package authlink + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "net/url" + "time" + + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "github.com/target/goalert/config" + "github.com/target/goalert/keyring" + "github.com/target/goalert/permission" + "github.com/target/goalert/util" + "github.com/target/goalert/validation" + "github.com/target/goalert/validation/validate" +) + +type Store struct { + db *sql.DB + + k keyring.Keyring + + newLink *sql.Stmt + rmLink *sql.Stmt + addSubject *sql.Stmt + findLink *sql.Stmt +} + +type Metadata struct { + UserDetails string + AlertID int `json:",omitempty"` + AlertAction string `json:",omitempty"` +} + +func (m Metadata) Validate() error { + return validate.Many( + validate.ASCII("UserDetails", m.UserDetails, 1, 255), + validate.OneOf("AlertAction", m.AlertAction, "", "ResultAcknowledge", "ResultResolve"), + ) +} + +func NewStore(ctx context.Context, db *sql.DB, k keyring.Keyring) (*Store, error) { + p := &util.Prepare{ + DB: db, + Ctx: ctx, + } + + return &Store{ + db: db, + k: k, + newLink: p.P(`insert into auth_link_requests (id, provider_id, subject_id, expires_at, metadata) values ($1, $2, $3, $4, $5)`), + rmLink: p.P(`delete from auth_link_requests where id = $1 and expires_at > now() returning provider_id, subject_id`), + addSubject: p.P(`insert into auth_subjects (provider_id, subject_id, user_id) values ($1, $2, $3)`), + findLink: p.P(`select metadata from auth_link_requests where id = $1 and expires_at > now()`), + }, p.Err +} + +func (s *Store) FindLinkMetadata(ctx context.Context, token string) (*Metadata, error) { + err := permission.LimitCheckAny(ctx, permission.User) + if err != nil { + return nil, err + } + + tokID, err := s.tokenID(ctx, token) + if err != nil { + // don't return anything, treat it as not found + return nil, nil + } + + var meta Metadata + var data json.RawMessage + err = s.findLink.QueryRowContext(ctx, tokID).Scan(&data) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + + err = json.Unmarshal(data, &meta) + if err != nil { + return nil, err + } + + return &meta, nil +} + +func (s *Store) tokenID(ctx context.Context, token string) (string, error) { + var c jwt.RegisteredClaims + _, err := s.k.VerifyJWT(token, &c) + if err != nil { + return "", validation.WrapError(err) + } + + if !c.VerifyIssuer("goalert", true) { + return "", validation.NewGenericError("invalid issuer") + } + if !c.VerifyAudience("auth-link", true) { + return "", validation.NewGenericError("invalid audience") + } + err = validate.UUID("ID", c.ID) + if err != nil { + return "", err + } + + return c.ID, nil +} + +func (s *Store) LinkAccount(ctx context.Context, token string) error { + err := permission.LimitCheckAny(ctx, permission.User) + if err != nil { + return err + } + + tokID, err := s.tokenID(ctx, token) + if err != nil { + return err + } + + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + var providerID, subjectID string + err = tx.StmtContext(ctx, s.rmLink).QueryRowContext(ctx, tokID).Scan(&providerID, &subjectID) + if errors.Is(err, sql.ErrNoRows) { + return validation.NewGenericError("invalid link token") + } + if err != nil { + return err + } + + _, err = tx.StmtContext(ctx, s.addSubject).ExecContext(ctx, providerID, subjectID, permission.UserID(ctx)) + if err != nil { + return err + } + + return tx.Commit() +} + +func (s *Store) AuthLinkURL(ctx context.Context, providerID, subjectID string, meta Metadata) (string, error) { + err := permission.LimitCheckAny(ctx, permission.System) + if err != nil { + return "", err + } + err = validate.Many( + validate.SubjectID("ProviderID", providerID), + validate.SubjectID("SubjectID", subjectID), + meta.Validate(), + ) + if err != nil { + return "", err + } + + id := uuid.New() + now := time.Now() + expires := now.Add(5 * time.Minute) + + var c jwt.RegisteredClaims + c.ID = id.String() + c.Audience = jwt.ClaimStrings{"auth-link"} + c.Issuer = "goalert" + c.NotBefore = jwt.NewNumericDate(now.Add(-2 * time.Minute)) + c.ExpiresAt = jwt.NewNumericDate(expires) + c.IssuedAt = jwt.NewNumericDate(now) + + token, err := s.k.SignJWT(c) + if err != nil { + return "", err + } + + _, err = s.newLink.ExecContext(ctx, id, providerID, subjectID, expires, meta) + if err != nil { + return "", err + } + + cfg := config.FromContext(ctx) + p := make(url.Values) + p.Set("authLinkToken", token) + return cfg.CallbackURL("/profile", p), nil +} diff --git a/devtools/mockslack/actions.go b/devtools/mockslack/actions.go index 717c3ed31c..1475822069 100644 --- a/devtools/mockslack/actions.go +++ b/devtools/mockslack/actions.go @@ -34,6 +34,10 @@ type actionBody struct { Name string TeamID string `json:"team_id"` } + Team struct { + ID string + Domain string + } ResponseURL string `json:"response_url"` Actions []actionItem } @@ -42,6 +46,18 @@ func (s *Server) ServeActionResponse(w http.ResponseWriter, r *http.Request) { var req struct { Text string Type string `json:"response_type"` + + Blocks []struct { + Type string + Text struct{ Text string } + Elements []struct { + Type string + Text struct{ Text string } + Value string + ActionID string `json:"action_id"` + URL string + } + } } if err := json.NewDecoder(r.Body).Decode(&req); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) @@ -60,11 +76,40 @@ func (s *Server) ServeActionResponse(w http.ResponseWriter, r *http.Request) { return } - msg, err := s.API().ChatPostMessage(r.Context(), ChatPostMessageOptions{ + opts := ChatPostMessageOptions{ ChannelID: a.ChannelID, - Text: req.Text, User: r.URL.Query().Get("user"), - }) + } + + if len(req.Blocks) > 0 { + // new API + for _, block := range req.Blocks { + switch block.Type { + case "section": + opts.Text = block.Text.Text + case "actions": + for _, action := range block.Elements { + if action.Type != "button" { + continue + } + + opts.Actions = append(opts.Actions, Action{ + ChannelID: a.ChannelID, + TeamID: a.TeamID, + AppID: a.AppID, + ActionID: action.ActionID, + Text: action.Text.Text, + Value: action.Value, + URL: action.URL, + }) + } + } + } + } else { + opts.Text = req.Text + } + + msg, err := s.API().ChatPostMessage(r.Context(), opts) if respondErr(w, err) { return } @@ -106,6 +151,8 @@ func (s *Server) PerformActionAs(userID string, a Action) error { p.User.Username = usr.Name p.User.Name = usr.Name p.User.TeamID = a.TeamID + p.Team.ID = a.TeamID + p.Team.Domain = "example.com" p.Channel.ID = a.ChannelID p.AppID = a.AppID diff --git a/devtools/mockslack/chatpostmessage.go b/devtools/mockslack/chatpostmessage.go index 971924dca6..7507a794a7 100644 --- a/devtools/mockslack/chatpostmessage.go +++ b/devtools/mockslack/chatpostmessage.go @@ -124,6 +124,7 @@ type Action struct { ActionID string Text string Value string + URL string } // parseAttachments parses the attachments from the payload value. @@ -174,6 +175,7 @@ func parseAttachments(appID, teamID, chanID, value string) (*attachments, error) Text textBlock ActionID string `json:"action_id"` Value string + URL string } err = json.Unmarshal(b.Elements, &acts) if err != nil { @@ -189,6 +191,7 @@ func parseAttachments(appID, teamID, chanID, value string) (*attachments, error) ActionID: a.ActionID, Text: a.Text.Text, Value: a.Value, + URL: a.URL, }) } default: diff --git a/engine/config.go b/engine/config.go index 6035d9b640..49e3bc1560 100644 --- a/engine/config.go +++ b/engine/config.go @@ -3,6 +3,7 @@ package engine import ( "github.com/target/goalert/alert" "github.com/target/goalert/alert/alertlog" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/config" "github.com/target/goalert/keyring" "github.com/target/goalert/notification" @@ -24,6 +25,7 @@ type Config struct { NCStore *notificationchannel.Store OnCallStore *oncall.Store ScheduleStore *schedule.Store + AuthLinkStore *authlink.Store ConfigSource config.Source diff --git a/engine/engine.go b/engine/engine.go index 27d300f4a3..989760ada0 100644 --- a/engine/engine.go +++ b/engine/engine.go @@ -10,6 +10,7 @@ import ( "github.com/pkg/errors" "github.com/target/goalert/alert" "github.com/target/goalert/app/lifecycle" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/engine/cleanupmanager" "github.com/target/goalert/engine/escalationmanager" "github.com/target/goalert/engine/heartbeatmanager" @@ -153,6 +154,13 @@ func NewEngine(ctx context.Context, db *sql.DB, c *Config) (*Engine, error) { return p, nil } +func (p *Engine) AuthLinkURL(ctx context.Context, providerID, subjectID string, meta authlink.Metadata) (url string, err error) { + permission.SudoContext(ctx, func(ctx context.Context) { + url, err = p.cfg.AuthLinkStore.AuthLinkURL(ctx, providerID, subjectID, meta) + }) + return url, err +} + func (p *Engine) processModule(ctx context.Context, m updater) { defer recoverPanic(ctx, m.Name()) ctx, cancel := context.WithTimeout(ctx, 30*time.Second) @@ -293,7 +301,9 @@ func (p *Engine) ReceiveSubject(ctx context.Context, providerID, subjectID, call return fmt.Errorf("failed to find user: %w", err) } if usr == nil { - return notification.ErrUnknownSubject + return ¬ification.UnknownSubjectError{ + AlertID: cb.AlertID, + } } ctx = permission.UserSourceContext(ctx, usr.ID, usr.Role, &permission.SourceInfo{ diff --git a/graphql2/generated.go b/graphql2/generated.go index 70c18e89f5..0f1657274b 100644 --- a/graphql2/generated.go +++ b/graphql2/generated.go @@ -254,6 +254,12 @@ type ComplexityRoot struct { PageInfo func(childComplexity int) int } + LinkAccountInfo struct { + AlertID func(childComplexity int) int + AlertNewStatus func(childComplexity int) int + UserDetails func(childComplexity int) int + } + Mutation struct { AddAuthSubject func(childComplexity int, input user.AuthSubject) int ClearTemporarySchedules func(childComplexity int, input ClearTemporarySchedulesInput) int @@ -276,6 +282,7 @@ type ComplexityRoot struct { DeleteAuthSubject func(childComplexity int, input user.AuthSubject) int EndAllAuthSessionsByCurrentUser func(childComplexity int) int EscalateAlerts func(childComplexity int, input []int) int + LinkAccount func(childComplexity int, token string) int SendContactMethodVerification func(childComplexity int, input SendContactMethodVerificationInput) int SetConfig func(childComplexity int, input []ConfigValueInput) int SetFavorite func(childComplexity int, input SetFavoriteInput) int @@ -358,6 +365,7 @@ type ComplexityRoot struct { LabelKeys func(childComplexity int, input *LabelKeySearchOptions) int LabelValues func(childComplexity int, input *LabelValueSearchOptions) int Labels func(childComplexity int, input *LabelSearchOptions) int + LinkAccountInfo func(childComplexity int, token string) int PhoneNumberInfo func(childComplexity int, number string) int Rotation func(childComplexity int, id string) int Rotations func(childComplexity int, input *RotationSearchOptions) int @@ -614,6 +622,7 @@ type IntegrationKeyResolver interface { Href(ctx context.Context, obj *integrationkey.IntegrationKey) (string, error) } type MutationResolver interface { + LinkAccount(ctx context.Context, token string) (bool, error) SetTemporarySchedule(ctx context.Context, input SetTemporaryScheduleInput) (bool, error) ClearTemporarySchedules(ctx context.Context, input ClearTemporarySchedulesInput) (bool, error) SetScheduleOnCallNotificationRules(ctx context.Context, input SetScheduleOnCallNotificationRulesInput) (bool, error) @@ -698,6 +707,7 @@ type QueryResolver interface { SlackChannels(ctx context.Context, input *SlackChannelSearchOptions) (*SlackChannelConnection, error) SlackChannel(ctx context.Context, id string) (*slack.Channel, error) GenerateSlackAppManifest(ctx context.Context) (string, error) + LinkAccountInfo(ctx context.Context, token string) (*LinkAccountInfo, error) } type RotationResolver interface { IsFavorite(ctx context.Context, obj *rotation.Rotation) (bool, error) @@ -1460,6 +1470,27 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.LabelConnection.PageInfo(childComplexity), true + case "LinkAccountInfo.alertID": + if e.complexity.LinkAccountInfo.AlertID == nil { + break + } + + return e.complexity.LinkAccountInfo.AlertID(childComplexity), true + + case "LinkAccountInfo.alertNewStatus": + if e.complexity.LinkAccountInfo.AlertNewStatus == nil { + break + } + + return e.complexity.LinkAccountInfo.AlertNewStatus(childComplexity), true + + case "LinkAccountInfo.userDetails": + if e.complexity.LinkAccountInfo.UserDetails == nil { + break + } + + return e.complexity.LinkAccountInfo.UserDetails(childComplexity), true + case "Mutation.addAuthSubject": if e.complexity.Mutation.AddAuthSubject == nil { break @@ -1707,6 +1738,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Mutation.EscalateAlerts(childComplexity, args["input"].([]int)), true + case "Mutation.linkAccount": + if e.complexity.Mutation.LinkAccount == nil { + break + } + + args, err := ec.field_Mutation_linkAccount_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Mutation.LinkAccount(childComplexity, args["token"].(string)), true + case "Mutation.sendContactMethodVerification": if e.complexity.Mutation.SendContactMethodVerification == nil { break @@ -2314,6 +2357,18 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.Query.Labels(childComplexity, args["input"].(*LabelSearchOptions)), true + case "Query.linkAccountInfo": + if e.complexity.Query.LinkAccountInfo == nil { + break + } + + args, err := ec.field_Query_linkAccountInfo_args(context.TODO(), rawArgs) + if err != nil { + return 0, false + } + + return e.complexity.Query.LinkAccountInfo(childComplexity, args["token"].(string)), true + case "Query.phoneNumberInfo": if e.complexity.Query.PhoneNumberInfo == nil { break @@ -3836,6 +3891,21 @@ func (ec *executionContext) field_Mutation_escalateAlerts_args(ctx context.Conte return args, nil } +func (ec *executionContext) field_Mutation_linkAccount_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["token"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("token")) + arg0, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["token"] = arg0 + return args, nil +} + func (ec *executionContext) field_Mutation_sendContactMethodVerification_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -4409,6 +4479,21 @@ func (ec *executionContext) field_Query_labels_args(ctx context.Context, rawArgs return args, nil } +func (ec *executionContext) field_Query_linkAccountInfo_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { + var err error + args := map[string]interface{}{} + var arg0 string + if tmp, ok := rawArgs["token"]; ok { + ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("token")) + arg0, err = ec.unmarshalNID2string(ctx, tmp) + if err != nil { + return nil, err + } + } + args["token"] = arg0 + return args, nil +} + func (ec *executionContext) field_Query_phoneNumberInfo_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} @@ -9140,6 +9225,187 @@ func (ec *executionContext) fieldContext_LabelConnection_pageInfo(ctx context.Co return fc, nil } +func (ec *executionContext) _LinkAccountInfo_userDetails(ctx context.Context, field graphql.CollectedField, obj *LinkAccountInfo) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LinkAccountInfo_userDetails(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.UserDetails, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(string) + fc.Result = res + return ec.marshalNString2string(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LinkAccountInfo_userDetails(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LinkAccountInfo", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type String does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LinkAccountInfo_alertID(ctx context.Context, field graphql.CollectedField, obj *LinkAccountInfo) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LinkAccountInfo_alertID(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.AlertID, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*int) + fc.Result = res + return ec.marshalOInt2ᚖint(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LinkAccountInfo_alertID(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LinkAccountInfo", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Int does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _LinkAccountInfo_alertNewStatus(ctx context.Context, field graphql.CollectedField, obj *LinkAccountInfo) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_LinkAccountInfo_alertNewStatus(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return obj.AlertNewStatus, nil + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*AlertStatus) + fc.Result = res + return ec.marshalOAlertStatus2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐAlertStatus(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_LinkAccountInfo_alertNewStatus(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "LinkAccountInfo", + Field: field, + IsMethod: false, + IsResolver: false, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type AlertStatus does not have child fields") + }, + } + return fc, nil +} + +func (ec *executionContext) _Mutation_linkAccount(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Mutation_linkAccount(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Mutation().LinkAccount(rctx, fc.Args["token"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return graphql.Null + } + res := resTmp.(bool) + fc.Result = res + return ec.marshalNBoolean2bool(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Mutation_linkAccount(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Mutation", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + return nil, errors.New("field of type Boolean does not have child fields") + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Mutation_linkAccount_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } + return fc, nil +} + func (ec *executionContext) _Mutation_setTemporarySchedule(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Mutation_setTemporarySchedule(ctx, field) if err != nil { @@ -14931,6 +15197,66 @@ func (ec *executionContext) fieldContext_Query_generateSlackAppManifest(ctx cont return fc, nil } +func (ec *executionContext) _Query_linkAccountInfo(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { + fc, err := ec.fieldContext_Query_linkAccountInfo(ctx, field) + if err != nil { + return graphql.Null + } + ctx = graphql.WithFieldContext(ctx, fc) + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = graphql.Null + } + }() + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Query().LinkAccountInfo(rctx, fc.Args["token"].(string)) + }) + if err != nil { + ec.Error(ctx, err) + return graphql.Null + } + if resTmp == nil { + return graphql.Null + } + res := resTmp.(*LinkAccountInfo) + fc.Result = res + return ec.marshalOLinkAccountInfo2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐLinkAccountInfo(ctx, field.Selections, res) +} + +func (ec *executionContext) fieldContext_Query_linkAccountInfo(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { + fc = &graphql.FieldContext{ + Object: "Query", + Field: field, + IsMethod: true, + IsResolver: true, + Child: func(ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) { + switch field.Name { + case "userDetails": + return ec.fieldContext_LinkAccountInfo_userDetails(ctx, field) + case "alertID": + return ec.fieldContext_LinkAccountInfo_alertID(ctx, field) + case "alertNewStatus": + return ec.fieldContext_LinkAccountInfo_alertNewStatus(ctx, field) + } + return nil, fmt.Errorf("no field named %q was found under type LinkAccountInfo", field.Name) + }, + } + defer func() { + if r := recover(); r != nil { + err = ec.Recover(ctx, r) + ec.Error(ctx, err) + } + }() + ctx = graphql.WithFieldContext(ctx, fc) + if fc.Args, err = ec.field_Query_linkAccountInfo_args(ctx, field.ArgumentMap(ec.Variables)); err != nil { + ec.Error(ctx, err) + return + } + return fc, nil +} + func (ec *executionContext) _Query___type(ctx context.Context, field graphql.CollectedField) (ret graphql.Marshaler) { fc, err := ec.fieldContext_Query___type(ctx, field) if err != nil { @@ -27195,6 +27521,42 @@ func (ec *executionContext) _LabelConnection(ctx context.Context, sel ast.Select return out } +var linkAccountInfoImplementors = []string{"LinkAccountInfo"} + +func (ec *executionContext) _LinkAccountInfo(ctx context.Context, sel ast.SelectionSet, obj *LinkAccountInfo) graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, linkAccountInfoImplementors) + out := graphql.NewFieldSet(fields) + var invalids uint32 + for i, field := range fields { + switch field.Name { + case "__typename": + out.Values[i] = graphql.MarshalString("LinkAccountInfo") + case "userDetails": + + out.Values[i] = ec._LinkAccountInfo_userDetails(ctx, field, obj) + + if out.Values[i] == graphql.Null { + invalids++ + } + case "alertID": + + out.Values[i] = ec._LinkAccountInfo_alertID(ctx, field, obj) + + case "alertNewStatus": + + out.Values[i] = ec._LinkAccountInfo_alertNewStatus(ctx, field, obj) + + default: + panic("unknown field " + strconv.Quote(field.Name)) + } + } + out.Dispatch() + if invalids > 0 { + return graphql.Null + } + return out +} + var mutationImplementors = []string{"Mutation"} func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) graphql.Marshaler { @@ -27214,6 +27576,15 @@ func (ec *executionContext) _Mutation(ctx context.Context, sel ast.SelectionSet) switch field.Name { case "__typename": out.Values[i] = graphql.MarshalString("Mutation") + case "linkAccount": + + out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { + return ec._Mutation_linkAccount(ctx, field) + }) + + if out.Values[i] == graphql.Null { + invalids++ + } case "setTemporarySchedule": out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) { @@ -28601,6 +28972,26 @@ func (ec *executionContext) _Query(ctx context.Context, sel ast.SelectionSet) gr return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) } + out.Concurrently(i, func() graphql.Marshaler { + return rrm(innerCtx) + }) + case "linkAccountInfo": + field := field + + innerFunc := func(ctx context.Context) (res graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + } + }() + res = ec._Query_linkAccountInfo(ctx, field) + return res + } + + rrm := func(ctx context.Context) graphql.Marshaler { + return ec.OperationContext.RootResolverMiddleware(ctx, innerFunc) + } + out.Concurrently(i, func() graphql.Marshaler { return rrm(innerCtx) }) @@ -33720,6 +34111,22 @@ func (ec *executionContext) marshalOAlertStatus2ᚕgithub.comᚋtargetᚋgoale return ret } +func (ec *executionContext) unmarshalOAlertStatus2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐAlertStatus(ctx context.Context, v interface{}) (*AlertStatus, error) { + if v == nil { + return nil, nil + } + var res = new(AlertStatus) + err := res.UnmarshalGQL(v) + return res, graphql.ErrorOnPath(ctx, err) +} + +func (ec *executionContext) marshalOAlertStatus2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐAlertStatus(ctx context.Context, sel ast.SelectionSet, v *AlertStatus) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return v +} + func (ec *executionContext) unmarshalOBoolean2bool(ctx context.Context, v interface{}) (bool, error) { res, err := graphql.UnmarshalBoolean(v) return res, graphql.ErrorOnPath(ctx, err) @@ -34147,6 +34554,13 @@ func (ec *executionContext) unmarshalOLabelValueSearchOptions2ᚖgithub.comᚋ return &res, graphql.ErrorOnPath(ctx, err) } +func (ec *executionContext) marshalOLinkAccountInfo2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐLinkAccountInfo(ctx context.Context, sel ast.SelectionSet, v *LinkAccountInfo) graphql.Marshaler { + if v == nil { + return graphql.Null + } + return ec._LinkAccountInfo(ctx, sel, v) +} + func (ec *executionContext) marshalONotificationState2ᚖgithub.comᚋtargetᚋgoalertᚋgraphql2ᚐNotificationState(ctx context.Context, sel ast.SelectionSet, v *NotificationState) graphql.Marshaler { if v == nil { return graphql.Null diff --git a/graphql2/graphqlapp/app.go b/graphql2/graphqlapp/app.go index d623b3b885..f7731e92cd 100644 --- a/graphql2/graphqlapp/app.go +++ b/graphql2/graphqlapp/app.go @@ -17,6 +17,7 @@ import ( "github.com/target/goalert/alert/alertlog" "github.com/target/goalert/alert/alertmetrics" "github.com/target/goalert/auth" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/auth/basic" "github.com/target/goalert/calsub" "github.com/target/goalert/config" @@ -74,9 +75,11 @@ type App struct { LimitStore *limit.Store SlackStore *slack.ChannelSender HeartbeatStore *heartbeat.Store - NoticeStore notice.Store + NoticeStore *notice.Store - NotificationManager notification.Manager + AuthLinkStore *authlink.Store + + NotificationManager *notification.Manager AuthHandler *auth.Handler diff --git a/graphql2/graphqlapp/mutation.go b/graphql2/graphqlapp/mutation.go index a86582da31..10e0b1fdc1 100644 --- a/graphql2/graphqlapp/mutation.go +++ b/graphql2/graphqlapp/mutation.go @@ -35,6 +35,11 @@ func (a *Mutation) SetFavorite(ctx context.Context, input graphql2.SetFavoriteIn return true, nil } +func (a *Mutation) LinkAccount(ctx context.Context, token string) (bool, error) { + err := a.AuthLinkStore.LinkAccount(ctx, token) + return err == nil, err +} + func (a *Mutation) SetScheduleOnCallNotificationRules(ctx context.Context, input graphql2.SetScheduleOnCallNotificationRulesInput) (bool, error) { schedID, err := parseUUID("ScheduleID", input.ScheduleID) if err != nil { @@ -104,6 +109,7 @@ func (a *Mutation) SetTemporarySchedule(ctx context.Context, input graphql2.SetT return err == nil, err } + func (a *Mutation) ClearTemporarySchedules(ctx context.Context, input graphql2.ClearTemporarySchedulesInput) (bool, error) { schedID, err := parseUUID("ScheduleID", input.ScheduleID) if err != nil { diff --git a/graphql2/graphqlapp/query.go b/graphql2/graphqlapp/query.go index 1e574ea1e7..792ae42e08 100644 --- a/graphql2/graphqlapp/query.go +++ b/graphql2/graphqlapp/query.go @@ -28,6 +28,34 @@ type ( func (a *App) Query() graphql2.QueryResolver { return (*Query)(a) } +func (a *Query) LinkAccountInfo(ctx context.Context, token string) (*graphql2.LinkAccountInfo, error) { + m, err := a.AuthLinkStore.FindLinkMetadata(ctx, token) + if err != nil { + return nil, err + } + if m == nil { + return nil, nil + } + + info := &graphql2.LinkAccountInfo{ + UserDetails: m.UserDetails, + } + if m.AlertID > 0 { + info.AlertID = &m.AlertID + } + var s graphql2.AlertStatus + switch m.AlertAction { + case notification.ResultAcknowledge.String(): + s = graphql2.AlertStatusStatusAcknowledged + info.AlertNewStatus = &s + case notification.ResultResolve.String(): + s = graphql2.AlertStatusStatusClosed + info.AlertNewStatus = &s + } + + return info, nil +} + func (a *App) formatNC(ctx context.Context, id string) (string, error) { if id == "" { return "", nil diff --git a/graphql2/models_gen.go b/graphql2/models_gen.go index 77da0b2772..833bb14246 100644 --- a/graphql2/models_gen.go +++ b/graphql2/models_gen.go @@ -300,6 +300,12 @@ type LabelValueSearchOptions struct { Omit []string `json:"omit"` } +type LinkAccountInfo struct { + UserDetails string `json:"userDetails"` + AlertID *int `json:"alertID"` + AlertNewStatus *AlertStatus `json:"alertNewStatus"` +} + type NotificationState struct { Details string `json:"details"` Status *NotificationStatus `json:"status"` diff --git a/graphql2/schema.graphql b/graphql2/schema.graphql index ca8f74d17f..131d28a7ce 100644 --- a/graphql2/schema.graphql +++ b/graphql2/schema.graphql @@ -109,6 +109,14 @@ type Query { slackChannel(id: ID!): SlackChannel generateSlackAppManifest: String! + + linkAccountInfo(token: ID!): LinkAccountInfo +} + +type LinkAccountInfo { + userDetails: String! + alertID: Int + alertNewStatus: AlertStatus } input AlertMetricsOptions { @@ -341,6 +349,8 @@ input SetScheduleShiftInput { } type Mutation { + linkAccount(token: ID!): Boolean! + setTemporarySchedule(input: SetTemporaryScheduleInput!): Boolean! clearTemporarySchedules(input: ClearTemporarySchedulesInput!): Boolean! diff --git a/keyring/store.go b/keyring/store.go index 1431da6e3e..daa4084cd6 100644 --- a/keyring/store.go +++ b/keyring/store.go @@ -544,7 +544,7 @@ func (db *DB) VerifyJWT(s string, c jwt.Claims) (bool, error) { return false, err } - return currentKey, nil + return currentKey, c.Valid() } // Verify will validate the signature and metadata, and optionally length, of a message. diff --git a/migrate/migrations/20220802104000-auth-link-requests.sql b/migrate/migrations/20220802104000-auth-link-requests.sql new file mode 100644 index 0000000000..c18e9fa51f --- /dev/null +++ b/migrate/migrations/20220802104000-auth-link-requests.sql @@ -0,0 +1,12 @@ +-- +migrate Up +CREATE TABLE auth_link_requests ( + id UUID PRIMARY KEY, + provider_id TEXT NOT NULL, + subject_id TEXT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + metadata JSONB NOT NULL DEFAULT '{}'::JSONB +); + +-- +migrate Down +DROP TABLE auth_link_requests; diff --git a/notification/namedreceiver.go b/notification/namedreceiver.go index 9ed790ec6b..e63854660e 100644 --- a/notification/namedreceiver.go +++ b/notification/namedreceiver.go @@ -1,6 +1,10 @@ package notification -import "context" +import ( + "context" + + "github.com/target/goalert/auth/authlink" +) type namedReceiver struct { r ResultReceiver @@ -23,6 +27,11 @@ func (nr *namedReceiver) SetMessageStatus(ctx context.Context, externalID string return nr.r.SetSendResult(ctx, res) } +// AuthLinkURL calls the underlying AuthLinkURL method. +func (nr *namedReceiver) AuthLinkURL(ctx context.Context, providerID, subjectID string, meta authlink.Metadata) (string, error) { + return nr.r.AuthLinkURL(ctx, providerID, subjectID, meta) +} + // Start implements the Receiver interface by calling the underlying Receiver.Start method. func (nr *namedReceiver) Start(ctx context.Context, d Dest) error { metricRecvTotal.WithLabelValues(d.Type.String(), "START") diff --git a/notification/receiver.go b/notification/receiver.go index 72b1ba46a1..57520cf270 100644 --- a/notification/receiver.go +++ b/notification/receiver.go @@ -2,7 +2,8 @@ package notification import ( "context" - "errors" + + "github.com/target/goalert/auth/authlink" ) // A Receiver processes incoming messages and responses. @@ -16,6 +17,9 @@ type Receiver interface { // ReceiveSubject records a response to a previously sent message from a provider/subject (e.g. Slack user). ReceiveSubject(ctx context.Context, providerID, subjectID, callbackID string, result Result) error + // AuthLinkURL will generate a URL to link a provider and subject to a GoAlert user. + AuthLinkURL(ctx context.Context, providerID, subjectID string, meta authlink.Metadata) (string, error) + // Start indicates a user has opted-in for notifications to this contact method. Start(context.Context, Dest) error @@ -26,5 +30,11 @@ type Receiver interface { IsKnownDest(ctx context.Context, value string) (bool, error) } -// ErrUnknownSubject is returned from ReceiveSubject when the subject is unknown. -var ErrUnknownSubject = errors.New("unknown subject for that provider") +// UnknownSubjectError is returned from ReceiveSubject when the subject is unknown. +type UnknownSubjectError struct { + AlertID int +} + +func (e UnknownSubjectError) Error() string { + return "unknown subject for that provider" +} diff --git a/notification/resultreceiver.go b/notification/resultreceiver.go index bc376411bb..637b8bd828 100644 --- a/notification/resultreceiver.go +++ b/notification/resultreceiver.go @@ -2,6 +2,8 @@ package notification import ( "context" + + "github.com/target/goalert/auth/authlink" ) // A ResultReceiver processes notification responses. @@ -10,6 +12,7 @@ type ResultReceiver interface { Receive(ctx context.Context, callbackID string, result Result) error ReceiveSubject(ctx context.Context, providerID, subjectID, callbackID string, result Result) error + AuthLinkURL(ctx context.Context, providerID, subjectID string, meta authlink.Metadata) (string, error) Start(context.Context, Dest) error Stop(context.Context, Dest) error diff --git a/notification/slack/channel.go b/notification/slack/channel.go index 9cfe33edfb..55583c6945 100644 --- a/notification/slack/channel.go +++ b/notification/slack/channel.go @@ -246,6 +246,7 @@ const ( alertResponseBlockID = "block_alert_response" alertCloseActionID = "action_alert_close" alertAckActionID = "action_alert_ack" + linkActActionID = "action_link_account" ) // alertMsgOption will return the slack.MsgOption for an alert-type message (e.g., notification or status update). diff --git a/notification/slack/servemessageaction.go b/notification/slack/servemessageaction.go index bb4d1a408a..284c9eab5a 100644 --- a/notification/slack/servemessageaction.go +++ b/notification/slack/servemessageaction.go @@ -14,6 +14,7 @@ import ( "github.com/slack-go/slack" "github.com/target/goalert/alert" + "github.com/target/goalert/auth/authlink" "github.com/target/goalert/config" "github.com/target/goalert/notification" "github.com/target/goalert/util/errutil" @@ -78,12 +79,17 @@ func (s *ChannelSender) ServeMessageAction(w http.ResponseWriter, req *http.Requ var payload struct { Type string ResponseURL string `json:"response_url"` - Channel struct { + Team struct { + ID string + Domain string + } + Channel struct { ID string } User struct { - ID string `json:"id"` - TeamID string `json:"team_id"` + ID string `json:"id"` + Username string `json:"username"` + Name string } Actions []struct { ActionID string `json:"action_id"` @@ -113,23 +119,74 @@ func (s *ChannelSender) ServeMessageAction(w http.ResponseWriter, req *http.Requ res = notification.ResultAcknowledge case alertCloseActionID: res = notification.ResultResolve + case linkActActionID: + s.withClient(ctx, func(c *slack.Client) error { + // remove ephemeral 'Link Account' button + _, err = c.PostEphemeralContext(ctx, payload.Channel.ID, payload.User.ID, + slack.MsgOptionText("", false), slack.MsgOptionReplaceOriginal(payload.ResponseURL), + slack.MsgOptionDeleteOriginal(payload.ResponseURL)) + if err != nil { + return err + } + return nil + }) + return default: errutil.HTTPError(ctx, w, validation.NewFieldErrorf("action_id", "unknown action ID '%s'", act.ActionID)) return } - err = s.recv.ReceiveSubject(ctx, "slack:"+payload.User.TeamID, payload.User.ID, act.Value, res) - if errors.Is(err, notification.ErrUnknownSubject) { - log.Log(ctx, fmt.Errorf("unknown provider/subject ID for Slack 'slack:%s/%s'", payload.User.TeamID, payload.User.ID)) + var e *notification.UnknownSubjectError + err = s.recv.ReceiveSubject(ctx, "slack:"+payload.Team.ID, payload.User.ID, act.Value, res) + + if errors.As(err, &e) { + var linkURL string + switch { + case payload.User.Name == "", payload.User.Username == "", payload.Team.ID == "", payload.Team.Domain == "": + // missing data, don't allow linking + log.Log(ctx, errors.New("slack payload missing requried data")) + default: + linkURL, err = s.recv.AuthLinkURL(ctx, "slack:"+payload.Team.ID, payload.User.ID, authlink.Metadata{ + UserDetails: fmt.Sprintf("Slack user %s (@%s) from %s.slack.com", payload.User.Name, payload.User.Username, payload.Team.Domain), + AlertID: e.AlertID, + AlertAction: res.String(), + }) + if err != nil { + log.Log(ctx, err) + } + } + err = s.withClient(ctx, func(c *slack.Client) error { - _, err := c.PostEphemeralContext(ctx, payload.Channel.ID, payload.User.ID, + var msg string + if linkURL == "" { + msg = "Your Slack account isn't currently linked to GoAlert, please try again later." + } else { + msg = "Please link your Slack account with GoAlert." + } + blocks := []slack.Block{ + slack.NewSectionBlock( + slack.NewTextBlockObject("plain_text", msg, false, false), + nil, nil, + ), + } + + if linkURL != "" { + btn := slack.NewButtonBlockElement(linkActActionID, linkURL, + slack.NewTextBlockObject("plain_text", "Link Account", false, false)) + btn.URL = linkURL + blocks = append(blocks, slack.NewActionBlock(alertResponseBlockID, btn)) + } + + _, err = c.PostEphemeralContext(ctx, payload.Channel.ID, payload.User.ID, slack.MsgOptionResponseURL(payload.ResponseURL, "ephemeral"), - - // TODO: add user-link/OAUTH flow - slack.MsgOptionText("Your Slack account isn't currently linked to GoAlert, the admin will need to set this up for it to work.", false), + slack.MsgOptionBlocks(blocks...), ) - return err + if err != nil { + return err + } + return nil }) + return } if alert.IsAlreadyAcknowledged(err) || alert.IsAlreadyClosed(err) { // ignore errors from duplicate requests diff --git a/test/smoke/harness/slack.go b/test/smoke/harness/slack.go index 918af4a6ec..44a8852516 100644 --- a/test/smoke/harness/slack.go +++ b/test/smoke/harness/slack.go @@ -50,6 +50,7 @@ type SlackMessageState interface { type SlackAction interface { Click() + URL() string } type SlackMessage interface { @@ -112,7 +113,8 @@ func (msg *slackMessage) Action(text string) SlackAction { a = &action break } - require.NotNil(msg.h.t, a, "could not find action with that text") + require.NotNilf(msg.h.t, a, `expected action "%s"`, text) + msg.h.t.Logf("found action: %s\n%#v", text, *a) return &slackAction{ slackMessage: msg, @@ -120,6 +122,11 @@ func (msg *slackMessage) Action(text string) SlackAction { } } +func (a *slackAction) URL() string { + a.h.t.Helper() + return a.Action.URL +} + func (a *slackAction) Click() { a.h.t.Helper() diff --git a/test/smoke/slackinteraction_test.go b/test/smoke/slackinteraction_test.go index 53d776851b..d9f691afe1 100644 --- a/test/smoke/slackinteraction_test.go +++ b/test/smoke/slackinteraction_test.go @@ -1,6 +1,8 @@ package smoke import ( + "fmt" + "net/url" "testing" "github.com/target/goalert/test/smoke/harness" @@ -30,7 +32,7 @@ func TestSlackInteraction(t *testing.T) { values ({{uuid "sid"}}, {{uuid "eid"}}, 'service'); ` - h := harness.NewHarness(t, sql, "slack-user-link") + h := harness.NewHarness(t, sql, "auth-link-requests") defer h.Close() h.SetConfigValue("Slack.InteractiveMessages", "true") @@ -44,9 +46,26 @@ func TestSlackInteraction(t *testing.T) { h.IgnoreErrorsWith("unknown provider/subject") msg.Action("Acknowledge").Click() // expect ephemeral - ch.ExpectEphemeralMessage("GoAlert", "admin") - h.LinkSlackUser() + urlStr := ch.ExpectEphemeralMessage("link", "Slack", "account").Action("Link Account").URL() + t.Logf("url: %s", urlStr) + + u, err := url.Parse(urlStr) + if err != nil { + t.Fatal("bad link url returned:", err) + } + + tokenStr := u.Query().Get("authLinkToken") + resp := h.GraphQLQuery2(fmt.Sprintf(` + mutation { + linkAccount(token: "%s") + } + `, tokenStr)) + + if len(resp.Errors) > 0 { + t.Fatalf("expected no errors but got %v", resp.Errors) + } + msg.Action("Acknowledge").Click() updated := msg.ExpectUpdate() diff --git a/web/src/app/main/App.tsx b/web/src/app/main/App.tsx index 24b4f8d99a..0687ab77b4 100644 --- a/web/src/app/main/App.tsx +++ b/web/src/app/main/App.tsx @@ -23,6 +23,7 @@ import { Theme } from '@mui/material/styles' import AppRoutes from './AppRoutes' import { useURLKey } from '../actions' import NavBar from './NavBar' +import AuthLink from './components/AuthLink' const useStyles = makeStyles((theme: Theme) => ({ root: { @@ -115,6 +116,7 @@ export default function App(): JSX.Element {
+ { + if (!ready) return + if (!token) return + if (fetching) return + if (error) return + if (info) return + + setToken('') + }, [!!info, !!error, fetching, ready, token]) + + if (!token || !ready || fetching) { + return null + } + + if (error) { + return ( + setSnack(false)} + open={snack && !!error} + > + + Unable to fetch account link details. Try again later. + + + ) + } + + if (!info) { + return ( + setSnack(false)} + open={snack} + > + + Invalid or expired account link URL. Try again. + + + ) + } + + let alertAction = '' + if (info.alertID && info.alertNewStatus) { + switch (info.alertNewStatus) { + case 'StatusAcknowledged': + alertAction = `alert #${data.linkAccountInfo.alertID} will be acknowledged.` + break + case 'StatusClosed': + alertAction = `alert #${data.linkAccountInfo.alertID} will be closed.` + break + default: + alertAction = `Alert #${data.linkAccountInfo.alertID} will be updated to ${info.alertNewStatus}.` + break + } + } + + return ( + setToken('')} + onSubmit={() => + linkAccount({ token }).then((result) => { + if (result.error) return + if (info.alertID) navigate(`/alerts/${info.alertID}`) + if (info.alertNewStatus) { + updateAlertStatus({ + input: { + alertIDs: [info.alertID], + newStatus: info.alertNewStatus, + }, + }) + } + + setToken('') + }) + } + form={ + + + Clicking confirm will link the current GoAlert user{' '} + {userName} with: + + {data.linkAccountInfo.userDetails}. +
+
+ {alertAction && ( + After linking, {alertAction} + )} +
+ } + /> + ) +} diff --git a/web/src/schema.d.ts b/web/src/schema.d.ts index 691171ab58..d5a63fd7ce 100644 --- a/web/src/schema.d.ts +++ b/web/src/schema.d.ts @@ -34,6 +34,13 @@ export interface Query { slackChannels: SlackChannelConnection slackChannel?: null | SlackChannel generateSlackAppManifest: string + linkAccountInfo?: null | LinkAccountInfo +} + +export interface LinkAccountInfo { + userDetails: string + alertID?: null | number + alertNewStatus?: null | AlertStatus } export interface AlertMetricsOptions { @@ -256,6 +263,7 @@ export interface SetScheduleShiftInput { } export interface Mutation { + linkAccount: boolean setTemporarySchedule: boolean clearTemporarySchedules: boolean setScheduleOnCallNotificationRules: boolean