-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PIR+CINN]Part-2 Pybind IrParser.ParseProgram and Polish UT into check_run #59449
Merged
Merged
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
85eb60c
[PIR+CINN]Support SubGraph Exporter for Unittest Platform
Aurelius84 ade735b
fix conflict
Aurelius84 bb3e81e
remove print
Aurelius84 77e36ee
fix UT
Aurelius84 0c4ad9f
add list.sort to fix random
Aurelius84 664afa8
Merge branch 'develop' into ir_cinn_test1
Aurelius84 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ | |
Program, | ||
Type, | ||
Value, | ||
parse_program, | ||
check_unregistered_ops, | ||
fake_op_result, | ||
is_fake_op_result, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,8 @@ | |
import shutil | ||
import unittest | ||
|
||
import numpy as np | ||
|
||
import paddle | ||
from paddle.jit.dy2static.export_subgraph import get_saving_dir | ||
|
||
|
@@ -50,53 +52,105 @@ def test_export(self): | |
out = self.net(x) | ||
self.check_export() | ||
|
||
def run_program(self, program, feed, fetch_list): | ||
paddle.enable_static() | ||
exe = paddle.static.Executor(paddle.CPUPlace()) | ||
outs = exe._run_pir_impl( | ||
program, | ||
feed=feed, | ||
fetch_list=fetch_list, | ||
feed_var_name="feed", | ||
fetch_var_name='fetch', | ||
scope=None, | ||
return_numpy=True, | ||
) | ||
paddle.disable_static() | ||
return outs | ||
|
||
def check_export(self): | ||
for prog_file in os.listdir(self.root_dir): | ||
if "forward" in prog_file: | ||
self.check_fwd(prog_file) | ||
return | ||
elif "backward" in prog_file: | ||
self.check_bwd(prog_file) | ||
else: | ||
raise RuntimeError("Not Support.") | ||
|
||
def check_fwd(self, prog_file): | ||
prog_info = [ | ||
"pt_input_0", | ||
"pt_output_0", | ||
"pt_output_1", | ||
"pt_intermediate_0", | ||
"pt_intermediate_1", | ||
"pt_intermediate_2", | ||
] | ||
path = os.path.join(self.root_dir, prog_file) | ||
with open(path, 'r') as f: | ||
content = f.readlines() | ||
index = 0 | ||
for op_str in content: | ||
if "pd_op.data" in op_str or "pd_op.fetch" in op_str: | ||
self.assertIn(prog_info[index], op_str) | ||
index += 1 | ||
content = f.read() | ||
program = paddle.pir.parse_program(content) | ||
|
||
def check_bwd(self, prog_file): | ||
prog_info = [ | ||
"pt_input_6", | ||
"pt_input_5", | ||
"pt_input_4", | ||
"pt_input_3", | ||
"pt_input_2", | ||
"pt_input_1", | ||
"pt_input_0", | ||
pt_input_0 = np.random.random([4, 4]).astype(np.float32) | ||
feed = {"pt_input_0": pt_input_0} | ||
fetch_list = [ | ||
'pt_output_0', | ||
'pt_output_1', | ||
'pt_intermediate_0', | ||
'pt_intermediate_1', | ||
'pt_intermediate_2', | ||
] | ||
outs = self.run_program(program, feed, fetch_list) | ||
|
||
self.assertEqual(len(outs), 5) | ||
out_shapes = [[4, 4], [], [4, 4], [4, 4], [4, 4]] | ||
for i, out in enumerate(outs): | ||
self.assertListEqual(list(out.shape), out_shapes[i]) | ||
|
||
def check_bwd(self, prog_file): | ||
path = os.path.join(self.root_dir, prog_file) | ||
with open(path, 'r') as f: | ||
content = f.readlines() | ||
index = 0 | ||
for op_str in content: | ||
if "pd_op.data" in op_str or "pd_op.fetch" in op_str: | ||
self.assertIn(prog_info[index], op_str) | ||
index += 1 | ||
content = f.read() | ||
|
||
program = paddle.pir.parse_program(content) | ||
data = np.random.random([4, 4]).astype(np.float32) | ||
feed = { | ||
"pt_input_6": data, | ||
"pt_input_5": data, | ||
"pt_input_4": data, | ||
"pt_input_3": np.array(0.1).astype(np.float32), | ||
"pt_input_2": data, | ||
"pt_input_1": data, | ||
"pt_input_0": data, | ||
} | ||
fetch_list = [] | ||
outs = self.run_program(program, feed, fetch_list) | ||
|
||
self.assertEqual(len(outs), 0) | ||
|
||
|
||
# class TestSaveInferProg(TestSaveFwdBwdProg): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个依赖动转静SOT一个BUG Fix 的PR,待依赖PR合入后,单独打开 |
||
|
||
# def test_export(self): | ||
# x = paddle.randn([4, 4]) | ||
# self.net.eval() | ||
# out = self.net(x) | ||
# self.check_export() | ||
|
||
# def check_export(self): | ||
# for prog_file in os.listdir(self.root_dir): | ||
# breakpoint() | ||
# if "infer" in prog_file: | ||
# self.check_infer(prog_file) | ||
# else: | ||
# raise RuntimeError("Not Support.") | ||
|
||
# def check_infer(self, prog_file): | ||
# path = os.path.join(self.root_dir, prog_file) | ||
# with open(path, 'r') as f: | ||
# content = f.read() | ||
# program = paddle.pir.parse_program(content) | ||
|
||
# pt_input_0 = np.random.random([4,4]).astype(np.float32) | ||
# feed = {"pt_input_0": pt_input_0} | ||
# fetch_list = ['pt_output_0', 'pt_output_1'] | ||
# outs = self.run_program(program, feed, fetch_list) | ||
|
||
# self.assertEqual(len(outs), 2) | ||
# out_shapes = [[], [4,4]] | ||
# for i, out in enumerate(outs): | ||
# self.assertListEqual(list(out.shape), out_shapes[i]) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段逻辑删除了对之前fetch op数量报错以及警告的检查,改为补充缺失的fetch op,会不会存在这样的问题,用户多次调用executor run,后续run可能和前边的run并没有太大关系,但是fetch op依然滞留到了program里,导致跑后续run的过程中,依然会fetch并不需要fetch的数据,由于fetch会进行copy操作,会造成隐形性能开销,这里是不是还是拦截或者提示一下相关信息比较好
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
之前的逻辑这里也会给program添加缺失的fetch_ops,这个行为是没有改变的,而且之前的add_pir_fetch_ops是会对所有的fetch_list 添加,即使之前部分fetch_var已经有fetch_op了。
关于多次run且彼此之间的fetch_list不一样的问题,其实在run()接口里应该要先clone program,然后add_feed_fetch_ops,然后缓存起来。后续优先根据feed/fetch 信息来查缓存program,这样每次run就是独立的,互不影响。
这个cache策略是已有的,在上层做的,add_feed_fetch_ops 是不需要关心这个缓存逻辑的