From 133ae6d0198dadd27aee636264316fc5aed073e6 Mon Sep 17 00:00:00 2001 From: alexrudd Date: Sun, 16 May 2021 01:36:47 +0100 Subject: [PATCH 1/2] Implement polling subscriptions for streams --- client.go | 105 +++++++++++++++++++++++++++++++++++++++++++++++- client_test.go | 107 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 210 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index c655d91..ee9c6f0 100644 --- a/client.go +++ b/client.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strings" + "time" ) const ( @@ -22,13 +23,15 @@ var ErrUnexpectedStreamVersion = errors.New("unexpected stream version when writ // Client exposes the message-db interface. type Client struct { - db *sql.DB + db *sql.DB + pollInterval time.Duration } // NewClient returns a new message-db client for the provided database. func NewClient(db *sql.DB) *Client { return &Client{ - db: db, + db: db, + pollInterval: 100 * time.Millisecond, } } @@ -250,3 +253,101 @@ func (c *Client) GetStreamVersion(ctx context.Context, stream StreamIdentifier) return 0, fmt.Errorf("unexpected column value type: %T", value) } + +// MessageHandler handles messages as they appear after being written. +type MessageHandler func(*Message) + +// LivenessHandler handles whether the subscription is in a "live" state or +// whether it is catching up. +type LivenessHandler func(bool) + +// ErrorHandler handles errors that appear and stop the subscription. +type ErrorHandler func(error) + +// SubscribeToStream subscribes to a stream and asynchronously passes messages +// to the message handler in batches. Once a subscription has caught up it will +// poll the database periodically for new messages. To stop a subscription +// cancel the provided context. +// When a subscription catches up it will call the LivenessHandler with true. If +// the subscription falls behind again it will called the LivenessHandler with +// false. +// If there is an error while reading messages then the subscription will be +// stopped and the ErrorHandler will be called with the stopping error. +func (c *Client) SubscribeToStream( + ctx context.Context, + stream StreamIdentifier, + handleMessage MessageHandler, + handleLiveness LivenessHandler, + handleError ErrorHandler, + opts ...GetStreamOption, +) error { + cfg := newDefaultStreamConfig() + for _, opt := range opts { + opt(cfg) + } + + // validate inputs + if err := stream.validate(); err != nil { + return fmt.Errorf("validating stream identifier: %w", err) + } else if handleMessage == nil || handleLiveness == nil || handleError == nil { + return errors.New("all subscription handlers are required") + } else if err := cfg.validate(); err != nil { + return fmt.Errorf("validating options: %w", err) + } + + // ignore context cancelled errors + wrappedHandleError := func(e error) { + if errors.Is(e, context.Canceled) { + handleError(nil) + } else { + handleError(ctx.Err()) + } + } + + go func() { + poll := time.NewTicker(1) + live := false + + for { + // check for context cancelled + select { + case <-ctx.Done(): + wrappedHandleError(ctx.Err()) + return + case <-poll.C: + } + + msgs, err := c.GetStreamMessages(ctx, stream, func(c *streamConfig) { + c.version = cfg.version + c.batchSize = cfg.batchSize + c.condition = cfg.condition + }) + if err != nil { + wrappedHandleError(err) + return + } + + for _, msg := range msgs { + handleMessage(msg) + } + + if len(msgs) > 0 { + cfg.version = msgs[len(msgs)-1].Version + 1 + } + + // if we've read fewer messages than the batch size we must have + // caught up and can go live. Otherwise we've fallen behind. + if len(msgs) < int(cfg.batchSize) && !live { + live = true + poll.Reset(c.pollInterval) + handleLiveness(live) + } else if len(msgs) == int(cfg.batchSize) && live { + live = false + poll.Reset(0) + handleLiveness(live) + } + } + }() + + return nil +} diff --git a/client_test.go b/client_test.go index 9f781d1..2fe35c8 100644 --- a/client_test.go +++ b/client_test.go @@ -6,6 +6,7 @@ import ( "errors" "flag" "fmt" + "sync" "testing" "github.com/alexrudd/gomdb" @@ -500,3 +501,109 @@ func TestGetStreamVersion(t *testing.T) { } }) } + +// TestSubscribeToStream tests the SubscribeToStream API. +func TestSubscribeToStream(t *testing.T) { + // t.Parallel() + + client := NewClient(t) + + t.Run("subscribe to empty stream", func(t *testing.T) { + // t.Parallel() + + stream := NewTestStream("nonexistant") + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + goneLive := sync.WaitGroup{} + goneLive.Add(1) + + err := client.SubscribeToStream( + ctx, + stream, + func(m *gomdb.Message) { + t.Fatal("No messages should exist on stream") + }, + func(live bool) { + if !live { + t.Fatal("subscription should be live") + } + goneLive.Done() + }, + func(err error) { + if err != nil { + t.Fatalf("received subscription error: %s", err) + } + }, + ) + if err != nil { + t.Fatal(err) + } + + goneLive.Wait() + }) + + t.Run("subscribe to new messages", func(t *testing.T) { + // t.Parallel() + + stream := NewTestStream("nonexistant") + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + received := sync.WaitGroup{} + received.Add(3) + + err := client.SubscribeToStream( + ctx, + stream, + func(m *gomdb.Message) { + received.Done() + }, + func(live bool) {}, + func(err error) {}, + ) + if err != nil { + t.Fatal(err) + } + + PopulateStream(t, client, stream, 3) + + received.Wait() + }) + + t.Run("catch up to stream then go live", func(t *testing.T) { + // t.Parallel() + + stream := NewTestStream("nonexistant") + PopulateStream(t, client, stream, 10) + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + received := sync.WaitGroup{} + received.Add(10) + version := int64(0) + + err := client.SubscribeToStream( + ctx, + stream, + func(m *gomdb.Message) { + version = m.Version + received.Done() + }, + func(live bool) { + if live && version != 9 { + t.Fatalf("expected to go live at version 9, actual: %v", version) + } + }, + func(err error) {}, + gomdb.WithStreamBatchSize(5), + ) + if err != nil { + t.Fatal(err) + } + + received.Wait() + }) + +} From 1e8acafa5f62e9415fd78631b09f8996244e2c4c Mon Sep 17 00:00:00 2001 From: alexrudd Date: Sun, 16 May 2021 19:12:05 +0100 Subject: [PATCH 2/2] Complete subscription implementation --- README.md | 70 +++++++++++++-- client.go | 158 ++++++++++++++++++++++++++------- client_test.go | 227 +++++++++++++++++++++++++++++++++++++++++------- options.go | 89 +++++++++++++++++-- options_test.go | 160 ++++++++++++++++++++++++++++++++++ types_test.go | 15 ++++ 6 files changed, 642 insertions(+), 77 deletions(-) create mode 100644 options_test.go diff --git a/README.md b/README.md index ee76721..f42cce3 100644 --- a/README.md +++ b/README.md @@ -19,10 +19,10 @@ if _, err := db.Exec("SET search_path TO message_store,public;"); err != nil { log.Fatalf("setting search path: %s", err) } -// create client +// Create client client := gomdb.NewClient(db) -// read from stream +// Read from stream msgs, err := client.GetStreamMessages(context.Background(), stream) if err != nil { log.Fatalf("reading from stream: %s", err) @@ -31,14 +31,14 @@ if err != nil { log.Println(msgs) ``` -See the [example](./example) directory for a more complete example. +See the [example](./example) directory or [client_test.go](./client_test.go) for a more complete examples. ## Running tests The unit tests require an instance of Message DB running to test against. ```bash -# start Message DB +# Start Message DB docker build -t message-db . docker run -d --rm \ -p 5432:5432 \ @@ -46,6 +46,64 @@ docker run -d --rm \ message-db \ -c message_store.sql_condition=on -# run tests +# Run tests go test -condition-on -``` \ No newline at end of file +``` + +## Subscriptions + +Subscriptions are built on top of the `GetStreamMessages` and `GetCategoryMessages` methods and simply poll from the last read version or position. + +```go +subCtx, cancel := context.WithCancel(context.Background()) +defer cancel() // cancel will stop the subscription + +err := client.SubscribeToCategory(subCtx, "user", + func(m *gomdb.Message) { // Message handler + log.Printf("Received message: %v", m) + }, + func(live bool) { // Liveness handler + if live { + log.Print("subscription is handling live messages!") + } else { + log.Print("subscription has fallen behind") + } + }, + func(err error) { // subscription dropped handler + if err != nil { + log.Fatalf("subscription dropped with error: %s", err) + } + }, +) +if err != nil { + log.Fatal(err) +} +``` + +The client can be configured with different polling strategies to reduce reads to the database for subscriptions that rarely receive messages + +```go +// Client configured with exponential backoff +client := gomdb.NewClient( + db, + gomdb.WithSubPollingStrategy( + gomdb.ExpBackoffPolling( + 50*time.Millisecond, // minimum polling delay on no messages read + 5*time.Second, // maximum polling delay on no messages read + 2, // delay will double for every read that returns no messages + ), + ), +) + +// Client configured with constant polling interval +client = gomdb.NewClient( + db, + gomdb.WithSubPollingStrategy( + gomdb.ConstantPolling(100*time.Millisecond), // polling delay on no messages read + ), +) +``` + +## Contributing + +All contributions welcome, especially anyone with SQL experience who could tidy up how queries are run and how read errors are handled. \ No newline at end of file diff --git a/client.go b/client.go index ee9c6f0..2478a01 100644 --- a/client.go +++ b/client.go @@ -1,3 +1,4 @@ +// Package gomdb provides a Client for calling Message DB procedures. package gomdb import ( @@ -15,6 +16,9 @@ const ( NoStreamVersion = int64(-1) // AnyVersion allows writing of a message regardless of the stream version. AnyVersion = int64(-2) + // DefaultPollingInterval defines the default polling duration for + // subscriptions. + DefaultPollingInterval = 100 * time.Millisecond ) // ErrUnexpectedStreamVersion is returned when a stream is not at the expected @@ -23,16 +27,22 @@ var ErrUnexpectedStreamVersion = errors.New("unexpected stream version when writ // Client exposes the message-db interface. type Client struct { - db *sql.DB - pollInterval time.Duration + db *sql.DB + pollingStrategy PollingStrategy } // NewClient returns a new message-db client for the provided database. -func NewClient(db *sql.DB) *Client { - return &Client{ - db: db, - pollInterval: 100 * time.Millisecond, +func NewClient(db *sql.DB, opts ...ClientOption) *Client { + c := &Client{ + db: db, + pollingStrategy: ConstantPolling(DefaultPollingInterval), } + + for _, opt := range opts { + opt(c) + } + + return c } // WriteMessage attempted to write the proposed message to the specifed stream. @@ -77,18 +87,18 @@ func (c *Client) WriteMessage(ctx context.Context, stream StreamIdentifier, mess defer rows.Close() - // read revision from results. - var revision int64 + // read version from results. + var version int64 if !rows.Next() { return 0, errors.New("write succeeded but no rows were returned") } - if err = rows.Scan(&revision); err != nil { - return 0, fmt.Errorf("write succeeded but could not read returned revision: %w", err) + if err = rows.Scan(&version); err != nil { + return 0, fmt.Errorf("write succeeded but could not read returned version: %w", err) } - return revision, nil + return version, nil } // GetStreamMessages reads messages from an individual stream. By default the @@ -234,7 +244,7 @@ func (c *Client) GetStreamVersion(ctx context.Context, stream StreamIdentifier) defer rows.Close() - // read revision from results. + // read version from results. if !rows.Next() { return 0, errors.New("no rows were returned") @@ -242,7 +252,7 @@ func (c *Client) GetStreamVersion(ctx context.Context, stream StreamIdentifier) var value interface{} if err = rows.Scan(&value); err != nil { - return 0, fmt.Errorf("reading stream revision: %w", err) + return 0, fmt.Errorf("reading stream version: %w", err) } if value == nil { @@ -261,8 +271,8 @@ type MessageHandler func(*Message) // whether it is catching up. type LivenessHandler func(bool) -// ErrorHandler handles errors that appear and stop the subscription. -type ErrorHandler func(error) +// SubDroppedHandler handles errors that appear and stop the subscription. +type SubDroppedHandler func(error) // SubscribeToStream subscribes to a stream and asynchronously passes messages // to the message handler in batches. Once a subscription has caught up it will @@ -272,13 +282,15 @@ type ErrorHandler func(error) // the subscription falls behind again it will called the LivenessHandler with // false. // If there is an error while reading messages then the subscription will be -// stopped and the ErrorHandler will be called with the stopping error. +// stopped and the SubDroppedHandler will be called with the stopping error. If +// the subscription is cancelled then the SubDroppedHandler will be called with +// nil. func (c *Client) SubscribeToStream( ctx context.Context, stream StreamIdentifier, handleMessage MessageHandler, handleLiveness LivenessHandler, - handleError ErrorHandler, + handleDropped SubDroppedHandler, opts ...GetStreamOption, ) error { cfg := newDefaultStreamConfig() @@ -289,41 +301,38 @@ func (c *Client) SubscribeToStream( // validate inputs if err := stream.validate(); err != nil { return fmt.Errorf("validating stream identifier: %w", err) - } else if handleMessage == nil || handleLiveness == nil || handleError == nil { + } else if handleMessage == nil || handleLiveness == nil || handleDropped == nil { return errors.New("all subscription handlers are required") } else if err := cfg.validate(); err != nil { return fmt.Errorf("validating options: %w", err) } // ignore context cancelled errors - wrappedHandleError := func(e error) { + wrappedHandleDropped := func(e error) { if errors.Is(e, context.Canceled) { - handleError(nil) + handleDropped(nil) } else { - handleError(ctx.Err()) + handleDropped(ctx.Err()) } } go func() { - poll := time.NewTicker(1) + poll := time.NewTimer(0) live := false + defer poll.Stop() for { // check for context cancelled select { case <-ctx.Done(): - wrappedHandleError(ctx.Err()) + wrappedHandleDropped(ctx.Err()) return case <-poll.C: } - msgs, err := c.GetStreamMessages(ctx, stream, func(c *streamConfig) { - c.version = cfg.version - c.batchSize = cfg.batchSize - c.condition = cfg.condition - }) + msgs, err := c.GetStreamMessages(ctx, stream, func(c *streamConfig) { *c = *cfg }) if err != nil { - wrappedHandleError(err) + wrappedHandleDropped(err) return } @@ -339,13 +348,100 @@ func (c *Client) SubscribeToStream( // caught up and can go live. Otherwise we've fallen behind. if len(msgs) < int(cfg.batchSize) && !live { live = true - poll.Reset(c.pollInterval) handleLiveness(live) } else if len(msgs) == int(cfg.batchSize) && live { live = false - poll.Reset(0) handleLiveness(live) } + + poll.Reset(c.pollingStrategy(int64(len(msgs)), cfg.batchSize)) + } + }() + + return nil +} + +// SubscribeToCategory subscribes to a category and asynchronously passes messages +// to the message handler in batches. Once a subscription has caught up it will +// poll the database periodically for new messages. To stop a subscription +// cancel the provided context. +// When a subscription catches up it will call the LivenessHandler with true. If +// the subscription falls behind again it will called the LivenessHandler with +// false. +// If there is an error while reading messages then the subscription will be +// stopped and the SubDroppedHandler will be called with the stopping error. If +// the subscription is cancelled then the SubDroppedHandler will be called with +// nil. +func (c *Client) SubscribeToCategory( + ctx context.Context, + category string, + handleMessage MessageHandler, + handleLiveness LivenessHandler, + handleDropped SubDroppedHandler, + opts ...GetCategoryOption, +) error { + cfg := newDefaultCategoryConfig() + for _, opt := range opts { + opt(cfg) + } + + // validate inputs + if strings.Contains(category, StreamNameSeparator) { + return fmt.Errorf("category cannot contain stream name separator (%s)", StreamNameSeparator) + } else if handleMessage == nil || handleLiveness == nil || handleDropped == nil { + return errors.New("all subscription handlers are required") + } else if err := cfg.validate(); err != nil { + return fmt.Errorf("validating options: %w", err) + } + + // ignore context cancelled errors + wrappedHandleDropped := func(e error) { + if errors.Is(e, context.Canceled) { + handleDropped(nil) + } else { + handleDropped(ctx.Err()) + } + } + + go func() { + poll := time.NewTimer(0) + live := false + defer poll.Stop() + + for { + // check for context cancelled + select { + case <-ctx.Done(): + wrappedHandleDropped(ctx.Err()) + return + case <-poll.C: + } + + msgs, err := c.GetCategoryMessages(ctx, category, func(c *categoryConfig) { *c = *cfg }) + if err != nil { + wrappedHandleDropped(err) + return + } + + for _, msg := range msgs { + handleMessage(msg) + } + + if len(msgs) > 0 { + cfg.position = msgs[len(msgs)-1].GlobalPosition + 1 + } + + // if we've read fewer messages than the batch size we must have + // caught up and can go live. Otherwise we've fallen behind. + if len(msgs) < int(cfg.batchSize) && !live { + live = true + handleLiveness(live) + } else if len(msgs) == int(cfg.batchSize) && live { + live = false + handleLiveness(live) + } + + poll.Reset(c.pollingStrategy(int64(len(msgs)), cfg.batchSize)) } }() diff --git a/client_test.go b/client_test.go index 2fe35c8..31e89d5 100644 --- a/client_test.go +++ b/client_test.go @@ -64,15 +64,18 @@ func NewClient(t *testing.T) *gomdb.Client { // NewTestStream creates a new StreamIdentifier using the provided category // prefix. -func NewTestStream(catPrefix string) gomdb.StreamIdentifier { - randstr.Base62(5) - +func NewTestStream(category string) gomdb.StreamIdentifier { return gomdb.StreamIdentifier{ - Category: catPrefix + randstr.Base62(5), + Category: category, ID: randstr.Base62(10), } } +// NewTestCategory returns a unique category name +func NewTestCategory(prefix string) string { + return prefix + randstr.Base62(10) +} + // PopulateStream creates the specified number of messages and writes them // to the specified stream. func PopulateStream(t *testing.T, client *gomdb.Client, stream gomdb.StreamIdentifier, messages int) { @@ -97,11 +100,9 @@ func PopulateStream(t *testing.T, client *gomdb.Client, stream gomdb.StreamIdent // PopulateCategory creates multiple streams within a single categatory and // populates them with messages. The actual category is returned. -func PopulateCategory(t *testing.T, client *gomdb.Client, catPrefix string, streams, messages int) string { +func PopulateCategory(t *testing.T, client *gomdb.Client, category string, streams, messages int) string { t.Helper() - category := catPrefix + randstr.Base62(5) - for i := 0; i < streams; i++ { stream := gomdb.StreamIdentifier{ Category: category, @@ -123,7 +124,7 @@ func TestWriteMessage(t *testing.T) { t.Run("stream does not exist", func(t *testing.T) { t.Parallel() - stream := NewTestStream("new_stream") + stream := NewTestStream(NewTestCategory("new_stream")) msg := gomdb.ProposedMessage{ ID: uuid.NewV4().String(), Type: "TestMessage", @@ -143,7 +144,7 @@ func TestWriteMessage(t *testing.T) { t.Run("skip OCC check", func(t *testing.T) { t.Parallel() - stream := NewTestStream("any_stream") + stream := NewTestStream(NewTestCategory("any_stream")) msg := gomdb.ProposedMessage{ ID: uuid.NewV4().String(), Type: "TestMessage", @@ -168,7 +169,7 @@ func TestWriteMessage(t *testing.T) { t.Run("fail OCC check", func(t *testing.T) { t.Parallel() - stream := NewTestStream("any_stream") + stream := NewTestStream(NewTestCategory("any_stream")) msg := gomdb.ProposedMessage{ ID: uuid.NewV4().String(), Type: "TestMessage", @@ -196,7 +197,7 @@ func TestGetStreamMessages(t *testing.T) { t.Run("stream does not exist", func(t *testing.T) { t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("nonexistant")) msgs, err := client.GetStreamMessages(context.TODO(), stream) if err != nil { @@ -211,7 +212,7 @@ func TestGetStreamMessages(t *testing.T) { t.Run("get entire stream", func(t *testing.T) { t.Parallel() - stream := NewTestStream("entire") + stream := NewTestStream(NewTestCategory("entire")) PopulateStream(t, client, stream, 10) msgs, err := client.GetStreamMessages(context.TODO(), stream) @@ -227,7 +228,7 @@ func TestGetStreamMessages(t *testing.T) { t.Run("get first half of stream", func(t *testing.T) { t.Parallel() - stream := NewTestStream("half") + stream := NewTestStream(NewTestCategory("half")) PopulateStream(t, client, stream, 10) msgs, err := client.GetStreamMessages(context.TODO(), stream, gomdb.WithStreamBatchSize(5)) @@ -249,7 +250,7 @@ func TestGetStreamMessages(t *testing.T) { t.Run("get second half of stream", func(t *testing.T) { t.Parallel() - stream := NewTestStream("half") + stream := NewTestStream(NewTestCategory("half")) PopulateStream(t, client, stream, 10) msgs, err := client.GetStreamMessages(context.TODO(), stream, gomdb.WithStreamBatchSize(5), gomdb.FromVersion(5)) @@ -275,7 +276,7 @@ func TestGetStreamMessages(t *testing.T) { t.Skip() } - stream := NewTestStream("conditional") + stream := NewTestStream(NewTestCategory("conditional")) PopulateStream(t, client, stream, 10) msgs, err := client.GetStreamMessages(context.TODO(), stream, @@ -319,7 +320,7 @@ func TestGetCategoryMessages(t *testing.T) { t.Run("get all messages for category", func(t *testing.T) { t.Parallel() - category := PopulateCategory(t, client, "category", 5, 10) + category := PopulateCategory(t, client, NewTestCategory("category"), 5, 10) msgs, err := client.GetCategoryMessages(context.TODO(), category) if err != nil { @@ -334,7 +335,7 @@ func TestGetCategoryMessages(t *testing.T) { t.Run("get half messages for category", func(t *testing.T) { t.Parallel() - category := PopulateCategory(t, client, "half", 5, 10) + category := PopulateCategory(t, client, NewTestCategory("half"), 5, 10) // read all messages. msgs, _ := client.GetCategoryMessages(context.TODO(), category) @@ -356,7 +357,7 @@ func TestGetCategoryMessages(t *testing.T) { t.Run("read only first 15 messages in category", func(t *testing.T) { t.Parallel() - category := PopulateCategory(t, client, "batched", 5, 10) + category := PopulateCategory(t, client, NewTestCategory("batched"), 5, 10) msgs, err := client.GetCategoryMessages(context.TODO(), category, gomdb.WithCategoryBatchSize(15)) if err != nil { @@ -371,7 +372,7 @@ func TestGetCategoryMessages(t *testing.T) { t.Run("read as consumer group", func(t *testing.T) { t.Parallel() - category := PopulateCategory(t, client, "batched", 5, 10) + category := PopulateCategory(t, client, NewTestCategory("consumer"), 5, 10) msgs1, err := client.GetCategoryMessages(context.TODO(), category, gomdb.AsConsumerGroup(0, 2)) if err != nil { @@ -393,7 +394,42 @@ func TestGetCategoryMessages(t *testing.T) { t.Run("read with correlation", func(t *testing.T) { t.Parallel() - t.Skip() // TODO + category := NewTestCategory("correlation") + stream := NewTestStream(category) + + // write correlated event + _, _ = client.WriteMessage(context.TODO(), stream, gomdb.ProposedMessage{ + ID: uuid.NewV4().String(), + Type: "Correlated", + Data: "data", + Metadata: map[string]string{ + gomdb.CorrelationKey: "correlated", + }, + }, gomdb.AnyVersion) + + // write uncorrelated event + _, _ = client.WriteMessage(context.TODO(), stream, gomdb.ProposedMessage{ + ID: uuid.NewV4().String(), + Type: "Uncorrelated", + Data: "data", + Metadata: map[string]string{ + gomdb.CorrelationKey: "uncorrelated", + }, + }, gomdb.AnyVersion) + + msgs, err := client.GetCategoryMessages(context.TODO(), category, gomdb.WithCorrelation("correlated")) + if err != nil { + t.Fatal(err) + } + + if len(msgs) != 1 { + t.Fatalf("expected 1 correlated message, actual: %v", len(msgs)) + } + + if msgs[0].Type != "Correlated" { + t.Fatalf("expected message type of Correlated, actual: %s", msgs[0].Type) + } + }) t.Run("read with condition", func(t *testing.T) { @@ -403,7 +439,7 @@ func TestGetCategoryMessages(t *testing.T) { t.Skip() } - category := PopulateCategory(t, client, "batched", 5, 10) + category := PopulateCategory(t, client, NewTestCategory("condition"), 5, 10) msgs, err := client.GetCategoryMessages(context.TODO(), category, gomdb.WithCategoryCondition("MOD(messages.position, 2) = 0"), @@ -433,7 +469,7 @@ func TestGetLastStreamMessage(t *testing.T) { t.Run("stream does not exist", func(t *testing.T) { t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("nonexistant")) msg, err := client.GetLastStreamMessage(context.TODO(), stream) if err != nil { @@ -448,7 +484,7 @@ func TestGetLastStreamMessage(t *testing.T) { t.Run("get last message", func(t *testing.T) { t.Parallel() - stream := NewTestStream("stream") + stream := NewTestStream(NewTestCategory("stream")) PopulateStream(t, client, stream, 3) msg, err := client.GetLastStreamMessage(context.TODO(), stream) @@ -473,7 +509,7 @@ func TestGetStreamVersion(t *testing.T) { t.Run("stream does not exist", func(t *testing.T) { t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("nonexistant")) version, err := client.GetStreamVersion(context.TODO(), stream) if err != nil { @@ -488,7 +524,7 @@ func TestGetStreamVersion(t *testing.T) { t.Run("get stream version", func(t *testing.T) { t.Parallel() - stream := NewTestStream("stream") + stream := NewTestStream(NewTestCategory("stream")) PopulateStream(t, client, stream, 3) version, err := client.GetStreamVersion(context.TODO(), stream) @@ -504,14 +540,14 @@ func TestGetStreamVersion(t *testing.T) { // TestSubscribeToStream tests the SubscribeToStream API. func TestSubscribeToStream(t *testing.T) { - // t.Parallel() + t.Parallel() client := NewClient(t) t.Run("subscribe to empty stream", func(t *testing.T) { - // t.Parallel() + t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("empty")) ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -544,9 +580,9 @@ func TestSubscribeToStream(t *testing.T) { }) t.Run("subscribe to new messages", func(t *testing.T) { - // t.Parallel() + t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("incoming")) ctx, cancel := context.WithCancel(context.TODO()) defer cancel() @@ -572,9 +608,9 @@ func TestSubscribeToStream(t *testing.T) { }) t.Run("catch up to stream then go live", func(t *testing.T) { - // t.Parallel() + t.Parallel() - stream := NewTestStream("nonexistant") + stream := NewTestStream(NewTestCategory("catchup")) PopulateStream(t, client, stream, 10) ctx, cancel := context.WithCancel(context.TODO()) @@ -605,5 +641,132 @@ func TestSubscribeToStream(t *testing.T) { received.Wait() }) +} +// TestSubscribeToCategory tests the SubscribeToCategory API. +func TestSubscribeToCategory(t *testing.T) { + t.Parallel() + + client := NewClient(t) + + t.Run("subscribe to empty category", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + goneLive := sync.WaitGroup{} + goneLive.Add(1) + + err := client.SubscribeToCategory( + ctx, + NewTestCategory("empty"), + func(m *gomdb.Message) { + t.Fatal("No messages should exist on stream") + }, + func(live bool) { + if !live { + t.Fatal("subscription should be live") + } + goneLive.Done() + }, + func(err error) { + if err != nil { + t.Fatalf("received subscription error: %s", err) + } + }, + ) + if err != nil { + t.Fatal(err) + } + + goneLive.Wait() + }) + + t.Run("subscribe to new category messages", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + goneLive := sync.WaitGroup{} + goneLive.Add(1) + received := sync.WaitGroup{} + received.Add(30) + + category := NewTestCategory("empty") + + err := client.SubscribeToCategory( + ctx, + category, + func(m *gomdb.Message) { + received.Done() + }, + func(live bool) { + if !live { + t.Fatal("subscription should be live") + } + goneLive.Done() + }, + func(err error) { + if err != nil { + t.Fatalf("received subscription error: %s", err) + } + }, + ) + if err != nil { + t.Fatal(err) + } + + goneLive.Wait() + + PopulateCategory(t, client, category, 3, 10) + + received.Wait() + }) + + t.Run("catch up to category then go live", func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.TODO()) + defer cancel() + + goneLive := sync.WaitGroup{} + goneLive.Add(1) + received := sync.WaitGroup{} + received.Add(30) + + category := NewTestCategory("empty") + PopulateCategory(t, client, category, 3, 10) + + err := client.SubscribeToCategory( + ctx, + category, + func(m *gomdb.Message) { + received.Done() + }, + func(live bool) { + if !live { + t.Fatal("subscription should be live") + } + goneLive.Done() + }, + func(err error) { + if err != nil { + t.Fatalf("received subscription error: %s", err) + } + }, + ) + if err != nil { + t.Fatal(err) + } + + received.Wait() + goneLive.Wait() + + // receive 10 more messages live + received.Add(10) + PopulateCategory(t, client, category, 2, 5) + received.Wait() + }) } diff --git a/options.go b/options.go index 40dc71e..135ae8d 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,79 @@ package gomdb -import "errors" +import ( + "errors" + "math" + "time" +) + +var ( + // ErrInvalidReadStreamVersion is returned when the stream version inside a + // read call is less than zero. + ErrInvalidReadStreamVersion = errors.New("stream version cannot be less than 0") + // ErrInvalidReadBatchSize is returned when the batch size inside a read + // call is less than one. + ErrInvalidReadBatchSize = errors.New("batch size must be greater than 0") + // ErrInvalidReadPosition is returned when the stream position inside a + // read call is less than zero. + ErrInvalidReadPosition = errors.New("stream position cannot be less than 0") + // ErrInvalidConsumerGroupMember is returned when the consumer group ID + // index is either less than zero or greater than or equal to the consumer + // group size. + ErrInvalidConsumerGroupMember = errors.New("consumer group member must be >= 0 < group size") + // ErrInvalidConsumerGroupSize is returned when the consumer group size is + // less that zero. + ErrInvalidConsumerGroupSize = errors.New("consumer group size must be 0 or greater (0 to disbale consumer groups)") +) + +// ClientOption is an option for modifiying how the Message DB client operates. +type ClientOption func(*Client) + +// WithSubPollingStrategy configures the client with the specified +// PollingStrategy. +func WithSubPollingStrategy(strat PollingStrategy) ClientOption { + return func(c *Client) { + c.pollingStrategy = strat + } +} + +// PollingStrategy returns the delay duration before the next polling attempt +// based on how many messages were returned from the previous poll vs how many +// were expected. +type PollingStrategy func(retrieved, expected int64) time.Duration + +// ExpBackoffPolling returns an exponential polling backoff strategy that starts +// at the min duration but is multipled for every read that did not return +// any messages up to the max duration. The backoff duration is reset to min +// everytime a message is read. +func ExpBackoffPolling(min, max time.Duration, multiplier float64) PollingStrategy { + noMessageCount := 0 + return func(retrieved, _ int64) time.Duration { + if retrieved > 0 { + noMessageCount = 0 + return time.Duration(0) + } + + backoff := time.Duration(math.Pow(multiplier, float64(noMessageCount))) * min + noMessageCount++ + + if backoff > max { + return max + } + + return backoff + } +} + +// ConstantPolling returns a constant interval polling strategy +func ConstantPolling(interval time.Duration) PollingStrategy { + return func(retrieved, _ int64) time.Duration { + if retrieved > 0 { + return time.Duration(0) + } + + return interval + } +} // GetStreamOption is an option for modifiying how to read from a stream. type GetStreamOption func(*streamConfig) @@ -35,9 +108,9 @@ type streamConfig struct { func (cfg *streamConfig) validate() error { if cfg.version < 0 { - return errors.New("stream version cannot be less than 0") + return ErrInvalidReadStreamVersion } else if cfg.batchSize < 1 { - return errors.New("batch size must be greater than 0") + return ErrInvalidReadBatchSize } return nil @@ -123,13 +196,13 @@ func newDefaultCategoryConfig() *categoryConfig { func (cfg *categoryConfig) validate() error { if cfg.position < 0 { - return errors.New("stream version cannot be less than 0") + return ErrInvalidReadPosition } else if cfg.batchSize < 1 { - return errors.New("batch size must be greater than 0") - } else if cfg.consumerGroupMember < 0 { - return errors.New("consumer group member must be 0 or greater") + return ErrInvalidReadBatchSize + } else if cfg.consumerGroupMember < 0 || (cfg.consumerGroupSize > 0 && cfg.consumerGroupMember >= cfg.consumerGroupSize) { + return ErrInvalidConsumerGroupMember } else if cfg.consumerGroupSize < 0 { - return errors.New("consumer group size must be 0 or greater (0 to disbale consumer groups)") + return ErrInvalidConsumerGroupSize } return nil diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..ee3b2b3 --- /dev/null +++ b/options_test.go @@ -0,0 +1,160 @@ +package gomdb + +import ( + "errors" + "testing" + "time" +) + +func Test_PollingStrategies(t *testing.T) { + t.Parallel() + + testcases := []struct { + name string + strategy PollingStrategy + retrieved []int64 + expected int64 + delays []time.Duration + }{ + { + name: "constant polling", + strategy: ConstantPolling(time.Second), + retrieved: []int64{0, 0, 1, 1}, + delays: []time.Duration{time.Second, time.Second, 0, 0}, + }, + { + name: "exponential polling", + strategy: ExpBackoffPolling(time.Second, 10*time.Second, 2), + retrieved: []int64{1, 1, 0, 0, 0, 0, 0, 0}, + delays: []time.Duration{0, 0, time.Second, 2 * time.Second, 4 * time.Second, 8 * time.Second, 10 * time.Second, 10 * time.Second}, + }, + } + + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + for i, r := range tc.retrieved { + d := tc.strategy(r, tc.expected) + if d != tc.delays[i] { + t.Fatalf("on retreived %v expected delay %s, actual %s", r, d, tc.delays[i]) + } + } + }) + } +} + +func Test_categoryConfig_validate(t *testing.T) { + testcases := []struct { + name string + config categoryConfig + expErr error + }{ + { + name: "invalid stream position", + config: categoryConfig{ + position: -1, + batchSize: 1, + }, + expErr: ErrInvalidReadPosition, + }, + { + name: "invalid batch size", + config: categoryConfig{ + position: 0, + batchSize: 0, + }, + expErr: ErrInvalidReadBatchSize, + }, + { + name: "negative consumer member index", + config: categoryConfig{ + position: 0, + batchSize: 1, + consumerGroupMember: -1, + }, + expErr: ErrInvalidConsumerGroupMember, + }, + { + name: "consumer member index out of range", + config: categoryConfig{ + position: 0, + batchSize: 1, + consumerGroupMember: 1, + consumerGroupSize: 1, + }, + expErr: ErrInvalidConsumerGroupMember, + }, + { + name: "negative consumer group size", + config: categoryConfig{ + position: 0, + batchSize: 1, + consumerGroupSize: -1, + }, + expErr: ErrInvalidConsumerGroupSize, + }, + { + name: "valid", + config: categoryConfig{ + position: 0, + batchSize: 1, + consumerGroupMember: 1, + consumerGroupSize: 2, + }, + }, + } + + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.config.validate() + if !errors.Is(err, tc.expErr) { + t.Fatalf("expected %v, actual %v", tc.expErr, err) + } + }) + } +} + +func Test_streamConfig_validate(t *testing.T) { + testcases := []struct { + name string + config streamConfig + expErr error + }{ + { + name: "invalid stream version", + config: streamConfig{ + version: -1, + batchSize: 1, + }, + expErr: ErrInvalidReadStreamVersion, + }, + { + name: "invalid batch size", + config: streamConfig{ + version: 0, + batchSize: 0, + }, + expErr: ErrInvalidReadBatchSize, + }, + { + name: "valid", + config: streamConfig{ + version: 0, + batchSize: 1, + }, + }, + } + + for _, tc := range testcases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + err := tc.config.validate() + if !errors.Is(err, tc.expErr) { + t.Fatalf("expected %v, actual %v", tc.expErr, err) + } + }) + } +} diff --git a/types_test.go b/types_test.go index ea37bed..d0ef627 100644 --- a/types_test.go +++ b/types_test.go @@ -38,6 +38,14 @@ func Test_ProposedMessage_validate(t *testing.T) { }, expErr: ErrMissingData, }, + { + name: "valid", + message: ProposedMessage{ + ID: uuid.NewV4().String(), + Type: "SomeType", + Data: "data", + }, + }, } for _, tc := range testcases { @@ -87,6 +95,13 @@ func Test_StreamIdentifier_validate(t *testing.T) { }, expErr: ErrInvalidStreamID, }, + { + name: "valid", + sid: StreamIdentifier{ + Category: "category", + ID: "123abc", + }, + }, } for _, tc := range testcases {