[INTERPRETER][NFC] Rename tensor_shape -> block_shape in interpreter (#5195)

`tensor_shape` is a confusing name and doesn't match the block pointer's semantics.
`block_shape` is much clearer.
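The new name also lines up with the public `tl.make_block_ptr` API, where `shape` describes the full tensor and `block_shape` describes the per-program window into it. A minimal sketch of that distinction (the kernel and its arguments are illustrative, not part of this patch):

import triton
import triton.language as tl

@triton.jit
def load_block_kernel(x_ptr, M, N, stride_m, stride_n,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    # `shape` is the tensor's logical extent; `block_shape` is the window
    # this program instance loads, which is the name the interpreter now uses.
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=(M, N),
                                  strides=(stride_m, stride_n),
                                  offsets=(0, 0),
                                  block_shape=(BLOCK_M, BLOCK_N),
                                  order=(1, 0))
    block = tl.load(block_ptr, boundary_check=(0, 1))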
Jokeren authored Nov 20, 2024
1 parent aaf64d6 commit 54c840b
25 changes: 12 additions & 13 deletions python/triton/runtime/interpreter.py
@@ -21,7 +21,7 @@ def __init__(self, data, dtype):
         '''
         data: numpy array
         dtype: triton type, either pointer_type or scalar_type.
-        we don't store block_type here because the shape information is already availale in the data field
+        we don't store block_type here because the shape information is already available in the data field
         attr: a dictionary of attributes
         '''
         self.data = data
@@ -46,24 +46,23 @@ def set_attr(self, key, value):

 class BlockPointerHandle:
 
-    def __init__(self, base, shape, strides, offsets, tensor_shape, order):
+    def __init__(self, base, shape, strides, offsets, block_shape, order):
         self.base = base
         self.shape = shape
         self.strides = strides
         self.offsets = offsets
-        self.tensor_shape = tensor_shape
+        self.block_shape = block_shape
         self.order = order
 
     def materialize_pointers(self, boundary_check):
         dtype_tt = self.base.get_element_ty()
         n_bytes = dtype_tt.primitive_bitwidth // 8
-        tensor_shape = self.tensor_shape
-        ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
-        masks = np.ones(self.tensor_shape, dtype=bool)
-        for dim in range(len(tensor_shape)):
-            bcast_dims = [1] * len(tensor_shape)
-            bcast_dims[dim] = tensor_shape[dim]
-            off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
+        ptrs = np.broadcast_to(self.base.data, self.block_shape)
+        masks = np.ones(self.block_shape, dtype=bool)
+        for dim in range(len(self.block_shape)):
+            bcast_dims = [1] * len(self.block_shape)
+            bcast_dims[dim] = self.block_shape[dim]
+            off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
             ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
             if dim in boundary_check:
                 masks = np.logical_and(masks, off < self.shape[dim].data)
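To make the renamed logic concrete, here is a standalone NumPy sketch of what materialize_pointers computes: byte addresses for a `block_shape` window into a tensor of logical `shape`, plus a validity mask for the dimensions in `boundary_check`. It uses plain ints instead of the interpreter's handle objects, and the function name is mine, not the interpreter's:

import numpy as np

def materialize_pointers_sketch(base, shape, strides, offsets, block_shape,
                                boundary_check, n_bytes):
    # Broadcast the base address over the block, then add per-dimension
    # byte offsets; mask off elements that fall outside the tensor.
    ptrs = np.broadcast_to(np.uint64(base), block_shape)
    masks = np.ones(block_shape, dtype=bool)
    for dim in range(len(block_shape)):
        bcast_dims = [1] * len(block_shape)
        bcast_dims[dim] = block_shape[dim]
        off = (offsets[dim] + np.arange(block_shape[dim])).reshape(bcast_dims)
        ptrs = ptrs + (n_bytes * off * strides[dim]).astype(np.uint64)
        if dim in boundary_check:
            masks = np.logical_and(masks, off < shape[dim])
    return ptrs, masks

# A 2x4 block at offset (1, 0) of a row-major 3x5 float32 tensor at address 0;
# the last column of the block falls outside the tensor, so masks is False there.
ptrs, masks = materialize_pointers_sketch(
    base=0, shape=(3, 5), strides=(5, 1), offsets=(1, 0),
    block_shape=(2, 4), boundary_check=(0, 1), n_bytes=4)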
@@ -655,17 +654,17 @@ def create_barrier(self):
         # Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
         pass
 
-    def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
+    def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
         # Create new offsets to avoid modifying the original
         new_offsets = [offset.clone() for offset in offsets]
-        return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
+        return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
 
     def create_advance(self, ptr, offsets):
         if len(ptr.offsets) != len(offsets):
             raise ValueError("len(ptr.offsets) != len(offsets)")
         # Create new offsets to avoid modifying the original
         new_offsets = [offset.clone() for offset in ptr.offsets]
-        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
+        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
         for i in range(len(offsets)):
             ret.offsets[i].data += offsets[i].data
         return ret
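create_advance backs `tl.advance`, which returns a new block pointer shifted by per-dimension element offsets; cloning the offsets first is what keeps the original handle unmodified. A sketch of the pattern this enables, walking a block across the column dimension (the kernel is illustrative, not from this patch):

import triton
import triton.language as tl

@triton.jit
def row_sum_kernel(x_ptr, out_ptr, M, N, stride_m, stride_n,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    block_ptr = tl.make_block_ptr(base=x_ptr, shape=(M, N),
                                  strides=(stride_m, stride_n),
                                  offsets=(0, 0),
                                  block_shape=(BLOCK_M, BLOCK_N),
                                  order=(1, 0))
    acc = tl.zeros((BLOCK_M,), dtype=tl.float32)
    for _ in range(0, N, BLOCK_N):
        block = tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
        acc += tl.sum(block, axis=1)
        # tl.advance (create_advance in the interpreter) yields a fresh pointer
        # moved BLOCK_N columns to the right; the old one is not mutated.
        block_ptr = tl.advance(block_ptr, (0, BLOCK_N))
    tl.store(out_ptr + tl.arange(0, BLOCK_M), acc)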
