Skip to content

Commit

Permalink
feat: support ocr task (#150)
Browse files Browse the repository at this point in the history
Because

- support OCR task

This commit

- add OCR task
  • Loading branch information
Phelan164 authored Aug 15, 2022
1 parent bf0b5b4 commit 7766c6f
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 4 deletions.
6 changes: 5 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.10.3
github.com/iancoleman/strcase v0.2.0
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815023532-edacc1131d77
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815030441-efcd607dc85a
github.com/instill-ai/x v0.1.0-alpha.0.20220706215306-bceeac65f523
github.com/knadh/koanf v1.4.1
github.com/mitchellh/mapstructure v1.5.0
github.com/pkg/errors v0.9.1
github.com/rs/cors v1.8.2
github.com/santhosh-tekuri/jsonschema/v5 v5.0.0
github.com/stretchr/testify v1.7.2
github.com/urfave/cli/v2 v2.11.2
go.uber.org/zap v1.21.0
golang.org/x/net v0.0.0-20220615171555-694bf12d69de
google.golang.org/genproto v0.0.0-20220616135557-88e70c0c3a90
Expand All @@ -33,6 +34,7 @@ require (
require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/containerd/containerd v1.6.6 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fsnotify/fsnotify v1.5.4 // indirect
Expand All @@ -58,6 +60,8 @@ require (
github.com/opencontainers/image-spec v1.0.3-0.20211202183452-c5a74bcca799 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.8.0 // indirect
golang.org/x/crypto v0.0.0-20220525230936-793ad666bf5e // indirect
Expand Down
12 changes: 10 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,8 @@ github.com/coreos/pkg v0.0.0-20160727233714-3ac0863d7acf/go.mod h1:E3G3o1h8I7cfc
github.com/coreos/pkg v0.0.0-20180928190104-399ea9e2e55f/go.mod h1:E3G3o1h8I7cfcXa63jLwjI0eiQQMgzzUDFVpN/nH/eA=
github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
Expand Down Expand Up @@ -700,8 +702,8 @@ github.com/imdario/mergo v0.3.10/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH
github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/imdario/mergo v0.3.12/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA=
github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8=
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815023532-edacc1131d77 h1:o++oPWFP0u3hRU0gt0WBj6SxTY+VtSsJYKqgIPOpg0s=
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815023532-edacc1131d77/go.mod h1:d9ebEdwMX2Las4OScym45qbQM+xcBQITqvq/8anTVas=
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815030441-efcd607dc85a h1:HjdZTEpdz7pY7skluy934TJfF1zsfmnDuTSuzznMviM=
github.com/instill-ai/protogen-go v0.2.1-alpha.0.20220815030441-efcd607dc85a/go.mod h1:d9ebEdwMX2Las4OScym45qbQM+xcBQITqvq/8anTVas=
github.com/instill-ai/x v0.1.0-alpha.0.20220706215306-bceeac65f523 h1:HsZW2VWEnPhxitcyJEGbuQ9vi2LVsSEA8ezPIkp4VQs=
github.com/instill-ai/x v0.1.0-alpha.0.20220706215306-bceeac65f523/go.mod h1:/UEx/zFyMo7so2ctBY0pzjmIoJB9Qz5Y4gvwU2FoU74=
github.com/intel/goresctrl v0.2.0/go.mod h1:+CZdzouYFn5EsxgqAQTEzMfwKwuc0fVdMrT9FCCAVRQ=
Expand Down Expand Up @@ -1086,6 +1088,8 @@ github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w=
github.com/ryanuber/columnize v0.0.0-20160712163229-9b3edd62028f/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
github.com/ryanuber/columnize v2.1.0+incompatible/go.mod h1:sm1tb6uqfes/u+d4ooFouqFdy9/2g9QGwK3SQygK0Ts=
Expand Down Expand Up @@ -1170,6 +1174,8 @@ github.com/urfave/cli v0.0.0-20171014202726-7bc6a0acffa5/go.mod h1:70zkFmudgCuE/
github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA=
github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli v1.22.2/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0=
github.com/urfave/cli/v2 v2.11.2 h1:FVfNg4m3vbjbBpLYxW//WjxUoHvJ9TlppXcqY9Q9ZfA=
github.com/urfave/cli/v2 v2.11.2/go.mod h1:f8iq5LtQ/bLxafbdBSLPPNsgaW0l/2fYYEHhAyPlwvo=
github.com/vishvananda/netlink v0.0.0-20181108222139-023a6dafdcdf/go.mod h1:+SR5DhBJrl6ZM7CoCKvpw5BKroDKQ+PJqOg65H/2ktk=
github.com/vishvananda/netlink v1.1.0/go.mod h1:cTgwzPIzzgDAYoQrMm0EdrjRUBkTqKYppBueQtXaqoE=
github.com/vishvananda/netlink v1.1.1-0.20201029203352-d40f9887b852/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho=
Expand All @@ -1189,6 +1195,8 @@ github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415/go.mod h1:
github.com/xeipuuv/gojsonschema v0.0.0-20180618132009-1d523034197f/go.mod h1:5yf86TLmAcydyeJq5YvxkGPE2fm/u4myDekKRoLuqhs=
github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU=
github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
Expand Down
3 changes: 2 additions & 1 deletion internal/db/migration/000001_init.up.sql
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ CREATE TYPE valid_task AS ENUM (
'TASK_UNSPECIFIED',
'TASK_CLASSIFICATION',
'TASK_DETECTION',
'TASK_KEYPOINT'
'TASK_KEYPOINT',
'TASK_OCR'
);

CREATE TABLE IF NOT EXISTS "model_definition" (
Expand Down
5 changes: 5 additions & 0 deletions internal/triton/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ type DetectionOutput struct {
Labels [][]string
}

// OcrOutput is the post-processed result of an OCR inference request.
//
// Both fields are indexed by batch item first; postProcessOcr only returns an
// OcrOutput when Boxes and Texts have the same outer length.
type OcrOutput struct {
	// Boxes is the bounding-box tensor reshaped to 3D. The service layer
	// reads the first four values of each box as left, top, width, height.
	Boxes [][][]float32

	// Texts is the text tensor reshaped to 2D. The service layer pairs
	// Texts[i][j] with Boxes[i][j] and skips entries whose text is empty.
	Texts [][]string
}

type KeypointOutput struct {
Keypoints [][][]float32
Scores []float32
Expand Down
61 changes: 61 additions & 0 deletions internal/triton/triton.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,49 @@ func postProcessDetection(modelInferResponse *inferenceserver.ModelInferResponse
}, nil
}

// postProcessOcr converts the raw inference response of an OCR model into an
// OcrOutput.
//
// outputNameBboxes is the name of the output tensor carrying the bounding
// boxes and outputNameLabels the name of the output tensor carrying the
// recognised texts. The boxes tensor is deserialized as float32 and reshaped
// to 3D; the texts tensor is deserialized as a bytes tensor and reshaped to 2D.
//
// It returns an OcrOutput (boxed in interface{}) on success. An error is
// returned when either output is missing, has no content, cannot be reshaped,
// or when boxes and texts disagree on the batch size or on the number of
// entries of a batch item.
func postProcessOcr(modelInferResponse *inferenceserver.ModelInferResponse, outputNameBboxes string, outputNameLabels string) (interface{}, error) {
	outputTensorBboxes, rawOutputContentBboxes, err := GetOutputFromInferResponse(outputNameBboxes, modelInferResponse)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to find inference output for boxes")
	}
	if rawOutputContentBboxes == nil {
		return nil, fmt.Errorf("Unable to find output content for boxes")
	}
	outputTensorLabels, rawOutputContentLabels, err := GetOutputFromInferResponse(outputNameLabels, modelInferResponse)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to find inference output for labels")
	}
	if rawOutputContentLabels == nil {
		return nil, fmt.Errorf("Unable to find output content for labels")
	}

	outputDataBboxes := DeserializeFloat32Tensor(rawOutputContentBboxes)
	batchedOutputDataBboxes, err := Reshape1DArrayFloat32To3D(outputDataBboxes, outputTensorBboxes.Shape)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to reshape inference output for boxes")
	}

	// The number of texts is derived from the first two dimensions of the
	// boxes tensor, so make sure they exist before indexing: a reshape of
	// empty data may succeed without ever validating the shape.
	if len(outputTensorBboxes.Shape) < 2 {
		log.Printf("Bboxes output has shape %v, expected at least 2 dimensions", outputTensorBboxes.Shape)
		return nil, fmt.Errorf("Unable to reshape inference output for boxes")
	}
	numLabels := outputTensorBboxes.Shape[0] * outputTensorBboxes.Shape[1]

	outputDataLabels := DeserializeBytesTensor(rawOutputContentLabels, numLabels)
	batchedOutputDataLabels, err := Reshape1DArrayStringTo2D(outputDataLabels, outputTensorLabels.Shape)
	if err != nil {
		log.Printf("%v", err.Error())
		return nil, fmt.Errorf("Unable to reshape inference output for labels")
	}

	if len(batchedOutputDataBboxes) != len(batchedOutputDataLabels) {
		log.Printf("Bboxes output has length %v but labels has length %v", len(batchedOutputDataBboxes), len(batchedOutputDataLabels))
		return nil, fmt.Errorf("Inconsistent batch size for bboxes and labels")
	}

	// Consumers pair Boxes[i][j] with Texts[i][j], so every batch item must
	// carry exactly one text per box; reject mismatches here instead of
	// letting the caller index out of range.
	for i := range batchedOutputDataBboxes {
		if len(batchedOutputDataBboxes[i]) != len(batchedOutputDataLabels[i]) {
			log.Printf("Batch item %v has %v bboxes but %v labels", i, len(batchedOutputDataBboxes[i]), len(batchedOutputDataLabels[i]))
			return nil, fmt.Errorf("Inconsistent number of bboxes and labels")
		}
	}

	return OcrOutput{
		Boxes: batchedOutputDataBboxes,
		Texts: batchedOutputDataLabels,
	}, nil
}

func postProcessClassification(modelInferResponse *inferenceserver.ModelInferResponse, outputName string) (interface{}, error) {
outputTensor, rawOutputContent, err := GetOutputFromInferResponse(outputName, modelInferResponse)
if err != nil {
Expand Down Expand Up @@ -306,6 +349,19 @@ func postProcessUnspecifiedTask(modelInferResponse *inferenceserver.ModelInferRe
serializedOutputs = append(serializedOutputs, reshapedOutput)
}
}
case "INT32":
deserializedRawOutput := DeserializeInt32Tensor(rawOutputContent)
if len(outputTensor.Shape) == 1 {
serializedOutputs = append(serializedOutputs, deserializedRawOutput)
} else if len(outputTensor.Shape) == 2 {
reshapedOutputs, err := Reshape1DArrayInt32To2D(deserializedRawOutput, outputTensor.Shape)
if err != nil {
return nil, err
}
for _, reshapedOutput := range reshapedOutputs {
serializedOutputs = append(serializedOutputs, reshapedOutput)
}
}
case "STRING":
deserializedRawOutput := DeserializeBytesTensor(rawOutputContent, outputTensor.Shape[0]*outputTensor.Shape[1])
reshapedOutputs, err := Reshape1DArrayStringTo2D(deserializedRawOutput, outputTensor.Shape)
Expand Down Expand Up @@ -394,6 +450,11 @@ func (ts *triton) PostProcess(inferResponse *inferenceserver.ModelInferResponse,
if err != nil {
return nil, fmt.Errorf("Unable to post-process keypoint output: %w", err)
}
case modelPB.ModelInstance_TASK_OCR:
outputs, err = postProcessOcr(inferResponse, modelMetadata.Outputs[0].Name, modelMetadata.Outputs[1].Name)
if err != nil {
return nil, fmt.Errorf("Unable to post-process detection output: %w", err)
}
default:
outputs, err = postProcessUnspecifiedTask(inferResponse, modelMetadata.Outputs)
if err != nil {
Expand Down
37 changes: 37 additions & 0 deletions internal/triton/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,17 @@ func DeserializeFloat32Tensor(encodedTensor []byte) []float32 {
return arr
}

// DeserializeInt32Tensor decodes a raw tensor buffer into a slice of int32,
// reading one value from each consecutive 4-byte group via ReadInt32.
//
// An empty buffer yields an empty, non-nil slice. Trailing bytes that do not
// form a complete 4-byte group are ignored.
func DeserializeInt32Tensor(encodedTensor []byte) []int32 {
	const wordSize = 4

	if len(encodedTensor) == 0 {
		return []int32{}
	}

	values := make([]int32, len(encodedTensor)/wordSize)
	for idx := range values {
		offset := idx * wordSize
		values[idx] = ReadInt32(encodedTensor[offset : offset+wordSize])
	}
	return values
}

// TODO: generalise reshape functions by using interface{} arguments and returned values
func Reshape1DArrayStringTo2D(array []string, shape []int64) ([][]string, error) {
if len(array) == 0 {
Expand Down Expand Up @@ -154,6 +165,32 @@ func Reshape1DArrayFloat32To2D(array []float32, shape []int64) ([][]float32, err
return res, nil
}

// Reshape1DArrayInt32To2D reshapes a flat int32 slice into a 2D slice laid out
// row by row according to shape (rows, columns).
//
// An empty array yields an empty, non-nil result regardless of shape.
// Otherwise shape must contain exactly two positive dimensions whose product
// equals len(array); any other input returns an error.
//
// The returned rows are views into array, not copies. Each row's capacity is
// capped at its length so appending to one row cannot overwrite the next.
func Reshape1DArrayInt32To2D(array []int32, shape []int64) ([][]int32, error) {
	if len(array) == 0 {
		return [][]int32{}, nil
	}

	if len(shape) != 2 {
		return nil, fmt.Errorf("Expected a 2D shape, got %vD shape %v", len(shape), shape)
	}

	rows, cols := shape[0], shape[1]
	total := int64(len(array))
	// Validate by division instead of multiplying the dimensions: this cannot
	// overflow, and it rejects negative dimensions, whose product may be
	// positive and would otherwise make the allocation below panic.
	if rows <= 0 || cols <= 0 || total%cols != 0 || total/cols != rows {
		return nil, fmt.Errorf("Cannot reshape array of length %v into shape %v", len(array), shape)
	}

	res := make([][]int32, rows)
	for i := range res {
		start := int64(i) * cols
		end := start + cols
		res[i] = array[start:end:end]
	}
	return res, nil
}

func GetOutputFromInferResponse(name string, response *inferenceserver.ModelInferResponse) (*inferenceserver.ModelInferResponse_InferOutputTensor, []byte, error) {
for idx, output := range response.Outputs {
if output.Name == name {
Expand Down
2 changes: 2 additions & 0 deletions internal/util/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var Tasks = map[string]modelPB.ModelInstance_Task{
"TASK_CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"TASK_DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"TASK_KEYPOINT": modelPB.ModelInstance_TASK_KEYPOINT,
"TASK_OCR": modelPB.ModelInstance_TASK_OCR,
}

var Tags = map[string]modelPB.ModelInstance_Task{
Expand All @@ -18,6 +19,7 @@ var Tags = map[string]modelPB.ModelInstance_Task{
"IMAGE-CLASSIFICATION": modelPB.ModelInstance_TASK_CLASSIFICATION,
"IMAGE-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OBJECT-DETECTION": modelPB.ModelInstance_TASK_DETECTION,
"OCR": modelPB.ModelInstance_TASK_OCR,
}

var Visibility = map[string]modelPB.Model_Visibility{
Expand Down
30 changes: 30 additions & 0 deletions pkg/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,36 @@ func (s *service) ModelInfer(modelInstanceUID uuid.UUID, imgsBytes [][]byte, tas
})

return keypointOutputs, nil
case modelPB.ModelInstance_TASK_OCR:
detResponses := postprocessResponse.(triton.OcrOutput)
batchedOutputDataBboxes := detResponses.Boxes
batchedOutputDataTexts := detResponses.Texts
var ocrOutputs []*modelPB.ModelInstanceOutput
for i := range batchedOutputDataBboxes {
var ocrOutput = modelPB.ModelInstanceOutput{
Output: &modelPB.ModelInstanceOutput_Ocr{
Ocr: &modelPB.OcrOutput{
BoundingBoxes: []*modelPB.BoundingBox{},
},
},
}
for j := range batchedOutputDataBboxes[i] {
box := batchedOutputDataBboxes[i][j]
text := batchedOutputDataTexts[i][j]
// Non-meaningful bboxes were added with coords [-1, -1, -1, -1, -1] and text "" for Triton to be able to batch Tensors
if text != "" && box[0] != -1 {
ocrOutput.GetOcr().BoundingBoxes = append(ocrOutput.GetOcr().BoundingBoxes, &modelPB.BoundingBox{
Left: box[0],
Top: box[1],
Width: box[2],
Height: box[3],
})
ocrOutput.GetOcr().Texts = append(ocrOutput.GetOcr().Texts, text)
}
}
ocrOutputs = append(ocrOutputs, &ocrOutput)
}
return ocrOutputs, nil
default:
outputs := postprocessResponse.([]triton.BatchUnspecifiedTaskOutputs)
var rawOutputs []*modelPB.ModelInstanceOutput
Expand Down

0 comments on commit 7766c6f

Please sign in to comment.