diff --git a/access/handler.go b/access/handler.go index 25316e7f3dd..b974e7034fc 100644 --- a/access/handler.go +++ b/access/handler.go @@ -1066,7 +1066,7 @@ func (h *Handler) SubscribeBlocksFromStartBlockID(request *access.SubscribeBlock } sub := h.api.SubscribeBlocksFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // SubscribeBlocksFromStartHeight handles subscription requests for blocks started from block height. @@ -1093,7 +1093,7 @@ func (h *Handler) SubscribeBlocksFromStartHeight(request *access.SubscribeBlocks } sub := h.api.SubscribeBlocksFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // SubscribeBlocksFromLatest handles subscription requests for blocks started from latest sealed block. @@ -1120,7 +1120,7 @@ func (h *Handler) SubscribeBlocksFromLatest(request *access.SubscribeBlocksFromL } sub := h.api.SubscribeBlocksFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) + return subscription.HandleRPCSubscription(sub, h.handleBlocksResponse(stream.Send, request.GetFullBlockResponse(), blockStatus)) } // handleBlocksResponse handles the subscription to block updates and sends @@ -1179,7 +1179,7 @@ func (h *Handler) SubscribeBlockHeadersFromStartBlockID(request *access.Subscrib } sub := h.api.SubscribeBlockHeadersFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // SubscribeBlockHeadersFromStartHeight handles subscription requests for block headers started from block height. @@ -1206,7 +1206,7 @@ func (h *Handler) SubscribeBlockHeadersFromStartHeight(request *access.Subscribe } sub := h.api.SubscribeBlockHeadersFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // SubscribeBlockHeadersFromLatest handles subscription requests for block headers started from latest sealed block. @@ -1233,7 +1233,7 @@ func (h *Handler) SubscribeBlockHeadersFromLatest(request *access.SubscribeBlock } sub := h.api.SubscribeBlockHeadersFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockHeadersResponse(stream.Send)) } // handleBlockHeadersResponse handles the subscription to block updates and sends @@ -1293,7 +1293,7 @@ func (h *Handler) SubscribeBlockDigestsFromStartBlockID(request *access.Subscrib } sub := h.api.SubscribeBlockDigestsFromStartBlockID(stream.Context(), startBlockID, blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // SubscribeBlockDigestsFromStartHeight handles subscription requests for lightweight blocks started from block height. @@ -1320,7 +1320,7 @@ func (h *Handler) SubscribeBlockDigestsFromStartHeight(request *access.Subscribe } sub := h.api.SubscribeBlockDigestsFromStartHeight(stream.Context(), request.GetStartBlockHeight(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // SubscribeBlockDigestsFromLatest handles subscription requests for lightweight block started from latest sealed block. @@ -1347,7 +1347,7 @@ func (h *Handler) SubscribeBlockDigestsFromLatest(request *access.SubscribeBlock } sub := h.api.SubscribeBlockDigestsFromLatest(stream.Context(), blockStatus) - return subscription.HandleSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleBlockDigestsResponse(stream.Send)) } // handleBlockDigestsResponse handles the subscription to block updates and sends @@ -1433,7 +1433,7 @@ func (h *Handler) SendAndSubscribeTransactionStatuses( sub := h.api.SubscribeTransactionStatuses(ctx, &tx, request.GetEventEncodingVersion()) messageIndex := counters.NewMonotonousCounter(0) - return subscription.HandleSubscription(sub, func(txResults []*TransactionResult) error { + return subscription.HandleRPCSubscription(sub, func(txResults []*TransactionResult) error { for i := range txResults { index := messageIndex.Value() if ok := messageIndex.Set(index + 1); !ok { diff --git a/engine/access/state_stream/backend/handler.go b/engine/access/state_stream/backend/handler.go index b2066440bb8..3acf1bad6ca 100644 --- a/engine/access/state_stream/backend/handler.go +++ b/engine/access/state_stream/backend/handler.go @@ -102,7 +102,7 @@ func (h *Handler) SubscribeExecutionData(request *executiondata.SubscribeExecuti sub := h.api.SubscribeExecutionData(stream.Context(), startBlockID, request.GetStartBlockHeight()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromStartBlockID handles subscription requests for @@ -129,7 +129,7 @@ func (h *Handler) SubscribeExecutionDataFromStartBlockID(request *executiondata. sub := h.api.SubscribeExecutionDataFromStartBlockID(stream.Context(), startBlockID) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromStartBlockHeight handles subscription requests for @@ -150,7 +150,7 @@ func (h *Handler) SubscribeExecutionDataFromStartBlockHeight(request *executiond sub := h.api.SubscribeExecutionDataFromStartBlockHeight(stream.Context(), request.GetStartBlockHeight()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeExecutionDataFromLatest handles subscription requests for @@ -171,7 +171,7 @@ func (h *Handler) SubscribeExecutionDataFromLatest(request *executiondata.Subscr sub := h.api.SubscribeExecutionDataFromLatest(stream.Context()) - return subscription.HandleSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, handleSubscribeExecutionData(stream.Send, request.GetEventEncodingVersion())) } // SubscribeEvents is deprecated and will be removed in a future version. @@ -213,7 +213,7 @@ func (h *Handler) SubscribeEvents(request *executiondata.SubscribeEventsRequest, sub := h.api.SubscribeEvents(stream.Context(), startBlockID, request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromStartBlockID handles subscription requests for events starting at the specified block ID. @@ -248,7 +248,7 @@ func (h *Handler) SubscribeEventsFromStartBlockID(request *executiondata.Subscri sub := h.api.SubscribeEventsFromStartBlockID(stream.Context(), startBlockID, filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromStartHeight handles subscription requests for events starting at the specified block height. @@ -278,7 +278,7 @@ func (h *Handler) SubscribeEventsFromStartHeight(request *executiondata.Subscrib sub := h.api.SubscribeEventsFromStartHeight(stream.Context(), request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // SubscribeEventsFromLatest handles subscription requests for events started from latest sealed block.. @@ -308,7 +308,7 @@ func (h *Handler) SubscribeEventsFromLatest(request *executiondata.SubscribeEven sub := h.api.SubscribeEventsFromLatest(stream.Context(), filter) - return subscription.HandleSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) + return subscription.HandleRPCSubscription(sub, h.handleEventsResponse(stream.Send, request.HeartbeatInterval, request.GetEventEncodingVersion())) } // handleSubscribeExecutionData handles the subscription to execution data and sends it to the client via the provided stream. @@ -546,7 +546,7 @@ func (h *Handler) SubscribeAccountStatusesFromStartBlockID( sub := h.api.SubscribeAccountStatusesFromStartBlockID(stream.Context(), startBlockID, filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } // SubscribeAccountStatusesFromStartHeight streams account statuses for all blocks starting at the requested @@ -573,7 +573,7 @@ func (h *Handler) SubscribeAccountStatusesFromStartHeight( sub := h.api.SubscribeAccountStatusesFromStartHeight(stream.Context(), request.GetStartBlockHeight(), filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } // SubscribeAccountStatusesFromLatestBlock streams account statuses for all blocks starting @@ -600,5 +600,5 @@ func (h *Handler) SubscribeAccountStatusesFromLatestBlock( sub := h.api.SubscribeAccountStatusesFromLatestBlock(stream.Context(), filter) - return subscription.HandleSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) + return subscription.HandleRPCSubscription(sub, h.handleAccountStatusesResponse(request.HeartbeatInterval, request.GetEventEncodingVersion(), stream.Send)) } diff --git a/engine/access/subscription/util.go b/engine/access/subscription/util.go index 593f3d78499..604cbf6e597 100644 --- a/engine/access/subscription/util.go +++ b/engine/access/subscription/util.go @@ -1,8 +1,10 @@ package subscription import ( + "context" + "fmt" + "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" "github.com/onflow/flow-go/engine/common/rpc" ) @@ -11,29 +13,50 @@ import ( // handles the received responses, and sends the processed information to the client via the provided stream using handleResponse. // // Parameters: +// - ctx: Context for the operation. // - sub: The subscription. // - handleResponse: The function responsible for handling the response of the subscribed type. // -// Expected errors during normal operation: -// - codes.Internal: If the subscription encounters an error or gets an unexpected response. -func HandleSubscription[T any](sub Subscription, handleResponse func(resp T) error) error { +// No errors are expected during normal operations. +func HandleSubscription[T any](ctx context.Context, sub Subscription, handleResponse func(resp T) error) error { for { - v, ok := <-sub.Channel() - if !ok { - if sub.Err() != nil { - return rpc.ConvertError(sub.Err(), "stream encountered an error", codes.Internal) + select { + case v, ok := <-sub.Channel(): + if !ok { + if sub.Err() != nil { + return fmt.Errorf("stream encountered an error: %w", sub.Err()) + } + return nil } - return nil - } - resp, ok := v.(T) - if !ok { - return status.Errorf(codes.Internal, "unexpected response type: %T", v) - } + resp, ok := v.(T) + if !ok { + return fmt.Errorf("unexpected response type: %T", v) + } - err := handleResponse(resp) - if err != nil { - return err + err := handleResponse(resp) + if err != nil { + return err + } + case <-ctx.Done(): + return nil } } } + +// HandleRPCSubscription is a generic handler for subscriptions to a specific type for rpc calls. +// +// Parameters: +// - sub: The subscription. +// - handleResponse: The function responsible for handling the response of the subscribed type. +// +// Expected errors during normal operation: +// - codes.Internal: If the subscription encounters an error or gets an unexpected response. +func HandleRPCSubscription[T any](sub Subscription, handleResponse func(resp T) error) error { + err := HandleSubscription(nil, sub, handleResponse) + if err != nil { + return rpc.ConvertError(err, "handle subscription error", codes.Internal) + } + + return nil +}