add graph operation fusion with env option
CyCle1024 committed Dec 27, 2024
1 parent 94fd42d commit 9062c6b
Showing 4 changed files with 56 additions and 4 deletions.
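
In short, this commit adds an opt-in execution path gated by the DICP_USE_TORCH_NPU_LAUNCHER environment variable: the C++ runtime dispatches each ATB node's Execute through a torch_npu OpCommand custom handler, compile_job.py exports a set of ATB tuning variables when the flag is set, and the converter fuses the ops emitted between consecutive RmsNorm boundaries into a single atb_op.Graph. A minimal usage sketch follows; only the environment variable name is taken from this commit, the rest is hypothetical user code.

import os

# Set the flag before the DICP model library is first loaded: _compile() reads it
# and exports the ATB_* tuning variables only at that point, and both the Python
# converter and the C++ runtime read the same variable.
os.environ["DICP_USE_TORCH_NPU_LAUNCHER"] = "1"
# ... then compile and run the dicp graph as usual.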
21 changes: 20 additions & 1 deletion dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.cpp
@@ -5,6 +5,8 @@
#include <algorithm>
#include <fstream>

#include "torch_npu/csrc/framework/OpCommand.h"

#include "ops/operation_creator.h"
#include "utils/config.h"
#include "utils/log.h"
@@ -139,6 +141,8 @@ atb::Tensor Model::CreateInternalTensorFromDesc(const atb::TensorDesc& tensorDes
}

Model::Model(const std::string& modelId, const std::string& modelPath) : modelId_(modelId), modelPath_(modelPath) {
const char *envStr = std::getenv("DICP_USE_TORCH_NPU_LAUNCHER");
UseTorchNpuLauncher_ = (envStr != nullptr && std::string(envStr) == "1");
auto st = BuildGraph();
DICP_LOG_IF(st != atb::NO_ERROR, ERROR) << modelId_ << " init graph:\n" << graph_.ToString();
graph_.Init();
@@ -261,7 +265,22 @@ atb::Status Model::ExecuteNode(int nodeId) {

DICP_LOG(INFO) << modelId_ << "execute node[" << nodeId << "] start";

st = node.operation->Execute(node.variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_);
if (UseTorchNpuLauncher_) {
at_npu::native::OpCommand cmd;
std::string taskName = "DicpDecoderModel_" + modelId_ + std::to_string(nodeId);
std::function<int()> task = [&]() {
atb::Status tmp_st = node.operation->Execute(node.variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_);
if (tmp_st != 0) {
DICP_LOG(ERROR) << "op command execute node[" << nodeId << "] fail, error code: " << st;
}
return 0;
};
cmd.Name(taskName);
cmd.SetCustomHandler(task);
cmd.Run();
} else {
st = node.operation->Execute(node.variantPack, (uint8_t*)(node.workspace), node.workspaceSize, context_);
}
if (st != 0) {
DICP_LOG(ERROR) << "execute node[" << nodeId << "] fail, error code: " << st;
}
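
The new branch above also changes error propagation: with DICP_USE_TORCH_NPU_LAUNCHER=1 the Execute call runs inside a task queued on the torch_npu OpCommand queue, its status is only logged from inside the handler, and the caller's subsequent st check no longer observes Execute failures. A control-flow sketch in Python pseudocode (every name other than the environment variable is a hypothetical stand-in):

import os

def execute_node(node, submit_to_op_command_queue):
    # Stand-in for the dispatch decision in Model::ExecuteNode.
    if os.getenv("DICP_USE_TORCH_NPU_LAUNCHER") == "1":
        def task():
            status = node.execute()           # stand-in for operation->Execute(...)
            if status != 0:
                print(f"op command execute node fail, error code: {status}")
            return 0                          # the handler always reports success to the queue
        submit_to_op_command_queue(task)      # stand-in for cmd.SetCustomHandler(task); cmd.Run()
        return 0                              # Execute errors are not surfaced to the caller here
    return node.execute()                     # synchronous path: the status propagates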
1 change: 1 addition & 0 deletions dlinfer/graph/dicp/vendor/AtbGraph/codegen/runtime/model.h
@@ -93,6 +93,7 @@ class Model {
void SetupInferShape(const nlohmann::json& inferShape, atb::InferShapeFunc& inferShapeFunc);

private:
bool UseTorchNpuLauncher_;
std::string modelId_;
std::string modelPath_;
Graph graph_;
7 changes: 7 additions & 0 deletions dlinfer/graph/dicp/vendor/AtbGraph/compile_job.py
@@ -31,6 +31,13 @@ def __init__(self, source_code) -> None:
def _compile(self):
try:
if not hasattr(torch.classes.DICPModel, "DICPModel"):
if os.getenv("DICP_USE_TORCH_NPU_LAUNCHER", "0") != "0":
os.environ["ATB_CONTEXT_HOSTTILING_RING"] = "1"
os.environ["ATB_CONTEXT_HOSTTILING_SIZE"] = "102400"
os.environ["ATB_WORKSPACE_MEM_ALLOC_GLOBAL"] = "1"
os.environ["ATB_USE_TILING_COPY_STREAM"] = "0"
os.environ["ATB_OPSRUNNER_KERNEL_CACHE_LOCAL_COUNT"] = "1"
os.environ["ATB_OPSRUNNER_KERNEL_CACHE_GLOABL_COUNT"] = "16"
current_dir = os.path.dirname(__file__)
lib_path = os.path.join(current_dir, "codegen/libdicp_model.so")
torch.classes.load_library(lib_path)
31 changes: 28 additions & 3 deletions dlinfer/graph/dicp/vendor/AtbGraph/conversion.py
@@ -1,9 +1,12 @@
import os
import functools
import operator
import torch
import math

import torch.fx
from torch.fx.immutable_collections import immutable_dict
from collections import OrderedDict
import torch.fx.traceback as fx_traceback
from dlinfer.graph.dicp.vendor.AtbGraph import atb_op

@@ -80,6 +83,18 @@ class AtenToAtbTransformer(SingleOpTransformer):
def __init__(self, gm):
super().__init__(gm, conversions)
self._register_binary_ops()
self.use_torch_npu_launcher = os.getenv("DICP_USE_TORCH_NPU_LAUNCHER", "0") != "0"
self.graph_op_group = None

def get_proxy(self, target, args, kwargs=immutable_dict()):
proxy = super().get_proxy(target, args, kwargs)
if self.use_torch_npu_launcher:
if target == atb_op.Graph:
return proxy
if isinstance(self.graph_op_group, OrderedDict):
assert id(proxy) not in self.graph_op_group
self.graph_op_group[id(proxy)] = proxy
return proxy

@register_conversion(torch.ops.atb.linear.default)
def linear(self, a, b, bias, trans_a, trans_b):
@@ -95,6 +110,7 @@ def identity(self, x, idx):

@register_conversion("torch.ops.dlinfer.rms_norm.default")
def npu_rms_norm(self, x, w, eps=1e-6):
self.graph_op_group = OrderedDict()
rms_norm = self.get_proxy(atb_op.RmsNorm, (x, w, eps))
return rms_norm

@@ -341,14 +357,20 @@ def silu_and_mul(self, gate_up, dim):
up = self.get_proxy(atb_op.GetItem, (split, 1))
act = self.get_proxy(atb_op.Swish, (gate,))
mul = self.get_proxy(atb_op.Mul, (act, up))
graph = self.get_proxy(
atb_op.Graph, (split, gate, up, act, mul), {"output": mul}
)
# graph = self.get_proxy(
# atb_op.Graph, (split, gate, up, act, mul), {"output": mul}
# )
return mul

@register_conversion("torch.ops.dlinfer.add_rms_norm.default")
def dlinfer_add_rms_norm(self, x1, x2, gamma, epsilon):
add = self.get_proxy(atb_op.Add, (x1, x2))
if self.use_torch_npu_launcher and len(self.graph_op_group) > 0:
op_tuple = tuple(self.graph_op_group.values())
graph = self.get_proxy(atb_op.Graph,
op_tuple,
{"output": add})
self.graph_op_group = OrderedDict()
norm = self.get_proxy(atb_op.RmsNorm, (add, gamma, epsilon))
# FIXME(tangzhiyi11): Temporarily disable graph op for MOE precision issues
# graph = self.get_proxy(
@@ -546,6 +568,9 @@ def alias(self, x):
def dlinfer_linear(self, x, weight, bias, all_reduce):
if all_reduce == False:
return self.get_proxy(atb_op.Linear, (x, weight, bias, False, True))
# else:
# linear_res = self.get_proxy(atb_op.Linear, (x, weight, bias, False, True))
# return self.get_proxy(atb_op.AllReduce, (linear_res, "sum"))
return self.get_proxy(atb_op.LinearAllReduce, (x, weight, bias))

@register_conversion(torch.ops.aten.index.Tensor)
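
The conversion-side half of the feature is bookkeeping: the get_proxy override records every non-Graph proxy into graph_op_group, npu_rms_norm starts a fresh group at each RmsNorm boundary, and dlinfer_add_rms_norm flushes the collected proxies into a single atb_op.Graph whose output is the Add result. A condensed sketch of that scheme, using hypothetical scaffolding names (GraphGrouper and its methods are not part of the dicp API):

from collections import OrderedDict

class GraphGrouper:
    """Mirrors the graph_op_group handling added to AtenToAtbTransformer."""

    def __init__(self, enabled: bool):
        self.enabled = enabled
        self.group = None                      # corresponds to self.graph_op_group

    def record(self, proxy):
        # get_proxy override: collect every proxy except Graph nodes themselves.
        if self.enabled and isinstance(self.group, OrderedDict):
            assert id(proxy) not in self.group
            self.group[id(proxy)] = proxy
        return proxy

    def start_group(self):
        # npu_rms_norm: a new group begins at each RmsNorm.
        self.group = OrderedDict()

    def flush(self, emit_graph, output_proxy):
        # dlinfer_add_rms_norm: fuse everything collected so far into one Graph.
        if self.enabled and self.group:
            emit_graph(tuple(self.group.values()), {"output": output_proxy})
        self.group = OrderedDict()

Keying the group by id(proxy) keeps insertion order while ruling out duplicates, which is exactly what the assert in the get_proxy override guards.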
