Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR+CINN]Part-2 Pybind IrParser.ParseProgram and Polish UT into check_run #59449

Merged
merged 6 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/parser/ir_parser.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/core/type.h"
#include "paddle/pir/core/value.h"
Expand Down Expand Up @@ -80,6 +81,7 @@ using paddle::dialect::SelectedRowsType;
using pir::Attribute;
using pir::Block;
using pir::BlockArgument;
using pir::IrParser;
using pir::Operation;
using pir::OpOperand;
using pir::OpResult;
Expand Down Expand Up @@ -250,6 +252,15 @@ void BindProgram(py::module *m) {
});
}

// Parse a textual PIR program representation into a Program object.
// The string is wrapped in a stream and handed to IrParser, which is
// constructed against the process-wide IrContext singleton.
std::shared_ptr<Program> ParseProgram(const std::string &program_str) {
  std::stringstream source(program_str);
  return IrParser(pir::IrContext::Instance(), source).ParseProgram();
}

void BindIrParser(py::module *m) { m->def("parse_program", &ParseProgram); }

void RefreshOpStopgradients(Operation *op) {
if (op->num_operands() == 0 || op->isa<pir::ParameterOp>() ||
op->isa<paddle::dialect::UniformOp>()) {
Expand Down Expand Up @@ -1609,6 +1620,7 @@ void BindPir(pybind11::module *module) {
BindControlFlowApi(&ir_module);
auto ops_modules = ir_module.def_submodule("ops");
BindOpsAPI(&ops_modules);
BindIrParser(&ir_module);
}

} // namespace pybind
Expand Down
39 changes: 19 additions & 20 deletions python/paddle/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,25 +428,23 @@ def has_fetch_operations(
that match the info contained in fetch_targets.
"""

fetch_count = 0
mismatch_count = 0
fetch_info = [[], []]
for op in block.ops:
if op.name() == fetch_op:
if op.operand_source(0) not in fetch_targets:
mismatch_count += 1
continue
fetch_count += 1
if mismatch_count > 0:
warnings.warn(
"There are {} fetch ops in Program which are not responsible for the fetch targets that you have passed in fetch_list".format(
mismatch_count
)
)
if fetch_count > 0 and fetch_count != len(fetch_targets):
raise Exception(
"Fetch operations in program do not match 'fetch_targets'"
)
return fetch_count > 0
fetch_info[0].append(op.operand_source(0))
fetch_info[1].append(op.attrs()["name"])

need_fetch_info = []
for i, fetch_var in enumerate(fetch_targets):
if isinstance(fetch_var, str):
if fetch_var not in fetch_info[1]:
raise Exception(
f"Found fetch_target[{i}] is type(str) and doesn't have fetch op."
)
elif fetch_var not in fetch_info[0]:
need_fetch_info.append(fetch_var)

return need_fetch_info
Comment on lines +437 to +447
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这段逻辑删除了对之前fetch op数量报错以及警告的检查,改为补充缺失的fetch op,会不会存在这样的问题,用户多次调用executor run,后续run可能和前边的run并没有太大关系,但是fetch op依然滞留到了program里,导致跑后续run的过程中,依然会fetch并不需要fetch的数据,由于fetch会进行copy操作,会造成隐形性能开销,这里是不是还是拦截或者提示一下相关信息比较好

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前的逻辑这里也会给program添加缺失的fetch_ops,这个行为是没有改变的,而且之前的add_pir_fetch_ops是会对所有的fetch_list 添加,即使之前部分fetch_var已经有fetch_op了。
关于多次run且彼此之间的fetch_list不一样的问题,其实在run()接口里应该要先clone program,然后add_feed_fetch_ops,然后缓存起来。后续优先根据feed/fetch 信息来查缓存program,这样每次run就是独立的,互不影响。

这个cache策略是已有的,在上层做的,add_feed_fetch_ops 是不需要关心这个缓存逻辑的



def _add_feed_fetch_ops(
Expand Down Expand Up @@ -519,11 +517,12 @@ def _add_pir_fetch_ops(program, fetch_list, fetch_var_name):

global_block = program.global_block()
fetch_op = "pd_op.fetch"
if not has_fetch_operations(
need_fetch_info = has_fetch_operations(
global_block, fetch_list, fetch_var_name, fetch_op
):
)
if need_fetch_info:
with paddle.static.program_guard(program):
for i, fetch_input in enumerate(fetch_list):
for i, fetch_input in enumerate(need_fetch_info):
assert isinstance(
fetch_input, (OpResult, Value)
), f"Wrong type for fetch_list[{i}]: {type(fetch_input)}"
Expand Down
49 changes: 18 additions & 31 deletions python/paddle/jit/dy2static/export_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from paddle import pir
from paddle.base import core
from paddle.base.dygraph.base import switch_to_static_graph
from paddle.base.framework import Variable, get_flags
from paddle.base.framework import get_flags

__all__ = []

Expand All @@ -42,6 +42,7 @@ def __init__(self, partial_program_layer, program, role):
self.program = program
self.role = role
self.root_dir = get_saving_dir()
self.fetch_col = 0

def save(self):
# step 1: Create subgraph saving path.
Expand All @@ -59,6 +60,9 @@ def _save(self, pir_program, path):
f.write(content)

def parse_inout(self):
"""
Return feed/fetch/intermediate var name list.
"""
raise NotImplementedError("Need to implement parse_inout method")

def translate_into_pir(self):
Expand Down Expand Up @@ -90,9 +94,8 @@ def verify_saving_dir(self, dir_path):

def insert_feed_op(self, intputs, rename_prefix):
global_block = self.program.block(0)

for i, var in enumerate(intputs):
old_name = var.name
intputs.sort()
for i, old_name in enumerate(intputs):
new_name = rename_prefix + str(i)
global_block._rename_var(old_name, new_name)
out = global_block.var(new_name)
Expand All @@ -116,32 +119,20 @@ def insert_fetch_op(self, outputs, rename_prefix):
type=core.VarDesc.VarType.FETCH_LIST,
persistable=False,
)
for i, out in enumerate(outputs):
var = self.get_var(out)
old_name = var.name
outputs.sort()
for i, old_name in enumerate(outputs):
new_name = rename_prefix + str(i)
global_block._rename_var(old_name, new_name)
new_var = global_block.var(new_name)
global_block.append_op(
type="fetch",
inputs={'X': [new_var]},
outputs={'Out': [fetch_var]},
attrs={'col': i},
attrs={'col': self.fetch_col},
)
self.fetch_col += 1
global_block._sync_with_cpp()

def rename_ops(self, ops, new_name, old_name):
for op in ops:
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)

def get_var(self, name_or_var):
if isinstance(name_or_var, Variable):
return name_or_var
assert isinstance(name_or_var, str)
global_block = self.program.block(0)
return global_block.var(name_or_var)


class InferExporter(BaseExporter):
def __init__(self, *args, **kwargs):
Expand All @@ -153,12 +144,10 @@ def parse_inout(self):
raw_inputs = self.pp_layer._inputs.tolist() + self.pp_layer._params
raw_outputs = self.pp_layer._outputs.tolist()
for var in raw_inputs:
new_var = global_block.var(var.name)
inputs.append(new_var)
inputs.append(var.name)

for var in raw_outputs:
new_var = global_block.var(var.name)
outputs.append(new_var)
outputs.append(var.name)

return inputs, outputs, []

Expand All @@ -180,14 +169,12 @@ def parse_inout(self):
if self.program.block(0).has_var(name)
}
for var in raw_inputs:
new_var = global_block.var(var.name)
inputs.append(new_var)
inputs.append(var.name)
if var.name in inter_outs:
inter_outs.remove(var.name)

for var in raw_outputs:
new_var = global_block.var(var.name)
outputs.append(new_var)
outputs.append(var.name)
if var.name in inter_outs:
inter_outs.remove(var.name)

Expand All @@ -206,22 +193,22 @@ def parse_inout(self):

for var_name in self.raw_inputs:
if global_block.has_var(var_name):
inputs.append(global_block.var(var_name))
inputs.append(var_name)

# add fill_constant grad_var as input
for var in self.pp_layer._outputs.tolist():
init_grad_name = var.name + "@GRAD"
if init_grad_name not in self.raw_inputs and global_block.has_var(
init_grad_name
):
inputs.append(global_block.var(init_grad_name))
inputs.append(init_grad_name)

for var_name in self.raw_outputs:
if (
global_block.has_var(var_name)
and var_name not in self.raw_inputs
):
outputs.append(global_block.var(var_name))
outputs.append(var_name)

return inputs, outputs, []

Expand Down
1 change: 1 addition & 0 deletions python/paddle/pir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
Program,
Type,
Value,
parse_program,
check_unregistered_ops,
fake_op_result,
is_fake_op_result,
Expand Down
7 changes: 6 additions & 1 deletion test/ir/pir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ set(TEST_IR_SYSTEM_CASES
list(REMOVE_ITEM TEST_INTERP_CASES ${TEST_IR_SYSTEM_CASES})
list(REMOVE_ITEM TEST_INTERP_CASES test_subgraph_exporter)
py_test_modules(
test_subgraph_exporter MODULES test_subgraph_exporter ENVS MIN_GRAPH_SIZE=0
test_subgraph_exporter
MODULES
test_subgraph_exporter
ENVS
MIN_GRAPH_SIZE=0
FLAGS_enable_pir_in_executor=1
FLAGS_pir_subgraph_saving_dir=${CMAKE_CURRENT_SOURCE_DIR})

foreach(target ${TEST_INTERP_CASES})
Expand Down
114 changes: 84 additions & 30 deletions test/ir/pir/test_subgraph_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import shutil
import unittest

import numpy as np

import paddle
from paddle.jit.dy2static.export_subgraph import get_saving_dir

Expand Down Expand Up @@ -50,53 +52,105 @@ def test_export(self):
out = self.net(x)
self.check_export()

def run_program(self, program, feed, fetch_list):
    """Execute a PIR program on CPU and return its fetched outputs.

    Temporarily switches to static-graph mode, runs ``program`` through the
    executor's PIR implementation with the given ``feed`` dict and
    ``fetch_list``, then restores dynamic mode.

    Returns the fetch results as numpy arrays.
    """
    paddle.enable_static()
    executor = paddle.static.Executor(paddle.CPUPlace())
    results = executor._run_pir_impl(
        program,
        feed=feed,
        fetch_list=fetch_list,
        feed_var_name="feed",
        fetch_var_name='fetch',
        scope=None,
        return_numpy=True,
    )
    paddle.disable_static()
    return results

def check_export(self):
for prog_file in os.listdir(self.root_dir):
if "forward" in prog_file:
self.check_fwd(prog_file)
return
elif "backward" in prog_file:
self.check_bwd(prog_file)
else:
raise RuntimeError("Not Support.")

def check_fwd(self, prog_file):
prog_info = [
"pt_input_0",
"pt_output_0",
"pt_output_1",
"pt_intermediate_0",
"pt_intermediate_1",
"pt_intermediate_2",
]
path = os.path.join(self.root_dir, prog_file)
with open(path, 'r') as f:
content = f.readlines()
index = 0
for op_str in content:
if "pd_op.data" in op_str or "pd_op.fetch" in op_str:
self.assertIn(prog_info[index], op_str)
index += 1
content = f.read()
program = paddle.pir.parse_program(content)

def check_bwd(self, prog_file):
prog_info = [
"pt_input_6",
"pt_input_5",
"pt_input_4",
"pt_input_3",
"pt_input_2",
"pt_input_1",
"pt_input_0",
pt_input_0 = np.random.random([4, 4]).astype(np.float32)
feed = {"pt_input_0": pt_input_0}
fetch_list = [
'pt_output_0',
'pt_output_1',
'pt_intermediate_0',
'pt_intermediate_1',
'pt_intermediate_2',
]
outs = self.run_program(program, feed, fetch_list)

self.assertEqual(len(outs), 5)
out_shapes = [[4, 4], [], [4, 4], [4, 4], [4, 4]]
for i, out in enumerate(outs):
self.assertListEqual(list(out.shape), out_shapes[i])

def check_bwd(self, prog_file):
path = os.path.join(self.root_dir, prog_file)
with open(path, 'r') as f:
content = f.readlines()
index = 0
for op_str in content:
if "pd_op.data" in op_str or "pd_op.fetch" in op_str:
self.assertIn(prog_info[index], op_str)
index += 1
content = f.read()

program = paddle.pir.parse_program(content)
data = np.random.random([4, 4]).astype(np.float32)
feed = {
"pt_input_6": data,
"pt_input_5": data,
"pt_input_4": data,
"pt_input_3": np.array(0.1).astype(np.float32),
"pt_input_2": data,
"pt_input_1": data,
"pt_input_0": data,
}
fetch_list = []
outs = self.run_program(program, feed, fetch_list)

self.assertEqual(len(outs), 0)


# class TestSaveInferProg(TestSaveFwdBwdProg):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个依赖动转静SOT一个BUG Fix 的PR,待依赖PR合入后,单独打开


# def test_export(self):
# x = paddle.randn([4, 4])
# self.net.eval()
# out = self.net(x)
# self.check_export()

# def check_export(self):
# for prog_file in os.listdir(self.root_dir):
# breakpoint()
# if "infer" in prog_file:
# self.check_infer(prog_file)
# else:
# raise RuntimeError("Not Support.")

# def check_infer(self, prog_file):
# path = os.path.join(self.root_dir, prog_file)
# with open(path, 'r') as f:
# content = f.read()
# program = paddle.pir.parse_program(content)

# pt_input_0 = np.random.random([4,4]).astype(np.float32)
# feed = {"pt_input_0": pt_input_0}
# fetch_list = ['pt_output_0', 'pt_output_1']
# outs = self.run_program(program, feed, fetch_list)

# self.assertEqual(len(outs), 2)
# out_shapes = [[], [4,4]]
# for i, out in enumerate(outs):
# self.assertListEqual(list(out.shape), out_shapes[i])

if __name__ == "__main__":
unittest.main()