[SVE][TOPI] Add conv2d NHWC hybrid SVE schedule for arm_cpu
#16899
Conversation
Nice work @Anndrey24, I haven't finished reviewing yet, but I had a few sporadic comments about the scheduling.
This commit adds an `arm_cpu` conv2d NHWC schedule which generates SVE instructions by extending the hybrid GeMM approach implemented in apache#16106 to use scalable expressions as splitting factors.

Various vscale-related fixes needed to implement the schedule are also included, such as:
- adding vscale bounds in the `ConstIntBoundAnalyzer` and `IntervalSetEvaluator`
- simplifying `MinNode` and `MaxNode` that have scalable expression operands in `RewriteSimplifier`, which would appear when defining the shape of a buffer padded to be a multiple of vscale and in its respective buffer access indices (e.g. `C_1 = T.Buffer((1024 * (T.vscale() * 16 + 256 - 16 % T.vscale() * 16),), data=C)` instead of `C_1 = T.Buffer((1024 * (T.max(255, T.vscale() * 16 + 255 - 16 % T.vscale() * 16) + 1),), data=C)`)

The correctness of the new schedule is checked using a TOPI test, while the presence of generated SVE instructions is verified by a codegen test. The new `rewrite_simplify` rules are also covered by additional test cases.
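For a sense of what "scalable expressions as splitting factors" means in practice, here is a minimal TE sketch (not the conv2d schedule from this patch; it assumes the `tvm.tir.vscale()` Python binding and an SVE-capable LLVM target for the final build):

```python
# A minimal sketch of a scalable splitting factor (not the actual conv2d
# schedule): the inner axis length is 4 * vscale, so vectorising it targets the
# runtime SVE vector length rather than a fixed number of lanes.
import tvm
from tvm import te

A = te.placeholder((1024,), name="A", dtype="float32")
B = te.compute((1024,), lambda i: A[i] + tvm.tir.const(1.0, "float32"), name="B")

sch = te.create_schedule(B.op)
vscale = tvm.tir.vscale()  # runtime multiple of the 128-bit SVE granule
xo, xi = sch[B].split(B.op.axis[0], factor=4 * vscale)  # scalable splitting factor
sch[B].vectorize(xi)

# Building needs an SVE-capable target, e.g. (target string is an assumption):
#   tvm.build(sch, [A, B], target="llvm -mtriple=aarch64-linux-gnu -mattr=+sve")
```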
LGTM!
@@ -369,6 +370,8 @@ class ConstIntBoundAnalyzer::Impl
       return VisitLeftShift(op);
     } else if (op->op.same_as(tir::builtin::bitwise_and())) {
       return VisitBitwiseAnd(op);
+    } else if (op->op.same_as(tir::builtin::vscale()) && TargetHasSVE()) {
+      return MakeBound(1, 16);
nit: we could make the upper bound the length of `kAArch64VScaleValues` in case more values are added in the future. Happy for this to be added in a later patch, though.
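As a rough check of the bound above (a sketch rather than a test from this patch; it assumes the analyzer picks up the SVE feature from the enclosing target scope and that `tvm.tir.vscale()` is exposed in Python):

```python
# Sketch: with an SVE-capable target in scope, vscale is bounded by [1, 16],
# so vscale * 16 should be bounded by [16, 256] instead of being unknown.
import tvm

analyzer = tvm.arith.Analyzer()
with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
    bound = analyzer.const_int_bound(tvm.tir.vscale() * 16)
    print(bound.min_value, bound.max_value)  # expected 16 256 if SVE is detected
```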
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Anndrey24, looks great! 🚀
Thanks @Anndrey24, @ekalda!
Addresses a nitpick comment mentioned here: apache#16899 (comment)

Change-Id: I5b3dbe2b08dbf3b498b55fb89d9bfc112049baa4
…pu` targets (#16951)

This patch partly reverts the unification of scalable and non-scalable scheduling of conv2d NHWC for `arm_cpu` targets introduced in #16899.

The non-scalable schedule for float32 splits the N axis (corresponding to the number of output channels) by 16 in both the unified and the non-unified schedule versions, and then additionally splits the inner partitions by 4 in only the non-unified version to which this patch reverts (first added in #16106). The two versions' behaviour would be equivalent if none of the padding on the N axis were removed during lowering; however, we allow that to happen because it proved to increase performance for very small convolutions. As it stands, there seems to be a regression in cases where the datatype is float32 and the number of output channels is greater than 16, a multiple of 4, and not a multiple of 16: even with the padding removed, the non-unified schedule can still vectorise over 4 elements, while the unified version can no longer vectorise over 16 elements.

Since all of the conv2d NHWC hybrid TOPI test cases used numbers of output channels either less than 16 or divisible by 16, this patch also adds a new case which falls in the aforementioned regression area.
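To make the regression window concrete, here is a toy calculation (the value 20 below is an assumed example, not necessarily the value used by the new test case):

```python
# Toy illustration: out_channels greater than 16, a multiple of 4, but not a
# multiple of 16, with the padding on the N axis removed during lowering.
out_channels = 20
tail = out_channels % 16  # 4 output channels left over after the split by 16

# Non-unified schedule: the extra split of the inner partition by 4 means the
# 4-element tail still fills whole 4-lane vectors.
nonunified_tail_vectorised = tail % 4 == 0   # True

# Unified schedule: only the split by 16 exists, so the 4-element tail is
# narrower than the 16-lane vector width and falls back to scalar code.
unified_tail_vectorised = tail % 16 == 0     # False

print(tail, nonunified_tail_vectorised, unified_tail_vectorised)
```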
cc @ekalda @lhutton1 @Lunderberg