[Relax] Implement operators to inspec DLTensor::strides and offset #16721
Thank you for adding ways to access this information; it will be useful for expanding support for in-place computation in Relax. I had some questions about some of the changes included in this PR, as I wasn't sure how they fit in with the stated purpose.
@@ -32,21 +32,21 @@
namespace tvm {
namespace tir {

class CollectUnmanagedAllocations : public StmtExprVisitor {
class CollectManagedAllocations : public StmtExprVisitor {
Do these changes relate to the stated purpose of the PR?
Good point. This was a bug that I encountered and fixed during implementation. The TIR build pipeline made some simplifying assumptions about the TIR that it would receive, which the new legalization functions do not follow.
I've spun it out into an independent PR #16726.
@@ -38,6 +38,19 @@ namespace tir {
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
public:
static PrimFunc Build(PrimFunc func) {
Is this also related to the stated purpose of the PR?
This was another bug I ran into. The `LowerTVMBuiltin` pass would look for each `Allocate` node, and would replace it with a call to the PackedFunc `TVMBackendAllocWorkspace`, using the exact device type on which the allocation should occur. This change allowed it to search in the function attributes for the device type, and not just in the `AttrStmt`.
Partly, this is due to the expanded role of the `"tir.is_host_func"` attribute. Initially, this attribute was used to tag the host/device portions after `SplitHostDevice`. However, it has also been useful for tagging a function as being on the host, when in a portion of the lowering flow that doesn't yet know the exact targets for the device/host (example). This is useful functionality, so I think expanding `LowerTVMBuiltin` to support it makes sense.
I've spun this change out into an independent PR #16727.
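The lookup order described above can be sketched in isolation. This is a simplified model, not the TVM implementation: `FuncInfo`, `annotated_device_type`, and `ResolveDeviceType` are hypothetical names, and the real pass works on `PrimFunc` attributes and `AttrStmt` nodes rather than plain integers.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for the function-level annotations of a PrimFunc.
// Not a TVM type.
struct FuncInfo {
  // Device type recorded on the function itself, if any.
  std::optional<int> annotated_device_type;
};

// Resolve the device type to use for an allocation.
//   1. Prefer the device type found on an enclosing AttrStmt, if present.
//   2. Otherwise fall back to the annotation on the function.
//   3. Otherwise report that the device type is unknown.
std::optional<int> ResolveDeviceType(const FuncInfo& func,
                                     std::optional<int> from_attr_stmt) {
  if (from_attr_stmt.has_value()) {
    return from_attr_stmt;
  }
  return func.annotated_device_type;
}
```

The point of the fallback is that a host function which never went through a device/host split has no `AttrStmt` to consult, so the function-level annotation is the only available source.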
// As of 2024-03-14, Relax does not have an explicit
// representation for striding in `TensorStructInfo`. The
// `FLegalize` function for most operators is implemented in terms
// of `topi`, and is then converted from TE to `tir::PrimFunc`
// using `tvm::tir::CreatePrimFunc`. The `te::Tensor` is
// converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses
// the default empty list for the strides. The empty strides
// represent a compact data array.
//
// Therefore, while Relax does not explicitly represent the
// striding of a tensor, it implicitly requires compact striding
// for any legalizable Tensor.
It might be worth discussing whether this is the specification we would want. Perhaps it would be good to have a helper function to assert/convert tensors into the expected format. Right now, tensors that don't conform to the expectations about stride/offset are completely "invisible" to Relax, other than through these newly added functions.
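A helper of the kind suggested here might look like the sketch below. It is not part of the PR; `CompactStrides` and `IsCompact` are hypothetical names. It relies only on the DLPack convention that strides are counted in elements and that a null `strides` pointer means compact row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major compact strides for a shape, counted in elements (not bytes).
std::vector<int64_t> CompactStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size());
  int64_t running = 1;
  // Walk from the innermost dimension outward.
  for (std::size_t i = shape.size(); i-- > 0;) {
    strides[i] = running;
    running *= shape[i];
  }
  return strides;
}

// Returns true if `strides` describes a compact row-major layout for `shape`.
// A null pointer is treated as compact, following the DLPack convention.
bool IsCompact(const std::vector<int64_t>& shape, const int64_t* strides) {
  if (strides == nullptr) {
    return true;
  }
  std::vector<int64_t> expected = CompactStrides(shape);
  for (std::size_t i = 0; i < shape.size(); ++i) {
    // A dimension of extent 1 is never stepped over, so its stride is ignored.
    if (shape[i] != 1 && strides[i] != expected[i]) {
      return false;
    }
  }
  return true;
}
```

For a shape of `{2, 3, 4}` the compact strides are `{12, 4, 1}`, so a transposed view with strides `{1, 2, 6}` would be rejected.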
Agreed. This was more to state what the current implementation requires. The upgrade path that I'd imagine would be with a sequence of incremental changes:
1. Provide ability to read the strides/offset.
2. Insert assert statements (probably while lowering `R.match_cast`) that the user-provided buffers meet the implicit requirements.
3. Update the `TensorStructInfo` to specify the byte offset and striding. (I'd default to a contiguous striding, but an arbitrary offset. The offset is the more useful one, and could be applied before delegating to an external kernel.)
4. Update the assert statements to match the more flexible `TensorStructInfo`.
This PR would be step (1) in this sequence.
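The remark about applying the offset before delegating to an external kernel can be illustrated with a small sketch. `TensorView` and `ApplyByteOffset` are hypothetical names; the struct only mimics the `data` and `byte_offset` fields of a `DLTensor`, and nothing here is taken from the PR itself.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, reduced view of a tensor: only the two DLTensor-like
// fields that matter for offset handling.
struct TensorView {
  void* data;            // base pointer of the allocation
  uint64_t byte_offset;  // offset in bytes from `data` to the first element
};

// Fold the byte offset into the data pointer, so that a kernel which
// assumes a zero offset can be handed the result directly.
TensorView ApplyByteOffset(const TensorView& view) {
  char* base = static_cast<char*>(view.data);
  return TensorView{base + view.byte_offset, 0};
}
```

An external kernel that only accepts a raw pointer would then receive `ApplyByteOffset(view).data`, while the striding would still need to be contiguous.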
If an allocation occurs within a host function, it may not have a device/host split.
Force-pushed from 27101fb to e9ffd2f.
This PR branch has been updated to be based on top of the two spun-out PRs.
A follow-up PR to apache#16563. This PR implements similar operators to inspect the runtime values of `DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the element offset is not explicitly present in the `DLTensor` struct, a Relax operator is implemented to infer it from the `byte_offset` and `data_type` fields, for use when interacting with the TIR `BufferNode::elem_offset` field.
Force-pushed from e9ffd2f to 3af29e5.
Thank you for spinning off changes to other PRs that can be reviewed separately. I think the main concept is sound. We may want to revisit the default inferred PrimStructInfo for some of these calls in the future, namely if we handle offsets/strides more systematically later, though the approach here is correct for the present.
Sounds like a plan. I think the biggest use of
…pache#16721)
* [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation
  If an allocation occurs within a host function, it may not have a device/host split.
* lint fix
* [Relax] Implement operators to inspec DLTensor::strides and offset
  A follow-up PR to apache#16563. This PR implements similar operators to inspect the runtime values of `DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the element offset is not explicitly present in the `DLTensor` struct, a Relax operator is implemented to infer it from the `byte_offset` and `data_type` fields, for use when interacting with the TIR `BufferNode::elem_offset` field.
A follow-up PR to #16563. This PR implements similar operators to inspect the runtime values of `DLTensor::strides` and `DLTensor::byte_offset`. In addition, while the element offset is not explicitly present in the `DLTensor` struct, a Relax operator is implemented to infer it from the `byte_offset` and `data_type` fields, for use when interacting with the TIR `BufferNode::elem_offset` field.
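The inference of an element offset from a byte offset and a data type can be sketched as follows. This is an illustration under stated assumptions, not the operator's actual implementation: `DataType` mirrors the `code`/`bits`/`lanes` layout of DLPack's `DLDataType`, `ElemOffset` is a hypothetical name, and returning an empty optional for an offset that is not a whole number of elements is a choice made for this sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Mirrors the field layout of DLPack's DLDataType.
struct DataType {
  uint8_t code;    // type code (int, uint, float, ...); unused here
  uint8_t bits;    // bits per lane
  uint16_t lanes;  // number of lanes for vector types
};

// Convert a byte offset into an element offset.
// Returns std::nullopt when the byte offset is not a whole number of
// elements (assumption: the sketch refuses rather than rounds).
std::optional<int64_t> ElemOffset(uint64_t byte_offset, DataType dtype) {
  // Bytes per element, rounding sub-byte types up to a full byte.
  uint64_t bytes_per_elem =
      (static_cast<uint64_t>(dtype.bits) * dtype.lanes + 7) / 8;
  if (bytes_per_elem == 0 || byte_offset % bytes_per_elem != 0) {
    return std::nullopt;
  }
  return static_cast<int64_t>(byte_offset / bytes_per_elem);
}
```

For a 32-bit, single-lane type, a `byte_offset` of 16 corresponds to an element offset of 4, while a `byte_offset` of 6 has no whole-element equivalent.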