diff --git a/tests/test_server/test_grpc/test_training_servicer.py b/tests/test_server/test_grpc/test_training_servicer.py
index f40a8fcc..2da060bb 100644
--- a/tests/test_server/test_grpc/test_training_servicer.py
+++ b/tests/test_server/test_grpc/test_training_servicer.py
@@ -8,8 +8,9 @@
 import h5py
 import numpy as np
 import pytest
+import xarray as xr
 
-from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
+from tiktorch.converters import pb_state_to_trainer, pb_tensor_to_xarray, trainer_state_to_pb, xarray_to_pb_tensor
 from tiktorch.proto import training_pb2, training_pb2_grpc
 from tiktorch.server.device_pool import TorchDevicePool
 from tiktorch.server.grpc import training_servicer
@@ -473,6 +474,28 @@ def test_close_trainer_session_twice(self, grpc_stub):
             grpc_stub.CloseTrainerSession(training_session_id)
         assert "Unknown session" in excinfo.value.details()
 
+    def test_forward(self, grpc_stub):
+        init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
+        training_session_id = training_pb2.TrainingSessionId(id=init_response.id)
+
+        grpc_stub.Start(training_session_id)
+
+        in_channels_unet2d = 3  # dependent on the config
+        out_channels_unet2d = 2
+        shape = (in_channels_unet2d, 128, 128)  # c, y, x
+        data = np.random.rand(*shape).astype(np.float32)
+        xarray_data = xr.DataArray(data, dims=("c", "y", "x"))
+        pb_tensor = xarray_to_pb_tensor(tensor_id="", array=xarray_data)
+        predict_request = training_pb2.PredictRequest(sessionId=training_session_id, tensors=[pb_tensor])
+
+        response = grpc_stub.Predict(predict_request)
+
+        predicted_tensors = [pb_tensor_to_xarray(pb_tensor) for pb_tensor in response.tensors]
+        assert len(predicted_tensors) == 1
+        predicted_tensor = predicted_tensors[0]
+        assert predicted_tensor.dims == ("c", "y", "x")
+        assert predicted_tensor.shape == (out_channels_unet2d, 128, 128)
+
     def test_close_session(self, grpc_stub):
         """
         Test closing a training session.
diff --git a/tiktorch/server/grpc/training_servicer.py b/tiktorch/server/grpc/training_servicer.py
index 4dac6b18..14806d0c 100644
--- a/tiktorch/server/grpc/training_servicer.py
+++ b/tiktorch/server/grpc/training_servicer.py
@@ -5,9 +5,10 @@
 from typing import Callable, List
 
 import grpc
+import torch
 
-from tiktorch.converters import trainer_state_to_pb
-from tiktorch.proto import training_pb2, training_pb2_grpc
+from tiktorch.converters import pb_tensor_to_numpy, trainer_state_to_pb
+from tiktorch.proto import training_pb2, training_pb2_grpc, utils_pb2
 from tiktorch.server.device_pool import IDevicePool
 from tiktorch.server.session.process import start_trainer_process
 from tiktorch.server.session.rpc_interface import IRPCTrainer
@@ -47,7 +48,7 @@ def Init(self, request: training_pb2.TrainingConfig, context):
 
         return training_pb2.TrainingSessionId(id=session.id)
 
-    def Start(self, request, context):
+    def Start(self, request: training_pb2.TrainingSessionId, context):
         session = self._getTrainerSession(context, request.id)
         session.client.start_training()
         return training_pb2.Empty()
@@ -63,17 +64,40 @@ def Pause(self, request: training_pb2.TrainingSessionId, context):
         return training_pb2.Empty()
 
     def Save(self, request: training_pb2.TrainingSessionId, context):
-        session = self._getTrainerSession(context, request.modelSessionId)
+        session = self._getTrainerSession(context, request.id)
         session.client.save()
         return training_pb2.Empty()
 
     def Export(self, request: training_pb2.TrainingSessionId, context):
-        session = self._getTrainerSession(context, request.modelSessionId)
+        session = self._getTrainerSession(context, request.id)
         session.client.export()
         return training_pb2.Empty()
 
-    def Predict(self, request: training_pb2.TrainingSessionId, context):
-        raise NotImplementedError
+    def Predict(self, request: training_pb2.PredictRequest, context):
+        session = self._getTrainerSession(context, request.sessionId.id)
+        tensors = [torch.tensor(pb_tensor_to_numpy(pb_tensor)) for pb_tensor in request.tensors]
+        self._check_tensors_shape(tensors)
+        predictions = session.client.forward(tensors).result()
+        return training_pb2.PredictResponse(tensors=[self._tensor_to_pb(prediction) for prediction in predictions])
+
+    def _tensor_to_pb(self, tensor: torch.Tensor):
+        self._check_tensors_shape([tensor])
+        dims = ("c", "y", "x") if tensor.ndim == 3 else ("c", "z", "y", "x")
+        shape = [utils_pb2.NamedInt(size=dim, name=i) for i, dim in zip(dims, tensor.shape)]
+
+        np_array = tensor.numpy()
+
+        proto_tensor = utils_pb2.Tensor(
+            tensorId="", dtype=str(np_array.dtype), shape=shape, buffer=np_array.tobytes()  # not used currently
+        )
+        return proto_tensor
+
+    def _check_tensors_shape(self, tensors: List[torch.Tensor]):
+        for tensor in tensors:
+            if tensor.ndim != 3 and tensor.ndim != 4:
+                raise ValueError(
+                    f"Tensor dims should be 3 (c, y, x) or 4 (c, z, y, x) but got {tensor.ndim} dimensions"
+                )
 
     def StreamUpdates(self, request: training_pb2.TrainingSessionId, context):
         raise NotImplementedError
diff --git a/tiktorch/server/session/backend/base.py b/tiktorch/server/session/backend/base.py
index 471c2d71..091c64c0 100644
--- a/tiktorch/server/session/backend/base.py
+++ b/tiktorch/server/session/backend/base.py
@@ -3,7 +3,9 @@
 import logging
 from abc import ABC
 from concurrent.futures import Future
+from typing import List
 
+import torch
 from bioimageio.core import PredictionPipeline
 
 from tiktorch.configkeys import TRAINING, VALIDATION
@@ -61,7 +63,7 @@ def __init__(self, trainer: Trainer):
         supervisor = TrainerSupervisor(trainer)
         super().__init__(supervisor)
 
-    def forward(self, input_tensors):
+    def forward(self, input_tensors: List[torch.Tensor]):
         res = Future()
         self._queue_tasks.send_command(commands.ForwardPass(res, input_tensors))
         return res
diff --git a/tiktorch/server/session/backend/supervisor.py b/tiktorch/server/session/backend/supervisor.py
index fa04e599..a49ae19a 100644
--- a/tiktorch/server/session/backend/supervisor.py
+++ b/tiktorch/server/session/backend/supervisor.py
@@ -127,8 +127,9 @@ def shutdown(self):
 
     def forward(self, input_tensors):
         self.pause()
-        self._trainer.forward(input_tensors)
+        res = self._trainer.forward(input_tensors)
         self.resume()
+        return res
 
     def save(self):
         raise NotImplementedError
diff --git a/tiktorch/server/session/process.py b/tiktorch/server/session/process.py
index 3a1b949e..81aa6b82 100644
--- a/tiktorch/server/session/process.py
+++ b/tiktorch/server/session/process.py
@@ -7,6 +7,7 @@
 from multiprocessing.connection import Connection
 from typing import List, Optional, Tuple, Type, TypeVar, Union
 
+import torch
 from bioimageio.core import PredictionPipeline, Tensor, create_prediction_pipeline
 from bioimageio.spec import InvalidDescr, load_description
 from bioimageio.spec.model import v0_5
@@ -125,7 +126,7 @@ def init(self, trainer_yaml_config: str):
         trainer = parser.parse()
         self._worker = base.TrainerSessionBackend(trainer)
 
-    def forward(self, input_tensors) -> Future:
+    def forward(self, input_tensors: List[torch.Tensor]) -> Future:
         res = self.worker.forward(input_tensors)
         return res
 
diff --git a/tiktorch/server/session/rpc_interface.py b/tiktorch/server/session/rpc_interface.py
index db714cb5..36146b1e 100644
--- a/tiktorch/server/session/rpc_interface.py
+++ b/tiktorch/server/session/rpc_interface.py
@@ -1,5 +1,7 @@
 from typing import List
 
+import torch
+
 from tiktorch.converters import Sample
 from tiktorch.rpc import RPCInterface, exposed
 from tiktorch.rpc.exceptions import Shutdown
@@ -56,7 +58,7 @@ def init(self, trainer_yaml_config: str):
         raise NotImplementedError
 
     @exposed
-    def forward(self, input_tensors: Sample):
+    def forward(self, input_tensors: List[torch.Tensor]):
         raise NotImplementedError
 
     @exposed
diff --git a/tiktorch/trainer.py b/tiktorch/trainer.py
index 79b371f9..44eada34 100644
--- a/tiktorch/trainer.py
+++ b/tiktorch/trainer.py
@@ -96,6 +96,7 @@ class Trainer(UNetTrainer):
     def __init__(
         self,
         model,
+        device,
         optimizer,
         lr_scheduler,
         loss_criterion,
@@ -138,6 +139,7 @@ def __init__(
             pre_trained=pre_trained,
             **kwargs,
         )
+        self._device = device
         self.logs_callbacks: LogsCallbacks = BaseCallbacks()
         self.should_stop_callbacks: Callbacks = ShouldStopCallbacks()
 
@@ -150,10 +152,13 @@ def train(self):
     def validate(self):
         return super().validate()
 
-    def forward(self, input_tensors):
+    def forward(self, input_tensors: List[torch.Tensor]):
         self.model.eval()
         with torch.no_grad():
-            self.model(input_tensors)
+            batched_tensor = torch.stack(input_tensors, dim=0).to(self._device)
+            predictions = self.model(batched_tensor)
+            predictions_list = list(predictions.unbind(dim=0))
+            return predictions_list
 
     def should_stop(self) -> bool:
         """
@@ -228,6 +233,7 @@ def parse(self) -> Trainer:
         pre_trained = trainer_config.pop("pre_trained", None)
 
         return Trainer(
+            device=config["device"],
             model=model,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,