Skip to content

Commit

Permalink
fix(oomagent/join): fix deadlock
Browse files Browse the repository at this point in the history
  • Loading branch information
lianxmfor committed Jan 26, 2022
1 parent 6775b4e commit abc2050
Showing 1 changed file with 60 additions and 52 deletions.
112 changes: 60 additions & 52 deletions oomagent/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,40 +199,44 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error {
var globalErr error

// This channel indicates when the the ChannelJoin oomstore operation is finished, whether succeeded or failed.
done := make(chan struct{})
done := make(chan struct{}, 1)

// This channel receives requests from the client.
entityRows := make(chan types.EntityRow)
entityRows := make(chan types.EntityRow, 1)

// This goroutine runs the join operation, and send whatever joined as the response
go func() {
joinResult, err := s.oomstore.ChannelJoin(context.Background(), types.ChannelJoinOpt{
JoinFeatureNames: firstReq.JoinFeatures,
EntityRows: entityRows,
ExistedFeatureNames: firstReq.ExistedFeatures,
})
if err != nil {
globalErr = err
} else {
header := joinResult.Header
for row := range joinResult.Data {
joinedRow, err := convertJoinedRow(row)
if err != nil {
globalErr = err
break
}
resp := &codegen.ChannelJoinResponse{
Header: header,
JoinedRow: joinedRow,
}
if err = stream.Send(resp); err != nil {
globalErr = err
break
defer func() {
close(entityRows)
}()

for {
req, err := stream.Recv()
if err == io.EOF {
return
}
if err != nil {
globalErr = err
return
}
if globalErr != nil {
return
}
if req.GetEntityRow() == nil {
globalErr = errdefs.Errorf("cannot process nil entity row")
return
}

select {
case <-done:
return
default:
entityRows <- types.EntityRow{
EntityKey: req.EntityRow.EntityKey,
UnixMilli: req.EntityRow.UnixMilli,
Values: req.EntityRow.Values,
}
// Only need to send header upon the first response
header = nil
}
}
done <- struct{}{}
}()

// DO NOT move it before the goroutine starts,
Expand All @@ -243,33 +247,37 @@ func (s *server) ChannelJoin(stream codegen.OomAgent_ChannelJoinServer) error {
Values: firstReq.EntityRow.Values,
}

for {
req, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
globalErr = err
break
}
if globalErr != nil {
break
}
if req.GetEntityRow() == nil {
globalErr = errdefs.Errorf("cannot process nil entity row")
break
}
entityRows <- types.EntityRow{
EntityKey: req.EntityRow.EntityKey,
UnixMilli: req.EntityRow.UnixMilli,
Values: req.EntityRow.Values,
// This goroutine runs the join operation, and send whatever joined as the response
joinResult, err := s.oomstore.ChannelJoin(context.Background(), types.ChannelJoinOpt{
JoinFeatureNames: firstReq.JoinFeatures,
EntityRows: entityRows,
ExistedFeatureNames: firstReq.ExistedFeatures,
})
if err != nil {
globalErr = err
} else {
header := joinResult.Header
for row := range joinResult.Data {
joinedRow, err := convertJoinedRow(row)
if err != nil {
globalErr = err
break
}
resp := &codegen.ChannelJoinResponse{
Header: header,
JoinedRow: joinedRow,
}
if err = stream.Send(resp); err != nil {
globalErr = err
break
}
// Only need to send header upon the first response
header = nil
}
}

close(entityRows)
// wait until oomstore ChannelJoin is done, whether succeeded or failed
<-done

// send a notification to the data receiving goroutine that ChannelJoin has done
done <- struct{}{}
return globalErr
}

Expand Down

0 comments on commit abc2050

Please sign in to comment.