This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

update metaconfig #80

Merged: 1 commit, May 26, 2022
103 changes: 22 additions & 81 deletions energonai/cli/service.py
@@ -7,103 +7,44 @@
 from energonai.context import mcfg
 
 
-def launches(model_class=None,
-             model_type=None,
-             engine_server=None,
-             max_batch_size=32,
-             tp_init_size=1,
-             pp_init_size=1,
-             host="127.0.0.1",
-             port=29500,
-             half=True,
-             checkpoint=None,
-             tokenizer_path=None,
-             server_host="127.0.0.1",
-             server_port=8005,
-             log_level="critical",
-             backend="nccl",
-             rm_padding=False):
-    click.echo(f'*** Energon Init Configurations: *** \n'
-               f'Model Name: {model_class} \n'
-               f'Model Type: {model_type} \n'
-               f'Engine Server: {engine_server} \n'
-               f'Max Batch Size: {max_batch_size} \n'
-               f'Tensor Parallelism Size: {tp_init_size} \n'
-               f'Pipeline Parallelism Size: {pp_init_size} \n'
-               f'Communication Host: {host} \n'
-               f'Communication Port: {port} \n'
-               f'Is Half: {half} \n'
-               f'Checkpoint Path: {checkpoint} \n'
-               f'Tokenizer Path: {tokenizer_path}'
-               f'Worker Server Host: {server_host} \n'
-               f'Worker Server Port: {server_port} \n'
-               f'Unvicorn Log Level: {log_level} \n'
-               f'Remove padding: {rm_padding} \n')
+@click.group()
+def service():
+    pass
 
-    if half:
-        dtype = torch.half
-    else:
-        dtype = torch.float
 
-    world_size = tp_init_size * pp_init_size
+@service.command()
+@click.option("--config_file", type=str)
+def init(config_file):
+
+    mcfg.load_config(config_file)
+
+    click.echo(f'*** Energon Init Configurations: ***')
+    for k in mcfg:
+        click.echo(f'{k}\t:\t{mcfg[k]}')
+
+    # prepare context for master and worker
+    world_size = mcfg['tp_init_size'] * mcfg['pp_init_size']
     num_worker = world_size - 1
 
-    engine_port = server_port
-    worker_port = server_port + 1
+    worker_port = mcfg['server_port'] + 1
     worker_rank = 1  # start from 1
 
     # launch each worker
     process_list = []
     mp.set_start_method('spawn')
     for i in range(num_worker):
         p = mp.Process(target=server.launch_worker,
-                       args=(host, port, tp_init_size, pp_init_size, "nccl", 1024, True, worker_rank + i, worker_rank + i,
-                             server_host, worker_port + i, log_level))
+                       args=(config_file, worker_rank + i, worker_rank + i, mcfg['server_host'], worker_port + i))
         p.start()
         process_list.append(p)
 
-    sig_server = inspect.signature(engine_server)
+    # launch the master
+    sig_server = inspect.signature(mcfg['engine_server'])
     parameters_server = sig_server.parameters
 
-    cfg = {
-        'model_class': model_class,
-        'model_type': model_type,
-        'max_batch_size': max_batch_size,
-        'tp_init_size': tp_init_size,
-        'pp_init_size': pp_init_size,
-        'host': host,
-        'port': port,
-        'dtype': dtype,
-        'checkpoint': checkpoint,
-        'tokenizer_path': tokenizer_path,
-        'server_host': server_host,
-        'server_port': engine_port,
-        'log_level': log_level,
-        'rm_padding': rm_padding
-    }
-
     argv = dict()
     for name, _ in parameters_server.items():
-        if name in cfg:
-            argv[name] = cfg[name]
-
-    engine_server(**argv)
-
-
-@click.group()
-def service():
-    pass
-
-
-@service.command()
-@click.option("--config_file", type=str)
-def init(config_file):
-    mcfg.load_config(config_file)
-
-    sig = inspect.signature(launches)
-    parameters = sig.parameters
-
-    argv = dict()
-    for name, _ in parameters.items():
         if name in mcfg:
             argv[name] = mcfg[name]
-    launches(**argv)
+
+    mcfg['engine_server'](**argv)
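
Note: after this change, init no longer takes a dozen CLI flags; it reads one Python config file through mcfg.load_config and looks every setting up by key, while each worker receives only the config path, its ranks, and its port. A minimal sketch of such a config file, assuming the launch_engine entry point from examples/hf_gpt2/hf_gpt2_server.py (shown at the bottom of this diff) and a hypothetical hf_gpt2 model class; any key left out falls back to the nec_args defaults introduced in energonai/context/config.py below:

    # example_config.py -- illustrative sketch, not part of this PR;
    # the two import paths below are assumptions
    from energonai.model import hf_gpt2          # hypothetical model class
    from hf_gpt2_server import launch_engine     # engine server, as in the example diff below

    model_class = hf_gpt2
    model_type = "gpt"
    engine_server = launch_engine
    half = True                  # load_config derives dtype=torch.half from this flag
    tp_init_size = 2
    pp_init_size = 2
    host = "127.0.0.1"           # torch.distributed communication host/port
    port = 29500
    server_host = "127.0.0.1"
    server_port = 8005           # master port; worker i binds server_port + 1 + i
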
42 changes: 40 additions & 2 deletions energonai/context/config.py
@@ -3,12 +3,34 @@

 import inspect
 import sys
+import torch
 from typing import Union
 from importlib.machinery import SourceFileLoader
 from pathlib import Path
 from energonai.logging import get_dist_logger
 
 
+nec_args = {
+    'model_class': None,
+    'model_type': None,
+    'max_batch_size': 32,
+    'tp_init_size': 1,
+    'pp_init_size': 1,
+    'host': "127.0.0.1",
+    'port': 29500,
+    'dtype': torch.float,
+    'checkpoint': None,
+    'tokenizer_path': None,
+    'server_host': "127.0.0.1",
+    'server_port': 8005,
+    'log_level': "critical",
+    'backend':"nccl",
+    'rm_padding': False,
+    'seed' : 1024,
+    'verbose' : True
+}
+
+
 class Config(dict):
     """This is a wrapper class for dict objects so that values of which can be
     accessed as attributes.
@@ -118,7 +140,13 @@ def __iter__(self):
         return self._config.__iter__()
 
     def __getitem__(self, key):
-        return self._config[key]
+        if key in self._config.keys():
+            return self._config[key]
+        else:
+            return None
+
+    def __setitem__(self, key, value):
+        self._config[key] = value
 
     def load_config(self, config: Union[dict, str]):
         """Loads the configuration from either a dict or a file.
@@ -134,5 +162,15 @@ def load_config(self, config: Union[dict, str]):
             self._config = Config(config)
         else:
             raise TypeError("Invalid type for config, only dictionary or string is supported")
+
+        for k,v in nec_args.items():
+            if k not in self._config:
+                self._config[k] = v
+
+        if mcfg['half']:
+            mcfg['dtype'] = torch.half
+        else:
+            mcfg['dtype'] = torch.float
 
 
 mcfg = MetaConfig()
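
Note: taken together, the MetaConfig changes do three things: __getitem__ returns None for an absent key instead of raising KeyError, __setitem__ makes the config writable, and load_config back-fills every key listed in nec_args before deriving dtype from the half flag. A small usage sketch of the resulting behavior (the inline dict is illustrative; per its docstring, load_config accepts a dict as well as a file path):

    import torch
    from energonai.context import mcfg

    mcfg.load_config({'half': True, 'tp_init_size': 2})
    assert mcfg['dtype'] is torch.half       # derived from 'half' at load time
    assert mcfg['max_batch_size'] == 32      # back-filled default from nec_args
    assert mcfg['engine_server'] is None     # absent key: returns None, no KeyError

One wrinkle worth noting: load_config reads the half flag through the module-level mcfg singleton rather than through self, so the dtype derivation only applies to the shared mcfg instance.
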
7 changes: 5 additions & 2 deletions energonai/engine/engine.py
@@ -11,12 +11,15 @@
 from .rpc_worker import RPCWorker
 from .pipeline_msg_dict import CircleInt
 
+# for TP
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
-from energonai.initialize import launch_from_multiprocess
-from colossalai.logging import get_dist_logger
+
+from energonai.initialize import launch_from_multiprocess
+from energonai.utils import ensure_directory_exists
+from colossalai.logging import get_dist_logger
 
 
 logger = get_dist_logger('energonai')

15 changes: 13 additions & 2 deletions energonai/engine/rpc_worker.py
@@ -4,12 +4,16 @@
 import inspect
 import torch.distributed.rpc as rpc
 import sys
+
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from colossalai.logging import get_dist_logger
+
 from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
 from .pipeline_wrapper import PipelineCommWrapper
 from .vit_pipeline_wrapper import ViTPipelineCommWrapper
-from colossalai.logging import get_dist_logger
+
+# from torch2trt import torch2trt
 
 logger = get_dist_logger('energonai')

@@ -48,6 +52,8 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size:
         self.model = None  # call the model
         self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)
         torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
+
+        # self.trt_sample = None
         self._init_self()
         self.return_dict = ReturnDict()

@@ -58,8 +64,13 @@ def _init_self(self):
             self.model = self.model_class(**self.model_config).cuda().half()
         else:
             self.model = self.model_class(**self.model_config).cuda()
+
         self.model.eval()
 
+        # if trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) > 1:
+        #     logger.error("Tensor Parallelism does not support TensorRT convert")
+        # elif trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) == 1:
+        #     model = torch2trt(model, [self.trt_sample])
+
         try:
             self.model = pipe_wrapper[self.model_type](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
24 changes: 11 additions & 13 deletions energonai/server/worker_server.py
@@ -3,6 +3,7 @@
 import torch.distributed.rpc as rpc
 from energonai.initialize import launch_from_multiprocess
 from colossalai.logging import get_dist_logger
+from energonai.context import mcfg
 
 logger = get_dist_logger('energonai')

@@ -21,23 +22,20 @@ async def shutdown():
     await server.shutdown()
 
 
-def launch_worker(host="127.0.0.1",
-                  port=29500,
-                  tp_init_size=1,
-                  pp_init_size=1,
-                  backend="nccl",
-                  seed=1024,
-                  verbose=True,
+def launch_worker(config_file,
                   rank=0,
                   local_rank=0,
                   server_host="127.0.0.1",
-                  server_port=8005,
-                  log_level="info"):
+                  server_port=8005):
+    mcfg.load_config(config_file)
+
+    world_size = mcfg['tp_init_size'] * mcfg['pp_init_size']
 
-    world_size = tp_init_size * pp_init_size
+    launch_from_multiprocess(mcfg['tp_init_size'], mcfg['pp_init_size'], mcfg['backend'],
+                             mcfg['seed'], mcfg['verbose'], rank, local_rank, world_size,
+                             mcfg['host'], mcfg['port'])
 
-    launch_from_multiprocess(tp_init_size, pp_init_size, backend, seed, verbose, rank, local_rank, world_size, host,
-                             port)
-
     WORKER_NAME = "wok{}"
     rpc_backend_options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=16,
                                                           # _transports=["uv"] TODO: potentially a bug
@@ -47,6 +45,6 @@ def launch_worker(host="127.0.0.1",
     logger.info(f'RPC STATUS: RPC of Rank: {rank} is initialized.')
 
     global server
-    config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
+    config = uvicorn.Config(app, host=server_host, port=server_port, log_level=mcfg['log_level'])
     server = uvicorn.Server(config=config)
     server.run()
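
Note: a worker now rebuilds its full configuration from the shared config file instead of receiving it as a long positional tuple; service.py spawns each one with just the file path, its ranks, and its port. Spelled out with illustrative values, the first worker of the example config sketched above would effectively be started as:

    launch_worker('example_config.py', rank=1, local_rank=1,
                  server_host='127.0.0.1', server_port=8006)
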
2 changes: 1 addition & 1 deletion examples/hf_gpt2/hf_gpt2_server.py
@@ -64,7 +64,7 @@ def launch_engine(model_class,
         model_config = {'dtype': dtype, 'checkpoint': True, 'checkpoint_path': checkpoint}
     else:
         model_config = {'dtype': dtype}
-
+
     global engine
     engine = InferenceEngine(model_class,
                              model_config,