
Index in triton #974

Closed
jiangzzsss opened this issue Dec 10, 2022 · 10 comments

Comments

@jiangzzsss

We'd like to do some indexing in Triton kernels. Say we have `x_ptr`, `idx_ptr`, `out_ptr`, and:

x = tl.load(x_ptr + offsets, mask = mask)
idx = tl.load(idx_ptr + offsets, mask = mask)

We tried two approaches.

Approach 1:

idx = idx.to(tl.int32)
output = tl.load(x_ptr + idx)

This works.

Approach 2:

output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
for i in range(0, BLOCK_SIZE):
      output[i] = x[idx[i]]

This reports errors (the error message is at the end).

**We want to know:**

  1. If we use approach 1, is it memory-efficient, given that we go through `tl.load`?
  2. Trying `x[0]` also errors: `TypeError: 'constexpr' object is not iterable`. We didn't find much in the docs, so are there any other ways of doing indexing?
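(Editorial sketch, not from the thread.) Approach 1 is effectively a gather: each lane turns its element of `idx` into a pointer offset and loads `x_ptr + idx[i]` from global memory. In plain NumPy terms, per block it computes:

```python
import numpy as np

# Hypothetical block-sized data, mirroring what sits at x_ptr and idx_ptr.
BLOCK_SIZE = 8
x = np.arange(BLOCK_SIZE, dtype=np.float32) * 10.0       # [0, 10, ..., 70]
idx = np.array([3, 0, 7, 1, 1, 5, 2, 6], dtype=np.int32)

# Approach 1, output = tl.load(x_ptr + idx), is an element-wise gather:
# output[i] = x[idx[i]]
output = x[idx]

print(output)  # [30.  0. 70. 10. 10. 50. 20. 60.]
```

On the memory-efficiency question: each lane issues one global-memory read, so the cost depends on how coalesced `idx` is; scattered indices may still be served by cache, but that is hardware-dependent.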

We're using Triton version 2.0.0.dev20221120 with Python 3.8.0, running on an A100.
Error log of approach 2:

Traceback (most recent call last):
  File "<string>", line 21, in tri_index_kernel
KeyError: ('2-.-0-.-0-1e8410f206c822547fb50e2ea86e45a6-2b0c5161c53c71b37ae20a9996ee4bb8-3aa563e00c5c695dd945e23b09a86848-42648570729a4835b21c1c18cebedbfe-ff946bd4b3b4a4cbdf8cedc6e1c658e0-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float32, torch.float32, torch.float32, 'i32'), (64,), (True, True, True, (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 838, in make_triton_ir
    generator.visit(fn.parse())
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 260, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 368, in generic_visit
    self.visit(item)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 320, in visit_FunctionDef
    has_ret = self.visit_compound_statement(node.body)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 648, in visit_For
    self.visit_compound_statement(node.body)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 254, in visit_compound_statement
    self.last_ret = self.visit(stmt)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 364, in visit_Assign
    _names += [self.visit(target)]
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 771, in visit
    return super().visit(node)
  File "/mnt/cache/share/spring/conda_envs/miniconda3/envs/s0.3.6_py38/lib/python3.8/ast.py", line 360, in visit
    return visitor(node)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 576, in visit_Subscript
    assert node.ctx.__class__.__name__ == "Load"
AssertionError

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "ztest.py", line 51, in <module>
    output = tri_index(x, idx)
  File "ztest.py", line 44, in tri_index
    tri_index_kernel[grid](x, idx, output, n_elements, BLOCK_SIZE=64)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 41, in tri_index_kernel
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 1256, in compile
    asm, shared, kernel_name = _compile(fn, signature, device, constants, configs[0], num_warps, num_stages,
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 892, in _compile
    module, _ = make_triton_ir(fn, signature, specialization, constants)
  File "/mnt/cache/jiangzhen/.local/lib/python3.8/site-packages/triton/compiler.py", line 843, in make_triton_ir
    raise CompilationError(fn.src, node) from e
triton.compiler.CompilationError: at 23:8:
def tri_index_kernel(
    x_ptr,  # *Pointer* to first input vector
    idx_ptr,  # *Pointer* to second input vector
    output_ptr,  # *Pointer* to output vector
    n_elements,  # Size of the vector
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process
                 # NOTE: `constexpr` so it can be used as a shape value
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask = mask)
    idx = tl.load(idx_ptr + offsets, mask = mask)
    output = tl.zeros([BLOCK_SIZE, ], dtype=tl.float32)
    min_off = tl.min(offsets, axis=0)
    max_off = tl.max(offsets, axis=0)
    # idx //= 1
    idx = idx.to(tl.int32)
    output = tl.load(x_ptr + idx)
    for i in range(0, BLOCK_SIZE):
        output[i] = x[idx[i]]
        ^
@ptillet
Collaborator

ptillet commented Dec 10, 2022

Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap, but it's a pretty advanced feature, so we haven't come up with a specific timeline yet.

@jiangzzsss
Author

Thanks for the reply! I'm still curious, though: if we store the values back and then use the indices as pointers to load them, will that be slow?

@Jokeren
Contributor

Jokeren commented Dec 21, 2022

Thanks for the reply! I'm still curious, though: if we store the values back and then use the indices as pointers to load them, will that be slow?

It's expected to be slow, since you store the values in global memory. In some cases, though, you will go through the cache.
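A minimal sketch of the store-then-load workaround being discussed (editorial illustration, assuming a dedicated global scratch buffer; NumPy stands in for the memory operations):

```python
import numpy as np

# Values the kernel already holds for this block, plus the indices.
x = np.array([2.0, 4.0, 6.0, 8.0], dtype=np.float32)
idx = np.array([3, 3, 0, 2], dtype=np.int32)

# Step 1: tl.store(scratch_ptr + offsets, x) -- spill the block to global memory.
scratch = x.copy()

# Step 2: output = tl.load(scratch_ptr + idx) -- gather back through the indices.
output = scratch[idx]

print(output)  # [8. 8. 2. 6.]
```

The round trip through global memory is the cost described above; whether the second pass hits cache depends on the access pattern.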

@nlgranger

Triton just raises an assertion error when trying to index a local tensor. I suppose it's related to this issue. Are there any workarounds?

@marcelroed

Any updates on this? Is there still no way to do indexing in a Triton kernel?

@jselvam11

https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py

There's this in xformers; it seems similar to indexing into a sparse tensor.

@nlgranger

https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py

There's this in xformers; it seems similar to indexing into a sparse tensor.

Yes, but it goes through global memory, which is slow, as @Jokeren mentioned.

@cagrikymk

I have a similar issue, but I only want to index into different blocks, for example (to compute a spline function up to a certain order):

data = tl.zeros((4, BLOCK_SIZE), dtype=tl.float32)
data[0] = w
data[1] = 1 - w
...

I get a similar compiler error. This could easily be fixed by creating 4 differently named blocks, but then iterating over those blocks with a for loop becomes the issue.

I think I can unroll the loop and name everything to work around the problem, but that would produce unmaintainable code. Is there a known trick to get this to work, other than going through global memory?
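A sketch of that manual unrolling (editorial illustration; the weight formulas are placeholders, not an actual spline basis): keep one named value per row so no subscript assignment is needed, and unroll any loop over the rows by hand. In NumPy terms:

```python
import numpy as np

BLOCK_SIZE = 4
w = np.array([0.25, 0.5, 0.75, 1.0], dtype=np.float32)

# Instead of data = tl.zeros((4, BLOCK_SIZE)); data[0] = w; data[1] = 1 - w; ...
# keep each row as its own named block:
data0 = w
data1 = 1.0 - w
data2 = w * w            # placeholder higher-order term
data3 = (1.0 - w) ** 2   # placeholder higher-order term

# Any loop over the rows is then unrolled by hand:
total = data0 + data1 + data2 + data3
print(total)  # [1.625 1.5   1.625 2.   ]
```

This keeps everything in registers, at the cost of the unmaintainable duplication noted above.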

@webstorms

Are there any updates on indexing through shared memory?

@peterbell10
Contributor

#5262 adds support for this as

output = tl.gather(x, idx, axis=0)
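Assuming `tl.gather` with `axis=0` follows the usual gather semantics (`output[i] = x[idx[i]]` in 1-D, like `numpy.take_along_axis`; worth verifying against #5262), a reference sketch:

```python
import numpy as np

x = np.array([10.0, 20.0, 30.0, 40.0], dtype=np.float32)
idx = np.array([2, 2, 0, 3], dtype=np.int64)

# output = tl.gather(x, idx, axis=0) should correspond to:
# output[i] = x[idx[i]]
output = np.take_along_axis(x, idx, axis=0)

print(output)  # [30. 30. 10. 40.]
```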
