This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

vit example #73

Merged 1 commit on May 24, 2022
2 changes: 1 addition & 1 deletion README.md
@@ -75,7 +75,7 @@ Here FasterTransformer is adopted in comparison and it does not support the redu
#### Batching
Here FIFO batching is selected in comparison.
<div align="center">
-<img src="https://user-images.githubusercontent.com/12018307/169729579-8735c905-30ed-44f9-af4e-275e021f4266.png" width = "600" height = "300" alt="Architecture" align=center />
+<img src="https://user-images.githubusercontent.com/12018307/169729579-8735c905-30ed-44f9-af4e-275e021f4266.png" width = "400" height = "130" alt="Architecture" align=center />
</div>

### Contributing
7 changes: 5 additions & 2 deletions energon/engine/engine.py
@@ -16,7 +16,9 @@
from energon.initialize import launch_from_multiprocess

from energon.utils import ensure_directory_exists
-from energon.logging import get_dist_logger
+from colossalai.logging import get_dist_logger

+logger = get_dist_logger('energon')


class InferenceEngine(Module):
@@ -78,10 +80,11 @@ def _init_dist_rpc(self):
                     rank=0,
                     world_size=self.global_world_size,
                     rpc_backend_options=rpc_backend_options)
+        logger.info('RPC STATUS: RPC of Rank: 0 is initialized.')

    def _init_model(self):
        for i in range(self.global_world_size):
-            print(f'[INFO] rank{self.rank} calls rank{i} to init.')
+            logger.info(f'RPC STATUS: rank: {self.rank} calls rank: {i} to init model.')
            ob_info = rpc.get_worker_info(self.WORKER_NAME.format(i))
            self.rrefs.append(
                rpc.remote(ob_info,
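For context, engine.py follows the standard torch.distributed.rpc master pattern: rank 0 joins the RPC group, then constructs a remote object on every rank and keeps an RRef to each. A minimal, self-contained sketch of that pattern (the Worker class and helper names are illustrative stand-ins, not the repository's actual classes; rpc.shutdown() is omitted for brevity):

import torch.distributed.rpc as rpc

class Worker:
    """Illustrative stand-in for the RPCWorker that rpc_worker.py defines."""

    def __init__(self, rank: int):
        self.rank = rank

    def ping(self) -> str:
        return f"worker {self.rank} ready"

def init_master(world_size: int, worker_name: str = "wok{}"):
    # rank 0 joins the RPC group, mirroring _init_dist_rpc above
    rpc.init_rpc(worker_name.format(0), rank=0, world_size=world_size)
    rrefs = []
    for i in range(world_size):
        # resolve each worker by name and construct an object remotely,
        # mirroring the rpc.remote call in _init_model
        info = rpc.get_worker_info(worker_name.format(i))
        rrefs.append(rpc.remote(info, Worker, args=(i,)))
    return rrefs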
25 changes: 18 additions & 7 deletions energon/engine/rpc_worker.py
@@ -8,6 +8,16 @@
from colossalai.context import ParallelMode
from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
from .pipeline_wrapper import PipelineCommWrapper
+from .vit_pipeline_wrapper import ViTPipelineCommWrapper
+from colossalai.logging import get_dist_logger

+logger = get_dist_logger('energon')

+pipe_wrapper = {
+    'vit': ViTPipelineCommWrapper,
+    'bert': PipelineCommWrapper,
+    'gpt': PipelineCommWrapper
+}


class ReturnDict:
@@ -32,30 +42,31 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size:
        self.model_config = model_config
        self.dtype = dtype
        self.max_batch_size = max_batch_size
+        self.model_type = model_type

        self.WORKER_NAME = "wok{}"
        self.model = None  # the wrapped model, built in _init_self
        self.rank = gpc.get_local_rank(ParallelMode.GLOBAL)

        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
        self._init_self()

        self.return_dict = ReturnDict()

    def _init_self(self):
-        print("[INFO] init model in rank {}".format(self.rank))
+        logger.info("init model on rank {}".format(self.rank))

        if self.dtype == torch.half:
            self.model = self.model_class(**self.model_config).cuda().half()
        else:
            self.model = self.model_class(**self.model_config).cuda()
-        # print("Pass")

        self.model.eval()

-        self.model = PipelineCommWrapper(model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)

+        try:
+            self.model = pipe_wrapper[self.model_type](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
+        except KeyError:
+            # a bare `except` here would swallow the error and leave the model unwrapped
+            logger.error(f'Only {list(pipe_wrapper.keys())} pipeline wrappers are supported.')
+            raise

    def run(self, key, inputs):
        # print("key: {}".format(key), flush=True)
        torch.cuda.set_device(f'cuda:{gpc.get_local_rank(ParallelMode.GLOBAL)}')
        for k, v in inputs.items():
            if v is not None:
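The pipe_wrapper dict introduced above is a plain dispatch table: model_type picks the communication wrapper class, and every entry shares the (model, max_batch_size, dtype) constructor signature. A hedged sketch of how a new model type could be registered and resolved (T5PipelineCommWrapper is hypothetical, not part of this PR):

pipe_wrapper['t5'] = T5PipelineCommWrapper  # hypothetical wrapper with the same signature

def resolve_wrapper(model_type: str):
    # resolving the class before construction gives a clearer error than
    # wrapping the constructor call itself in try/except
    wrapper_cls = pipe_wrapper.get(model_type)
    if wrapper_cls is None:
        raise ValueError(f'unsupported model_type {model_type!r}; expected one of {sorted(pipe_wrapper)}')
    return wrapper_cls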
123 changes: 123 additions & 0 deletions energon/engine/vit_pipeline_wrapper.py
@@ -0,0 +1,123 @@
import inspect
import threading

import torch
import torch.nn as nn
import torch.distributed as dist
from typing import List, Tuple, Union

from energon.communication import send_forward, recv_forward, send_tensor_meta, recv_tensor_meta
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

from .pipeline_meta import PipelineMeta
from .pipeline_msg_dict import PipelineMsgDict, CircleInt # PipelineMsgPriorityQueue


# This wrapper is only for transformer-style models (here, ViT).
class ViTPipelineCommWrapper:

    def __init__(self, model: nn.Module, max_batch_size: int = 1, dtype=torch.float) -> None:
        self.model = model
        self.dtype = dtype

        self.tensor_dim = 0
        self.hidden_shape = 0
        self.max_batch_size = max_batch_size

        if gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
            img = torch.rand((max_batch_size, 3, 224, 224), dtype=dtype).cuda()
            sample = dict(img=img)
            self._init_tensor_meta(sample)

        self.pipe_msg_queue = PipelineMsgDict()
        self.lock = threading.Lock()
        self.key = CircleInt()

    def _init_tensor_meta(self, sample):

        with torch.inference_mode():
            recv_tensor_shape = None
            if gpc.is_first_rank(ParallelMode.PIPELINE):
                output = self.model(x=sample['img'])
                send_tensor_meta(output)
                send_forward(output)
                self.tensor_dim = output.dim()
                self.hidden_shape = output.size()
            elif gpc.is_last_rank(ParallelMode.PIPELINE):
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_shape = input_tensor.size()
            else:
                recv_tensor_shape = recv_tensor_meta(recv_tensor_shape)
                input_tensor = recv_forward(recv_tensor_shape, dtype=self.dtype)  # only a tensor now
                self.tensor_dim = input_tensor.dim()
                self.hidden_shape = input_tensor.size()
                output = self.model(x=input_tensor)
                send_tensor_meta(output)
                send_forward(output)

    def run(self, key, inputs):
        if gpc.is_initialized(ParallelMode.PIPELINE):
            return self.run_with_pp(key, inputs)
        else:
            return self.run_without_pp(key, inputs)

    def run_without_pp(self, key, inputs):
        pipe_meta = None
        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

        self.lock.acquire()

        cur_key = self.key.val
        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
        self.key.addOne()
        with torch.inference_mode():
            output = self.model(x=sample['img'])
        self.lock.release()

        return output, cur_key

    '''
    hidden_shape example: torch.Size([32, 512, 1600])
    fill_meta_tensor differs between model types.
    '''

    def fill_meta_tensor(self, inputs, pipe_meta):
        pipe_meta.get_meta_tensor()[0] = inputs['img'].shape[0]
        pipe_meta.get_meta_tensor()[1] = inputs['img'].shape[0]
        pipe_meta.get_meta_tensor()[2] = self.hidden_shape[1]
        pipe_meta.get_meta_tensor()[3] = self.hidden_shape[2]
        pipe_meta.update_meta()

    def run_with_pp(self, key, inputs):
        pipe_meta = PipelineMeta(self.tensor_dim, self.max_batch_size)
        self.fill_meta_tensor(inputs, pipe_meta)
        self.pipe_msg_queue.enqueue(key, inputs, pipe_meta)

        self.lock.acquire()
        cur_key = self.key.val
        sample, pipe_meta = self.pipe_msg_queue.top(cur_key)
        self.key.addOne()

        with torch.inference_mode():

            if gpc.is_first_rank(ParallelMode.PIPELINE):
                output = self.model(x=sample['img'])
                send_forward(output)
                self.lock.release()
                return None

            if gpc.is_last_rank(ParallelMode.PIPELINE):
                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(x=input_tensor)
                self.lock.release()
                return output, cur_key

            else:
                input_tensor = recv_forward(pipe_meta.get_tensor_shapes(), dtype=self.dtype)
                output = self.model(x=input_tensor)
                send_forward(output)
                self.lock.release()
                return None
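For orientation, the wrapper's public surface is small: build it around an already-constructed ViT model, then call run(key, inputs) with an 'img' batch. A hedged usage sketch, assuming a launched pipeline context and the vit_large_patch32_224 constructor that examples/vit/run.py imports:

import torch
from vit import vit_large_patch32_224
from energon.engine.vit_pipeline_wrapper import ViTPipelineCommWrapper

model = vit_large_patch32_224(dtype=torch.half).cuda()
wrapper = ViTPipelineCommWrapper(model=model, max_batch_size=4, dtype=torch.half)

img = torch.rand((4, 3, 224, 224), dtype=torch.half).cuda()
result = wrapper.run(key=0, inputs={'img': img})
# without pipeline parallelism this returns (output, key); with it, only
# the last pipeline rank returns (output, key) and the other ranks return None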
1 change: 0 additions & 1 deletion energon/initialize.py
@@ -23,7 +23,6 @@ def launch_from_multiprocess(tp_size: int = 1,
    os.environ['MASTER_PORT'] = f'{port}'

    config = dict(parallel=dict(pipeline=dict(size=pp_size), tensor=dict(size=tp_size, mode='1d')))
-    gpc.load_config(config)

    launch(config=config,
           local_rank=local_rank,
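The deleted gpc.load_config call appears to have been redundant: the launch(config=config, ...) call immediately below is handed the same config dict, so it presumably loads the configuration itself.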
26 changes: 6 additions & 20 deletions energon/server/worker_server.py
@@ -2,32 +2,16 @@
from fastapi import FastAPI
import torch.distributed.rpc as rpc
from energon.initialize import launch_from_multiprocess
+from colossalai.logging import get_dist_logger

-app = FastAPI()  # create the api object
+logger = get_dist_logger('energon')


-@app.get("/")  # root route
+app = FastAPI()
+@app.get("/")
def root():
    return {"200"}


-# @app.get("/start/{tp_size}")
-# def init(tp_size: int, pp_size: int, backend: str, seed: int, verbose: bool, rank: int, local_rank: int, host: str, port: int):
-#     # http://127.0.0.1:8005/start/1?pp_size=1&backend=nccl&seed=1024&verbose=true&rank=0&local_rank=0&host=localhost&port=29500
-#     world_size = tp_size * pp_size

-#     os.environ['MASTER_ADDR'] = host
-#     os.environ['MASTER_PORT'] = f'{port}'
-#     launch_from_multiprocess(tp_size, pp_size, backend, seed, verbose, rank, local_rank, world_size, host, port)
-#     WORKER_NAME = "wok{}"
-#     rpc_backend_options = rpc.TensorPipeRpcBackendOptions(
-#         num_worker_threads=16)
-#     rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size, rpc_backend_options=rpc_backend_options)
-#     rpc.shutdown()
-#     # print(f'{WORKER_NAME.format(rank)} Start!')
-#     return {f'{WORKER_NAME.format(rank)} Start!'}


@app.get("/shutdown")
async def shutdown():
@@ -60,6 +44,8 @@ def launch_worker(host="127.0.0.1",
    )
    rpc.init_rpc(WORKER_NAME.format(rank), rank=rank, world_size=world_size, rpc_backend_options=rpc_backend_options)

+    logger.info(f'RPC STATUS: RPC of Rank: {rank} is initialized.')

    global server
    config = uvicorn.Config(app, host=server_host, port=server_port, log_level=log_level)
    server = uvicorn.Server(config=config)
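The worker pairs a FastAPI control surface with an in-process uvicorn server; the /shutdown route (truncated in this diff) presumably stops that server. One common pattern for this, sketched under that assumption (the body shown here is not taken from the PR):

import uvicorn
from fastapi import FastAPI

app = FastAPI()
server = None  # assigned in launch_worker, matching the `global server` above

@app.get("/shutdown")
async def shutdown():
    # uvicorn.Server polls should_exit in its serve loop and exits cleanly
    server.should_exit = True
    return {"shutting down"}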
Binary file added examples/vit/dataset/n01667114_9985.JPEG
35 changes: 35 additions & 0 deletions examples/vit/proc_img.py
@@ -0,0 +1,35 @@
import torch
from typing import Any
from PIL import Image
import torchvision.transforms as transforms
import torchvision.datasets as datasets


default_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

def accimage_loader(path: str) -> Any:
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def proc_img(path: str, size: int = 224, normalize=default_normalize, loader=pil_loader) -> torch.Tensor:
    img = loader(path)
    transform = transforms.Compose([
        transforms.RandomResizedCrop(size),  # use the `size` argument instead of a hard-coded 224
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    img = transform(img)
    return img
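A hedged usage sketch (the import path is an assumption; the sample path is the image this PR adds): preprocess one image and add the batch dimension the engine expects, matching the (1, 3, 224, 224) warm-up shape used by the wrapper. Note that the transform uses training-style augmentation (random crop and flip); a deterministic Resize plus CenterCrop would be the more usual inference preprocessing, but the example keeps the random version.

from proc_img import proc_img

tensor = proc_img('examples/vit/dataset/n01667114_9985.JPEG')  # (3, 224, 224), normalized
batch = tensor.unsqueeze(0).half().cuda()                      # (1, 3, 224, 224), as run.py prepares it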
57 changes: 57 additions & 0 deletions examples/vit/run.py
@@ -0,0 +1,57 @@
from vit import vit_large_patch32_384, vit_base_patch16_224, vit_lite_depth7_patch4_32, vit_large_patch32_224
from colossalai import launch_from_torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
import torch
from typing import Any

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from PIL import Image

config = dict(parallel=dict(pipeline=dict(size=2), tensor=dict(size=1, mode='1d')))

launch_from_torch(config)

def accimage_loader(path: str) -> Any:
    import accimage

    try:
        return accimage.Image(path)
    except OSError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)

def pil_loader(path: str) -> Image.Image:
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, "rb") as f:
        img = Image.open(f)
        return img.convert("RGB")

img = pil_loader('/home/lcdjs/ColossalAI-Inference/examples/vit/dataset/n01667114_9985.JPEG')

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
img = transform(img)



img = torch.unsqueeze(img, 0).half().cuda()

model = vit_large_patch32_224(dtype=torch.half).cuda()

# print(model)

if gpc.is_first_rank(ParallelMode.PIPELINE):
    output = model(img)
    print(type(output))

# print(model)
# print(torch.cuda.memory_allocated())
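Since the script calls launch_from_torch, it presumably expects to be started with torchrun (or python -m torch.distributed.launch) so that the rank and world-size environment variables are set; with the pipeline=2 config above, a plausible invocation is `torchrun --nproc_per_node=2 run.py`. Note that only the first pipeline rank runs a forward pass here, so the script appears to be a per-stage smoke test rather than full pipelined inference, which goes through the InferenceEngine and wrapper above.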