Skip to content

Commit

Permalink
fix: pipeline trigger model hangout (#80)
Browse files Browse the repository at this point in the history
Because

- Trigger model timeout 5 seconds is not enough for making model inference
- Goroutines use sync.WaitGroup which does not decrease the counter due to exception errors

This commit

- Increase timeout to 30 seconds
- defer wg.Done() and end function when have any error
- close #79 
- close #77
  • Loading branch information
Phelan164 authored Oct 23, 2022
1 parent 8307ac7 commit a692d2b
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,22 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe
wg.Add(1)

var modelInstOutputs []*pipelinePB.ModelInstanceOutput
errors := make(chan error)
go func() {
defer wg.Done()

for idx, modelInstance := range dbPipeline.Recipe.ModelInstances {

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
resp, err := s.modelServiceClient.TriggerModelInstance(ctx, &modelPB.TriggerModelInstanceRequest{
Name: modelInstance,
Inputs: inputs,
})
if err != nil {
errors <- err
logger.Error(fmt.Sprintf("[model-backend] Error %s at %dth model instance %s: %v", "TriggerModel", idx, modelInstance, err.Error()))
return
}

taskOutputs := cvtModelTaskOutputToPipelineTaskOutput(resp.TaskOutputs)
Expand All @@ -405,29 +410,54 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe
// Increment trigger image numbers
uid, err := resource.GetPermalinkUID(dbPipeline.Owner)
if err != nil {
errors <- err
logger.Error(err.Error())
return
}
if strings.HasPrefix(dbPipeline.Owner, "users/") {
s.redisClient.IncrBy(context.Background(), fmt.Sprintf("user:%s:trigger.image.num", uid), int64(len(inputs)))
} else if strings.HasPrefix(dbPipeline.Owner, "orgs/") {
s.redisClient.IncrBy(context.Background(), fmt.Sprintf("org:%s:trigger.image.num", uid), int64(len(inputs)))
}
}
wg.Done()
errors <- nil
}()

switch {
// If this is a SYNC trigger (i.e., HTTP, gRPC source and destination connectors), return right away
case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_SYNC):
wg.Wait()
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
switch {
case strings.Contains(err.Error(), "code = DeadlineExceeded"):
return nil, status.Errorf(codes.DeadlineExceeded, "trigger model instance got timeout error")
default:
return nil, status.Errorf(codes.Internal, fmt.Sprintf("trigger model instance got error %v", err.Error()))
}
}
}
return &pipelinePB.TriggerPipelineResponse{
DataMappingIndices: dataMappingIndices,
ModelInstanceOutputs: modelInstOutputs,
}, nil
// If this is a async trigger, write to the destination connector
case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_ASYNC):
go func() {
wg.Wait()
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
logger.Error(fmt.Sprintf("[model-backend] Error trigger model instance got error %v", err.Error()))
return
}
}

for idx, modelInstRecName := range dbPipeline.Recipe.ModelInstances {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down

0 comments on commit a692d2b

Please sign in to comment.