feat: add confidence score for ocr output (#167)
Because

- OCR output should have a confidence score like other models

This commit

- add a confidence score to the OCR output
- update the output format to follow the latest protobuf definitions
Phelan164 authored Sep 13, 2022
1 parent c80ef02 commit e915452
Showing 5 changed files with 108 additions and 33 deletions.
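
For context, here is a minimal, self-contained sketch of what the post-processed OCR output looks like after this change. The struct is copied locally so the snippet compiles on its own; the field names match the updated internal/triton/const.go below, while the example values are made up.

```go
package main

import "fmt"

// OcrOutput mirrors the struct updated in internal/triton/const.go below;
// it is copied here only to keep the sketch self-contained.
type OcrOutput struct {
	Boxes  [][][]float32
	Texts  [][]string
	Scores [][]float32
}

func main() {
	// A hypothetical single-image batch with two detected text regions.
	out := OcrOutput{
		Boxes:  [][][]float32{{{10, 20, 100, 30}, {15, 60, 80, 28}}},
		Texts:  [][]string{{"hello", "world"}},
		Scores: [][]float32{{0.98, 0.87}}, // new field; -1 is used when the model exposes no score output
	}
	for i := range out.Boxes { // batch dimension, one entry per image
		for j := range out.Boxes[i] {
			fmt.Printf("box=%v text=%q score=%.2f\n",
				out.Boxes[i][j], out.Texts[i][j], out.Scores[i][j])
		}
	}
}
```
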
2 changes: 1 addition & 1 deletion go.mod
@@ -12,7 +12,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.3
github.com/iancoleman/strcase v0.2.0
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220912044511-0ab17e86726e
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220913153004-111a73642332
github.com/instill-ai/usage-client v0.1.2-alpha
github.com/instill-ai/x v0.2.0-alpha
github.com/knadh/koanf v1.4.1
2 changes: 2 additions & 0 deletions go.sum
@@ -710,6 +710,8 @@ github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220911092846-074c65eae91d h1:
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220911092846-074c65eae91d/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220912044511-0ab17e86726e h1:XCMsBKtRbME8AWm8VlXQS2tZb/IeZmVtsOBkTm9sxhQ=
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220912044511-0ab17e86726e/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220913153004-111a73642332 h1:b8nej+a5kg6uEiLNde2OZ7Sspkue0wFkWIuvenzkdKI=
github.com/instill-ai/protogen-go v0.3.1-alpha.0.20220913153004-111a73642332/go.mod h1:7/Jj3ATVozPwB0WmKRM612o/k5UJF8K9oRCNKYH8iy0=
github.com/instill-ai/usage-client v0.1.2-alpha h1:aGZKqZSZu4FB4ov1lyaJVIuoD6MQ+zDBUFqSYWVauSE=
github.com/instill-ai/usage-client v0.1.2-alpha/go.mod h1:Vi+RgL2YNT+hfztD33JzqFl/Y7/SsV+NpWGIjUgig3s=
github.com/instill-ai/x v0.2.0-alpha h1:8yszKP9DE8bvSRAtEpOwqhG2wwqU3olhTqhwoiLrHfc=
5 changes: 3 additions & 2 deletions internal/triton/const.go
@@ -6,8 +6,9 @@ type DetectionOutput struct {
}

type OcrOutput struct {
Boxes [][][]float32
Texts [][]string
Boxes [][][]float32
Texts [][]string
Scores [][]float32
}

type KeypointOutput struct {
90 changes: 84 additions & 6 deletions internal/triton/triton.go
@@ -245,7 +245,66 @@ func postProcessDetection(modelInferResponse *inferenceserver.ModelInferResponse
}, nil
}

func postProcessOcr(modelInferResponse *inferenceserver.ModelInferResponse, outputNameBboxes string, outputNameLabels string) (interface{}, error) {
func postProcessOcrWithScore(modelInferResponse *inferenceserver.ModelInferResponse, outputNameBboxes string, outputNameLabels string, outputNameScores string) (interface{}, error) {
outputTensorBboxes, rawOutputContentBboxes, err := GetOutputFromInferResponse(outputNameBboxes, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for boxes")
}
if rawOutputContentBboxes == nil {
return nil, fmt.Errorf("Unable to find output content for boxes")
}
outputTensorLabels, rawOutputContentLabels, err := GetOutputFromInferResponse(outputNameLabels, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for labels")
}
if rawOutputContentLabels == nil {
return nil, fmt.Errorf("Unable to find output content for labels")
}
outputTensorScores, rawOutputContentScores, err := GetOutputFromInferResponse(outputNameScores, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to find inference output for scores")
}
if rawOutputContentScores == nil {
return nil, fmt.Errorf("Unable to find output content for scores")
}

outputDataBboxes := DeserializeFloat32Tensor(rawOutputContentBboxes)
batchedOutputDataBboxes, err := Reshape1DArrayFloat32To3D(outputDataBboxes, outputTensorBboxes.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for boxes")
}

outputDataLabels := DeserializeBytesTensor(rawOutputContentLabels, outputTensorLabels.Shape[0]*outputTensorLabels.Shape[1])
batchedOutputDataLabels, err := Reshape1DArrayStringTo2D(outputDataLabels, outputTensorLabels.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for labels")
}

outputDataScores := DeserializeFloat32Tensor(rawOutputContentScores)
batchedOutputDataScores, err := Reshape1DArrayFloat32To2D(outputDataScores, outputTensorScores.Shape)
if err != nil {
log.Printf("%v", err.Error())
return nil, fmt.Errorf("Unable to reshape inference output for labels")
}

if len(batchedOutputDataBboxes) != len(batchedOutputDataLabels) || len(batchedOutputDataLabels) != len(batchedOutputDataScores) {
log.Printf("Bboxes output has length %v but labels has length %v and scores has length %v", len(batchedOutputDataBboxes), len(batchedOutputDataLabels), len(batchedOutputDataScores))
return nil, fmt.Errorf("Inconsistent batch size for bboxes and labels")
}

return OcrOutput{
Boxes: batchedOutputDataBboxes,
Texts: batchedOutputDataLabels,
Scores: batchedOutputDataScores,
}, nil
}

func postProcessOcrWithoutScore(modelInferResponse *inferenceserver.ModelInferResponse, outputNameBboxes string, outputNameLabels string) (interface{}, error) {
outputTensorBboxes, rawOutputContentBboxes, err := GetOutputFromInferResponse(outputNameBboxes, modelInferResponse)
if err != nil {
log.Printf("%v", err.Error())
@@ -282,9 +341,19 @@ func postProcessOcr(modelInferResponse *inferenceserver.ModelInferResponse, outp
return nil, fmt.Errorf("Inconsistent batch size for bboxes and labels")
}

var batchedOutputDataScores [][]float32
for i := range batchedOutputDataLabels {
var batchedOutputDataScore []float32
for range batchedOutputDataLabels[i] {
batchedOutputDataScore = append(batchedOutputDataScore, -1)
}
batchedOutputDataScores = append(batchedOutputDataScores, batchedOutputDataScore)
}

return OcrOutput{
Boxes: batchedOutputDataBboxes,
Texts: batchedOutputDataLabels,
Boxes: batchedOutputDataBboxes,
Texts: batchedOutputDataLabels,
Scores: batchedOutputDataScores,
}, nil
}

@@ -482,10 +551,19 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse,
if len(modelMetadata.Outputs) < 2 {
return nil, fmt.Errorf("Wrong output format of OCR task")
}
outputs, err = postProcessOcr(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process detection output: %w", err)
switch len(modelMetadata.Outputs) {
case 2:
outputs, err = postProcessOcrWithoutScore(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process detection output: %w", err)
}
case 3:
outputs, err = postProcessOcrWithScore(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name, modelMetadata.Outputs[2].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process detection output: %w", err)
}
}

default:
outputs, err = postProcessUnspecifiedTask(inferResponse, modelMetadata.Outputs)
if err != nil {
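
The score path added above deserializes a flat float32 tensor and folds it back into a [batch][n] slice via Reshape1DArrayFloat32To2D; when a model exposes only two outputs, postProcessOcrWithoutScore instead pads Scores with -1 per region so downstream code can always rely on the field being present. The helper below is only an illustrative stand-in for that reshape step, not the repository's implementation, whose signature and error handling may differ.

```go
package main

import "fmt"

// reshapeTo2D is an illustrative stand-in for the Reshape1DArrayFloat32To2D
// helper used above: it folds a flat score tensor of length shape[0]*shape[1]
// back into [batch][n] form.
func reshapeTo2D(flat []float32, shape []int64) ([][]float32, error) {
	if len(shape) != 2 || int64(len(flat)) != shape[0]*shape[1] {
		return nil, fmt.Errorf("cannot reshape %d values into %v", len(flat), shape)
	}
	out := make([][]float32, shape[0])
	for i := int64(0); i < shape[0]; i++ {
		out[i] = flat[i*shape[1] : (i+1)*shape[1]]
	}
	return out, nil
}

func main() {
	// Two images with three text regions each.
	flat := []float32{0.9, 0.8, 0.7, 0.95, 0.85, 0.75}
	batched, err := reshapeTo2D(flat, []int64{2, 3})
	if err != nil {
		panic(err)
	}
	fmt.Println(batched) // [[0.9 0.8 0.7] [0.95 0.85 0.75]]
}
```
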
42 changes: 18 additions & 24 deletions pkg/service/service.go
@@ -234,8 +234,8 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
for i := range batchedOutputDataBboxes {
var detOutput = modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Detection{
Detection: &modelPB.BoundingBoxOutput{
BoundingBoxes: []*modelPB.BoundingBoxObject{},
Detection: &modelPB.DetectionOutput{
DetectionObjects: []*modelPB.DetectionObject{},
},
},
}
@@ -244,11 +244,9 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
label := batchedOutputDataLabels[i][j]
// Non-meaningful bboxes were added with coords [-1, -1, -1, -1, -1] and label "0" for Triton to be able to batch Tensors
if label != "0" {
bbObj := &modelPB.BoundingBoxObject{
TaskField: &modelPB.BoundingBoxObject_Category{
Category: label,
},
Score: box[4],
bbObj := &modelPB.DetectionObject{
Category: label,
Score: box[4],
// Convert x1y1x2y2 to xywh where xy is top-left corner
BoundingBox: &modelPB.BoundingBox{
Left: box[0],
@@ -257,7 +255,7 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
Height: box[3] - box[1],
},
}
detOutput.GetDetection().BoundingBoxes = append(detOutput.GetDetection().BoundingBoxes, bbObj)
detOutput.GetDetection().DetectionObjects = append(detOutput.GetDetection().DetectionObjects, bbObj)
}
}
detOutputs = append(detOutputs, &detOutput)
@@ -267,7 +265,7 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
keypointResponse := postprocessResponse.(triton.KeypointOutput)
var keypointOutputs []*modelPB.TaskOutput
for i := range keypointResponse.Keypoints { // batch size
var keypointGroups []*modelPB.BoundingBoxObject
var keypointObjs []*modelPB.KeypointObject
for j := range keypointResponse.Keypoints[i] { // n keypoints in one image
if keypointResponse.Scores[i][j] == -1 { // dummy object added for batching so that every image has the same output shape
continue
@@ -284,12 +282,8 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
V: points[k][2],
})
}
keypointGroups = append(keypointGroups, &modelPB.BoundingBoxObject{
TaskField: &modelPB.BoundingBoxObject_KeypointGroup{
KeypointGroup: &modelPB.KeypointGroup{
Keypoints: keypoints,
},
},
keypointObjs = append(keypointObjs, &modelPB.KeypointObject{
Keypoints: keypoints,
BoundingBox: &modelPB.BoundingBox{
Left: keypointResponse.Boxes[i][j][0],
Top: keypointResponse.Boxes[i][j][1],
@@ -301,8 +295,8 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
}
keypointOutputs = append(keypointOutputs, &modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Keypoint{
Keypoint: &modelPB.BoundingBoxOutput{
BoundingBoxes: keypointGroups,
Keypoint: &modelPB.KeypointOutput{
KeypointObjects: keypointObjs,
},
},
})
@@ -313,31 +307,31 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
detResponses := postprocessResponse.(triton.OcrOutput)
batchedOutputDataBboxes := detResponses.Boxes
batchedOutputDataTexts := detResponses.Texts
batchedOutputDataScores := detResponses.Scores
var ocrOutputs []*modelPB.TaskOutput
for i := range batchedOutputDataBboxes {
var ocrOutput = modelPB.TaskOutput{
Output: &modelPB.TaskOutput_Ocr{
Ocr: &modelPB.BoundingBoxOutput{
BoundingBoxes: []*modelPB.BoundingBoxObject{},
Ocr: &modelPB.OcrOutput{
OcrObjects: []*modelPB.OcrObject{},
},
},
}
for j := range batchedOutputDataBboxes[i] {
box := batchedOutputDataBboxes[i][j]
text := batchedOutputDataTexts[i][j]
score := batchedOutputDataScores[i][j]
// Non-meaningful bboxes were added with coords [-1, -1, -1, -1, -1] and text "" for Triton to be able to batch Tensors
if text != "" && box[0] != -1 {
ocrOutput.GetOcr().BoundingBoxes = append(ocrOutput.GetOcr().BoundingBoxes, &modelPB.BoundingBoxObject{
ocrOutput.GetOcr().OcrObjects = append(ocrOutput.GetOcr().OcrObjects, &modelPB.OcrObject{
BoundingBox: &modelPB.BoundingBox{
Left: box[0],
Top: box[1],
Width: box[2],
Height: box[3],
},
Score: 1.0,
TaskField: &modelPB.BoundingBoxObject_Text{
Text: text,
},
Score: score,
Text: text,
})
}
}
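
A condensed view of the per-image OCR mapping performed in pkg/service/service.go above. The generated protobuf messages are replaced with simplified local structs here so the sketch stays self-contained; the real modelPB.OcrObject and modelPB.BoundingBox carry additional fields and accessors.

```go
package main

import "fmt"

// Simplified stand-ins for the generated protobuf messages used above.
type BoundingBox struct{ Left, Top, Width, Height float32 }

type OcrObject struct {
	BoundingBox BoundingBox
	Score       float32
	Text        string
}

// toOcrObjects mirrors the per-image loop in the service layer: dummy regions
// (empty text or a box starting at -1), which exist only so Triton can batch
// tensors, are skipped, and the new per-region score is forwarded instead of
// the previously hard-coded 1.0.
func toOcrObjects(boxes [][]float32, texts []string, scores []float32) []OcrObject {
	var objs []OcrObject
	for j := range boxes {
		if texts[j] == "" || boxes[j][0] == -1 {
			continue
		}
		objs = append(objs, OcrObject{
			BoundingBox: BoundingBox{boxes[j][0], boxes[j][1], boxes[j][2], boxes[j][3]},
			Score:       scores[j],
			Text:        texts[j],
		})
	}
	return objs
}

func main() {
	boxes := [][]float32{{12, 30, 80, 24}, {-1, -1, -1, -1}}
	texts := []string{"invoice", ""}
	scores := []float32{0.93, -1}
	fmt.Println(toOcrObjects(boxes, texts, scores))
}
```
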
