Skip to content

Commit

Permalink
Add save and export to the training service
Browse files Browse the repository at this point in the history
  • Loading branch information
thodkatz committed Dec 21, 2024
1 parent 4798dbe commit 146305e
Show file tree
Hide file tree
Showing 10 changed files with 357 additions and 51 deletions.
14 changes: 13 additions & 1 deletion proto/training.proto
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ service Training {

rpc GetLogs(ModelSession) returns (GetLogsResponse) {}

rpc Export(ModelSession) returns (Empty) {}
rpc Save(SaveRequest) returns (Empty) {}

rpc Export(ExportRequest) returns (Empty) {}

rpc Predict(PredictRequest) returns (PredictResponse) {}

Expand Down Expand Up @@ -58,6 +60,16 @@ message GetLogsResponse {
repeated Logs logs = 1;
}

// Request to checkpoint the current training state of a running session.
message SaveRequest {
  // The training session whose model state should be saved.
  ModelSession modelSessionId = 1;
  // Server-side filesystem path where the checkpoint will be written.
  string filePath = 2;
}


// Request to export the trained model of a session to a file.
message ExportRequest {
  // The training session whose model should be exported.
  ModelSession modelSessionId = 1;
  // Server-side filesystem path for the exported model archive.
  string filePath = 2;
}

message ValidationResponse {
double validation_score_average = 1;
Expand Down
69 changes: 60 additions & 9 deletions tests/test_server/test_grpc/test_training_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import threading
import time
from pathlib import Path
from typing import Callable
from typing import Callable, Optional

import grpc
import h5py
Expand Down Expand Up @@ -41,8 +41,11 @@ def grpc_stub_cls():
return training_pb2_grpc.TrainingStub


def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: str = "cpu"):
return f"""
def unet2d_config_path(
checkpoint_dir: Path, train_data_dir: str, val_data_path: str, resume: Optional[str] = None, device: str = "cpu"
):
# todo: upsampling makes model torchscript incompatible
base = f"""
device: {device} # Use CPU for faster test execution, change to 'cuda' if GPU is available and necessary
model:
name: UNet2D
Expand All @@ -53,13 +56,14 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
num_groups: 4
final_sigmoid: false
is_segmentation: true
upsample: default
trainer:
checkpoint_dir: {checkpoint_dir}
resume: null
validate_after_iters: 2
validate_after_iters: 250
log_after_iters: 2
max_num_epochs: 1000
max_num_iterations: 10000
max_num_epochs: 10000
max_num_iterations: 100000
eval_score_higher_is_better: True
optimizer:
learning_rate: 0.0002
Expand Down Expand Up @@ -149,6 +153,9 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
- name: ToTensor
expand_dims: false
"""
if resume:
return f"resume: {resume}{base}"
return base


def create_random_dataset(shape, channel_per_class):
Expand All @@ -171,15 +178,22 @@ def create_random_dataset(shape, channel_per_class):
return tmp.name


def prepare_unet2d_test_environment(device: str = "cpu") -> str:
def prepare_unet2d_test_environment(resume: Optional[str] = None, device: str = "cpu") -> str:
checkpoint_dir = Path(tempfile.mkdtemp())

shape = (3, 1, 128, 128)
in_channel = 3
z = 1 # 2d
y = 128
x = 128
shape = (in_channel, z, y, x)
binary_loss = False
train = create_random_dataset(shape, binary_loss)
val = create_random_dataset(shape, binary_loss)

return unet2d_config_path(checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device)
config = unet2d_config_path(
resume=resume, checkpoint_dir=checkpoint_dir, train_data_dir=train, val_data_path=val, device=device
)
return config


class TestTrainingServicer:
Expand Down Expand Up @@ -528,6 +542,43 @@ def test_forward_while_paused(self, grpc_stub):
assert predicted_tensor.dims == ("b", "c", "z", "y", "x")
assert predicted_tensor.shape == (batch, out_channels_unet2d, 1, 128, 128)

def test_save(self, grpc_stub):
    """End-to-end check of the Save RPC.

    Starts training, saves a checkpoint to a temp file, then verifies the
    checkpoint is usable by initializing a fresh session that resumes from it.
    """
    init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
    training_session_id = utils_pb2.ModelSession(id=init_response.id)

    grpc_stub.Start(training_session_id)

    with tempfile.TemporaryDirectory() as model_checkpoint_dir:
        model_checkpoint_file = Path(model_checkpoint_dir) / "model.pth"
        save_request = training_pb2.SaveRequest(modelSessionId=training_session_id, filePath=str(model_checkpoint_file))
        grpc_stub.Save(save_request)
        # Save is expected to have written the checkpoint synchronously.
        assert model_checkpoint_file.exists()

        # assume stopping training to release devices
        grpc_stub.CloseTrainerSession(training_session_id)

        # attempt to init a new model with the new checkpoint and start training
        # NOTE(review): model_checkpoint_file is a Path while the config
        # helper's `resume` parameter is typed Optional[str]; f-string
        # interpolation makes this work, but confirm the intended type.
        init_response = grpc_stub.Init(
            training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment(resume=model_checkpoint_file))
        )
        training_session_id = utils_pb2.ModelSession(id=init_response.id)
        grpc_stub.Start(training_session_id)

def test_export(self, grpc_stub):
    """End-to-end check of the Export RPC.

    Starts training, exports the model to a temp zip path and asserts the
    file was created, then closes the session.
    """
    init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
    training_session_id = utils_pb2.ModelSession(id=init_response.id)

    grpc_stub.Start(training_session_id)

    with tempfile.TemporaryDirectory() as model_checkpoint_dir:
        model_export_file = Path(model_checkpoint_dir) / "bioimageio.zip"
        export_request = training_pb2.ExportRequest(modelSessionId=training_session_id, filePath=str(model_export_file))
        grpc_stub.Export(export_request)
        # Export is expected to have written the archive synchronously.
        assert model_export_file.exists()

        # assume stopping training since model is exported
        grpc_stub.CloseTrainerSession(training_session_id)

def test_close_session(self, grpc_stub):
"""
Test closing a training session.
Expand Down
30 changes: 17 additions & 13 deletions tiktorch/proto/training_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 36 additions & 3 deletions tiktorch/proto/training_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def __init__(self, channel):
request_serializer=utils__pb2.ModelSession.SerializeToString,
response_deserializer=training__pb2.GetLogsResponse.FromString,
)
self.Save = channel.unary_unary(
'/training.Training/Save',
request_serializer=training__pb2.SaveRequest.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
)
self.Export = channel.unary_unary(
'/training.Training/Export',
request_serializer=utils__pb2.ModelSession.SerializeToString,
request_serializer=training__pb2.ExportRequest.SerializeToString,
response_deserializer=utils__pb2.Empty.FromString,
)
self.Predict = channel.unary_unary(
Expand Down Expand Up @@ -117,6 +122,12 @@ def GetLogs(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Save(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def Export(self, request, context):
"""Missing associated documentation comment in .proto file."""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
Expand Down Expand Up @@ -179,9 +190,14 @@ def add_TrainingServicer_to_server(servicer, server):
request_deserializer=utils__pb2.ModelSession.FromString,
response_serializer=training__pb2.GetLogsResponse.SerializeToString,
),
'Save': grpc.unary_unary_rpc_method_handler(
servicer.Save,
request_deserializer=training__pb2.SaveRequest.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
),
'Export': grpc.unary_unary_rpc_method_handler(
servicer.Export,
request_deserializer=utils__pb2.ModelSession.FromString,
request_deserializer=training__pb2.ExportRequest.FromString,
response_serializer=utils__pb2.Empty.SerializeToString,
),
'Predict': grpc.unary_unary_rpc_method_handler(
Expand Down Expand Up @@ -328,6 +344,23 @@ def GetLogs(request,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Save(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/training.Training/Save',
training__pb2.SaveRequest.SerializeToString,
utils__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def Export(request,
target,
Expand All @@ -340,7 +373,7 @@ def Export(request,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/training.Training/Export',
utils__pb2.ModelSession.SerializeToString,
training__pb2.ExportRequest.SerializeToString,
utils__pb2.Empty.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
Expand Down
13 changes: 9 additions & 4 deletions tiktorch/server/session/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from abc import ABC
from concurrent.futures import Future
from pathlib import Path
from typing import List

import torch
Expand Down Expand Up @@ -83,11 +84,15 @@ def start_training(self) -> None:
self._queue_tasks.send_command(start_cmd.awaitable)
start_cmd.awaitable.wait()

def save(self) -> None:
raise NotImplementedError
def save(self, file_path: Path) -> None:
    """Persist the current training state to ``file_path``.

    Enqueues a ``SaveTrainingCmd`` on the task queue and blocks until the
    worker has completed it, so the checkpoint exists when this returns.
    """
    save_cmd = commands.SaveTrainingCmd(file_path)
    self._queue_tasks.send_command(save_cmd.awaitable)
    save_cmd.awaitable.wait()

def export(self) -> None:
raise NotImplementedError
def export(self, file_path: Path) -> None:
    """Export the trained model to ``file_path``.

    Enqueues an ``ExportTrainingCmd`` on the task queue and blocks until the
    worker has completed it, so the exported file exists when this returns.
    """
    export_cmd = commands.ExportTrainingCmd(file_path)
    self._queue_tasks.send_command(export_cmd.awaitable)
    export_cmd.awaitable.wait()

def get_state(self) -> TrainerState:
return self._supervisor.get_state()
19 changes: 19 additions & 0 deletions tiktorch/server/session/backend/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import typing
from dataclasses import dataclass, field
from pathlib import Path
from typing import Generic, Type, TypeVar

from tiktorch.trainer import TrainerAction, TrainerState
Expand Down Expand Up @@ -131,6 +132,24 @@ def execute(self, ctx: Context) -> None:
pass


class ExportTrainingCmd(ICommand):
    """Queue command that exports the session's trained model to a file."""

    def __init__(self, file_path: Path):
        super().__init__()
        # Destination for the exported model; consumed in execute().
        self._export_path = file_path

    def execute(self, ctx: Context[TrainerSupervisor]) -> None:
        # Delegate the actual export to the trainer supervisor.
        ctx.session.export(self._export_path)


class SaveTrainingCmd(ICommand):
    """Queue command that checkpoints the session's training state to a file."""

    def __init__(self, file_path: Path):
        super().__init__()
        # Destination for the checkpoint; consumed in execute().
        self._save_path = file_path

    def execute(self, ctx: Context[TrainerSupervisor]) -> None:
        # Delegate the actual save to the trainer supervisor.
        ctx.session.save(self._save_path)


class ShutdownWithTeardownCmd(ShutdownCmd):
def execute(self, ctx: Context[Supervisors]) -> None:
ctx.session.shutdown()
Expand Down
12 changes: 8 additions & 4 deletions tiktorch/server/session/backend/supervisor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import threading
from pathlib import Path
from typing import Generic, Set, TypeVar, Union

from bioimageio.core import PredictionPipeline, Sample
Expand Down Expand Up @@ -134,11 +135,14 @@ def forward(self, input_tensors):
self.resume()
return res

def save(self):
raise NotImplementedError
def save(self, file_path: Path):
    """Write the trainer's state dict to ``file_path``.

    Training is paused first so the checkpoint captures a consistent
    snapshot, and resumed afterwards.
    """
    self.pause()
    self._trainer.save_state_dict(file_path)
    self.resume()

def export(self):
raise NotImplementedError
def export(self, file_path: Path):
    """Export the trained model to ``file_path``.

    Pauses training so the export sees a consistent model state.
    """
    self.pause()
    self._trainer.export(file_path)
    # NOTE(review): unlike save(), training is NOT resumed after export.
    # The tests close the session right after exporting, so this looks
    # intentional — but confirm, or an Export call leaves training paused.

def _should_stop(self):
return self._pause_triggered
Expand Down
12 changes: 6 additions & 6 deletions tiktorch/server/session/process.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
import multiprocessing as _mp
import pathlib
import tempfile
import uuid
from concurrent.futures import Future
from multiprocessing.connection import Connection
from pathlib import Path
from typing import List, Optional, Tuple, Type, TypeVar, Union

import torch
Expand Down Expand Up @@ -139,11 +139,11 @@ def start_training(self):
def pause_training(self):
self.worker.pause_training()

def save(self):
self.worker.save()
def save(self, file_path: Path):
    # Forward the checkpoint request to the trainer worker.
    self.worker.save(file_path)

def export(self):
self.worker.export()
def export(self, file_path: Path):
    # Forward the model-export request to the trainer worker.
    self.worker.export(file_path)

def get_state(self):
return self.worker.get_state()
Expand Down Expand Up @@ -210,7 +210,7 @@ def _get_prediction_pipeline_from_model_bytes(model_bytes: bytes, devices: List[
def _get_model_descr_from_model_bytes(model_bytes: bytes) -> v0_5.ModelDescr:
with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as _tmp_file:
_tmp_file.write(model_bytes)
temp_file_path = pathlib.Path(_tmp_file.name)
temp_file_path = Path(_tmp_file.name)
model_descr = load_description(temp_file_path, format_version="latest")
if isinstance(model_descr, InvalidDescr):
raise ValueError(f"Failed to load valid model descriptor {model_descr.validation_summary}")
Expand Down
Loading

0 comments on commit 146305e

Please sign in to comment.