Plug ThreadPoolFactory into MultiInference request handling logic.
PiperOrigin-RevId: 306856194
lilao authored and tensorflow-copybara committed Apr 16, 2020
1 parent e862e59 commit 9a2db1d
Showing 7 changed files with 147 additions and 67 deletions.
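For orientation, the sketch below (not part of the commit) shows how the thread pool options introduced in this change travel from a ThreadPoolFactory into MultiInference. The free function HandleMultiInference, the `example` namespace, and the thread_pool_factory.h include path are illustrative assumptions; the option-population pattern and the RunMultiInferenceWithServerCore signature come from the diff itself, which forwards the options through RunMultiInference and TensorFlowMultiInferenceRunner down to PerformOneShotTensorComputation.

#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/servables/tensorflow/multi_inference_helper.h"
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory.h"  // include path assumed

namespace example {

// Hypothetical caller; `core` and `thread_pool_factory` are assumed to be
// owned elsewhere (e.g. by PredictionServiceImpl), and the factory may be null.
tensorflow::Status HandleMultiInference(
    tensorflow::serving::ServerCore* core,
    tensorflow::serving::ThreadPoolFactory* thread_pool_factory,
    const tensorflow::serving::MultiInferenceRequest& request,
    tensorflow::serving::MultiInferenceResponse* response) {
  // Same pattern as the new GetThreadPoolOptions() helper in
  // prediction_service_impl.cc: a null factory leaves the default
  // session-owned thread pools in place.
  tensorflow::thread::ThreadPoolOptions thread_pool_options;
  if (thread_pool_factory != nullptr) {
    thread_pool_options.inter_op_threadpool =
        thread_pool_factory->GetInterOpThreadPool();
    thread_pool_options.intra_op_threadpool =
        thread_pool_factory->GetIntraOpThreadPool();
  }
  // New in this commit: the options are forwarded through
  // RunMultiInferenceWithServerCore -> RunMultiInference ->
  // TensorFlowMultiInferenceRunner -> PerformOneShotTensorComputation.
  return tensorflow::serving::RunMultiInferenceWithServerCore(
      tensorflow::RunOptions(), core, thread_pool_options, request, response);
}

}  // namespace example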
37 changes: 19 additions & 18 deletions tensorflow_serving/model_servers/prediction_service_impl.cc
@@ -34,6 +34,18 @@ int DeadlineToTimeoutMillis(const gpr_timespec deadline) {
gpr_now(GPR_CLOCK_MONOTONIC)));
}

thread::ThreadPoolOptions GetThreadPoolOptions(
ThreadPoolFactory *thread_pool_factory) {
thread::ThreadPoolOptions thread_pool_options;
if (thread_pool_factory != nullptr) {
thread_pool_options.inter_op_threadpool =
thread_pool_factory->GetInterOpThreadPool();
thread_pool_options.intra_op_threadpool =
thread_pool_factory->GetIntraOpThreadPool();
}
return thread_pool_options;
}

} // namespace

::grpc::Status PredictionServiceImpl::Predict(::grpc::ServerContext *context,
@@ -79,16 +91,10 @@ ::grpc::Status PredictionServiceImpl::Classify(
run_options.set_timeout_in_ms(
DeadlineToTimeoutMillis(context->raw_deadline()));
}
thread::ThreadPoolOptions thread_pool_options;
if (thread_pool_factory_ != nullptr) {
thread_pool_options.inter_op_threadpool =
thread_pool_factory_->GetInterOpThreadPool();
thread_pool_options.intra_op_threadpool =
thread_pool_factory_->GetIntraOpThreadPool();
}
const ::grpc::Status status =
ToGRPCStatus(TensorflowClassificationServiceImpl::Classify(
run_options, core_, thread_pool_options, *request, response));
run_options, core_, GetThreadPoolOptions(thread_pool_factory_),
*request, response));
if (!status.ok()) {
VLOG(1) << "Classify request failed: " << status.error_message();
}
@@ -104,16 +110,10 @@ ::grpc::Status PredictionServiceImpl::Regress(::grpc::ServerContext *context,
run_options.set_timeout_in_ms(
DeadlineToTimeoutMillis(context->raw_deadline()));
}
thread::ThreadPoolOptions thread_pool_options;
if (thread_pool_factory_ != nullptr) {
thread_pool_options.inter_op_threadpool =
thread_pool_factory_->GetInterOpThreadPool();
thread_pool_options.intra_op_threadpool =
thread_pool_factory_->GetIntraOpThreadPool();
}
const ::grpc::Status status =
ToGRPCStatus(TensorflowRegressionServiceImpl::Regress(
run_options, core_, thread_pool_options, *request, response));
run_options, core_, GetThreadPoolOptions(thread_pool_factory_),
*request, response));
if (!status.ok()) {
VLOG(1) << "Regress request failed: " << status.error_message();
}
@@ -129,8 +129,9 @@ ::grpc::Status PredictionServiceImpl::MultiInference(
run_options.set_timeout_in_ms(
DeadlineToTimeoutMillis(context->raw_deadline()));
}
const ::grpc::Status status = ToGRPCStatus(
RunMultiInferenceWithServerCore(run_options, core_, *request, response));
const ::grpc::Status status = ToGRPCStatus(RunMultiInferenceWithServerCore(
run_options, core_, GetThreadPoolOptions(thread_pool_factory_), *request,
response));
if (!status.ok()) {
VLOG(1) << "MultiInference request failed: " << status.error_message();
}
16 changes: 8 additions & 8 deletions tensorflow_serving/servables/tensorflow/multi_inference.cc
@@ -96,7 +96,7 @@ Status TensorFlowMultiInferenceRunner::Infer(
int num_examples;
TF_RETURN_IF_ERROR(PerformOneShotTensorComputation(
run_options, request.input(), input_tensor_name, output_tensor_names,
session_, &outputs, &num_examples));
session_, &outputs, &num_examples, thread_pool_options_));
RecordRequestExampleCount(model_name, num_examples);

TRACELITERAL("PostProcessResults");
@@ -129,15 +129,15 @@ Status TensorFlowMultiInferenceRunner::Infer(
return Status::OK();
}

Status RunMultiInference(const RunOptions& run_options,
const MetaGraphDef& meta_graph_def,
const optional<int64>& servable_version,
Session* session, const MultiInferenceRequest& request,
MultiInferenceResponse* response) {
Status RunMultiInference(
const RunOptions& run_options, const MetaGraphDef& meta_graph_def,
const optional<int64>& servable_version, Session* session,
const MultiInferenceRequest& request, MultiInferenceResponse* response,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options) {
TRACELITERAL("RunMultiInference");

TensorFlowMultiInferenceRunner inference_runner(session, &meta_graph_def,
servable_version);
TensorFlowMultiInferenceRunner inference_runner(
session, &meta_graph_def, servable_version, thread_pool_options);
return inference_runner.Infer(run_options, request, response);
}

26 changes: 16 additions & 10 deletions tensorflow_serving/servables/tensorflow/multi_inference.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MULTI_INFERENCE_H_

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow_serving/apis/inference.pb.h"
@@ -32,14 +33,17 @@ class TensorFlowMultiInferenceRunner {
TensorFlowMultiInferenceRunner(Session* session,
const MetaGraphDef* meta_graph_def)
: TensorFlowMultiInferenceRunner(session, meta_graph_def,
/*servable_version*/ {}) {}
/*servable_version=*/{}) {}

TensorFlowMultiInferenceRunner(Session* session,
const MetaGraphDef* meta_graph_def,
optional<int64> servable_version)
TensorFlowMultiInferenceRunner(
Session* session, const MetaGraphDef* meta_graph_def,
optional<int64> servable_version,
const thread::ThreadPoolOptions& thread_pool_options =
thread::ThreadPoolOptions())
: session_(session),
meta_graph_def_(meta_graph_def),
servable_version_(servable_version) {}
servable_version_(servable_version),
thread_pool_options_(thread_pool_options) {}

// Run inference and return the inference results in the same order as the
// InferenceTasks in the request.
@@ -55,14 +59,16 @@ class TensorFlowMultiInferenceRunner {
// If available, servable_version is used to set the ModelSpec version in the
// InferenceResults of the MultiInferenceResponse.
const optional<int64> servable_version_;
const tensorflow::thread::ThreadPoolOptions thread_pool_options_;
};

// Creates TensorFlowMultiInferenceRunner and calls Infer on it.
Status RunMultiInference(const RunOptions& run_options,
const MetaGraphDef& meta_graph_def,
const optional<int64>& servable_version,
Session* session, const MultiInferenceRequest& request,
MultiInferenceResponse* response);
Status RunMultiInference(
const RunOptions& run_options, const MetaGraphDef& meta_graph_def,
const optional<int64>& servable_version, Session* session,
const MultiInferenceRequest& request, MultiInferenceResponse* response,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options =
tensorflow::thread::ThreadPoolOptions());

} // namespace serving
} // namespace tensorflow
tensorflow_serving/servables/tensorflow/multi_inference_helper.cc
@@ -34,24 +34,26 @@ const ModelSpec& GetModelSpecFromRequest(const MultiInferenceRequest& request) {

} // namespace

Status RunMultiInferenceWithServerCore(const RunOptions& run_options,
ServerCore* core,
const MultiInferenceRequest& request,
MultiInferenceResponse* response) {
Status RunMultiInferenceWithServerCore(
const RunOptions& run_options, ServerCore* core,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options,
const MultiInferenceRequest& request, MultiInferenceResponse* response) {
return RunMultiInferenceWithServerCoreWithModelSpec(
run_options, core, GetModelSpecFromRequest(request), request, response);
run_options, core, thread_pool_options, GetModelSpecFromRequest(request),
request, response);
}

Status RunMultiInferenceWithServerCoreWithModelSpec(
const RunOptions& run_options, ServerCore* core,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options,
const ModelSpec& model_spec, const MultiInferenceRequest& request,
MultiInferenceResponse* response) {
ServableHandle<SavedModelBundle> bundle;
TF_RETURN_IF_ERROR(core->GetServableHandle(model_spec, &bundle));

return RunMultiInference(run_options, bundle->meta_graph_def,
bundle.id().version, bundle->session.get(),
request, response);
bundle.id().version, bundle->session.get(), request,
response, thread_pool_options);
}

} // namespace serving
10 changes: 6 additions & 4 deletions tensorflow_serving/servables/tensorflow/multi_inference_helper.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MULTI_INFERENCE_HELPER_H_

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow_serving/apis/inference.pb.h"
#include "tensorflow_serving/model_servers/server_core.h"
#include "tensorflow_serving/util/optional.h"
@@ -25,15 +26,16 @@ namespace tensorflow {
namespace serving {

// Runs MultiInference
Status RunMultiInferenceWithServerCore(const RunOptions& run_options,
ServerCore* core,
const MultiInferenceRequest& request,
MultiInferenceResponse* response);
Status RunMultiInferenceWithServerCore(
const RunOptions& run_options, ServerCore* core,
const thread::ThreadPoolOptions& thread_pool_options,
const MultiInferenceRequest& request, MultiInferenceResponse* response);

// Like RunMultiInferenceWithServerCore(), but uses 'model_spec' instead of the
// one(s) embedded in 'request'.
Status RunMultiInferenceWithServerCoreWithModelSpec(
const RunOptions& run_options, ServerCore* core,
const thread::ThreadPoolOptions& thread_pool_options,
const ModelSpec& model_spec, const MultiInferenceRequest& request,
MultiInferenceResponse* response);

@@ -121,9 +121,9 @@ TEST_F(MultiInferenceTest, MissingInputTest) {
PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::INVALID_ARGUMENT, "Input is empty");
}

@@ -134,9 +134,9 @@ TEST_F(MultiInferenceTest, UndefinedSignatureTest) {
request.add_tasks());

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::INVALID_ARGUMENT, "signature not found");
}

@@ -156,9 +156,9 @@ TEST_F(MultiInferenceTest, InconsistentModelSpecsInRequestTest) {
task->set_method_name(kRegressMethodName);

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::INVALID_ARGUMENT,
"must access the same model name");
}
@@ -171,9 +171,9 @@ TEST_F(MultiInferenceTest, EvaluateDuplicateSignaturesTest) {
PopulateTask("regress_x_to_y", kRegressMethodName, -1, request.add_tasks());

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::INVALID_ARGUMENT,
"Duplicate evaluation of signature: regress_x_to_y");
}
@@ -184,9 +184,9 @@ TEST_F(MultiInferenceTest, UsupportedSignatureTypeTest) {
PopulateTask("serving_default", kPredictMethodName, -1, request.add_tasks());

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::UNIMPLEMENTED, "Unsupported signature");
}

@@ -197,9 +197,9 @@ TEST_F(MultiInferenceTest, SignaturesWithDifferentInputsTest) {
PopulateTask("regress_x2_to_y3", kRegressMethodName, -1, request.add_tasks());

MultiInferenceResponse response;
ExpectStatusError(RunMultiInferenceWithServerCore(RunOptions(),
GetServerCore(),
request, &response),
ExpectStatusError(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(), request, &response),
tensorflow::error::INVALID_ARGUMENT,
"Input tensor must be the same");
}
@@ -220,6 +220,7 @@ TEST_F(MultiInferenceTest, ValidSingleSignatureTest) {

MultiInferenceResponse response;
TF_ASSERT_OK(RunMultiInferenceWithServerCore(RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(),
request, &response));
EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
@@ -252,6 +253,7 @@ TEST_F(MultiInferenceTest, MultipleValidRegressSignaturesTest) {

MultiInferenceResponse response;
TF_ASSERT_OK(RunMultiInferenceWithServerCore(RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(),
request, &response));
EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
@@ -282,6 +284,7 @@ TEST_F(MultiInferenceTest, RegressAndClassifySignaturesTest) {

MultiInferenceResponse response;
TF_ASSERT_OK(RunMultiInferenceWithServerCore(RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(),
request, &response));
EXPECT_THAT(response, test_util::EqualsProto(expected_response));
}
@@ -299,15 +302,46 @@ TEST_F(MultiInferenceTest, ModelSpecOverride) {
MultiInferenceResponse response;
EXPECT_NE(tensorflow::error::NOT_FOUND,
RunMultiInferenceWithServerCore(RunOptions(), GetServerCore(),
thread::ThreadPoolOptions(),
request, &response)
.code());
EXPECT_EQ(tensorflow::error::NOT_FOUND,
RunMultiInferenceWithServerCoreWithModelSpec(
RunOptions(), GetServerCore(), model_spec_override, request,
&response)
RunOptions(), GetServerCore(), thread::ThreadPoolOptions(),
model_spec_override, request, &response)
.code());
}

TEST_F(MultiInferenceTest, ThreadPoolOptions) {
MultiInferenceRequest request;
AddInput({{"x", 2}}, &request);
PopulateTask("regress_x_to_y", kRegressMethodName, servable_version_,
request.add_tasks());

MultiInferenceResponse expected_response;
auto* inference_result = expected_response.add_results();
auto* model_spec = inference_result->mutable_model_spec();
*model_spec = request.tasks(0).model_spec();
model_spec->mutable_version()->set_value(servable_version_);
auto* regression_result = inference_result->mutable_regression_result();
regression_result->add_regressions()->set_value(3.0);

test_util::CountingThreadPool inter_op_threadpool(Env::Default(), "InterOp",
/*num_threads=*/1);
test_util::CountingThreadPool intra_op_threadpool(Env::Default(), "IntraOp",
/*num_threads=*/1);
thread::ThreadPoolOptions thread_pool_options;
thread_pool_options.inter_op_threadpool = &inter_op_threadpool;
thread_pool_options.intra_op_threadpool = &intra_op_threadpool;
MultiInferenceResponse response;
TF_ASSERT_OK(RunMultiInferenceWithServerCore(
RunOptions(), GetServerCore(), thread_pool_options, request, &response));
EXPECT_THAT(response, test_util::EqualsProto(expected_response));

// The intra_op_threadpool doesn't have anything scheduled.
ASSERT_GE(inter_op_threadpool.NumScheduled(), 1);
}

} // namespace
} // namespace serving
} // namespace tensorflow
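As a usage note (also not from the commit), any custom thread::ThreadPoolInterface implementation can be plugged into these options, not just the test-only test_util::CountingThreadPool used above. A minimal sketch follows; the InlineThreadPool class is hypothetical, the threadpool_interface.h include path is an assumption, and only the pure-virtual methods of the Eigen-style interface are overridden.

#include <functional>

#include "tensorflow/core/platform/threadpool_interface.h"  // include path assumed
#include "tensorflow/core/platform/threadpool_options.h"

// Toy pool that runs every scheduled closure inline on the calling thread.
class InlineThreadPool : public tensorflow::thread::ThreadPoolInterface {
 public:
  void Schedule(std::function<void()> fn) override { fn(); }
  int NumThreads() const override { return 1; }
  int CurrentThreadId() const override { return -1; }
};

// Wiring mirrors the ThreadPoolOptions test above:
//   InlineThreadPool inter_op, intra_op;
//   tensorflow::thread::ThreadPoolOptions options;
//   options.inter_op_threadpool = &inter_op;
//   options.intra_op_threadpool = &intra_op;
//   RunMultiInferenceWithServerCore(RunOptions(), core, options, request, &response);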