Skip to content

Commit

Permalink
Test
Browse files Browse the repository at this point in the history
  • Loading branch information
Philipp Heckel committed Dec 25, 2021
1 parent 41514cd commit d676227
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 33 deletions.
20 changes: 9 additions & 11 deletions server/message.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package server

import (
"time"

"heckel.io/ntfy/util"
"time"
)

// List of possible events
Expand All @@ -19,15 +18,14 @@ const (

// message represents a message published to a topic
type message struct {
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
UnifiedPush bool `json:"unifiedpush,omitempty"` //this could be 'up'
ID string `json:"id"` // Random message ID
Time int64 `json:"time"` // Unix time in seconds
Event string `json:"event"` // One of the above
Topic string `json:"topic"`
Priority int `json:"priority,omitempty"`
Tags []string `json:"tags,omitempty"`
Title string `json:"title,omitempty"`
Message string `json:"message,omitempty"`
}

// messageEncoder is a function that knows how to encode a message
Expand Down
46 changes: 24 additions & 22 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ import (
"context"
"embed"
"encoding/json"
firebase "firebase.google.com/go"
"firebase.google.com/go/messaging"
"fmt"
"google.golang.org/api/option"
"heckel.io/ntfy/util"
"html/template"
"io"
"log"
Expand All @@ -16,11 +20,6 @@ import (
"strings"
"sync"
"time"

firebase "firebase.google.com/go"
"firebase.google.com/go/messaging"
"google.golang.org/api/option"
"heckel.io/ntfy/util"
)

// TODO add "max messages in a topic" limit
Expand Down Expand Up @@ -288,7 +287,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
} else if r.Method == http.MethodOptions {
return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicRegex.MatchString(r.URL.Path) {
return s.handleHome(w, r)
return s.handleTopic(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish)
} else if r.Method == http.MethodGet && sendRegex.MatchString(r.URL.Path) {
Expand All @@ -310,6 +309,17 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error {
})
}

func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error {
unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too!
if unifiedpush {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
return err
}
return s.handleHome(w, r)
}

func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error {
return nil
}
Expand Down Expand Up @@ -340,25 +350,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err
}
m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
cache, firebase, email, err := s.parseParams(r, m)
cache, firebase, email, err := s.parsePublishParams(r, m)
if err != nil {
return err
}

if r.Method == http.MethodGet && unifiedpush {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`)
return err
}

if email != "" {
if err := v.EmailAllowed(); err != nil {
return errHTTPTooManyRequestsLimitEmails
}
}

m.UnifiedPush = unifiedpush
if s.mailer == nil && email != "" {
return errHTTPBadRequestEmailDisabled
}
Expand All @@ -371,37 +371,35 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
return err
}
}
if s.firebase != nil && firebase && !delayed && !unifiedpush {
if s.firebase != nil && firebase && !delayed {
go func() {
if err := s.firebase(m); err != nil {
log.Printf("Unable to publish to Firebase: %v", err.Error())
}
}()
}
if s.mailer != nil && email != "" && !delayed && !unifiedpush {
if s.mailer != nil && email != "" && !delayed {
go func() {
if err := s.mailer.Send(v.ip, email, m); err != nil {
log.Printf("Unable to send email: %v", err.Error())
}
}()
}

if cache {
if err := s.cache.AddMessage(m); err != nil {
return err
}
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests

if err := json.NewEncoder(w).Encode(m); err != nil {
return err
}
s.inc(&s.messages)
return nil
}

func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) {
func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) {
cache = readParam(r, "x-cache", "cache") != "no"
firebase = readParam(r, "x-firebase", "firebase") != "no"
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
Expand Down Expand Up @@ -439,6 +437,10 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase
}
m.Time = delay.Unix()
}
unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too!
if unifiedpush {
firebase = false
}
return cache, firebase, email, nil
}

Expand Down
7 changes: 7 additions & 0 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,13 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
require.Equal(t, 400, response.Code)
}

func TestServer_UnifiedPushDiscovery(t *testing.T) {
s := newTestServer(t, newTestConfig(t))
response := request(t, s, "GET", "/mytopic?up=1", "", nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String())
}

func newTestConfig(t *testing.T) *Config {
conf := NewConfig()
conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")
Expand Down

0 comments on commit d676227

Please sign in to comment.