Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relax] Implement operators to inspect DLTensor::strides and offset #16721

Merged
merged 4 commits into from
Mar 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions python/tvm/relax/expr.py
Original file line number Diff line number Diff line change
@@ -280,6 +280,33 @@ def shape(self) -> "_DLTensorShapeProxy":
self._check_for_tensor_struct_info()
return _DLTensorShapeProxy(self)

@property
def strides(self) -> "_DLTensorStrideProxy":
    """Expose `DLTensor::strides` of this expression through a proxy

    `self._check_for_tensor_struct_info()` is run first (it is
    defined outside this view; presumably it rejects expressions
    that are not tensors).  The returned proxy wraps this
    expression and is meant to be indexed, e.g. `expr.strides[i]`.
    """
    self._check_for_tensor_struct_info()
    stride_proxy = _DLTensorStrideProxy(self)
    return stride_proxy

@property
def byte_offset(self) -> "Expr":
    """Build an expression for the runtime `DLTensor::byte_offset`

    Returns
    -------
    byte_offset: Expr

        A `relax.Call` to the `relax.inspect.tensor_byte_offset`
        operator, with this expression as its only argument.
    """
    self._check_for_tensor_struct_info()
    return tvm.relax.Call(tvm.ir.Op.get("relax.inspect.tensor_byte_offset"), [self])

@property
def elem_offset(self) -> "Expr":
    """Build an expression for the runtime element offset of a DLTensor

    The element offset is not a field of the DLTensor.  It is
    derived from the DLTensor's byte offset and datatype, and is
    exposed in Relax for ease of use, and for translation into the
    `tir::BufferNode::elem_offset` field when interacting with TIR
    buffers.

    Returns
    -------
    elem_offset: Expr

        A `relax.Call` to the `relax.inspect.tensor_elem_offset`
        operator, with this expression as its only argument.
    """
    self._check_for_tensor_struct_info()
    return tvm.relax.Call(tvm.ir.Op.get("relax.inspect.tensor_elem_offset"), [self])


class _DLTensorDTypeProxy(tvm.runtime.ObjectGeneric):
"""A proxy object for unpacking DLDatatype from DLTensor
@@ -431,6 +458,76 @@ def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
return tvm.relax.Call(op, [self.tensor, axis])


class _DLTensorStrideProxy(tvm.runtime.ObjectGeneric):
    """A proxy object for reading individual entries of DLTensor::strides

    The proxy is not itself a relax expression.  Indexing it, as in
    `tensor.strides[i]`, produces a `relax.Call` to the
    `relax.inspect.tensor_stride_i` operator, which represents the
    runtime stride of axis `i`.

    NOTE(review): whether such a call can be folded into a
    `relax.PrimValue` at compile time depends on the operator's
    normalization, which is not visible from this class.

    Parameters
    ----------
    tensor: relax.Expr

        The relax tensor (or a variable referring to a relax tensor),
        whose runtime strides are being inspected.
    """

    def __init__(self, tensor):
        self.tensor = tensor

    def asobject(self):
        """Reject use of the proxy itself as a relax expression

        This method is called when `_DLTensorStrideProxy` is used in a
        context that requires a `relax.Expr`.  That usage is not
        supported.  Raising here allows the error message to suggest
        the supported spellings, which the default error message from
        `tvm.runtime.convert_to_object` would not contain.

        Raises
        ------
        TypeError
            Always.
        """
        tensor = self.tensor
        message = (
            f"{tensor}.strides cannot be converted to a relax expression, "
            f"and should be used as a proxy object to access the runtime strides of the DLTensor. "
            f"The DLTensor::ndim field can be accessed as len({tensor}), "
            f"and the DLTensor::strides array can be accessed as {tensor}.strides[i]"
        )
        raise TypeError(message)

    def __getitem__(self, axis: Union[int, PrimExpr, Expr]) -> Expr:
        """Returns the stride of a tensor axis

        Parameters
        ----------
        axis: Union[int, PrimExpr, Expr]

            The tensor axis whose stride should be returned.  For ease
            of use, any python integers or TIR expressions are
            converted to `relax.Expr`.

        Returns
        -------
        stride: Expr

            A `relax.Call` representing the stride of the tensor's axis.

        Raises
        ------
        TypeError
            If the index has struct info that is not a
            `relax.PrimStructInfo`.
        """
        # Anything that is not already a relax expression is wrapped in
        # a PrimValue, so the operator always receives relax arguments.
        index = axis if isinstance(axis, tvm.relax.Expr) else tvm.relax.PrimValue(axis)

        # An index without struct info is accepted as-is; only a known
        # non-primitive struct info is rejected here.
        sinfo = index.struct_info_
        if not (sinfo is None or isinstance(sinfo, tvm.relax.PrimStructInfo)):
            raise TypeError(
                f"The index used to access {self.tensor}.strides "
                f'must have struct info R.Prim("int64"), '
                f"but index {index} had struct info {sinfo}."
            )

        stride_op = tvm.ir.Op.get("relax.inspect.tensor_stride_i")
        return tvm.relax.Call(stride_op, [self.tensor, index])


@tvm._ffi.register_object("relax.expr.Call")
class Call(ExprWithOp):
"""Function call node in Relax.
1 change: 1 addition & 0 deletions python/tvm/relax/transform/legalize_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@
from . import grad
from . import image
from . import index
from . import inspect_op
from . import linear_algebra
from . import manipulate
from . import nn
128 changes: 128 additions & 0 deletions python/tvm/relax/transform/legalize_ops/inspect_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Legalization functions for DLTensor inspection."""

import enum

from tvm.script import tir as T

from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from .common import register_legalize


@enum.unique
class TVMStructFieldKind(enum.IntEnum):
    """Equivalent to tvm::tir::builtin::TVMStructFieldKind

    The members are passed (as plain integers) to `T.tvm_struct_get`
    to select which field of a struct is read.

    This does not use `enum.auto()` to define the values, because
    `enum.auto()` starts from 1, and this must match the C++
    definition which starts from 0.

    The class is decorated with `enum.unique` because `IntEnum`
    would otherwise silently turn a duplicated value into an alias
    of the earlier member.  For a hand-maintained mirror of a C++
    enum, such a typo should fail at import time instead.
    """

    # Fields selected with the `kArr*` kinds.
    kArrAddr = 0
    kArrData = 1
    kArrShape = 2
    kArrStrides = 3
    kArrNDim = 4
    kArrTypeCode = 5
    kArrTypeBits = 6
    kArrTypeLanes = 7
    kArrByteOffset = 8
    kArrDeviceId = 9
    kArrDeviceType = 10
    # NOTE(review): presumably an end-of-range marker rather than a
    # readable field -- confirm against the C++ definition.
    kArrKindBound_ = 11

    # Fields selected with the `kTVMValue*` kinds.
    kTVMValueContent = 12
    # NOTE(review): presumably an end-of-range marker -- confirm.
    kTVMValueKindBound_ = 13


@register_legalize("relax.inspect.tensor_stride_i")
def _tensor_stride_i(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_stride_i` into a call to a TIR PrimFunc.

    Parameters
    ----------
    bb: BlockBuilder
        The builder to which the generated PrimFunc is added.

    call: Call
        The call being legalized.  Its arguments are forwarded
        unchanged to the generated PrimFunc, whose parameters are
        `(dlpack_handle, axis)`.

    Returns
    -------
    Expr
        A call to the added PrimFunc, using the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_stride_i(dlpack_handle: T.handle, axis: T.int64) -> T.int64:
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        # Runtime validation of the axis: 0 <= axis < ndim.  `ndim` is
        # read as int32 and widened to int64 for the comparison.
        assert T.int64(0) <= axis, "Specified axis may not be negative"
        ndim: T.int32 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrNDim), "int32"
        )
        assert axis < T.Cast(
            "int64", ndim
        ), "Specified axis may not be larger than the tensor's dimensionality"
        stride_ptr: T.handle("int64") = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrStrides), "handle"
        )

        if T.isnullptr(stride_ptr):
            # No strides array is present.  The stride is then computed
            # from the shape as the product of the extents of all axes
            # after `axis`, i.e. the stride of a compact row-major
            # layout.
            shape_ptr: T.handle("int64") = T.tvm_struct_get(
                dlpack_handle, 0, int(TVMStructFieldKind.kArrShape), "handle"
            )
            shape = T.decl_buffer(ndim, "int64", data=shape_ptr)

            # Scalar (zero-dimensional) buffer used as the accumulator.
            product = T.decl_buffer([], "int64")
            product[()] = 1

            # TODO(Lunderberg): Add a TIR lowering pass to allow
            # ranges to start somewhere other than zero.  This loop
            # could then iterate on `range(axis+1, ndim)`.
            for dim_offset in range(ndim - (axis + 1)):
                # Shift the zero-based loop variable to the real axis
                # index, covering axis+1 .. ndim-1.
                dim = dim_offset + (axis + 1)
                product[()] = product[()] * shape[dim]

            return product[()]
        else:
            # A strides array is present: read the requested entry.
            strides = T.decl_buffer(ndim, "int64", data=stride_ptr)
            stride: T.int64 = strides[axis]
            return stride

    # Add the PrimFunc through the block builder and replace the
    # operator call with a call to the resulting global variable.
    gvar = bb.add_func(_get_tensor_stride_i, "_get_tensor_stride_i")
    return Call(gvar, call.args)


@register_legalize("relax.inspect.tensor_byte_offset")
def _tensor_byte_offset(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_byte_offset` into a call to a TIR PrimFunc.

    Parameters
    ----------
    bb: BlockBuilder
        The builder to which the generated PrimFunc is added.

    call: Call
        The call being legalized.  Its arguments are forwarded
        unchanged to the generated PrimFunc, whose only parameter is
        `dlpack_handle`.

    Returns
    -------
    Expr
        A call to the added PrimFunc, using the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_byte_offset(dlpack_handle: T.handle) -> T.int64:
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        # NOTE(review): the field is read as uint64, while the
        # PrimFunc is annotated as returning int64.  Confirm how
        # TVMScript reconciles the two types.
        byte_offset: T.uint64 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
        )
        return byte_offset

    # Add the PrimFunc through the block builder and replace the
    # operator call with a call to the resulting global variable.
    gvar = bb.add_func(_get_tensor_byte_offset, "_get_tensor_byte_offset")
    return Call(gvar, call.args)


@register_legalize("relax.inspect.tensor_elem_offset")
def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
    """Legalize `relax.inspect.tensor_elem_offset` into a call to a TIR PrimFunc.

    The generated PrimFunc divides the tensor's byte offset by the
    size of one element in bytes, where the element size is derived
    from the datatype's bit width and lane count.

    Parameters
    ----------
    bb: BlockBuilder
        The builder to which the generated PrimFunc is added.

    call: Call
        The call being legalized.  Its arguments are forwarded
        unchanged to the generated PrimFunc, whose only parameter is
        `dlpack_handle`.

    Returns
    -------
    Expr
        A call to the added PrimFunc, using the original arguments.
    """

    @T.prim_func(private=True)
    def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64:
        T.func_attr({"tir.is_host": T.bool(True), "tir.is_scheduled": T.bool(True)})
        byte_offset: T.uint64 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrByteOffset), "uint64"
        )
        scalar_bits: T.uint8 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeBits), "uint8"
        )
        lanes: T.uint16 = T.tvm_struct_get(
            dlpack_handle, 0, int(TVMStructFieldKind.kArrTypeLanes), "uint16"
        )
        # Both factors are widened to uint64 before multiplying.  An
        # element size that is not a whole number of bytes is rounded
        # up by the `ceildiv`.
        bytes_per_element = T.ceildiv(scalar_bits.astype("uint64") * lanes.astype("uint64"), 8)
        # NOTE(review): the result is uint64 arithmetic, while the
        # PrimFunc is annotated as returning int64 -- confirm.
        elem_offset = byte_offset // bytes_per_element
        return elem_offset

    # Add the PrimFunc through the block builder and replace the
    # operator call with a call to the resulting global variable.
    gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
    return Call(gvar, call.args)
Loading