Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix pipeline trigger model hanging #80

Merged
merged 7 commits into from
Oct 23, 2022
38 changes: 34 additions & 4 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,22 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe
wg.Add(1)

var modelInstOutputs []*pipelinePB.ModelInstanceOutput
errors := make(chan error)
go func() {
defer wg.Done()

for idx, modelInstance := range dbPipeline.Recipe.ModelInstances {

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
resp, err := s.modelServiceClient.TriggerModelInstance(ctx, &modelPB.TriggerModelInstanceRequest{
Name: modelInstance,
Inputs: inputs,
})
if err != nil {
errors <- err
logger.Error(fmt.Sprintf("[model-backend] Error %s at %dth model instance %s: %v", "TriggerModel", idx, modelInstance, err.Error()))
return
}

taskOutputs := cvtModelTaskOutputToPipelineTaskOutput(resp.TaskOutputs)
Expand All @@ -405,29 +410,54 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe
// Increment trigger image numbers
uid, err := resource.GetPermalinkUID(dbPipeline.Owner)
if err != nil {
errors <- err
logger.Error(err.Error())
return
}
if strings.HasPrefix(dbPipeline.Owner, "users/") {
s.redisClient.IncrBy(context.Background(), fmt.Sprintf("user:%s:trigger.image.num", uid), int64(len(inputs)))
} else if strings.HasPrefix(dbPipeline.Owner, "orgs/") {
s.redisClient.IncrBy(context.Background(), fmt.Sprintf("org:%s:trigger.image.num", uid), int64(len(inputs)))
}
}
wg.Done()
errors <- nil
}()

switch {
// If this is a SYNC trigger (i.e., HTTP, gRPC source and destination connectors), return right away
case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_SYNC):
wg.Wait()
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
switch {
case strings.Contains(err.Error(), "code = DeadlineExceeded"):
return nil, status.Errorf(codes.DeadlineExceeded, "trigger model instance got timeout error")
default:
return nil, status.Errorf(codes.Internal, fmt.Sprintf("trigger model instance got error %v", err.Error()))
}
}
}
return &pipelinePB.TriggerPipelineResponse{
DataMappingIndices: dataMappingIndices,
ModelInstanceOutputs: modelInstOutputs,
}, nil
// If this is a async trigger, write to the destination connector
case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_ASYNC):
go func() {
wg.Wait()
go func() {
wg.Wait()
close(errors)
}()
for err := range errors {
if err != nil {
logger.Error(fmt.Sprintf("[model-backend] Error trigger model instance got error %v", err.Error()))
return
}
}

for idx, modelInstRecName := range dbPipeline.Recipe.ModelInstances {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
Expand Down