diff --git a/assets/huggingface-vit-template/huggingface-infer/config.pbtxt b/assets/huggingface-vit-template/huggingface-infer/config.pbtxt index b79ed1d0..b2db35e1 100644 --- a/assets/huggingface-vit-template/huggingface-infer/config.pbtxt +++ b/assets/huggingface-vit-template/huggingface-infer/config.pbtxt @@ -1,6 +1,6 @@ name: "huggingface-infer" platform: "onnxruntime_onnx" -max_batch_size: 32 +max_batch_size: 16 input [ { name: "pixel_values" diff --git a/assets/huggingface-vit-template/huggingface/config.pbtxt b/assets/huggingface-vit-template/huggingface/config.pbtxt index db835782..6b5ac853 100644 --- a/assets/huggingface-vit-template/huggingface/config.pbtxt +++ b/assets/huggingface-vit-template/huggingface/config.pbtxt @@ -1,6 +1,6 @@ name: "huggingface" platform: "ensemble" -max_batch_size: 32 +max_batch_size: 16 input [ { name: "input" diff --git a/assets/huggingface-vit-template/pre/config.pbtxt b/assets/huggingface-vit-template/pre/config.pbtxt index 02531b56..04b113ef 100644 --- a/assets/huggingface-vit-template/pre/config.pbtxt +++ b/assets/huggingface-vit-template/pre/config.pbtxt @@ -1,6 +1,6 @@ name: "pre" backend: "python" -max_batch_size: 32 +max_batch_size: 16 input [ { name: "input" diff --git a/cmd/main/misc.go b/cmd/main/misc.go index 779508a2..1d566d07 100644 --- a/cmd/main/misc.go +++ b/cmd/main/misc.go @@ -120,6 +120,8 @@ func errorHandler(ctx context.Context, mux *runtime.ServeMux, marshaler runtime. switch v.Violations[0].Type { case "UPDATE", "DELETE", "STATE", "RENAME": httpStatus = http.StatusUnprocessableEntity + case "MAX BATCH SIZE LIMITATION": + httpStatus = http.StatusBadRequest } default: httpStatus = runtime.HTTPStatusFromCode(s.Code()) diff --git a/config/config.go b/config/config.go index ec64e048..8464fe54 100644 --- a/config/config.go +++ b/config/config.go @@ -80,15 +80,24 @@ type PipelineBackendConfig struct { } } +type MaxBatchSizeConfig struct { + Unspecified int `koanf:"unspecified"` + Classification int `koanf:"classification"` + Detection int `koanf:"detection"` + Keypoint int `koanf:"keypoint"` + Ocr int `koanf:"ocr"` +} + // AppConfig defines type AppConfig struct { - Server ServerConfig `koanf:"server"` - Database DatabaseConfig `koanf:"database"` - TritonServer TritonServerConfig `koanf:"tritonserver"` - MgmtBackend MgmtBackendConfig `koanf:"mgmtbackend"` - Cache CacheConfig `koanf:"cache"` - UsageBackend UsageBackendConfig `koanf:"usagebackend"` - PipelineBackend PipelineBackendConfig `koanf:"pipelinebackend"` + Server ServerConfig `koanf:"server"` + Database DatabaseConfig `koanf:"database"` + TritonServer TritonServerConfig `koanf:"tritonserver"` + MgmtBackend MgmtBackendConfig `koanf:"mgmtbackend"` + Cache CacheConfig `koanf:"cache"` + UsageBackend UsageBackendConfig `koanf:"usagebackend"` + PipelineBackend PipelineBackendConfig `koanf:"pipelinebackend"` + MaxBatchSizeLimitation MaxBatchSizeConfig `koanf:"maxbatchsizelimitation"` } // Config - Global variable to export diff --git a/config/config.yaml b/config/config.yaml index 7ca2e102..2a86491b 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -45,3 +45,9 @@ pipelinebackend: https: cert: # /ssl/tls.crt key: # /ssl/tls.key +maxbatchsizelimitation: + unspecified: 2 + classification: 16 + detection: 8 + keypoint: 8 + ocr: 2 diff --git a/integration-test/data/dummy-cls-model-bz17.zip b/integration-test/data/dummy-cls-model-bz17.zip new file mode 100644 index 00000000..1760da14 Binary files /dev/null and b/integration-test/data/dummy-cls-model-bz17.zip differ diff --git 
a/integration-test/data/dummy-cls-model.zip b/integration-test/data/dummy-cls-model.zip index 2ab3c4d3..bd43485d 100644 Binary files a/integration-test/data/dummy-cls-model.zip and b/integration-test/data/dummy-cls-model.zip differ diff --git a/integration-test/data/dummy-cls-no-readme.zip b/integration-test/data/dummy-cls-no-readme.zip index e76f738a..a248d1f2 100644 Binary files a/integration-test/data/dummy-cls-no-readme.zip and b/integration-test/data/dummy-cls-no-readme.zip differ diff --git a/integration-test/data/dummy-det-model-bz9.zip b/integration-test/data/dummy-det-model-bz9.zip new file mode 100644 index 00000000..ffd4d405 Binary files /dev/null and b/integration-test/data/dummy-det-model-bz9.zip differ diff --git a/integration-test/data/dummy-det-model.zip b/integration-test/data/dummy-det-model.zip index fef12b12..eac55f9c 100644 Binary files a/integration-test/data/dummy-det-model.zip and b/integration-test/data/dummy-det-model.zip differ diff --git a/integration-test/data/dummy-keypoint-model-bz9.zip b/integration-test/data/dummy-keypoint-model-bz9.zip new file mode 100644 index 00000000..3187aa41 Binary files /dev/null and b/integration-test/data/dummy-keypoint-model-bz9.zip differ diff --git a/integration-test/data/dummy-unspecified-model-bz3.zip b/integration-test/data/dummy-unspecified-model-bz3.zip new file mode 100644 index 00000000..9e840456 Binary files /dev/null and b/integration-test/data/dummy-unspecified-model-bz3.zip differ diff --git a/integration-test/data/dummy-unspecified-model.zip b/integration-test/data/dummy-unspecified-model.zip index e3078647..12db4535 100644 Binary files a/integration-test/data/dummy-unspecified-model.zip and b/integration-test/data/dummy-unspecified-model.zip differ diff --git a/integration-test/data/empty-response-model.zip b/integration-test/data/empty-response-model.zip index 4730caf8..8c5702a5 100644 Binary files a/integration-test/data/empty-response-model.zip and b/integration-test/data/empty-response-model.zip differ diff --git a/integration-test/rest.js b/integration-test/rest.js index a9e7ca91..dbd45296 100644 --- a/integration-test/rest.js +++ b/integration-test/rest.js @@ -44,7 +44,7 @@ export default function (data) { }); } - // Create Model API + // // Create Model API createModel.CreateModelFromLocal() createModel.CreateModelFromGitHub() diff --git a/integration-test/rest_create_model.js b/integration-test/rest_create_model.js index 11281cc7..07819997 100644 --- a/integration-test/rest_create_model.js +++ b/integration-test/rest_create_model.js @@ -13,6 +13,10 @@ const cls_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dumm const det_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-det-model.zip`, "b"); const keypoint_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-keypoint-model.zip`, "b"); const unspecified_model = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-unspecified-model.zip`, "b"); +const cls_model_bz17 = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-cls-model-bz17.zip`, "b"); +const det_model_bz9 = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-det-model-bz9.zip`, "b"); +const keypoint_model_bz9 = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-keypoint-model-bz9.zip`, "b"); +const unspecified_model_bz3 = open(`${__ENV.TEST_FOLDER_ABS_PATH}/integration-test/data/dummy-unspecified-model-bz3.zip`, "b"); export function CreateModelFromLocal() { @@ -188,6 +192,90 @@ export 
function CreateModelFromLocal() { r.status === 204 }); }); + + group("Model Backend API: Upload a model which exceed max batch size limitation", function () { + let fd_cls = new FormData(); + let model_id_cls = randomString(10) + let model_description = randomString(20) + fd_cls.append("id", model_id_cls); + fd_cls.append("description", model_description); + fd_cls.append("model_definition", "model-definitions/local"); + fd_cls.append("content", http.file(cls_model_bz17, "dummy-cls-model-bz17.zip")); + check(http.request("POST", `${apiHost}/v1alpha/models:multipart`, fd_cls.body(), { + headers: genHeader(`multipart/form-data; boundary=${fd_cls.boundary}`), + }), { + "POST /v1alpha/models:multipart task cls response status": (r) => + r.status === 400, + }); + + let fd_det = new FormData(); + let model_id_det = randomString(10) + model_description = randomString(20) + fd_det.append("id", model_id_det); + fd_det.append("description", model_description); + fd_det.append("model_definition", "model-definitions/local"); + fd_det.append("content", http.file(det_model_bz9, "dummy-det-model-bz9.zip")); + check(http.request("POST", `${apiHost}/v1alpha/models:multipart`, fd_det.body(), { + headers: genHeader(`multipart/form-data; boundary=${fd_det.boundary}`), + }), { + "POST /v1alpha/models:multipart task det response status": (r) => + r.status === 400, + }); + + let fd_keypoint = new FormData(); + let model_id_keypoint = randomString(10) + model_description = randomString(20) + fd_keypoint.append("id", model_id_keypoint); + fd_keypoint.append("description", model_description); + fd_keypoint.append("model_definition", "model-definitions/local"); + fd_keypoint.append("content", http.file(keypoint_model_bz9, "dummy-keypoint-model-bz9.zip")); + check(http.request("POST", `${apiHost}/v1alpha/models:multipart`, fd_keypoint.body(), { + headers: genHeader(`multipart/form-data; boundary=${fd_keypoint.boundary}`), + }), { + "POST /v1alpha/models:multipart task keypoint response status": (r) => + r.status === 400, + }); + + let fd_unspecified = new FormData(); + let model_id_unspecified = randomString(10) + model_description = randomString(20) + fd_unspecified.append("id", model_id_unspecified); + fd_unspecified.append("description", model_description); + fd_unspecified.append("model_definition", "model-definitions/local"); + fd_unspecified.append("content", http.file(unspecified_model_bz3, "dummy-unspecified-model-bz3.zip")); + check(http.request("POST", `${apiHost}/v1alpha/models:multipart`, fd_unspecified.body(), { + headers: genHeader(`multipart/form-data; boundary=${fd_unspecified.boundary}`), + }), { + "POST /v1alpha/models:multipart task unspecified response status": (r) => + r.status === 400, + }); + + // clean up + check(http.request("DELETE", `${apiHost}/v1alpha/models/${model_id_cls}`, null, { + headers: genHeader(`application/json`), + }), { + "DELETE clean up response status": (r) => + r.status === 404 + }); + check(http.request("DELETE", `${apiHost}/v1alpha/models/${model_id_det}`, null, { + headers: genHeader(`application/json`), + }), { + "DELETE clean up response status": (r) => + r.status === 404 + }); + check(http.request("DELETE", `${apiHost}/v1alpha/models/${model_id_keypoint}`, null, { + headers: genHeader(`application/json`), + }), { + "DELETE clean up response status": (r) => + r.status === 404 + }); + check(http.request("DELETE", `${apiHost}/v1alpha/models/${model_id_unspecified}`, null, { + headers: genHeader(`application/json`), + }), { + "DELETE clean up response status": (r) => + 
r.status === 404 + }); + }); } } diff --git a/internal/triton/triton.go b/internal/triton/triton.go index 92f7b65e..642a5e05 100644 --- a/internal/triton/triton.go +++ b/internal/triton/triton.go @@ -497,8 +497,8 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse, } func (ts *triton) LoadModelRequest(modelName string) (*inferenceserver.RepositoryModelLoadResponse, error) { - // Create context for our request with 60 second timeout - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + // Create context for our request with 600 second timeout. The time for warmup model inference + ctx, cancel := context.WithTimeout(context.Background(), 600*time.Second) defer cancel() // Create status request for a given model diff --git a/internal/util/util.go b/internal/util/util.go index 6d083905..a05ba922 100644 --- a/internal/util/util.go +++ b/internal/util/util.go @@ -313,6 +313,35 @@ func ConvertAllJSONEnumValueToProtoStyle(enumRegistry map[string]map[string]int3 } } +func GetMaxBatchSize(configFilePath string) (int, error) { + if _, err := os.Stat(configFilePath); errors.Is(err, os.ErrNotExist) { + return -1, err + } + file, err := os.Open(configFilePath) + if err != nil { + return -1, err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + r, err := regexp.Compile(`max_batch_size:`) + if err != nil { + return -1, err + } + + for scanner.Scan() { + if r.MatchString(scanner.Text()) { + maxBatchSize := scanner.Text() + maxBatchSize = strings.Trim(maxBatchSize, "max_batch_size:") + maxBatchSize = strings.Trim(maxBatchSize, " ") + intMaxBatchSize, err := strconv.Atoi(maxBatchSize) + return intMaxBatchSize, err + } + } + + return -1, fmt.Errorf("not found") +} + func DoSupportBatch(configFilePath string) (bool, error) { if _, err := os.Stat(configFilePath); errors.Is(err, os.ErrNotExist) { return false, err diff --git a/pkg/handler/handler.go b/pkg/handler/handler.go index 9d1e7cee..de030403 100644 --- a/pkg/handler/handler.go +++ b/pkg/handler/handler.go @@ -21,6 +21,7 @@ import ( "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/santhosh-tekuri/jsonschema/v5" + "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -94,11 +95,11 @@ func isEnsembleConfig(configPath string) bool { return strings.Contains(fileString, "platform: \"ensemble\"") } -func unzip(filePath string, dstDir string, owner string, uploadedModel *datamodel.Model) (string, error) { +func unzip(filePath string, dstDir string, owner string, uploadedModel *datamodel.Model) (string, string, error) { archive, err := zip.OpenReader(filePath) if err != nil { fmt.Println("Error when open zip file ", err) - return "", err + return "", "", err } defer archive.Close() var readmeFilePath string @@ -116,7 +117,7 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode if !strings.HasPrefix(filePath, filepath.Clean(dstDir)+string(os.PathSeparator)) { fmt.Println("invalid file path") - return "", fmt.Errorf("invalid file path") + return "", "", fmt.Errorf("invalid file path") } if f.FileInfo().IsDir() { dirName := f.Name @@ -147,11 +148,11 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode } filePath := filepath.Join(dstDir, dirName) if err := util.ValidateFilePath(filePath); err != nil { - return "", err + return "", "", err } err = os.MkdirAll(filePath, os.ModePerm) if err != nil { - return "", err + return "", 
"", err } continue } @@ -170,18 +171,18 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode readmeFilePath = filePath } if err := util.ValidateFilePath(filePath); err != nil { - return "", err + return "", "", err } dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()) if err != nil { - return "", err + return "", "", err } fileInArchive, err := f.Open() if err != nil { - return "", err + return "", "", err } if _, err := io.Copy(dstFile, fileInArchive); err != nil { - return "", err + return "", "", err } dstFile.Close() @@ -194,7 +195,7 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode } err = util.UpdateConfigModelName(filePath, oldModelName, newModelName) if err != nil { - return "", err + return "", "", err } } } @@ -203,7 +204,7 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode for oldModelName, newModelName := range newModelNameMap { err = util.UpdateConfigModelName(ensembleFilePath, oldModelName, newModelName) if err != nil { - return "", err + return "", "", err } } for i := 0; i < len(createdTModels); i++ { @@ -214,11 +215,11 @@ func unzip(filePath string, dstDir string, owner string, uploadedModel *datamode } } uploadedModel.Instances[0].TritonModels = createdTModels - return readmeFilePath, nil + return readmeFilePath, ensembleFilePath, nil } // modelDir and dstDir are absolute path -func updateModelPath(modelDir string, dstDir string, owner string, modelID string, modelInstance *datamodel.ModelInstance) (string, error) { +func updateModelPath(modelDir string, dstDir string, owner string, modelID string, modelInstance *datamodel.ModelInstance) (string, string, error) { var createdTModels []datamodel.TritonModel var ensembleFilePath string var newModelNameMap = make(map[string]string) @@ -234,12 +235,12 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin return nil }) if err != nil { - return "", err + return "", "", err } modelRootDir := strings.Join([]string{dstDir, owner}, "/") err = os.MkdirAll(modelRootDir, os.ModePerm) if err != nil { - return "", err + return "", "", err } for _, f := range files { if f.path == modelDir { @@ -259,7 +260,7 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin err = os.MkdirAll(filePath, os.ModePerm) if err != nil { - return "", err + return "", "", err } newModelNameMap[oldModelName] = subStrs[0] if v, err := strconv.Atoi(subStrs[len(subStrs)-1]); err == nil { @@ -277,14 +278,14 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin dstFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.fInfo.Mode()) if err != nil { - return "", err + return "", "", err } srcFile, err := os.Open(f.path) if err != nil { - return "", err + return "", "", err } if _, err := io.Copy(dstFile, srcFile); err != nil { - return "", err + return "", "", err } dstFile.Close() srcFile.Close() @@ -296,7 +297,7 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin } err = util.UpdateConfigModelName(filePath, oldModelName, subStrs[0]) if err != nil { - return "", err + return "", "", err } } } @@ -305,7 +306,7 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin for oldModelName, newModelName := range newModelNameMap { err = util.UpdateConfigModelName(ensembleFilePath, oldModelName, newModelName) if err != nil { - return "", err + return "", "", err } } for i := 0; i < 
len(createdTModels); i++ { @@ -316,7 +317,7 @@ func updateModelPath(modelDir string, dstDir string, owner string, modelID strin } } modelInstance.TritonModels = createdTModels - return readmeFilePath, nil + return readmeFilePath, ensembleFilePath, nil } func saveFile(stream modelPB.ModelService_CreateModelBinaryFileUploadServer) (outFile string, modelInfo *datamodel.Model, modelDefinitionID string, err error) { @@ -498,6 +499,8 @@ func (h *handler) Readiness(ctx context.Context, pb *modelPB.ReadinessRequest) ( // HandleCreateModelByMultiPartFormData is a custom handler func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + logger, _ := logger.GetZapLogger() + contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "multipart/form-data") { owner, err := resource.GetOwnerFromHeader(r) @@ -636,7 +639,7 @@ func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request return } - readmeFilePath, err := unzip(tmpFile, config.Config.TritonServer.ModelStore, owner, &uploadedModel) + readmeFilePath, ensembleFilePath, err := unzip(tmpFile, config.Config.TritonServer.ModelStore, owner, &uploadedModel) _ = os.Remove(tmpFile) // remove uploaded temporary zip file if err != nil { util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) @@ -665,6 +668,58 @@ func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request uploadedModel.Instances[0].Task = datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED) } + maxBatchSize, err := util.GetMaxBatchSize(ensembleFilePath) + if err != nil { + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] create a model error", + "Local model", + "Missing ensemble model", + "", + err.Error(), + ) + if e != nil { + logger.Error(e.Error()) + } + obj, _ := json.Marshal(st.Details()) + makeJSONResponse(w, 400, st.Message(), string(obj)) + util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) + return + } + + allowedMaxBatchSize := 0 + switch uploadedModel.Instances[0].Task { + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Unspecified + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_CLASSIFICATION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Classification + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_DETECTION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Detection + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_KEYPOINT): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr + } + + if maxBatchSize > allowedMaxBatchSize { + st, e := sterr.CreateErrorPreconditionFailure( + "[handler] create a model", + []*errdetails.PreconditionFailure_Violation{ + { + Type: "MAX BATCH SIZE LIMITATION", + Subject: "Create a model error", + Description: fmt.Sprintf("The max_batch_size in config.pbtxt exceeded the limitation %v, please try with a smaller max_batch_size", allowedMaxBatchSize), + }, + }) + if e != nil { + logger.Error(e.Error()) + } + util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) + obj, _ :=
json.Marshal(st.Details()) + makeJSONResponse(w, 400, st.Message(), string(obj)) + return + } + dbModel, err := modelService.CreateModel(owner, &uploadedModel) if err != nil { util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) @@ -693,6 +748,8 @@ func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request // AddModel - upload a model to the model server func (h *handler) CreateModelBinaryFileUpload(stream modelPB.ModelService_CreateModelBinaryFileUploadServer) (err error) { + logger, _ := logger.GetZapLogger() + owner, err := resource.GetOwner(stream.Context()) if err != nil { return err @@ -720,7 +777,7 @@ func (h *handler) CreateModelBinaryFileUpload(stream modelPB.ModelService_Create uploadedModel.Owner = owner // extract zip file from tmp to models directory - readmeFilePath, err := unzip(tmpFile, config.Config.TritonServer.ModelStore, owner, uploadedModel) + readmeFilePath, ensembleFilePath, err := unzip(tmpFile, config.Config.TritonServer.ModelStore, owner, uploadedModel) _ = os.Remove(tmpFile) // remove uploaded temporary zip file if err != nil { util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) @@ -746,6 +803,54 @@ func (h *handler) CreateModelBinaryFileUpload(stream modelPB.ModelService_Create uploadedModel.Instances[0].Task = datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED) } + maxBatchSize, err := util.GetMaxBatchSize(ensembleFilePath) + if err != nil { + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] create a model error", + "Local model", + "Missing ensemble model", + "", + err.Error(), + ) + if e != nil { + logger.Error(e.Error()) + } + util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) + return st.Err() + } + + allowedMaxBatchSize := 0 + switch uploadedModel.Instances[0].Task { + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Unspecified + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_CLASSIFICATION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Classification + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_DETECTION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Detection + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_KEYPOINT): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr + } + + if maxBatchSize > allowedMaxBatchSize { + st, e := sterr.CreateErrorPreconditionFailure( + "[handler] create a model", + []*errdetails.PreconditionFailure_Violation{ + { + Type: "MAX BATCH SIZE LIMITATION", + Subject: "Create a model error", + Description: fmt.Sprintf("The max_batch_size in config.pbtxt exceeded the limitation %v, please try with a smaller max_batch_size", allowedMaxBatchSize), + }, + }) + if e != nil { + logger.Error(e.Error()) + } + util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) + return st.Err() + } + dbModel, err := h.service.CreateModel(owner, uploadedModel) if err != nil { util.RemoveModelRepository(config.Config.TritonServer.ModelStore, owner, uploadedModel.ID, uploadedModel.Instances[0].ID) @@ -831,7 
+936,7 @@ func createGitHubModel(h *handler, ctx context.Context, req *modelPB.CreateModel Configuration: bInstanceConfig, } - readmeFilePath, err := updateModelPath(modelSrcDir, config.Config.TritonServer.ModelStore, owner, githubModel.ID, &instance) + readmeFilePath, ensembleFilePath, err := updateModelPath(modelSrcDir, config.Config.TritonServer.ModelStore, owner, githubModel.ID, &instance) _ = os.RemoveAll(modelSrcDir) // remove uploaded temporary files if err != nil { st, err := sterr.CreateErrorResourceInfo( @@ -889,6 +994,52 @@ func createGitHubModel(h *handler, ctx context.Context, req *modelPB.CreateModel } else { instance.Task = datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED) } + + maxBatchSize, err := util.GetMaxBatchSize(ensembleFilePath) + if err != nil { + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] create a model error", + "GitHub model", + "Missing ensemble model", + "", + err.Error(), + ) + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.CreateModelResponse{}, st.Err() + } + + allowedMaxBatchSize := 0 + switch instance.Task { + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Unspecified + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_CLASSIFICATION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Classification + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_DETECTION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Detection + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_KEYPOINT): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr + } + if maxBatchSize > allowedMaxBatchSize { + st, e := sterr.CreateErrorPreconditionFailure( + "[handler] create a model", + []*errdetails.PreconditionFailure_Violation{ + { + Type: "MAX BATCH SIZE LIMITATION", + Subject: "Create a model error", + Description: fmt.Sprintf("The max_batch_size in config.pbtxt exceeded the limitation %v, please try with a smaller max_batch_size", allowedMaxBatchSize), + }, + }) + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.CreateModelResponse{}, st.Err() + } + githubModel.Instances = append(githubModel.Instances, instance) } dbModel, err := h.service.CreateModel(owner, &githubModel) @@ -1000,7 +1151,7 @@ func createArtiVCModel(h *handler, ctx context.Context, req *modelPB.CreateModel Configuration: bInstanceConfig, } - readmeFilePath, err := updateModelPath(modelSrcDir, config.Config.TritonServer.ModelStore, owner, artivcModel.ID, &instance) + readmeFilePath, ensembleFilePath, err := updateModelPath(modelSrcDir, config.Config.TritonServer.ModelStore, owner, artivcModel.ID, &instance) _ = os.RemoveAll(modelSrcDir) // remove uploaded temporary files if err != nil { st, err := sterr.CreateErrorResourceInfo( @@ -1058,6 +1209,53 @@ func createArtiVCModel(h *handler, ctx context.Context, req *modelPB.CreateModel } else { instance.Task = datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED) } + + maxBatchSize, err := util.GetMaxBatchSize(ensembleFilePath) + if err != nil { + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] create a model error", + "ArtiVC model", + "Missing ensemble model", + "", + err.Error(), + ) + if e != nil { + logger.Error(e.Error()) + } + return 
&modelPB.CreateModelResponse{}, st.Err() + } + + allowedMaxBatchSize := 0 + switch instance.Task { + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Unspecified + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_CLASSIFICATION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Classification + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_DETECTION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Detection + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_KEYPOINT): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr + } + if maxBatchSize > allowedMaxBatchSize { + st, e := sterr.CreateErrorPreconditionFailure( + "[handler] create a model", + []*errdetails.PreconditionFailure_Violation{ + { + Type: "MAX BATCH SIZE LIMITATION", + Subject: "Create a model error", + Description: fmt.Sprintf("The max_batch_size in config.pbtxt exceeded the limitation %v, please try with a smaller max_batch_size", allowedMaxBatchSize), + }, + }) + + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.CreateModelResponse{}, st.Err() + } + artivcModel.Instances = append(artivcModel.Instances, instance) } dbModel, err := h.service.CreateModel(owner, &artivcModel) @@ -1170,7 +1368,7 @@ func createHuggingFaceModel(h *handler, ctx context.Context, req *modelPB.Create Configuration: bInstanceConfig, } - readmeFilePath, err := updateModelPath(modelDir, config.Config.TritonServer.ModelStore, owner, huggingfaceModel.ID, &instance) + readmeFilePath, ensembleFilePath, err := updateModelPath(modelDir, config.Config.TritonServer.ModelStore, owner, huggingfaceModel.ID, &instance) _ = os.RemoveAll(modelDir) // remove uploaded temporary files if err != nil { st, err := sterr.CreateErrorResourceInfo( @@ -1240,6 +1438,52 @@ func createHuggingFaceModel(h *handler, ctx context.Context, req *modelPB.Create } else { instance.Task = datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED) } + + maxBatchSize, err := util.GetMaxBatchSize(ensembleFilePath) + if err != nil { + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] create a model error", + "HuggingFace model", + "Missing ensemble model", + "", + err.Error(), + ) + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.CreateModelResponse{}, st.Err() + } + + allowedMaxBatchSize := 0 + switch instance.Task { + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_UNSPECIFIED): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Unspecified + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_CLASSIFICATION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Classification + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_DETECTION): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Detection + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_KEYPOINT): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Keypoint + case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_OCR): + allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr + } + if maxBatchSize > allowedMaxBatchSize { + st, e := sterr.CreateErrorPreconditionFailure( + "[handler] create a model", + []*errdetails.PreconditionFailure_Violation{ + { + Type: "MAX BATCH SIZE LIMITATION", 
+ Subject: "Create a model error", + Description: fmt.Sprintf("The max_batch_size in config.pbtxt exceeded the limitation %v, please try with a smaller max_batch_size", allowedMaxBatchSize), + }, + }) + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.CreateModelResponse{}, st.Err() + } + huggingfaceModel.Instances = append(huggingfaceModel.Instances, instance) dbModel, err := h.service.CreateModel(owner, &huggingfaceModel) @@ -1736,7 +1980,7 @@ func (h *handler) DeployModelInstance(ctx context.Context, req *modelPB.DeployMo } err = h.service.DeployModelInstance(dbModelInstance.UID) if err != nil { - st, err := sterr.CreateErrorResourceInfo( + st, e := sterr.CreateErrorResourceInfo( codes.Internal, "[handler] deploy model error", "triton-inference-server", @@ -1744,8 +1988,19 @@ func (h *handler) DeployModelInstance(ctx context.Context, req *modelPB.DeployMo "", err.Error(), ) - if err != nil { - logger.Error(err.Error()) + if strings.Contains(err.Error(), "Failed to allocate memory") { + st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] deploy model error", + "triton-inference-server", + "Out of memory for deploying the model to triton server, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) } return &modelPB.DeployModelInstanceResponse{}, st.Err() @@ -1805,6 +2060,8 @@ func (h *handler) UndeployModelInstance(ctx context.Context, req *modelPB.Undepl } func (h *handler) TestModelInstanceBinaryFileUpload(stream modelPB.ModelService_TestModelInstanceBinaryFileUploadServer) error { + logger, _ := logger.GetZapLogger() + if !h.triton.IsTritonServerReady() { return status.Error(codes.Unavailable, "Triton Server not ready yet") } @@ -1847,7 +2104,29 @@ func (h *handler) TestModelInstanceBinaryFileUpload(stream modelPB.ModelService_ task := modelPB.ModelInstance_Task(modelInstanceInDB.Task) response, err := h.service.ModelInferTestMode(owner, modelInstanceInDB.UID, imageBytes, task) if err != nil { - return err + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] inference model error", + "Triton inference server", + "", + "", + err.Error(), + ) + if strings.Contains(err.Error(), "Failed to allocate memory") { + st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] inference model error", + "Triton inference server OOM", + "Out of memory for running the model, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) + } + return st.Err() } err = stream.SendAndClose(&modelPB.TestModelInstanceBinaryFileUploadResponse{ @@ -1858,6 +2137,8 @@ func (h *handler) TestModelInstanceBinaryFileUpload(stream modelPB.ModelService_ } func (h *handler) TriggerModelInstanceBinaryFileUpload(stream modelPB.ModelService_TriggerModelInstanceBinaryFileUploadServer) error { + logger, _ := logger.GetZapLogger() + if !h.triton.IsTritonServerReady() { return status.Error(codes.Unavailable, "Triton Server not ready yet") } @@ -1900,7 +2181,29 @@ func (h *handler) TriggerModelInstanceBinaryFileUpload(stream modelPB.ModelServi task := modelPB.ModelInstance_Task(modelInstanceInDB.Task) response, err := h.service.ModelInfer(modelInstanceInDB.UID, imgsBytes, task) if err != nil { - return err + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] inference model error", + "Triton inference server", + "", + "", + err.Error(), + ) + if strings.Contains(err.Error(), "Failed to allocate memory") { + 
st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] inference model error", + "Triton inference server OOM", + "Out of memory for running the model, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) + } + return st.Err() } err = stream.SendAndClose(&modelPB.TriggerModelInstanceBinaryFileUploadResponse{ @@ -1911,6 +2214,8 @@ func (h *handler) TriggerModelInstanceBinaryFileUpload(stream modelPB.ModelServi } func (h *handler) TriggerModelInstance(ctx context.Context, req *modelPB.TriggerModelInstanceRequest) (*modelPB.TriggerModelInstanceResponse, error) { + logger, _ := logger.GetZapLogger() + owner, err := resource.GetOwner(ctx) if err != nil { return &modelPB.TriggerModelInstanceResponse{}, err @@ -1955,7 +2260,29 @@ func (h *handler) TriggerModelInstance(ctx context.Context, req *modelPB.Trigger task := modelPB.ModelInstance_Task(modelInstanceInDB.Task) response, err := h.service.ModelInfer(modelInstanceInDB.UID, imgsBytes, task) if err != nil { - return &modelPB.TriggerModelInstanceResponse{}, status.Error(codes.InvalidArgument, err.Error()) + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] inference model error", + "Triton inference server", + "", + "", + err.Error(), + ) + if strings.Contains(err.Error(), "Failed to allocate memory") { + st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] inference model error", + "Triton inference server OOM", + "Out of memory for running the model, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.TriggerModelInstanceResponse{}, st.Err() } return &modelPB.TriggerModelInstanceResponse{ @@ -1965,6 +2292,8 @@ func (h *handler) TriggerModelInstance(ctx context.Context, req *modelPB.Trigger } func (h *handler) TestModelInstance(ctx context.Context, req *modelPB.TestModelInstanceRequest) (*modelPB.TestModelInstanceResponse, error) { + logger, _ := logger.GetZapLogger() + owner, err := resource.GetOwner(ctx) if err != nil { return &modelPB.TestModelInstanceResponse{}, err @@ -2012,7 +2341,29 @@ func (h *handler) TestModelInstance(ctx context.Context, req *modelPB.TestModelI task := modelPB.ModelInstance_Task(modelInstanceInDB.Task) response, err := h.service.ModelInferTestMode(owner, modelInstanceInDB.UID, imgsBytes, task) if err != nil { - return &modelPB.TestModelInstanceResponse{}, status.Error(codes.Internal, err.Error()) + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] inference model error", + "Triton inference server", + "", + "", + err.Error(), + ) + if strings.Contains(err.Error(), "Failed to allocate memory") { + st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] inference model error", + "Triton inference server OOM", + "Out of memory for running the model, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) + } + return &modelPB.TestModelInstanceResponse{}, st.Err() } return &modelPB.TestModelInstanceResponse{ @@ -2022,6 +2373,8 @@ func (h *handler) TestModelInstance(ctx context.Context, req *modelPB.TestModelI } func inferModelInstanceByUpload(w http.ResponseWriter, r *http.Request, pathParams map[string]string, mode string) { + logger, _ := logger.GetZapLogger() + contentType := r.Header.Get("Content-Type") if strings.Contains(contentType, "multipart/form-data") { owner, err := 
resource.GetOwnerFromHeader(r) @@ -2102,7 +2455,30 @@ func inferModelInstanceByUpload(w http.ResponseWriter, r *http.Request, pathPara response, err = modelService.ModelInfer(modelInstanceInDB.UID, imgsBytes, task) } if err != nil { - makeJSONResponse(w, 500, "Error Predict Model", err.Error()) + st, e := sterr.CreateErrorResourceInfo( + codes.FailedPrecondition, + "[handler] inference model error", + "Triton inference server", + "", + "", + err.Error(), + ) + if strings.Contains(err.Error(), "Failed to allocate memory") { + st, e = sterr.CreateErrorResourceInfo( + codes.ResourceExhausted, + "[handler] inference model error", + "Triton inference server OOM", + "Out of memory for running the model, maybe try with smaller batch size", + "", + err.Error(), + ) + } + + if e != nil { + logger.Error(e.Error()) + } + obj, _ := json.Marshal(st.Details()) + makeJSONResponse(w, 500, st.Message(), string(obj)) return }
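
For reference, the validation flow this patch introduces works as follows: the create-model handlers call util.GetMaxBatchSize on the unzipped model's ensemble config.pbtxt, look up the allowed limit for the instance's task from the new maxbatchsizelimitation block in config/config.yaml, and reject the request with a "MAX BATCH SIZE LIMITATION" precondition-failure violation (mapped to HTTP 400 in cmd/main/misc.go) when the parsed value exceeds that limit. The sketch below is a minimal, self-contained illustration of that logic; parseMaxBatchSize, the in-memory config string, and the hard-coded task limits are stand-ins for the repository's GetMaxBatchSize helper and loaded config values, not the actual implementation.

```go
package main

import (
	"bufio"
	"fmt"
	"strconv"
	"strings"
)

// Per-task limits mirroring the new maxbatchsizelimitation block in config/config.yaml.
var allowedMaxBatchSize = map[string]int{
	"unspecified":    2,
	"classification": 16,
	"detection":      8,
	"keypoint":       8,
	"ocr":            2,
}

// parseMaxBatchSize scans config.pbtxt content for the first "max_batch_size:" entry,
// roughly what util.GetMaxBatchSize does against the ensemble config file on disk.
func parseMaxBatchSize(pbtxt string) (int, error) {
	scanner := bufio.NewScanner(strings.NewReader(pbtxt))
	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if strings.HasPrefix(line, "max_batch_size:") {
			value := strings.TrimSpace(strings.TrimPrefix(line, "max_batch_size:"))
			return strconv.Atoi(value)
		}
	}
	return -1, fmt.Errorf("max_batch_size not found")
}

func main() {
	// Example ensemble config; a classification model declaring max_batch_size: 17
	// exceeds the classification limit of 16 and is rejected.
	config := `name: "huggingface"
platform: "ensemble"
max_batch_size: 17
`
	maxBatchSize, err := parseMaxBatchSize(config)
	if err != nil {
		panic(err)
	}
	task := "classification"
	if maxBatchSize > allowedMaxBatchSize[task] {
		// The handlers report this as a PreconditionFailure violation of type
		// "MAX BATCH SIZE LIMITATION", which the gateway translates to HTTP 400.
		fmt.Printf("rejected: max_batch_size %d exceeds the %s limit of %d\n",
			maxBatchSize, task, allowedMaxBatchSize[task])
		return
	}
	fmt.Println("accepted")
}
```

This mirrors what the new integration tests exercise: the dummy-cls-model-bz17.zip fixture (presumably max_batch_size 17, per its name, against the classification limit of 16) and the bz9/bz3 fixtures for the detection, keypoint, and unspecified tasks are all expected to come back with status 400.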