diff --git a/oomagent/server.go b/oomagent/server.go index e19b64aba..54583ce2e 100644 --- a/oomagent/server.go +++ b/oomagent/server.go @@ -199,40 +199,44 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error { var globalErr error // This channel indicates when the the ChannelJoin oomstore operation is finished, whether succeeded or failed. - done := make(chan struct{}) + done := make(chan struct{}, 1) + // This channel receives requests from the client. - entityRows := make(chan types.EntityRow) + entityRows := make(chan types.EntityRow, 1) - // This goroutine runs the join operation, and send whatever joined as the response go func() { - joinResult, err := s.oomstore.ChannelJoin(context.Background(), types.ChannelJoinOpt{ - JoinFeatureNames: firstReq.JoinFeatures, - EntityRows: entityRows, - ExistedFeatureNames: firstReq.ExistedFeatures, - }) - if err != nil { - globalErr = err - } else { - header := joinResult.Header - for row := range joinResult.Data { - joinedRow, err := convertJoinedRow(row) - if err != nil { - globalErr = err - break - } - resp := &codegen.ChannelJoinResponse{ - Header: header, - JoinedRow: joinedRow, - } - if err = stream.Send(resp); err != nil { - globalErr = err - break + defer func() { + close(entityRows) + }() + + for { + req, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + globalErr = err + return + } + if globalErr != nil { + return + } + if req.GetEntityRow() == nil { + globalErr = errdefs.Errorf("cannot process nil entity row") + return + } + + select { + case <-done: + return + default: + entityRows <- types.EntityRow{ + EntityKey: req.EntityRow.EntityKey, + UnixMilli: req.EntityRow.UnixMilli, + Values: req.EntityRow.Values, } - // Only need to send header upon the first response - header = nil } } - done <- struct{}{} }() // DO NOT move it before the goroutine starts, @@ -243,33 +247,37 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error { Values: firstReq.EntityRow.Values, } - for { - req, err := stream.Recv() - if err == io.EOF { - break - } - if err != nil { - globalErr = err - break - } - if globalErr != nil { - break - } - if req.GetEntityRow() == nil { - globalErr = errdefs.Errorf("cannot process nil entity row") - break - } - entityRows <- types.EntityRow{ - EntityKey: req.EntityRow.EntityKey, - UnixMilli: req.EntityRow.UnixMilli, - Values: req.EntityRow.Values, + // This goroutine runs the join operation, and send whatever joined as the response + joinResult, err := s.oomstore.ChannelJoin(context.Background(), types.ChannelJoinOpt{ + JoinFeatureNames: firstReq.JoinFeatures, + EntityRows: entityRows, + ExistedFeatureNames: firstReq.ExistedFeatures, + }) + if err != nil { + globalErr = err + } else { + header := joinResult.Header + for row := range joinResult.Data { + joinedRow, err := convertJoinedRow(row) + if err != nil { + globalErr = err + break + } + resp := &codegen.ChannelJoinResponse{ + Header: header, + JoinedRow: joinedRow, + } + if err = stream.Send(resp); err != nil { + globalErr = err + break + } + // Only need to send header upon the first response + header = nil } } - close(entityRows) - // wait until oomstore ChannelJoin is done, whether succeeded or failed - <-done - + // send a notification to the data receiving goroutine that ChannelJoin has done + done <- struct{}{} return globalErr }