Index in triton #974
Comments
Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap, but it's a pretty advanced feature, so we don't have a specific timeline yet.
Thanks for the reply! I'm still curious, though: if we store the values back to global memory and use the indices as pointer offsets to load them, will that be slow?
It's expected to be slow, since you store the values in global memory, though in some cases the accesses will be served from the cache.
Triton just raises an assertion error when trying to index a local tensor. I suppose it is related to this issue. Are there any workarounds?
Any updates on this? Is there still no way to do indexing in a Triton kernel?
https://github.com/facebookresearch/xformers/blob/main/xformers/ops/_triton/k_index_select_cat.py There's this in xformers, which seems similar to indexing into a sparse tensor.
Yes, but it goes through global memory, which is slow, as mentioned by @Jokeren.
I have a similar issue, but I only want to index different blocks, for example to compute a spline function up to a certain order:
I get a similar kind of compiler error. It could be easily fixed by creating 4 different shared memory blocks, each with its own name, but then iterating over these blocks with a for loop becomes the problem. I think I could unroll the loop and name everything explicitly, but that would produce unmaintainable code. Is there a known trick to get this to work other than going through global memory?
Are there any updates on indexing through shared memory?
#5262 adds support for this as output = tl.gather(x, idx, axis=0)
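For anyone unsure what an axis-0 gather computes, here is a minimal pure-Python sketch of the indexing rule on a 2D input. It assumes tl.gather follows the same along-axis convention as torch.gather, that is out[i][j] = x[idx[i][j]][j]; check the Triton docs for your version before relying on that. The helper name gather_axis0 is hypothetical and is not part of Triton.

```python
from typing import List


def gather_axis0(x: List[List[float]], idx: List[List[int]]) -> List[List[float]]:
    """Reference axis-0 gather (assumed convention): out[i][j] = x[idx[i][j]][j]."""
    n_rows = len(x)
    out = []
    for idx_row in idx:
        out_row = []
        for j, src_row in enumerate(idx_row):
            # Each index selects a row of x; the column j is kept as-is.
            if not 0 <= src_row < n_rows:
                raise IndexError(f"index {src_row} out of range for {n_rows} rows")
            out_row.append(x[src_row][j])
        out.append(out_row)
    return out


x = [[10, 20], [30, 40], [50, 60]]
idx = [[2, 0], [1, 1]]
print(gather_axis0(x, idx))  # prints [[50, 20], [30, 40]]
```

Inside a kernel, the tl.gather call above would presumably take the place of this loop, with x and idx both held as on-chip tensors, so the round trip through global memory discussed earlier in the thread would not be needed.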
We'd like to do some indexing in Triton kernels.
Say we have x_ptr, idx_ptr, and out_ptr.
We tried two approaches:
1.
This works.
2.
This reports errors (the error message is at the end).
**We want to know:
we didn't find much in the docs, so are there any other ways of doing indexing?**
We are using Triton version 2.0.0.dev20221120 and Python 3.8.0, running on an A100.
Error logs of approach 2: