-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredictorpool.go
63 lines (54 loc) · 1.58 KB
/
predictorpool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package paddle
import (
	"context"
	"fmt"
	"runtime"
	"sync/atomic"

	"github.com/jackc/puddle/v2"
	paddle "github.com/paddlepaddle/paddle/paddle/fluid/inference/goapi"
)
// PredictorPool hands out Paddle predictors so that several goroutines can
// run inference at the same time, each on its own predictor.
//
// Create one with NewPredictorPool and borrow predictors with Get.
//
// Background reading:
//   - https://github.com/PaddlePaddle/Paddle/issues/17288
//   - https://www.paddlepaddle.org.cn/inference/master/guides/performance_tuning/multi_thread.html
type PredictorPool struct {
	// pool owns every predictor; all of them are created up front by
	// NewPredictorPool.
	pool *puddle.Pool[*paddle.Predictor]
}
// NewPredictorPool builds a PredictorPool holding size predictors created
// from config. If size is less than 1, runtime.NumCPU() is used instead.
//
// The first predictor is created with paddle.NewPredictor(config) and every
// other one is a Clone of it. All predictors are created eagerly before this
// function returns (see the comment on the pre-creation loop for why).
//
// NewPredictorPool panics, with an error wrapping the underlying cause, if
// the pool cannot be configured or any predictor cannot be added to it.
func NewPredictorPool(config *paddle.Config, size int) *PredictorPool {
	if size < 1 {
		size = runtime.NumCPU()
	}

	mainPredictor := paddle.NewPredictor(config)

	// first is flipped from 0 to 1 by the first constructor call, so exactly
	// one pool resource reuses mainPredictor and all others are clones. The
	// atomic CAS keeps that true even if the constructor were ever invoked
	// from several goroutines.
	var first int64
	pool, err := puddle.NewPool(&puddle.Config[*paddle.Predictor]{
		Constructor: func(context.Context) (*paddle.Predictor, error) {
			if atomic.CompareAndSwapInt64(&first, 0, 1) {
				return mainPredictor, nil
			}
			return mainPredictor.Clone(), nil
		},
		// No explicit teardown is performed here.
		// NOTE(review): presumably predictor memory is reclaimed by the
		// Paddle Go API itself (e.g. a finalizer) — confirm.
		Destructor: func(*paddle.Predictor) {},
		MaxSize:    int32(size),
	})
	if err != nil {
		panic(fmt.Errorf("paddle: creating predictor pool of size %d: %w", size, err))
	}

	// Pre-create every predictor now instead of letting Acquire create them
	// on demand: on-demand creation could run the constructor (and therefore
	// predictor.Clone()) from several goroutines at once, and Clone is not
	// concurrency-safe, see
	// https://github.com/PaddlePaddle/Paddle/issues/24887.
	for i := 0; i < size; i++ {
		if err := pool.CreateResource(context.Background()); err != nil {
			panic(fmt.Errorf("paddle: creating predictor %d of %d: %w", i+1, size, err))
		}
	}
	return &PredictorPool{pool: pool}
}
// Get borrows a predictor from the pool.
//
// It acquires with context.Background(), so it waits for an idle predictor
// rather than timing out. The returned put function gives the predictor back
// to the pool; call it exactly once when finished and do not use predictor
// afterwards. Get panics with the pool's error if acquisition fails.
func (p *PredictorPool) Get() (predictor *paddle.Predictor, put func()) {
	res, err := p.pool.Acquire(context.Background())
	if err != nil {
		panic(err)
	}
	predictor = res.Value()
	put = res.Release
	return predictor, put
}