Revert "optimize nhwc reduce (PaddlePaddle#1302)" (PaddlePaddle#1412)
This reverts commit dca26ff.
Fast-reverts PaddlePaddle#1302 because Paddle-CINN-CI failed and is blocking other PRs.
Due to the CI incident, I am bypassing the tests.
zhhsplendid authored and jiahy0825 committed May 25, 2023
1 parent 77f10c8 commit 3a1af2e
Showing 24 changed files with 1,201 additions and 947 deletions.
36 changes: 18 additions & 18 deletions cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc
@@ -95,23 +95,23 @@ TEST_F(TestCooperativeProcess, Matmul) {
 ScheduleBlock(root)
 {
 {
-serial for (i, 0, 2)
+serial for (i_0, 0, 2)
 {
-serial for (j, 0, 2)
+serial for (j_0, 0, 2)
 {
-serial for (i_0, 0, 8)
+serial for (i_1, 0, 8)
 {
-serial for (j_0, 0, 2)
+serial for (j_1, 0, 2)
 {
-serial for (i_1, 0, 2)
+serial for (i_2, 0, 2)
 {
-serial for (j_1, 0, 8)
+serial for (j_2, 0, 8)
 {
 ScheduleBlock(temp_matmul_out__reduce_init)
 {
-i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1)))
+i0, i1 = axis.bind(((16 * i_0) + ((2 * i_1) + i_2)), ((16 * j_0) + ((8 * j_1) + j_2)))
 {
-temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f
+temp_matmul_out__reduce_init[((16 * i_0) + ((2 * i_1) + i_2)), ((16 * j_0) + ((8 * j_1) + j_2))] = 0.00000000f
 }
 }
 }
@@ -120,19 +120,19 @@ TEST_F(TestCooperativeProcess, Matmul) {
 }
 }
 }
-thread_bind[blockIdx.x] for (i_j_fused, 0, 4)
+thread_bind[blockIdx.x] for (i_0_j_0_fused, 0, 4)
 {
-thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 16)
+thread_bind[threadIdx.x] for (i_1_j_1_fused, 0, 16)
 {
 serial for (reduce_k_0, 0, 8)
 {
-serial for (ax0_0_ax1_0_fused, 0, 2)
+serial for (ax0_0_ax1_0_fused_0, 0, 2)
 {
-thread_bind[threadIdx.x] for (ax0_0_ax1_0_fused_0, 0, 16)
+thread_bind[threadIdx.x] for (ax0_0_ax1_0_fused_1, 0, 16)
 {
 ScheduleBlock(Y_reshape_shared_temp_buffer)
 {
-v0, v1 = axis.bind(((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) / 8) + (4 * reduce_k_0)), ((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) % 8) + ((8 * (i_0_j_0_fused % 2)) + (16 * (i_j_fused % 2)))))
+v0, v1 = axis.bind(((((16 * ax0_0_ax1_0_fused_0) + ax0_0_ax1_0_fused_1) / 8) + (4 * reduce_k_0)), ((((16 * ax0_0_ax1_0_fused_0) + ax0_0_ax1_0_fused_1) % 8) + ((16 * (i_0_j_0_fused % 2)) + (8 * (i_1_j_1_fused % 2)))))
 attrs(compute_at_extra_var:ax0_0,ax1_0)
 {
 Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1]
@@ -145,7 +145,7 @@ TEST_F(TestCooperativeProcess, Matmul) {
 {
 ScheduleBlock(X_reshape_shared_temp_buffer)
 {
-v0, v1 = axis.bind(((ax0_ax1_fused / 4) + ((2 * (i_0_j_0_fused / 2)) + (16 * (i_j_fused / 2)))), ((ax0_ax1_fused % 4) + (4 * reduce_k_0)))
+v0, v1 = axis.bind(((ax0_ax1_fused / 4) + ((2 * (i_1_j_1_fused / 2)) + (16 * (i_0_j_0_fused / 2)))), ((ax0_ax1_fused % 4) + (4 * reduce_k_0)))
 attrs(compute_at_extra_var:ax0,ax1)
 {
 X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1]
@@ -155,15 +155,15 @@ TEST_F(TestCooperativeProcess, Matmul) {
 __syncthreads()
 serial for (reduce_k_1, 0, 4)
 {
-serial for (i_1, 0, 2)
+serial for (i_2, 0, 2)
 {
-serial for (j_1, 0, 8)
+serial for (j_2, 0, 8)
 {
 ScheduleBlock(temp_matmul_out)
 {
-i0, i1, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1))
+i0, i1, i2 = axis.bind(((2 * (i_1_j_1_fused / 2)) + ((16 * (i_0_j_0_fused / 2)) + i_2)), ((16 * (i_0_j_0_fused % 2)) + ((8 * (i_1_j_1_fused % 2)) + j_2)), ((4 * reduce_k_0) + reduce_k_1))
 {
-temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))]))
+temp_matmul_out[((2 * (i_1_j_1_fused / 2)) + ((16 * (i_0_j_0_fused / 2)) + i_2)), ((16 * (i_0_j_0_fused % 2)) + ((8 * (i_1_j_1_fused % 2)) + j_2))] = (temp_matmul_out[((2 * (i_1_j_1_fused / 2)) + ((16 * (i_0_j_0_fused / 2)) + i_2)), ((16 * (i_0_j_0_fused % 2)) + ((8 * (i_1_j_1_fused % 2)) + j_2))] + (X_reshape_shared_temp_buffer[((2 * (i_1_j_1_fused / 2)) + ((16 * (i_0_j_0_fused / 2)) + i_2)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((16 * (i_0_j_0_fused % 2)) + ((8 * (i_1_j_1_fused % 2)) + j_2))]))
 }
 }
 }
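For readers not fluent in CINN's schedule IR: the dump above describes a shared-memory tiled matmul. The i/j loop nest is split into block and thread tiles (the thread_bind[blockIdx.x] and thread_bind[threadIdx.x] loops), tiles of X_reshape and Y_reshape are cooperatively staged into shared buffers, and the reduction runs at two levels (reduce_k_0 over tiles, reduce_k_1 within a staged tile) separated by __syncthreads(). Below is a minimal hand-written CUDA sketch of that same pattern; the kernel name, TILE size, and the square-matrix/divisibility assumptions are illustrative and not taken from the generated code.

// Minimal sketch (not CINN output): the shared-memory tiled matmul
// pattern that the schedule IR above expresses. TILE, the kernel name,
// and the square-matrix assumption are illustrative.
#define TILE 16

__global__ void matmul_tiled(const float* X, const float* Y,
                             float* out, int n) {
  // Analogous to X_reshape_shared_temp_buffer / Y_reshape_shared_temp_buffer.
  __shared__ float Xs[TILE][TILE];
  __shared__ float Ys[TILE][TILE];

  int row = blockIdx.y * TILE + threadIdx.y;
  int col = blockIdx.x * TILE + threadIdx.x;
  float acc = 0.0f;  // zero-init, as in temp_matmul_out__reduce_init

  // Outer reduction over K tiles (reduce_k_0 in the IR).
  for (int k0 = 0; k0 < n / TILE; ++k0) {
    // Cooperative fetch: each thread loads one element of each tile.
    Xs[threadIdx.y][threadIdx.x] = X[row * n + k0 * TILE + threadIdx.x];
    Ys[threadIdx.y][threadIdx.x] = Y[(k0 * TILE + threadIdx.y) * n + col];
    __syncthreads();

    // Inner reduction over the staged tile (reduce_k_1 in the IR).
    for (int k1 = 0; k1 < TILE; ++k1) {
      acc += Xs[threadIdx.y][k1] * Ys[k1][threadIdx.x];
    }
    __syncthreads();  // don't overwrite tiles other threads still read
  }
  out[row * n + col] = acc;
}

// Launch assumption: n divisible by TILE, e.g.
//   dim3 block(TILE, TILE);
//   dim3 grid(n / TILE, n / TILE);
//   matmul_tiled<<<grid, block>>>(X, Y, out, n);

The two __syncthreads() calls mirror the barrier in the IR: one makes the staged tiles visible to all threads before the inner reduction, the other keeps a thread from refilling shared memory while neighbors still read the previous tile.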
