Add forward action to the training service
thodkatz committed Dec 11, 2024
1 parent 6d2a078 commit c4db94a
Showing 7 changed files with 73 additions and 14 deletions.
25 changes: 24 additions & 1 deletion tests/test_server/test_grpc/test_training_servicer.py
@@ -8,8 +8,9 @@
import h5py
import numpy as np
import pytest
import xarray as xr

from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
from tiktorch.converters import pb_state_to_trainer, pb_tensor_to_xarray, trainer_state_to_pb, xarray_to_pb_tensor
from tiktorch.proto import training_pb2, training_pb2_grpc
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import training_servicer
@@ -473,6 +474,28 @@ def test_close_trainer_session_twice(self, grpc_stub):
grpc_stub.CloseTrainerSession(training_session_id)
assert "Unknown session" in excinfo.value.details()

def test_forward(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)

in_channels_unet2d = 3 # dependent on the config
out_channels_unet2d = 2
shape = (in_channels_unet2d, 128, 128) # c, y, x
data = np.random.rand(*shape).astype(np.float32)
xarray_data = xr.DataArray(data, dims=("c", "y", "x"))
pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])

response = grpc_stub.Predict(predict_request)

predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
assert len(predicted_tensors) == 1
predicted_tensor = predicted_tensors[0]
assert predicted_tensor.dims == ("c", "y", "x")
assert predicted_tensor.shape == (out_channels_unet2d, 128, 128)

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
38 changes: 31 additions & 7 deletions tiktorch/server/grpc/training_servicer.py
@@ -5,9 +5,10 @@
from typing import Callable, List

import grpc
import torch

from tiktorch.converters import trainer_state_to_pb
from tiktorch.proto import training_pb2, training_pb2_grpc
from tiktorch.converters import pb_tensor_to_numpy, trainer_state_to_pb
from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
from tiktorch.server.device_pool import IDevicePool
from tiktorch.server.session.process import start_trainer_process
from tiktorch.server.session.rpc_interface import IRPCTrainer
@@ -47,7 +48,7 @@ def Init(self, request: training_pb2.TrainingConfig, context):

return training_pb2.TrainingSessionId(id=session.id)

def Start(self, request, context):
def Start(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.id)
session.client.start_training()
return training_pb2.Empty()
@@ -63,17 +64,40 @@ def Pause(self, request: training_pb2.TrainingSessionId, context):
return training_pb2.Empty()

def Save(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.modelSessionId)
session = self._getTrainerSession(context, request.id)
session.client.save()
return training_pb2.Empty()

def Export(self, request: training_pb2.TrainingSessionId, context):
session = self._getTrainerSession(context, request.modelSessionId)
session = self._getTrainerSession(context, request.id)
session.client.export()
return training_pb2.Empty()

def Predict(self, request: training_pb2.TrainingSessionId, context):
raise NotImplementedError
def Predict(self, request: training_pb2.PredictRequest, context):
session = self._getTrainerSession(context, request.sessionId.id)
tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors]
self._check_tensors_shape(tensors)
predictions = session.client.forward(tensors).result()
return training_pb2.PredictResponse(tensors=[self._tensor_to_pb(prediction) for prediction in predictions])

def _tensor_to_pb(self, tensor: torch.Tensor):
self._check_tensors_shape([tensor])
dims = ("c", "y", "x") if tensor.ndim == 3 else ("c", "z", "y", "x")
shape = [utils_pb2.NamedInt(size=dim, name=i) for i, dim in zip(dims, tensor.shape)]

np_array = tensor.numpy()

proto_tensor = utils_pb2.Tensor(
tensorId="", dtype=str(np_array.dtype), shape=shape, buffer=np_array.tobytes() # not used currently
)
return proto_tensor

def _check_tensors_shape(self, tensors: List[torch.Tensor]):
for tensor in tensors:
if tensor.ndim != 3 and tensor.ndim != 4:
raise ValueError(
f"Tensor dims should be 3 (c, y, x) or 4 (c, z, y, x) but got {tensor.ndim} dimensions"
)

def StreamUpdates(self, request: training_pb2.TrainingSessionId, context):
raise NotImplementedError
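
For reference, a minimal sketch of the wire format the new Predict path relies on: _tensor_to_pb emits a dtype string, a list of named axes, and the raw bytes, and a pb_tensor_to_numpy / pb_tensor_to_xarray-style helper on the other end is assumed to rebuild the array from exactly those three fields. This is a sketch against the utils_pb2 messages shown above, not part of the commit:

import numpy as np
import torch

from tiktorch.proto import utils_pb2

# Encode, mirroring _tensor_to_pb: a (c, y, x) prediction becomes dtype + named shape + bytes.
prediction = torch.rand(2, 128, 128)
np_array = prediction.numpy()
pb_tensor = utils_pb2.Tensor(
    tensorId="",
    dtype=str(np_array.dtype),
    shape=[utils_pb2.NamedInt(name=name, size=size) for name, size in zip(("c", "y", "x"), np_array.shape)],
    buffer=np_array.tobytes(),
)

# Decode, as a client-side converter is assumed to do: the three fields are enough
# to reconstruct the array and its axis names.
dims = tuple(named.name for named in pb_tensor.shape)
sizes = tuple(named.size for named in pb_tensor.shape)
restored = np.frombuffer(pb_tensor.buffer, dtype=np.dtype(pb_tensor.dtype)).reshape(sizes)
assert dims == ("c", "y", "x") and np.array_equal(restored, np_array)
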
4 changes: 3 additions & 1 deletion tiktorch/server/session/backend/base.py
@@ -3,7 +3,9 @@
import logging
from abc import ABC
from concurrent.futures import Future
from typing import List

import torch
from bioimageio.core import PredictionPipeline

from tiktorch.configkeys import TRAINING, VALIDATION
@@ -61,7 +63,7 @@ def __init__(self, trainer: Trainer):
supervisor = TrainerSupervisor(trainer)
super().__init__(supervisor)

def forward(self, input_tensors):
def forward(self, input_tensors: List[torch.Tensor]):
res = Future()
self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors))
return res
3 changes: 2 additions & 1 deletion tiktorch/server/session/backend/supervisor.py
@@ -127,8 +127,9 @@ def shutdown(self):

def forward(self, input_tensors):
self.pause()
self._trainer.forward(input_tensors)
res = self._trainer.forward(input_tensors)
self.resume()
return res

def save(self):
raise NotImplementedError
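
Together, the backend and supervisor changes make forward an asynchronous round trip: the backend enqueues a ForwardPass command and returns a concurrent.futures.Future, while the supervisor pauses training, runs the trainer's forward, resumes, and now returns the predictions so that Future can be resolved. A hypothetical stand-in sketch of the contract the servicer's Predict handler consumes (forward_stub is invented here; the real Future is resolved from the training process, not inline):

from concurrent.futures import Future
from typing import List

import torch


def forward_stub(input_tensors: List[torch.Tensor]) -> Future:
    # Stand-in for TrainerSessionBackend.forward: enqueue the work, hand back a Future.
    res: Future = Future()
    res.set_result([tensor * 2 for tensor in input_tensors])  # dummy "predictions"
    return res


# Callers block on .result(), as Predict does with session.client.forward(tensors).result().
predictions = forward_stub([torch.rand(3, 128, 128)]).result()
assert predictions[0].shape == (3, 128, 128)
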
3 changes: 2 additions & 1 deletion tiktorch/server/session/process.py
@@ -7,6 +7,7 @@
from multiprocessing.connection import Connection
from typing import List, Optional, Tuple, Type, TypeVar, Union

import torch
from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline
from bioimageio.spec import InvalidDescr, load_description
from bioimageio.spec.model import v0_5
@@ -125,7 +126,7 @@ def init(self, trainer_yaml_config: str):
trainer = parser.parse()
self._worker = base.TrainerSessionBackend(trainer)

def forward(self, input_tensors) -> Future:
def forward(self, input_tensors: List[torch.Tensor]) -> Future:
res = self.worker.forward(input_tensors)
return res

4 changes: 3 additions & 1 deletion tiktorch/server/session/rpc_interface.py
@@ -1,5 +1,7 @@
from typing import List

import torch

from tiktorch.converters import Sample
from tiktorch.rpc import RPCInterface, exposed
from tiktorch.rpc.exceptions import Shutdown
@@ -56,7 +58,7 @@ def init(self, trainer_yaml_config: str):
raise NotImplementedError

@exposed
def forward(self, input_tensors: Sample):
def forward(self, input_tensors: List[torch.Tensor]):
raise NotImplementedError

@exposed
10 changes: 8 additions & 2 deletions tiktorch/trainer.py
@@ -96,6 +96,7 @@ class Trainer(UNetTrainer):
def __init__(
self,
model,
device,
optimizer,
lr_scheduler,
loss_criterion,
@@ -138,6 +139,7 @@ def __init__(
pre_trained=pre_trained,
**kwargs,
)
self._device = device
self.logs_callbacks: LogsCallbacks = BaseCallbacks()
self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()

@@ -150,10 +152,13 @@ def train(self):
def validate(self):
return super().validate()

def forward(self, input_tensors):
def forward(self, input_tensors: List[torch.Tensor]):
self.model.eval()
with torch.no_grad():
self.model(input_tensors)
batched_tensor = torch.stack(input_tensors, dim=0).to(self._device)
predictions = self.model(batched_tensor)
predictions_list = list(predictions.unbind(dim=0))
return predictions_list

def should_stop(self) -> bool:
"""
@@ -228,6 +233,7 @@ def parse(self) -> Trainer:
pre_trained = trainer_config.pop("pre_trained", None)

return Trainer(
device=config["device"],
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
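
The new Trainer.forward batches the per-sample (c, y, x) inputs, runs the model under eval/no_grad on the configured device, and splits the batch back into a list. A minimal sketch of that pattern, using a stand-in 1x1 convolution instead of the configured 2D UNet (3 input channels to 2 output channels, matching the test above); the real model and device come from the YAML config:

import torch
from torch import nn

model = nn.Conv2d(in_channels=3, out_channels=2, kernel_size=1)  # stand-in for the 2D UNet
device = torch.device("cpu")  # assumption; Trainer reads the device from config["device"]

input_tensors = [torch.rand(3, 128, 128) for _ in range(2)]  # each sample is (c, y, x)

model.eval()
with torch.no_grad():
    batched_tensor = torch.stack(input_tensors, dim=0).to(device)  # (b, c, y, x)
    predictions = model(batched_tensor)                            # (b, 2, 128, 128)
    predictions_list = list(predictions.unbind(dim=0))             # back to per-sample (c, y, x)

assert predictions_list[0].shape == (2, 128, 128)
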
