Skip to content

Commit

Permalink
[Relax] Implement operators to inspec DLTensor::strides and offset
Browse files Browse the repository at this point in the history
A follow-up PR to apache#16563.  This PR
implements similar operators to inspect the runtime values of
`DLTensor::strides` and `DLTensor::byte_offset`.  In addition, while the
element offset is not explicitly present in the `DLTensor` struct, a
Relax operator is implemented to infer it from the `byte_offset` and
`data_type` fields, for use when interacting with the TIR
`BufferNode::elem_offset` field.
  • Loading branch information
Lunderberg committed Mar 14, 2024
1 parent c3a6aba commit 27101fb
Show file tree
Hide file tree
Showing 7 changed files with 667 additions and 157 deletions.
97 changes: 97 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,33 @@ def shape(self) -> "_DLTensorShapeProxy":
self._check_for_tensor_struct_info()
return _DLTensorShapeProxy(self)

@property
def strides(self) -> "_DLTensorStrideProxy":
"""Returns a proxy object for accessing DLTensor::strides"""
self._check_for_tensor_struct_info()
return _DLTensorStrideProxy(self)

@property
def byte_offset(self) -> "Expr":
"""Returns a proxy object for accessing DLTensor::byte_offset"""
self._check_for_tensor_struct_info()
op = tvm.ir.Op.get("relax.inspect.tensor_byte_offset")
return tvm.relax.Call(op, [self])

@property
def elem_offset(self) -> "Expr":
"""Returns a proxy object for accessing a DLTensor's elem_offset
This parameter is not stored in the DLTensor, but is instead
derived from the DLTensor's byte offset and datatype. This is
exposed in Relax for ease of use, and for translation into the
`tir::BufferNode::elem_offset` field when interacting with TIR
buffers.
"""
self._check_for_tensor_struct_info()
op = tvm.ir.Op.get("relax.inspect.tensor_elem_offset")
return tvm.relax.Call(op, [self])


class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking DLDatatype from DLTensor
Expand Down Expand Up @@ -431,6 +458,76 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
return tvm.relax.Call(op, [self.tensor, axis])


class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking the strides from DLTensor
Exposes accessors for the `DLTensor::strides` field. Accessing
these fields will produce `relax.Call` expressions, representing
the field's runtime value. If the datatype of the tensor is known
at compile-time, the `relax.Call` will be normalized into a
`relax.PrimValue`, with no runtime cost.
Parameters
----------
tensor: relax.Expr
The relax tensor (or a variable referring to a relax tensor),
whose runtime strides is being inspected.
"""

def __init__(self, tensor):
self.tensor = tensor

def asobject(self):
"""Provide expected in error message
This method is called when `_DLTensorStrideProxy` is used in a
context that requires a `relax.Expr`. This usage is not
supported, and raising an error here can provide suggested
fixes that are not present in the default error message from
`tvm.runtime.convert_to_object`.
"""
raise TypeError(
f"{self.tensor}.strides cannot be converted to a relax expression, "
f"and should be used as a proxy object to access the runtime strides of the DLTensor. "
f"The DLTensor::ndim field can be accessed as len({self.tensor}), "
f"and the DLTensor::strides array can be accessed as {self.tensor}.strides[i]"
)

def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
"""Returns the extent of a tensor axis
Parameters
----------
axis: Union[int, PrimExpr, Expr]
The tensor axis whose extent should be returned. For ease
of use, any python integers or TIR expressions are
converted to `relax.Expr`.
Returns
-------
extent: Expr
The extent of the tensor's axis.
"""

if not isinstance(axis, tvm.relax.Expr):
axis = tvm.relax.PrimValue(axis)

if axis.struct_info_ is not None and not isinstance(
axis.struct_info_, tvm.relax.PrimStructInfo
):
raise TypeError(
f"The index used to access {self.tensor}.strides "
f'must have struct info R.Prim("int64"), '
f"but index {axis} had struct info {axis.struct_info_}."
)

op = tvm.ir.Op.get("relax.inspect.tensor_stride_i")
return tvm.relax.Call(op, [self.tensor, axis])


@tvm._ffi.register_object("relax.expr.Call")
class Call(ExprWithOp):
"""Function call node in Relax.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from . import grad
from . import image
from . import index
from . import inspect_op
from . import linear_algebra
from . import manipulate
from . import nn
Expand Down
128 changes: 128 additions & 0 deletions python/tvm/relax/transform/legalize_ops/inspect_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Legalization functions for DLTensor inspection."""

import enum

from tvm.script import tir as T

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from .common import register_legalize


class TVMStructFieldKind(enum.IntEnum):
"""Equivalent to tvm::tir::builtin::TVMStructFieldKind
This does not use `enum.auto()` to define the values, because
`enum.auto()` starts from 1, and this must match the C++
definition which starts from 0.
"""

kArrAddr = 0
kArrData = 1
kArrShape = 2
kArrStrides = 3
kArrNDim = 4
kArrTypeCode = 5
kArrTypeBits = 6
kArrTypeLanes = 7
kArrByteOffset = 8
kArrDeviceId = 9
kArrDeviceType = 10
kArrKindBound_ = 11
kTVMValueContent = 12
kTVMValueKindBound_ = 13


@register_legalize("relax.inspect.tensor_stride_i")
def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr:
@T.prim_func(private=True)
def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64:
T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
assert T.int64(0) <= axis, "Specified axis may not be negative"
ndim: T.int32 = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32"
)
assert axis < T.Cast(
"int64", ndim
), "Specified axis may not be larger than the tensor's dimensionality"
stride_ptr: T.handle("int64") = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle"
)

if T.isnullptr(stride_ptr):
shape_ptr: T.handle("int64") = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle"
)
shape = T.decl_buffer(ndim, "int64", data=shape_ptr)

product = T.decl_buffer([], "int64")
product[()] = 1

# TODO(Lunderberg): Add a TIR lowering pass to allow
# ranges to start somewhere other than zero. This loop
# could then iterate on `range(axis+1, ndim)`.
for dim_offset in range(ndim - (axis + 1)):
dim = dim_offset + (axis + 1)
product[()] = product[()] * shape[dim]

return product[()]
else:
strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
stride: T.int64 = strides[axis]
return stride

gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i")
return Call(gvar, call.args)


@register_legalize("relax.inspect.tensor_byte_offset")
def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr:
@T.prim_func(private=True)
def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64:
T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
byte_offset: T.uint64 = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
)
return byte_offset

gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset")
return Call(gvar, call.args)


@register_legalize("relax.inspect.tensor_elem_offset")
def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
@T.prim_func(private=True)
def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
byte_offset: T.uint64 = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
)
scalar_bits: T.uint8 = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8"
)
lanes: T.uint16 = T.tvm_struct_get(
dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16"
)
bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") * lanes.astype("uint64"), 8)
elem_offset = byte_offset // bytes_per_element
return elem_offset

gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
return Call(gvar, call.args)
Loading

0 comments on commit 27101fb

Please sign in to comment.