Skip to content

Commit

Permalink
feat: support semantic segmentation (#203)
Browse files Browse the repository at this point in the history
Because

- support semantic segmentation task for VDP

This commit

- add the semantic segmentation task
  • Loading branch information
Phelan164 authored Dec 16, 2022
1 parent f68eb62 commit f22262c
Show file tree
Hide file tree
Showing 11 changed files with 109 additions and 3 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ type MaxBatchSizeConfig struct {
Keypoint int `koanf:"keypoint"`
Ocr int `koanf:"ocr"`
InstanceSegmentation int `koanf:"instancesegmentation"`
SemanticSegmentation int `koanf:"semanticsegmentation"`
}

// AppConfig defines
Expand Down
1 change: 1 addition & 0 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,4 @@ maxbatchsizelimitation:
keypoint: 8
ocr: 2
instancesegmentation: 8
semanticsegmentation: 8
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3
github.com/iancoleman/strcase v0.2.0
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221126215020-740adcc891b9
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221213025837-a8a36c1a4b0e
github.com/instill-ai/usage-client v0.2.1-alpha
github.com/instill-ai/x v0.2.0-alpha
github.com/knadh/koanf v1.4.3
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,8 @@ github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221126215020-740adcc891b9 h1:vwYnipjl7M+xiUxvs2sIATPtpY5LH4RLMSl/9wHaBoY=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221126215020-740adcc891b9/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221213025837-a8a36c1a4b0e h1:i/5mDUDsBjxREe2EldTCIUF5/h4bdBLhXaCydezRvg8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20221213025837-a8a36c1a4b0e/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/usage-client v0.2.1-alpha h1:XXMCTDT2BWOgGwerOpxghzt6hW9J7/yUR1tkNRuGjjM=
github.com/instill-ai/usage-client v0.2.1-alpha/go.mod h1:ThySPYe08Jy7OpfdtCZDckm19ET39K+KXGJ4lr+rOss=
github.com/instill-ai/x v0.2.0-alpha h1:8yszKP9DE8bvSRAtEpOwqhG2wwqU3olhTqhwoiLrHfc=
Expand Down
3 changes: 2 additions & 1 deletion internal/db/migration/000001_init.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ CREATE TYPE valid_task AS ENUM (
'TASK_DETECTION',
'TASK_KEYPOINT',
'TASK_OCR',
'TASK_INSTANCE_SEGMENTATION'
'TASK_INSTANCE_SEGMENTATION',
'TASK_SEMANTIC_SEGMENTATION'
);

CREATE TYPE valid_release_stage AS ENUM (
Expand Down
5 changes: 5 additions & 0 deletions internal/triton/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,8 @@ type InstanceSegmentationOutput struct {
Scores [][]float32
Labels [][]string
}

// SemanticSegmentationOutput holds the post-processed results of a semantic
// segmentation inference for a batch of images. The outer slice of each field
// is indexed by image; the inner slices are parallel, so Rles[i][j] pairs with
// Categories[i][j].
type SemanticSegmentationOutput struct {
	// Rles contains one mask string per segment for each image — presumably
	// run-length-encoded segmentation masks; confirm encoding against the model.
	Rles [][]string
	// Categories contains the category label for each corresponding RLE entry.
	Categories [][]string
}
54 changes: 54 additions & 0 deletions internal/triton/triton.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,51 @@ func postProcessInstanceSegmentation(modelInferResponse *inferenceserver.ModelIn
}, nil
}

// postProcessSemanticSegmentation extracts the RLE-mask and category output
// tensors from a Triton inference response and converts them into a
// SemanticSegmentationOutput.
//
// outputNameRles and outputNameCategories name the two string tensors in the
// response; both are expected to have shape [batch, segments] so that each
// RLE pairs with a category. Returns an error when either tensor is missing,
// cannot be reshaped to 2D, or when the two batches disagree in size.
func postProcessSemanticSegmentation(modelInferResponse *inferenceserver.ModelInferResponse, outputNameRles string, outputNameCategories string) (interface{}, error) {
	outputTensorRles, rawOutputContentRles, err := GetOutputFromInferResponse(outputNameRles, modelInferResponse)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to find inference output for RLEs")
	}
	if rawOutputContentRles == nil {
		return nil, fmt.Errorf("Unable to find output content for RLEs")
	}

	outputTensorCategories, rawOutputContentCategories, err := GetOutputFromInferResponse(outputNameCategories, modelInferResponse)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to find inference output for labels")
	}
	if rawOutputContentCategories == nil {
		return nil, fmt.Errorf("Unable to find output content for labels")
	}

	// Flatten count is batch * segments; the tensors are serialized as Triton
	// BYTES tensors and must be deserialized before reshaping back to 2D.
	outputDataLabels := DeserializeBytesTensor(rawOutputContentCategories, outputTensorCategories.Shape[0]*outputTensorCategories.Shape[1])
	batchedOutputDataCategories, err := Reshape1DArrayStringTo2D(outputDataLabels, outputTensorCategories.Shape)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to reshape inference output for labels")
	}

	outputDataRles := DeserializeBytesTensor(rawOutputContentRles, outputTensorRles.Shape[0]*outputTensorRles.Shape[1])
	batchedOutputDataRles, err := Reshape1DArrayStringTo2D(outputDataRles, outputTensorRles.Shape)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to reshape inference output for RLEs")
	}

	// The two outputs must be index-aligned per image; bail out on mismatch.
	// BUG FIX: the original log passed the categories length for the "Rles"
	// placeholder and vice versa — the arguments are now in the stated order.
	if len(batchedOutputDataCategories) != len(batchedOutputDataRles) {
		log.Printf("Rles output has length %v but categories has length %v",
			len(batchedOutputDataRles), len(batchedOutputDataCategories))
		return nil, fmt.Errorf("Inconsistent batch size for rles and categories")
	}

	return SemanticSegmentationOutput{
		Rles:       batchedOutputDataRles,
		Categories: batchedOutputDataCategories,
	}, nil
}

func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse, modelMetadata *inferenceserver.ModelMetadataResponse, task modelPB.ModelInstance_Task) (interface{}, error) {
var (
outputs interface{}
Expand Down Expand Up @@ -652,6 +697,15 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse,
return nil, fmt.Errorf("Unable to post-process instance segmentation output: %w", err)
}

case modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
if len(modelMetadata.Outputs) < 2 {
return nil, fmt.Errorf("Wrong output format of semantic segmentation task")
}
outputs, err = postProcessSemanticSegmentation(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process semantic segmentation output: %w", err)
}

default:
outputs, err = postProcessUnspecifiedTask(inferResponse, modelMetadata.Outputs)
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/util/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ var Tasks = map[string]modelPB.ModelInstance_Task{
"TASK_OCR": modelPB.ModelInstance_TASK_OCR,
"TASK_INSTANCESEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"TASK_INSTANCE_SEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"TASK_SEMANTIC_SEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
"TASK_SEMANTICSEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
}

var Tags = map[string]modelPB.ModelInstance_Task{
Expand All @@ -24,6 +26,8 @@ var Tags = map[string]modelPB.ModelInstance_Task{
"OCR": modelPB.ModelInstance_TASK_OCR,
"INSTANCESEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"INSTANCE_SEGMENTATION": modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION,
"SEMANTIC_SEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
"SEMANTICSEGMENTATION": modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION,
}

var Visibility = map[string]modelPB.Model_Visibility{
Expand Down
10 changes: 10 additions & 0 deletions pkg/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,8 @@ func HandleCreateModelByMultiPartFormData(w http.ResponseWriter, r *http.Request
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
}

if maxBatchSize > allowedMaxBatchSize {
Expand Down Expand Up @@ -863,6 +865,8 @@ func (h *handler) CreateModelBinaryFileUpload(stream modelPB.ModelService_Create
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
}

if maxBatchSize > allowedMaxBatchSize {
Expand Down Expand Up @@ -1072,6 +1076,8 @@ func createGitHubModel(h *handler, ctx context.Context, req *modelPB.CreateModel
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
Expand Down Expand Up @@ -1305,6 +1311,8 @@ func createArtiVCModel(h *handler, ctx context.Context, req *modelPB.CreateModel
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
Expand Down Expand Up @@ -1545,6 +1553,8 @@ func createHuggingFaceModel(h *handler, ctx context.Context, req *modelPB.Create
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.Ocr
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_INSTANCE_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.InstanceSegmentation
case datamodel.ModelInstanceTask(modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION):
allowedMaxBatchSize = config.Config.MaxBatchSizeLimitation.SemanticSegmentation
}
if maxBatchSize > allowedMaxBatchSize {
st, e := sterr.CreateErrorPreconditionFailure(
Expand Down
28 changes: 28 additions & 0 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,34 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
}
return instanceSegmentationOutputs, nil

case modelPB.ModelInstance_TASK_SEMANTIC_SEGMENTATION:
semanticSegmentationResponses := postprocessResponse.(triton.SemanticSegmentationOutput)
batchedOutputDataRles := semanticSegmentationResponses.Rles
batchedOutputDataCategories := semanticSegmentationResponses.Categories
var semanticSegmentationOutputs []*modelPB.TaskOutput
for i := range batchedOutputDataCategories { // loop over images
var semanticSegmentationOutput = modelPB.TaskOutput{
Output: &modelPB.TaskOutput_SemanticSegmentation{
SemanticSegmentation: &modelPB.SemanticSegmentationOutput{
Stuffs: []*modelPB.SemanticSegmentationStuff{},
},
},
}
for j := range batchedOutputDataCategories[i] { // single image
rle := batchedOutputDataRles[i][j]
category := batchedOutputDataCategories[i][j]
// Non-meaningful bboxes were added with coords [-1, -1, -1, -1, -1] and text "" for Triton to be able to batch Tensors
if category != "" && rle != "" {
semanticSegmentationOutput.GetSemanticSegmentation().Stuffs = append(semanticSegmentationOutput.GetSemanticSegmentation().Stuffs, &modelPB.SemanticSegmentationStuff{
Rle: rle,
Category: category,
})
}
}
semanticSegmentationOutputs = append(semanticSegmentationOutputs, &semanticSegmentationOutput)
}
return semanticSegmentationOutputs, nil

default:
outputs := postprocessResponse.([]triton.BatchUnspecifiedTaskOutputs)
var rawOutputs []*modelPB.TaskOutput
Expand Down
2 changes: 1 addition & 1 deletion pkg/usage/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (u *usage) RetrieveUsageData() interface{} {
InstanceOfflineStateNum: instanceOfflineStateNum,
ModelDefinitionIds: modelDefinitionIds,
Tasks: tasks,
TestImageNum: testImageNum,
TestNum: testImageNum,
})
}

Expand Down

0 comments on commit f22262c

Please sign in to comment.