Skip to content

Commit

Permalink
Move PredictRequest and ModelSession to utils.proto
Browse files Browse the repository at this point in the history
Since both the inference and training servicers share the concept of a
session id, the training session id was replaced with the model session one
used for inference. This model session protobuf interface was moved to a
separate utils proto file.

Since PredictRequest is now common to both servicers, it can be leveraged for abstraction.
  • Loading branch information
thodkatz committed Dec 21, 2024
1 parent 50b0944 commit 4798dbe
Show file tree
Hide file tree
Showing 13 changed files with 221 additions and 264 deletions.
14 changes: 1 addition & 13 deletions proto/inference.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ service Inference {


message CreateDatasetDescriptionRequest {
string modelSessionId = 1;
ModelSession modelSessionId = 1;
double mean = 3;
double stddev = 4;
}
Expand Down Expand Up @@ -53,9 +53,6 @@ message NamedFloats {
}


message ModelSession {
string id = 1;
}

message LogEntry {
enum Level {
Expand All @@ -73,15 +70,6 @@ message LogEntry {
}


message PredictRequest {
string modelSessionId = 1;
string datasetId = 2;
repeated Tensor tensors = 3;
}

message PredictResponse {
repeated Tensor tensors = 1;
}


service FlightControl {
Expand Down
35 changes: 12 additions & 23 deletions proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,33 +9,33 @@ import "utils.proto";
service Training {
rpc ListDevices(Empty) returns (Devices) {}

rpc Init(TrainingConfig) returns (TrainingSessionId) {}
rpc Init(TrainingConfig) returns (ModelSession) {}

rpc Start(TrainingSessionId) returns (Empty) {}
rpc Start(ModelSession) returns (Empty) {}

rpc Resume(TrainingSessionId) returns (Empty) {}
rpc Resume(ModelSession) returns (Empty) {}

rpc Pause(TrainingSessionId) returns (Empty) {}
rpc Pause(ModelSession) returns (Empty) {}

rpc StreamUpdates(TrainingSessionId) returns (stream StreamUpdateResponse) {}
rpc StreamUpdates(ModelSession) returns (stream StreamUpdateResponse) {}

rpc GetLogs(TrainingSessionId) returns (GetLogsResponse) {}
rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc Save(TrainingSessionId) returns (Empty) {}

rpc Export(TrainingSessionId) returns (Empty) {}
rpc Export(ModelSession) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

rpc GetStatus(TrainingSessionId) returns (GetStatusResponse) {}
rpc GetStatus(ModelSession) returns (GetStatusResponse) {}

rpc CloseTrainerSession(TrainingSessionId) returns (Empty) {}
rpc CloseTrainerSession(ModelSession) returns (Empty) {}
}

message TrainingSessionId {

message GetBestModelIdxResponse {
string id = 1;
}


message Logs {
enum ModelPhase {
Train = 0;
Expand All @@ -59,17 +59,6 @@ message GetLogsResponse {
}



message PredictRequest {
repeated Tensor tensors = 1;
TrainingSessionId sessionId = 2;
}


message PredictResponse {
repeated Tensor tensors = 1;
}

message ValidationResponse {
double validation_score_average = 1;
}
Expand Down
13 changes: 13 additions & 0 deletions proto/utils.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,19 @@ syntax = "proto3";

message Empty {}

// Identifies a session by its id. Used as the request type of the Training
// RPCs (Start, Resume, Pause, ...), returned by Training.Init, and embedded
// in PredictRequest.
message ModelSession {
// Session identifier.
string id = 1;
}

// Request of a Predict RPC: the session to run against plus the input tensors.
message PredictRequest {
// Session the prediction is addressed to.
// NOTE(review): despite the "Id" suffix this field holds a ModelSession
// message, not a bare id string.
ModelSession modelSessionId = 1;
// Input tensors passed to the prediction.
// NOTE(review): Tensor is not declared in the visible lines; presumably it
// is defined further down in this file - confirm.
repeated Tensor tensors = 2;
}

// Response of a Predict RPC.
message PredictResponse {
// Tensors returned by the prediction.
repeated Tensor tensors = 1;
}

message NamedInt {
uint32 size = 1;
string name = 2;
Expand Down
43 changes: 22 additions & 21 deletions tests/test_server/test_grpc/test_inference_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ def test_model_session_creation_using_non_existent_upload(self, grpc_stub):

def test_predict_call_fails_without_specifying_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest())
grpc_stub.Predict(utils_pb2.PredictRequest())

assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session-id has not been provided" in e.value.details()
assert "model-session with id doesn't exist" in e.value.details()

def test_model_init_failed_close_session(self, bioimage_model_explicit_add_one_siso_v5, grpc_stub):
"""
Expand Down Expand Up @@ -169,17 +169,18 @@ def test_returns_ack_message(self, bioimage_model_explicit_add_one_siso_v5, grpc
class TestForwardPass:
def test_call_fails_with_unknown_model_session_id(self, grpc_stub):
with pytest.raises(grpc.RpcError) as e:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId="myid1"))
model_id = utils_pb2.ModelSession(id="myid")
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id))
assert grpc.StatusCode.FAILED_PRECONDITION == e.value.code()
assert "model-session with id myid1 doesn't exist" in e.value.details()
assert "model-session with id myid doesn't exist" in e.value.details()

def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -188,11 +189,11 @@ def test_call_predict_valid_explicit(self, grpc_stub, bioimage_model_explicit_ad

def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_v4):
model_bytes = bioimage_model_add_one_v4
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -201,16 +202,16 @@ def test_call_predict_valid_explicit_v4(self, grpc_stub, bioimage_model_add_one_

def test_call_predict_invalid_shape_explicit(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(32 * 32).reshape(1, 1, 32, 32), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("input", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_model_add_one_miso_v5):
model_bytes = bioimage_model_add_one_miso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))

arr1 = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensor_id1 = "input1"
Expand All @@ -227,8 +228,8 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
converters.xarray_to_pb_tensor(tensor_id, arr) for tensor_id, arr in zip(input_tensor_ids, tensors_arr)
]

res = grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.CloseModelSession(model)
res = grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
grpc_stub.CloseModelSession(model_id)
assert len(res.tensors) == 1
pb_tensor = res.tensors[0]
assert pb_tensor.tensorId == "output"
Expand All @@ -238,33 +239,33 @@ def test_call_predict_multiple_inputs_with_reference(self, grpc_stub, bioimage_m
@pytest.mark.parametrize("shape", [(1, 2, 10, 20), (1, 2, 12, 20), (1, 2, 10, 23), (1, 2, 12, 23)])
def test_call_predict_valid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))

@pytest.mark.parametrize(
"shape",
[(1, 1, 10, 20), (1, 2, 8, 20), (1, 2, 11, 20), (1, 2, 10, 21)],
)
def test_call_predict_invalid_shape_parameterized(self, grpc_stub, shape, bioimage_model_param_add_one_siso_v5):
model_bytes = bioimage_model_param_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(np.prod(shape)).reshape(*shape), dims=("batch", "channel", "x", "y"))
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axis")

def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=("batch", "channel", "x", "y"))
input_tensors = [converters.xarray_to_pb_tensor("invalidTensorName", arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Spec 'invalidTensorName' doesn't exist")

@pytest.mark.parametrize(
Expand All @@ -278,10 +279,10 @@ def test_call_predict_invalid_tensor_ids(self, grpc_stub, bioimage_model_explici
)
def test_call_predict_invalid_axes(self, grpc_stub, axes, bioimage_model_explicit_add_one_siso_v5):
model_bytes = bioimage_model_explicit_add_one_siso_v5
model = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
model_id = grpc_stub.CreateModelSession(valid_model_request(model_bytes))
arr = xr.DataArray(np.arange(2 * 10 * 20).reshape(1, 2, 10, 20), dims=axes)
input_tensor_id = "input"
input_tensors = [converters.xarray_to_pb_tensor(input_tensor_id, arr)]
with pytest.raises(grpc.RpcError) as error:
grpc_stub.Predict(inference_pb2.PredictRequest(modelSessionId=model.id, tensors=input_tensors))
grpc_stub.Predict(utils_pb2.PredictRequest(modelSessionId=model_id, tensors=input_tensors))
assert error.value.details().startswith("Exception calling application: Incompatible axes names")
32 changes: 16 additions & 16 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def test_init_failed_then_devices_are_released(self, grpc_stub):

# attempt to init with the same device
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
response = training_pb2.TrainingSessionId(id=init_response.id)
response = utils_pb2.ModelSession(id=init_response.id)
assert response.id is not None

def test_start_training_success(self):
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_concurrent_state_transitions(self, grpc_stub):
The test should exit gracefully without hanging processes or threads.
"""
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

threads = []
for _ in range(2):
Expand All @@ -290,7 +290,7 @@ def assert_state(state_to_check):
self.assert_state(grpc_stub, training_session_id, state_to_check)

init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

grpc_stub.Start(training_session_id)
assert_state(TrainerState.RUNNING)
Expand All @@ -304,7 +304,7 @@ def assert_state(state_to_check):

def test_error_handling_on_invalid_state_transitions_after_training_started(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

# Attempt to start again while already running
grpc_stub.Start(training_session_id)
Expand All @@ -326,7 +326,7 @@ def test_error_handling_on_invalid_state_transitions_after_training_started(self

def test_error_handling_on_invalid_state_transitions_before_training_started(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

# Attempt to resume before start
with pytest.raises(grpc.RpcError) as excinfo:
Expand All @@ -345,7 +345,7 @@ def test_start_training_without_init(self, grpc_stub):
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Start(utils_pb2.Empty())
assert excinfo.value.code() == grpc.StatusCode.FAILED_PRECONDITION
assert "trainer-session with id doesn't exist" in excinfo.value.details()
assert "model-session with id doesn't exist" in excinfo.value.details()

def test_recover_training_failed(self):
class MockedExceptionTrainer:
Expand Down Expand Up @@ -439,25 +439,25 @@ def init(self, trainer_yaml_config: str = ""):

def test_graceful_shutdown_after_init(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.CloseTrainerSession(training_session_id)

def test_graceful_shutdown_after_start(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.Start(training_session_id)
grpc_stub.CloseTrainerSession(training_session_id)

def test_graceful_shutdown_after_pause(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.Start(training_session_id)
grpc_stub.Pause(training_session_id)
grpc_stub.CloseTrainerSession(training_session_id)

def test_graceful_shutdown_after_resume(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.Start(training_session_id)
grpc_stub.Pause(training_session_id)
grpc_stub.Resume(training_session_id)
Expand All @@ -466,7 +466,7 @@ def test_graceful_shutdown_after_resume(self, grpc_stub):
def test_close_trainer_session_twice(self, grpc_stub):
# Attempt to close the session twice
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.CloseTrainerSession(training_session_id)

# The second attempt should raise an error
Expand All @@ -476,7 +476,7 @@ def test_close_trainer_session_twice(self, grpc_stub):

def test_forward_while_running(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

grpc_stub.Start(training_session_id)

Expand All @@ -487,7 +487,7 @@ def test_forward_while_running(self, grpc_stub):
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

response = grpc_stub.Predict(predict_request)

Expand All @@ -502,7 +502,7 @@ def test_forward_while_running(self, grpc_stub):

def test_forward_while_paused(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)

grpc_stub.Start(training_session_id)

Expand All @@ -513,7 +513,7 @@ def test_forward_while_paused(self, grpc_stub):
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("b", "c", "z", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
predict_request = utils_pb2.PredictRequest(modelSessionId=training_session_id, tensors=[pb_tensor])

grpc_stub.Pause(training_session_id)

Expand All @@ -533,7 +533,7 @@ def test_close_session(self, grpc_stub):
Test closing a training session.
"""
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
training_session_id = utils_pb2.ModelSession(id=init_response.id)
grpc_stub.CloseTrainerSession(training_session_id)

# attempt to perform an operation while session is closed
Expand Down
Loading

0 comments on commit 4798dbe

Please sign in to comment.