-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TOPI][ARM] Improve injective schedule #2801
Conversation
@ajtulloch please review :) |
Yeah. this is very useful like @hlu1 How about if len(s[x].op.axis) >= 5: # it is very useful when we have NCHWxC.
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1], s[x].op.axis[2])
s[x].parallel(fused)
elif len(s[x].op.axis) >= 3:
fused = s[x].fuse(s[x].op.axis[0], s[x].op.axis[1])
s[x].parallel(fused)
else:
s[x].parallel(s[x].op.axis[0])
s[x].vectorize(list(s[x].op.axis)[-1]) |
We also need to consider the case of int8/unit8. For example, when you add two int8 numbers together to produce 1 int16 number, the simd width is 128/16 = 8. I think in general 8 should be a good compromise. |
Ok. Could you add >=5 like we do it in x86? https://github.com/dmlc/tvm/blob/master/topi/python/topi/x86/injective.py#L26. This could help us in NCHWxC layout transform. |
That should have been covered by:
I used |
Oops. I haven't noticed it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
d191369
to
a6e3656
Compare
@hlu1 Could you help to see this discussion on discuss forum? https://discuss.tvm.ai/t/relay-build-target-rasp3b-something-wrong/2195 This issue should be related with this changeset. |
@FrozenGene, thanks for letting me know. Fixed in #3061 |
The generic injective schedule does not have vectorization and is therefore slow on ARM CPU. With vectorization, it can run 2-3x faster. For example, for a upsample_relu layer with 48 x 48 x 48 (C, H, W), the vectorized code runs at 0.003 ms/iter compared to 0.008 ms/iter on raspberry pi.