Skip to content

Commit

Permalink
Implement polling subscriptions (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrudd authored May 16, 2021
1 parent 899dfae commit a313aa8
Show file tree
Hide file tree
Showing 6 changed files with 823 additions and 50 deletions.
70 changes: 64 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ if _, err := db.Exec("SET search_path TO message_store,public;"); err != nil {
log.Fatalf("setting search path: %s", err)
}

// create client
// Create client
client := gomdb.NewClient(db)

// read from stream
// Read from stream
msgs, err := client.GetStreamMessages(context.Background(), stream)
if err != nil {
log.Fatalf("reading from stream: %s", err)
Expand All @@ -31,21 +31,79 @@ if err != nil {
log.Println(msgs)
```

See the [example](./example) directory for a more complete example.
See the [example](./example) directory or [client_test.go](./client_test.go) for a more complete examples.

## Running tests

The unit tests require an instance of Message DB running to test against.

```bash
# start Message DB
# Start Message DB
docker build -t message-db .
docker run -d --rm \
-p 5432:5432 \
-e POSTGRES_HOST_AUTH_METHOD=trust \
message-db \
-c message_store.sql_condition=on

# run tests
# Run tests
go test -condition-on
```
```

## Subscriptions

Subscriptions are built on top of the `GetStreamMessages` and `GetCategoryMessages` methods and simply poll from the last read version or position.

```go
subCtx, cancel := context.WithCancel(context.Background())
defer cancel() // cancel will stop the subscription

err := client.SubscribeToCategory(subCtx, "user",
func(m *gomdb.Message) { // Message handler
log.Printf("Received message: %v", m)
},
func(live bool) { // Liveness handler
if live {
log.Print("subscription is handling live messages!")
} else {
log.Print("subscription has fallen behind")
}
},
func(err error) { // subscription dropped handler
if err != nil {
log.Fatalf("subscription dropped with error: %s", err)
}
},
)
if err != nil {
log.Fatal(err)
}
```

The client can be configured with different polling strategies to reduce reads to the database for subscriptions that rarely receive messages

```go
// Client configured with exponential backoff
client := gomdb.NewClient(
db,
gomdb.WithSubPollingStrategy(
gomdb.ExpBackoffPolling(
50*time.Millisecond, // minimum polling delay on no messages read
5*time.Second, // maximum polling delay on no messages read
2, // delay will double for every read that returns no messages
),
),
)

// Client configured with constant polling interval
client = gomdb.NewClient(
db,
gomdb.WithSubPollingStrategy(
gomdb.ConstantPolling(100*time.Millisecond), // polling delay on no messages read
),
)
```

## Contributing

All contributions welcome, especially anyone with SQL experience who could tidy up how queries are run and how read errors are handled.
219 changes: 208 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
// Package gomdb provides a Client for calling Message DB procedures.
package gomdb

import (
Expand All @@ -7,13 +8,17 @@ import (
"errors"
"fmt"
"strings"
"time"
)

const (
// NoStreamVersion is expected version for a stream that doesn't exist.
NoStreamVersion = int64(-1)
// AnyVersion allows writing of a message regardless of the stream version.
AnyVersion = int64(-2)
// DefaultPollingInterval defines the default polling duration for
// subscriptions.
DefaultPollingInterval = 100 * time.Millisecond
)

// ErrUnexpectedStreamVersion is returned when a stream is not at the expected
Expand All @@ -22,14 +27,22 @@ var ErrUnexpectedStreamVersion = errors.New("unexpected stream version when writ

// Client exposes the message-db interface.
type Client struct {
db *sql.DB
db *sql.DB
pollingStrategy PollingStrategy
}

// NewClient returns a new message-db client for the provided database.
func NewClient(db *sql.DB) *Client {
return &Client{
db: db,
func NewClient(db *sql.DB, opts ...ClientOption) *Client {
c := &Client{
db: db,
pollingStrategy: ConstantPolling(DefaultPollingInterval),
}

for _, opt := range opts {
opt(c)
}

return c
}

// WriteMessage attempted to write the proposed message to the specifed stream.
Expand Down Expand Up @@ -74,18 +87,18 @@ func (c *Client) WriteMessage(ctx context.Context, stream StreamIdentifier, mess

defer rows.Close()

// read revision from results.
var revision int64
// read version from results.
var version int64

if !rows.Next() {
return 0, errors.New("write succeeded but no rows were returned")
}

if err = rows.Scan(&revision); err != nil {
return 0, fmt.Errorf("write succeeded but could not read returned revision: %w", err)
if err = rows.Scan(&version); err != nil {
return 0, fmt.Errorf("write succeeded but could not read returned version: %w", err)
}

return revision, nil
return version, nil
}

// GetStreamMessages reads messages from an individual stream. By default the
Expand Down Expand Up @@ -231,15 +244,15 @@ func (c *Client) GetStreamVersion(ctx context.Context, stream StreamIdentifier)

defer rows.Close()

// read revision from results.
// read version from results.

if !rows.Next() {
return 0, errors.New("no rows were returned")
}

var value interface{}
if err = rows.Scan(&value); err != nil {
return 0, fmt.Errorf("reading stream revision: %w", err)
return 0, fmt.Errorf("reading stream version: %w", err)
}

if value == nil {
Expand All @@ -250,3 +263,187 @@ func (c *Client) GetStreamVersion(ctx context.Context, stream StreamIdentifier)

return 0, fmt.Errorf("unexpected column value type: %T", value)
}

// MessageHandler handles messages as they appear after being written.
type MessageHandler func(*Message)

// LivenessHandler handles whether the subscription is in a "live" state or
// whether it is catching up.
type LivenessHandler func(bool)

// SubDroppedHandler handles errors that appear and stop the subscription.
type SubDroppedHandler func(error)

// SubscribeToStream subscribes to a stream and asynchronously passes messages
// to the message handler in batches. Once a subscription has caught up it will
// poll the database periodically for new messages. To stop a subscription
// cancel the provided context.
// When a subscription catches up it will call the LivenessHandler with true. If
// the subscription falls behind again it will called the LivenessHandler with
// false.
// If there is an error while reading messages then the subscription will be
// stopped and the SubDroppedHandler will be called with the stopping error. If
// the subscription is cancelled then the SubDroppedHandler will be called with
// nil.
func (c *Client) SubscribeToStream(
ctx context.Context,
stream StreamIdentifier,
handleMessage MessageHandler,
handleLiveness LivenessHandler,
handleDropped SubDroppedHandler,
opts ...GetStreamOption,
) error {
cfg := newDefaultStreamConfig()
for _, opt := range opts {
opt(cfg)
}

// validate inputs
if err := stream.validate(); err != nil {
return fmt.Errorf("validating stream identifier: %w", err)
} else if handleMessage == nil || handleLiveness == nil || handleDropped == nil {
return errors.New("all subscription handlers are required")
} else if err := cfg.validate(); err != nil {
return fmt.Errorf("validating options: %w", err)
}

// ignore context cancelled errors
wrappedHandleDropped := func(e error) {
if errors.Is(e, context.Canceled) {
handleDropped(nil)
} else {
handleDropped(ctx.Err())
}
}

go func() {
poll := time.NewTimer(0)
live := false
defer poll.Stop()

for {
// check for context cancelled
select {
case <-ctx.Done():
wrappedHandleDropped(ctx.Err())
return
case <-poll.C:
}

msgs, err := c.GetStreamMessages(ctx, stream, func(c *streamConfig) { *c = *cfg })
if err != nil {
wrappedHandleDropped(err)
return
}

for _, msg := range msgs {
handleMessage(msg)
}

if len(msgs) > 0 {
cfg.version = msgs[len(msgs)-1].Version + 1
}

// if we've read fewer messages than the batch size we must have
// caught up and can go live. Otherwise we've fallen behind.
if len(msgs) < int(cfg.batchSize) && !live {
live = true
handleLiveness(live)
} else if len(msgs) == int(cfg.batchSize) && live {
live = false
handleLiveness(live)
}

poll.Reset(c.pollingStrategy(int64(len(msgs)), cfg.batchSize))
}
}()

return nil
}

// SubscribeToCategory subscribes to a category and asynchronously passes messages
// to the message handler in batches. Once a subscription has caught up it will
// poll the database periodically for new messages. To stop a subscription
// cancel the provided context.
// When a subscription catches up it will call the LivenessHandler with true. If
// the subscription falls behind again it will called the LivenessHandler with
// false.
// If there is an error while reading messages then the subscription will be
// stopped and the SubDroppedHandler will be called with the stopping error. If
// the subscription is cancelled then the SubDroppedHandler will be called with
// nil.
func (c *Client) SubscribeToCategory(
ctx context.Context,
category string,
handleMessage MessageHandler,
handleLiveness LivenessHandler,
handleDropped SubDroppedHandler,
opts ...GetCategoryOption,
) error {
cfg := newDefaultCategoryConfig()
for _, opt := range opts {
opt(cfg)
}

// validate inputs
if strings.Contains(category, StreamNameSeparator) {
return fmt.Errorf("category cannot contain stream name separator (%s)", StreamNameSeparator)
} else if handleMessage == nil || handleLiveness == nil || handleDropped == nil {
return errors.New("all subscription handlers are required")
} else if err := cfg.validate(); err != nil {
return fmt.Errorf("validating options: %w", err)
}

// ignore context cancelled errors
wrappedHandleDropped := func(e error) {
if errors.Is(e, context.Canceled) {
handleDropped(nil)
} else {
handleDropped(ctx.Err())
}
}

go func() {
poll := time.NewTimer(0)
live := false
defer poll.Stop()

for {
// check for context cancelled
select {
case <-ctx.Done():
wrappedHandleDropped(ctx.Err())
return
case <-poll.C:
}

msgs, err := c.GetCategoryMessages(ctx, category, func(c *categoryConfig) { *c = *cfg })
if err != nil {
wrappedHandleDropped(err)
return
}

for _, msg := range msgs {
handleMessage(msg)
}

if len(msgs) > 0 {
cfg.position = msgs[len(msgs)-1].GlobalPosition + 1
}

// if we've read fewer messages than the batch size we must have
// caught up and can go live. Otherwise we've fallen behind.
if len(msgs) < int(cfg.batchSize) && !live {
live = true
handleLiveness(live)
} else if len(msgs) == int(cfg.batchSize) && live {
live = false
handleLiveness(live)
}

poll.Reset(c.pollingStrategy(int64(len(msgs)), cfg.batchSize))
}
}()

return nil
}
Loading

0 comments on commit a313aa8

Please sign in to comment.