diff --git a/pkg/service/service.go b/pkg/service/service.go index 28647f5bd..ed678a38e 100644 --- a/pkg/service/service.go +++ b/pkg/service/service.go @@ -378,17 +378,22 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe wg.Add(1) var modelInstOutputs []*pipelinePB.ModelInstanceOutput + errors := make(chan error) go func() { + defer wg.Done() + for idx, modelInstance := range dbPipeline.Recipe.ModelInstances { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() resp, err := s.modelServiceClient.TriggerModelInstance(ctx, &modelPB.TriggerModelInstanceRequest{ Name: modelInstance, Inputs: inputs, }) if err != nil { + errors <- err logger.Error(fmt.Sprintf("[model-backend] Error %s at %dth model instance %s: %v", "TriggerModel", idx, modelInstance, err.Error())) + return } taskOutputs := cvtModelTaskOutputToPipelineTaskOutput(resp.TaskOutputs) @@ -405,7 +410,9 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe // Increment trigger image numbers uid, err := resource.GetPermalinkUID(dbPipeline.Owner) if err != nil { + errors <- err logger.Error(err.Error()) + return } if strings.HasPrefix(dbPipeline.Owner, "users/") { s.redisClient.IncrBy(context.Background(), fmt.Sprintf("user:%s:trigger.image.num", uid), int64(len(inputs))) @@ -413,13 +420,26 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe s.redisClient.IncrBy(context.Background(), fmt.Sprintf("org:%s:trigger.image.num", uid), int64(len(inputs))) } } - wg.Done() + errors <- nil }() switch { // If this is a SYNC trigger (i.e., HTTP, gRPC source and destination connectors), return right away case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_SYNC): - wg.Wait() + go func() { + wg.Wait() + close(errors) + }() + for err := range errors { + if err != nil { + switch { + case strings.Contains(err.Error(), "code = DeadlineExceeded"): + return nil, status.Errorf(codes.DeadlineExceeded, "trigger model instance got timeout error") + default: + return nil, status.Errorf(codes.Internal, fmt.Sprintf("trigger model instance got error %v", err.Error())) + } + } + } return &pipelinePB.TriggerPipelineResponse{ DataMappingIndices: dataMappingIndices, ModelInstanceOutputs: modelInstOutputs, @@ -427,7 +447,17 @@ func (s *service) TriggerPipeline(req *pipelinePB.TriggerPipelineRequest, dbPipe // If this is a async trigger, write to the destination connector case dbPipeline.Mode == datamodel.PipelineMode(pipelinePB.Pipeline_MODE_ASYNC): go func() { - wg.Wait() + go func() { + wg.Wait() + close(errors) + }() + for err := range errors { + if err != nil { + logger.Error(fmt.Sprintf("[model-backend] Error trigger model instance got error %v", err.Error())) + return + } + } + for idx, modelInstRecName := range dbPipeline.Recipe.ModelInstances { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()