Merge pull request #80 from hpcaitech/feature/trt
update metaconfig
MaruyamaAya authored May 26, 2022
2 parents 107e48f + 2db112b commit e41fcc6
Showing 6 changed files with 92 additions and 101 deletions.
103 changes: 22 additions & 81 deletions energonai/cli/service.py
@@ -7,103 +7,44 @@
from energonai.context import mcfg


def launches(model_class=None,
model_type=None,
engine_server=None,
max_batch_size=32,
tp_init_size=1,
pp_init_size=1,
host="127.0.0.1",
port=29500,
half=True,
checkpoint=None,
tokenizer_path=None,
server_host="127.0.0.1",
server_port=8005,
log_level="critical",
backend="nccl",
rm_padding=False):
click.echo(f'*** Energon Init Configurations: *** \n'
f'Model Name: {model_class} \n'
f'Model Type: {model_type} \n'
f'Engine Server: {engine_server} \n'
f'Max Batch Size: {max_batch_size} \n'
f'Tensor Parallelism Size: {tp_init_size} \n'
f'Pipeline Parallelism Size: {pp_init_size} \n'
f'Communication Host: {host} \n'
f'Communication Port: {port} \n'
f'Is Half: {half} \n'
f'Checkpoint Path: {checkpoint} \n'
f'Tokenizer Path: {tokenizer_path}'
f'Worker Server Host: {server_host} \n'
f'Worker Server Port: {server_port} \n'
f'Unvicorn Log Level: {log_level} \n'
f'Remove padding: {rm_padding} \n')
@click.group()
def service():
pass

if half:
dtype = torch.half
else:
dtype = torch.float

world_size = tp_init_size * pp_init_size
@service.command()
@click.option("--config_file", type=str)
def init(config_file):

mcfg.load_config(config_file)

click.echo(f'*** Energon Init Configurations: ***')
for k in mcfg:
click.echo(f'{k}\t:\t{mcfg[k]}')

# prepare context for master and worker
world_size = mcfg['tp_init_size'] * mcfg['pp_init_size']
num_worker = world_size - 1

engine_port = server_port
worker_port = server_port + 1
worker_port = mcfg['server_port'] + 1
worker_rank = 1 # start from 1

# launch each worker
process_list = []
mp.set_start_method('spawn')
for i in range(num_worker):
p = mp.Process(target=server.launch_worker,
args=(host, port, tp_init_size, pp_init_size, "nccl", 1024, True, worker_rank + i, worker_rank + i,
server_host, worker_port + i, log_level))
args=(config_file, worker_rank + i, worker_rank + i, mcfg['server_host'], worker_port + i))
p.start()
process_list.append(p)

sig_server = inspect.signature(engine_server)
# launch the master
sig_server = inspect.signature(mcfg['engine_server'])
parameters_server = sig_server.parameters

cfg = {
'model_class': model_class,
'model_type': model_type,
'max_batch_size': max_batch_size,
'tp_init_size': tp_init_size,
'pp_init_size': pp_init_size,
'host': host,
'port': port,
'dtype': dtype,
'checkpoint': checkpoint,
'tokenizer_path': tokenizer_path,
'server_host': server_host,
'server_port': engine_port,
'log_level': log_level,
'rm_padding': rm_padding
}

argv = dict()
for name, _ in parameters_server.items():
if name in cfg:
argv[name] = cfg[name]

engine_server(**argv)


@click.group()
def service():
pass


@service.command()
@click.option("--config_file", type=str)
def init(config_file):
mcfg.load_config(config_file)

sig = inspect.signature(launches)
parameters = sig.parameters

argv = dict()
for name, _ in parameters.items():
if name in mcfg:
argv[name] = mcfg[name]
launches(**argv)

mcfg['engine_server'](**argv)
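
Note on the new flow: init now reads every setting through mcfg, which is populated from a single Python config file (loaded via SourceFileLoader in energonai/context/config.py). The sketch below shows what such a config file could look like; it is an illustration only, and the model/engine callables and paths are placeholders, not part of this commit.

# example_config.py -- illustrative sketch of a config file for `service init`.
# `model_class` and `engine_server` must be importable callables in a real setup;
# the imports and paths below are placeholders.
from my_package.model import build_gpt           # placeholder
from my_package.server import launch_engine      # placeholder

model_class = build_gpt
model_type = "gpt"
engine_server = launch_engine

max_batch_size = 32
tp_init_size = 1        # tensor-parallel degree
pp_init_size = 1        # pipeline-parallel degree

host = "127.0.0.1"      # torch.distributed init host/port
port = 29500
half = True             # load_config turns this into dtype = torch.half

checkpoint = "/path/to/checkpoint.pt"      # placeholder
tokenizer_path = "/path/to/tokenizer"      # placeholder

server_host = "127.0.0.1"
server_port = 8005
log_level = "info"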
42 changes: 40 additions & 2 deletions energonai/context/config.py
@@ -3,12 +3,34 @@

import inspect
import sys
import torch
from typing import Union
from importlib.machinery import SourceFileLoader
from pathlib import Path
from energonai.logging import get_dist_logger


nec_args = {
'model_class': None,
'model_type': None,
'max_batch_size': 32,
'tp_init_size': 1,
'pp_init_size': 1,
'host': "127.0.0.1",
'port': 29500,
'dtype': torch.float,
'checkpoint': None,
'tokenizer_path': None,
'server_host': "127.0.0.1",
'server_port': 8005,
'log_level': "critical",
'backend':"nccl",
'rm_padding': False,
'seed' : 1024,
'verbose' : True
}


class Config(dict):
"""This is a wrapper class for dict objects so that values of which can be
accessed as attributes.
@@ -118,7 +140,13 @@ def __iter__(self):
return self._config.__iter__()

def __getitem__(self, key):
return self._config[key]
if key in self._config.keys():
return self._config[key]
else:
return None

def __setitem__(self, key, value):
self._config[key] = value

def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
@@ -134,5 +162,15 @@ def load_config(self, config: Union[dict, str]):
self._config = Config(config)
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")

for k,v in nec_args.items():
if k not in self._config:
self._config[k] = v

if mcfg['half']:
mcfg['dtype'] = torch.half
else:
mcfg['dtype'] = torch.float

mcfg = MetaConfig()

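
A quick usage sketch of the reworked MetaConfig: missing keys fall back to the nec_args defaults, unknown keys now return None instead of raising KeyError, and the boolean half flag is resolved to a torch dtype at load time. The config path below is a placeholder.

# Usage sketch (placeholder config path; assumes the file defines at least `half`).
import torch
from energonai.context import mcfg

mcfg.load_config("example_config.py")

print(mcfg['max_batch_size'])   # 32 from nec_args unless the file overrides it
print(mcfg['does_not_exist'])   # None, since __getitem__ no longer raises KeyError
print(mcfg['dtype'])            # torch.half if half=True in the file, else torch.float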
7 changes: 5 additions & 2 deletions energonai/engine/engine.py
@@ -11,12 +11,15 @@
from .rpc_worker import RPCWorker
from .pipeline_msg_dict import CircleInt

# for TP
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from energonai.initialize import launch_from_multiprocess
from colossalai.logging import get_dist_logger

from energonai.initialize import launch_from_multiprocess
from energonai.utils import ensure_directory_exists
from colossalai.logging import get_dist_logger



logger = get_dist_logger('energonai')

15 changes: 13 additions & 2 deletions energonai/engine/rpc_worker.py
@@ -4,12 +4,16 @@
import inspect
import torch.distributed.rpc as rpc
import sys

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.logging import get_dist_logger

from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .pipeline_wrapper import PipelineCommWrapper
from .vit_pipeline_wrapper import ViTPipelineCommWrapper
from colossalai.logging import get_dist_logger

# from torch2trt import torch2trt

logger = get_dist_logger('energonai')

@@ -48,6 +52,8 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size:
self.model = None # call the model
self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)
torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')

# self.trt_sample = None
self._init_self()
self.return_dict = ReturnDict()

@@ -58,8 +64,13 @@ def _init_self(self):
self.model = self.model_class(**self.model_config).cuda().half()
else:
self.model = self.model_class(**self.model_config).cuda()

self.model.eval()

# if trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) > 1:
# logger.error("Tensor Parallelism does not support TensorRT convert")
# elif trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) == 1:
# model = torch2trt(model, [self.trt_sample])

try:
self.model = pipe_wrapper[self.model_type](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
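
The TensorRT path is left commented out in this commit. As an illustration only, uncommenting it would amount to something like the helper below, assuming the torch2trt package is available and a representative sample input is stored in self.trt_sample; this is not enabled by the diff.

# Illustrative sketch of the commented-out torch2trt path (not enabled in this commit).
from torch2trt import torch2trt
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

def maybe_convert_to_trt(model, trt_sample, logger):
    """Convert a single-GPU model with torch2trt; skip under tensor/model parallelism."""
    if trt_sample is None:
        return model
    if gpc.get_world_size(ParallelMode.MODEL) > 1:
        logger.error("Tensor Parallelism does not support TensorRT convert")
        return model
    # torch2trt traces the model with the sample input and returns a TensorRT module.
    return torch2trt(model, [trt_sample])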
24 changes: 11 additions & 13 deletions energonai/server/worker_server.py
@@ -3,6 +3,7 @@
import torch.distributed.rpc as rpc
from energonai.initialize import launch_from_multiprocess
from colossalai.logging import get_dist_logger
from energonai.context import mcfg

logger = get_dist_logger('energonai')

@@ -21,23 +22,20 @@ async def shutdown():
await server.shutdown()


def launch_worker(host="127.0.0.1",
port=29500,
tp_init_size=1,
pp_init_size=1,
backend="nccl",
seed=1024,
verbose=True,
def launch_worker(config_file,
rank=0,
local_rank=0,
server_host="127.0.0.1",
server_port=8005,
log_level="info"):
server_port=8005):
mcfg.load_config(config_file)

world_size = mcfg['tp_init_size'] * mcfg['pp_init_size']

world_size = tp_init_size * pp_init_size
launch_from_multiprocess(mcfg['tp_init_size'], mcfg['pp_init_size'], mcfg['backend'],
mcfg['seed'], mcfg['verbose'], rank, local_rank, world_size,
mcfg['host'], mcfg['port'])

launch_from_multiprocess(tp_init_size, pp_init_size, backend, seed, verbose, rank, local_rank, world_size, host,
port)

WORKER_NAME = "wok{}"
rpc_backend_options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16,
# _transports=["uv"] TODO: potentially a bug
@@ -47,6 +45,6 @@ def launch_worker(host="127.0.0.1",
logger.info(f'RPC STATUS: RPC of Rank: {rank} is initialized.')

global server
config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
config = uvicorn.Config(app, host=server_host, port=server_port, log_level=mcfg['log_level'])
server = uvicorn.Server(config=config)
server.run()
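
With this change a worker no longer takes a dozen keyword arguments; it re-loads the shared config file and derives everything from mcfg. A minimal sketch of spawning one worker directly, mirroring what service.py does; the config path and the worker_server import path are assumptions, not taken from this commit.

# Sketch: spawning a single worker with the new config-file-based signature.
# The config path and the `energonai.server.worker_server` import path are assumptions.
import torch.multiprocessing as mp
from energonai.server.worker_server import launch_worker

if __name__ == "__main__":
    mp.set_start_method("spawn")
    p = mp.Process(target=launch_worker,
                   args=("example_config.py",   # config_file (placeholder path)
                         1,                     # rank (workers start from 1)
                         1,                     # local_rank
                         "127.0.0.1",           # server_host
                         8006))                 # server_port = master port + 1
    p.start()
    p.join()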
2 changes: 1 addition & 1 deletion examples/hf_gpt2/hf_gpt2_server.py
@@ -64,7 +64,7 @@ def launch_engine(model_class,
model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
else:
model_config = {'dtype': dtype}

global engine
engine = InferenceEngine(model_class,
model_config,
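
The only functional change in the example server is declaring engine as a module-level global before assigning it, so that route handlers defined later in the file see the same InferenceEngine instance. A toy illustration of the pattern follows; the names are placeholders and unrelated to the real engine.

# Toy illustration of the `global engine` pattern used above (placeholder names).
from fastapi import FastAPI

app = FastAPI()
engine = None  # module-level handle shared with the route handlers

class DummyEngine:
    def run(self, text: str) -> str:
        return text.upper()

def launch_engine():
    global engine          # without this, `engine` would be a local variable
    engine = DummyEngine()

@app.get("/run")
def run(text: str):
    return {"output": engine.run(text)}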
