Skip to content

Commit

Permalink
various improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Mar 21, 2019
1 parent 0695dc3 commit 004a5c0
Show file tree
Hide file tree
Showing 14 changed files with 1,035 additions and 156 deletions.
Empty file added tests/__init__.py
Empty file.
48 changes: 48 additions & 0 deletions tests/handler/dummyserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import logging
import torch

from torch.multiprocessing import Pipe
from typing import Union, Tuple

from tiktorch.rpc_interface import INeuralNetworkAPI, IFlightControl
from tiktorch.handler import HandlerProcess
from tiktorch.handler.constants import SHUTDOWN, SHUTDOWN_ANSWER
from tiktorch.types import NDArray, NDArrayBatch

logger = logging.getLogger(__name__)

class DummyServer(INeuralNetworkAPI, IFlightControl):
def __init__(self, **kwargs):
self.handler_conn, server_conn = Pipe()
self.handler = HandlerProcess(server_conn=server_conn, **kwargs)
self.handler.start()

def forward(self, batch: NDArrayBatch) -> None:
self.handler_conn.send(
(
"forward",
{"keys": [a.id for a in batch], "data": torch.stack([torch.from_numpy(a.as_numpy()) for a in batch])},
)
)

def active_children(self):
self.handler_conn.send(("active_children", {}))

def listen(self, timeout: float = 10) -> Union[None, Tuple[str, dict]]:
if self.handler_conn.poll(timeout=timeout):
answer = self.handler_conn.recv()
logger.debug("got answer: %s", answer)
return answer
else:
return None

def shutdown(self):
self.handler_conn.send(SHUTDOWN)
got_shutdown_answer = False
while self.handler.is_alive():
if self.handler_conn.poll(timeout=2):
answer = self.handler_conn.recv()
if answer == SHUTDOWN_ANSWER:
got_shutdown_answer = True

assert got_shutdown_answer
44 changes: 44 additions & 0 deletions tests/handler/test_dryrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import logging.config
from importlib import import_module
from functools import partial

from tests.handler.dummyserver import DummyServer
from tiktorch.handler.dryrun import DryRunProcess, in_subproc

# from tests.data.tiny_models import TinyConvNet2d
# model = TinyConvNet2d()

logging.config.dictConfig(
{
"version": 1,
"disable_existing_loggers": False,
"handlers": {"default": {"level": "DEBUG", "class": "logging.StreamHandler", "stream": "ext://sys.stdout"}},
"loggers": {"": {"handlers": ["default"], "level": "DEBUG", "propagate": True}},
}
)


def test_minimal_device_test():
assert DryRunProcess.minimal_device_test(torch.device("cpu"))


def test_minimal_device_test_in_subproc():
ret = in_subproc(DryRunProcess.minimal_device_test, torch.device("cpu"))
assert ret.recv()


def test_confirm_training_shape(tiny_model_2d):
tiny_model_2d['config'].update({"training_shape": (15, 15)})
ts = DummyServer(**tiny_model_2d)
try:
ts.active_children()
assert ts.listen(timeout=10) is not None
ts.handler_conn.send(('set_devices', {"device_names": ["cpu"]}))
answer = ts.listen(timeout=10)
# todo: fix this test! Where is the answer?
assert answer is not None
except Exception:
raise
finally:
ts.shutdown()
59 changes: 2 additions & 57 deletions tests/handler/test_handler.py
Original file line number Diff line number Diff line change
@@ -1,74 +1,19 @@
import os
import logging
import numpy
import time
import torch

from tiktorch.rpc_interface import INeuralNetworkAPI, IFlightControl
from tiktorch.handler import HandlerProcess
from tiktorch.handler.constants import SHUTDOWN, SHUTDOWN_ANSWER
from tiktorch.types import NDArray, NDArrayBatch
from torch.multiprocessing import Pipe
from typing import Union, Tuple
from tests.handler.dummyserver import DummyServer

logger = logging.getLogger(__name__)


class DummyServer(INeuralNetworkAPI, IFlightControl):
def __init__(self, **kwargs):
self.handler_conn, server_conn = Pipe()
self.handler = HandlerProcess(server_conn=server_conn, **kwargs)
self.handler.start()

@property
def devices(self):
return self.handler.devices

@devices.setter
def devices(self, devices: list):
self.handler.devices = devices

def dry_run_on_device(self, device, upper_bound):
return self.handler.dry_run_on_device(device, upper_bound)

def forward(self, batch: NDArrayBatch) -> None:
self.handler_conn.send(
(
"forward",
{"keys": [a.id for a in batch], "data": torch.stack([torch.from_numpy(a.as_numpy()) for a in batch])},
)
)

def active_children(self):
self.handler_conn.send(("active_children", {}))

def listen(self, timeout: float = 10) -> Union[None, Tuple[str, dict]]:
if self.handler_conn.poll(timeout=timeout):
answer = self.handler_conn.recv()
logger.debug("got answer: %s", answer)
return answer
else:
return None

def shutdown(self):
self.handler_conn.send(SHUTDOWN)
got_shutdown_answer = False
while self.handler.is_alive():
if self.handler_conn.poll(timeout=2):
answer = self.handler_conn.recv()
if answer == SHUTDOWN_ANSWER:
got_shutdown_answer = True

assert got_shutdown_answer


def test_initialization(tiny_model):
ts = DummyServer(**tiny_model)
ts.active_children()
active_children = ts.listen()
ts.shutdown()
assert active_children is not None
assert len(active_children) == 2
assert len(active_children) in (2, 3)
assert "TrainingProcess" in active_children
assert "InferenceProcess" in active_children

Expand Down
2 changes: 1 addition & 1 deletion tiktorch/device_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
import numpy as np

from tiktorch.utils import DynamicShape, assert_, to_list, define_patched_model
from tiktorch.utils import assert_, define_patched_model
from tiktorch.blockinator import Blockinator, th_pad
from tiktorch.trainy import Trainer

Expand Down
2 changes: 1 addition & 1 deletion tiktorch/handler/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .base import HandlerProcess
from .handler import HandlerProcess
Loading

0 comments on commit 004a5c0

Please sign in to comment.