[Relax] Implement operators to inspect DLTensor::strides and offset #16721

Merged (4 commits) on Mar 26, 2024

Conversation

Lunderberg
Contributor

A follow-up PR to #16563. This PR implements similar operators to inspect the runtime values of DLTensor::strides and DLTensor::byte_offset. In addition, while the element offset is not explicitly present in the DLTensor struct, a Relax operator is implemented to infer it from the byte_offset and data_type fields, for use when interacting with the TIR BufferNode::elem_offset field.
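
The relationship between `byte_offset` and the element offset described above can be sketched in a few lines. This is a minimal illustration, not the operator implemented in this PR: `DataTypeSketch` and `elem_offset_from_byte_offset` are hypothetical names, and the sketch assumes each element occupies a whole number of bytes.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DataTypeSketch:
    """Hypothetical stand-in for the bits/lanes fields of a DLTensor data type."""

    bits: int
    lanes: int = 1


def elem_offset_from_byte_offset(byte_offset: int, dtype: DataTypeSketch) -> int:
    """Convert an offset in bytes into an offset counted in elements."""
    total_bits = dtype.bits * dtype.lanes
    if total_bits % 8 != 0:
        # Assumption of this sketch: sub-byte element types are not handled.
        raise ValueError("element size is not a whole number of bytes")
    bytes_per_elem = total_bits // 8
    if byte_offset % bytes_per_elem != 0:
        raise ValueError("byte_offset is not a whole number of elements")
    return byte_offset // bytes_per_elem


# A float32 tensor whose data starts 64 bytes into the allocation
# starts 16 elements in, since each element occupies 4 bytes.
float32 = DataTypeSketch(bits=32)
print(elem_offset_from_byte_offset(64, float32))
```

The element-based value is the form that would line up with a field such as `BufferNode::elem_offset`, which counts elements rather than bytes.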

Contributor

@slyubomirsky slyubomirsky left a comment

Thank you for adding ways to access this information; it will be useful for expanding support for in-place computation in Relax. I had some questions about some of the changes included in this PR, as I wasn't sure how they fit in with the stated purpose.

@@ -32,21 +32,21 @@
namespace tvm {
namespace tir {

-class CollectUnmanagedAllocations : public StmtExprVisitor {
+class CollectManagedAllocations : public StmtExprVisitor {
Contributor

Do these changes relate to the stated purpose of the PR?

Contributor Author

Good point. This was a bug that I encountered and fixed during implementation. The TIR build pipeline made some simplifying assumptions about the TIR that it would receive, which the new legalization functions do not follow.

I've spun it out into an independent PR, #16726.

@@ -38,6 +38,19 @@ namespace tir {
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
public:
static PrimFunc Build(PrimFunc func) {
Contributor

Is this also related to the stated purpose of the PR?

Contributor Author

This was another bug I ran into. The LowerTVMBuiltin pass would look for each Allocate node, and would replace it with a call to the PackedFunc TVMBackendAllocWorkspace, using the exact device type on which the allocation should occur. This change allowed it to search in the function attributes for the device type, and not just in the AttrStmt.

Partly, this is due to the expanded role of the "tir.is_host_func" attribute. Initially, this attribute was used to tag the host/device portions after SplitHostDevice. However, it has also been useful for tagging a function as being on the host, when in a portion of the lowering flow that doesn't yet know the exact targets for the device/host (example). This is useful functionality, so I think expanding LowerTVMBuiltin to support it makes sense.
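
The fallback described above (look at the enclosing `AttrStmt` first, then at the function attributes) can be modeled as a small lookup. This is a toy model only, not the `LowerTVMBuiltin` code: `resolve_device_type` is a hypothetical helper, and the `"device_type"` key in the attribute dictionary is an assumption made for illustration.

```python
from typing import Mapping, Optional

K_DL_CPU = 1  # value of kDLCPU in DLPack's DLDeviceType enum


def resolve_device_type(
    attr_stmt_device_type: Optional[int],
    func_attrs: Mapping[str, int],
) -> int:
    """Pick the device type to use for a workspace allocation.

    attr_stmt_device_type: value found on an enclosing AttrStmt, or None
        if the allocation is not wrapped in one.
    func_attrs: function-level attributes; the "device_type" key is a
        hypothetical name used only in this sketch.
    """
    if attr_stmt_device_type is not None:
        # The enclosing annotation is the most specific source.
        return attr_stmt_device_type
    if "device_type" in func_attrs:
        # Fallback: the function as a whole is tagged with a device.
        return func_attrs["device_type"]
    raise ValueError("no device type available for this allocation")


# An allocation in a host function that has no device/host split yet.
print(resolve_device_type(None, {"device_type": K_DL_CPU}))
```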

Contributor Author

I've spun this change out into an independent PR #16727.

Comment on lines +375 to +386
// As of 2024-03-14, Relax does not have an explicit
// representation for striding in `TensorStructInfo`. The
// `FLegalize` function for most operators is implemented in terms
// of `topi`, and is then converted from TE to `tir::PrimFunc`
// using `tvm::tir::CreatePrimFunc`. The `te::Tensor` is
// converted to a `tir::Buffer` in `RewriteStageToBlock`, and uses
// the default empty list for the strides. The empty strides
// represent a compact data array.
//
// Therefore, while Relax does not explicitly represent the
// striding of a tensor, it implicitly requires compact striding
// for any legalizable Tensor.
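
For reference, "compact" here means the row-major strides that follow from the shape alone. The sketch below shows that computation and a strict compactness check; the function names are hypothetical, strides are counted in elements (as in DLPack), and `None` stands in for an absent strides array.

```python
from typing import List, Optional, Sequence


def compact_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides, in elements, for a densely packed tensor."""
    strides = []
    running = 1
    # Walk from the innermost dimension outwards, accumulating extents.
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    return list(reversed(strides))


def is_compact(shape: Sequence[int], strides: Optional[Sequence[int]]) -> bool:
    """Strict check: absent strides, or strides equal to the compact ones.

    This sketch does not special-case dimensions of extent 1, whose
    stride never affects addressing.
    """
    if strides is None:
        return True
    return list(strides) == compact_strides(shape)


print(compact_strides((2, 3, 4)))
print(is_compact((2, 3, 4), (24, 4, 1)))
```

A tensor that fails this check is the kind the comment above describes as not legalizable under the current implicit requirement.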
Contributor

It might be worth discussing whether this is the specification we would want. Perhaps it would be good to have a helper function to assert/convert tensors into the expected format. Right now, tensors that don't conform to the expectations about stride/offset are completely "invisible" to Relax, other than through these newly added functions.

Contributor Author

@Lunderberg Lunderberg Mar 15, 2024

Agreed. This was more to state what the current implementation requires. The upgrade path I'd imagine would be a sequence of incremental changes:

  1. Provide ability to read the strides/offset.
  2. Insert assert statements (probably while lowering R.match_cast) that the user-provided buffers meet the implicit requirements.
  3. Update the TensorStructInfo to specify the byte offset and striding. (I'd default to a contiguous striding, but an arbitrary offset. The offset is the more useful one, and could be applied before delegating to an external kernel.)
  4. Update the assert statements to match the more flexible TensorStructInfo.

This PR would be step (1) in this sequence.
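
Steps (2) and (4) above could look roughly like the check below. This is a sketch under stated assumptions rather than a proposed implementation: `check_tensor_layout` and its `allow_offset` flag are hypothetical, strides are in elements, and `None` means no strides were provided.

```python
from typing import Optional, Sequence


def check_tensor_layout(
    shape: Sequence[int],
    strides: Optional[Sequence[int]],
    byte_offset: int,
    *,
    allow_offset: bool = False,
) -> None:
    """Raise ValueError if a user-provided tensor breaks the layout rules.

    allow_offset=False models step (2): compact strides and zero offset.
    allow_offset=True models step (4) with the defaults suggested in
    step (3): still compact, but with an arbitrary byte offset.
    """
    if strides is not None:
        if len(strides) != len(shape):
            raise ValueError("strides and shape must have the same length")
        expected = 1
        for extent, stride in zip(reversed(shape), reversed(strides)):
            # A dimension of extent 1 is never stepped over,
            # so its stride does not affect addressing.
            if extent != 1 and stride != expected:
                raise ValueError("tensor does not have compact strides")
            expected *= extent
    if byte_offset != 0 and not allow_offset:
        raise ValueError("tensor has a non-zero byte_offset")


check_tensor_layout((2, 3), (3, 1), 0)                       # passes step (2)
check_tensor_layout((2, 3), None, 16, allow_offset=True)     # passes step (4)
```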

If an allocation occurs within a host function, it may not have a
device/host split.
@Lunderberg Lunderberg force-pushed the relax_inspect_dltensor branch from 27101fb to e9ffd2f Compare March 15, 2024 13:11
@Lunderberg
Contributor Author

This PR branch has been updated to be based on top of the two spun-out PRs.

@Lunderberg Lunderberg force-pushed the relax_inspect_dltensor branch from e9ffd2f to 3af29e5 Compare March 18, 2024 20:24
Contributor

@slyubomirsky slyubomirsky left a comment

Thank you for spinning off changes to other PRs that can be reviewed separately. I think the main concept is sound. We may want to revisit the default inferred PrimStructInfo for some of these calls in the future, namely if we handle offsets/strides more systematically later, though the approach here is correct for the present.

@Lunderberg Lunderberg merged commit 8274d14 into apache:main Mar 26, 2024
19 checks passed
@Lunderberg
Contributor Author

> We may want to revisit the default inferred PrimStructInfo for some of these calls in the future, namely if we handle offsets/strides more systematically later, though the approach here is correct for the present.

Sounds like a plan. I think the biggest use of strides would be in exposing a view of a tensor to a compute kernel, without requiring the entire tensor to be exposed (e.g. improved R.split legalization). That said, there are enough kernels that assume contiguous tensors, as are currently provided by Relax, that for now I'd want to keep that requirement.
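
To make the split example concrete, the sketch below describes each piece of a split as a view (element offset, shape, strides) into the original flat storage, with no copy. It is an illustration only and not how `R.split` is legalized; `split_views` and `read` are hypothetical helpers operating on plain Python lists.

```python
from typing import Dict, List, Sequence


def split_views(
    shape: Sequence[int], strides: Sequence[int], axis: int, sections: int
) -> List[Dict[str, object]]:
    """Describe an even split along `axis` as views into shared storage."""
    if shape[axis] % sections != 0:
        raise ValueError("axis extent must be divisible by sections")
    piece = shape[axis] // sections
    views = []
    for i in range(sections):
        view_shape = list(shape)
        view_shape[axis] = piece
        views.append(
            {
                # Each piece starts `i * piece` steps along the split axis.
                "elem_offset": i * piece * strides[axis],
                "shape": tuple(view_shape),
                # Strides are inherited unchanged from the parent tensor.
                "strides": tuple(strides),
            }
        )
    return views


def read(flat: Sequence[int], view: Dict[str, object], index: Sequence[int]) -> int:
    """Fetch one element of a view from the shared flat storage."""
    pos = view["elem_offset"] + sum(i * s for i, s in zip(index, view["strides"]))
    return flat[pos]


flat = list(range(8))                       # a 2x4 row-major tensor
left, right = split_views((2, 4), (4, 1), axis=1, sections=2)
print(right)
print([read(flat, right, (1, j)) for j in range(2)])
```

Each 2x2 piece keeps the parent's strides `(4, 1)` rather than the compact `(2, 1)`, which is why a kernel that assumes contiguous input could not consume such a view directly.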

@Lunderberg Lunderberg deleted the relax_inspect_dltensor branch March 26, 2024 13:57
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
…pache#16721)

* [TIR] LowerTVMBuiltin may use device_type from PrimFunc annotation

If an allocation occurs within a host function, it may not have a
device/host split.

* lint fix

* [Relax] Implement operators to inspect DLTensor::strides and offset

A follow-up PR to apache#16563.  This PR
implements similar operators to inspect the runtime values of
`DLTensor::strides` and `DLTensor::byte_offset`.  In addition, while the
element offset is not explicitly present in the `DLTensor` struct, a
Relax operator is implemented to infer it from the `byte_offset` and
`data_type` fields, for use when interacting with the TIR
`BufferNode::elem_offset` field.