RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite? #7
Exporting to ONNX is something that I've been tinkering with, and I can report that the 169m RWKV-4 model does run in the browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx
There are two things missing:
Running |
Great work :) Did you get this error with webgl: "cannot resolve operator 'Max' with opsets: ai.onnx v13"? You can remove RWKV_HEAD_QK and RWKV-ffnPre, which are not required for Pile models; that will probably fix it. P.S. Upgrade onnxruntime to the latest version and then you can test CUDAExecutionProvider in Python. I think you might be using an older onnxruntime, because all new versions require explicitly setting providers when initializing InferenceSession(). |
@AXKuhta Nice! I got a web demo going here (for 169m and 430m):
But it seems like something is going wrong - the model isn't "coherent" in using the context. For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris". I checked that the tokenizer is working properly, so I think it's something to do with the inference / context-handling code. Some other random notes:
|
That seems familiar!
It looks like you display the outputs during the prompt-feeding stage, which happens one token at a time. That should fix it:

```diff
 let token = greedySampling(results.x.data);
 if (promptTokens.length == 0) {
+    if(streamingCallback) streamingCallback(token);
     ctx.push( token );
 } else {
     ctx.push( promptTokens.shift() );
 }
-
-if(streamingCallback) streamingCallback(token);
 feeds.xx_att = results.xx_att_r;
 feeds.aa_att = results.aa_att_r;
```
|
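For reference, here is a minimal self-contained sketch of the same loop structure. It assumes a hypothetical `step(token, state)` function standing in for one `session.run()` call that returns `{ logits, state }`; `generate`, `step`, and the toy model are illustrative names, not the demo's code, and `greedySampling` is assumed to be a plain argmax. Prompt tokens are fed one at a time with their logits ignored, and only sampled tokens reach the callback. The sketch assumes a non-empty prompt.

```javascript
// Argmax over the logits (greedy decoding); assumed to match the demo's greedySampling.
function greedySampling(logits) {
  let best = 0;
  for (let i = 1; i < logits.length; i++) {
    if (logits[i] > logits[best]) best = i;
  }
  return best;
}

// `step(token, state)` is a hypothetical stand-in for one session.run() call:
// it consumes one token and returns { logits, state } for the next position.
function generate(step, promptTokens, maxNewTokens, streamingCallback) {
  const pending = promptTokens.slice(); // copy so the caller's array is untouched
  const ctx = [pending.shift()];        // the first prompt token starts the loop
  let state = null;
  let generated = 0;
  while (generated < maxNewTokens) {
    const result = step(ctx[ctx.length - 1], state);
    state = result.state;
    if (pending.length === 0) {
      // Generation phase: sample, stream, and append the new token.
      const token = greedySampling(result.logits);
      if (streamingCallback) streamingCallback(token);
      ctx.push(token);
      generated++;
    } else {
      // Prompt-feeding phase: ignore the logits and feed the next prompt token.
      ctx.push(pending.shift());
    }
  }
  return ctx;
}

// Usage with a toy model that always predicts (token + 1) mod 10.
const toyStep = (token, state) => {
  const logits = new Float32Array(10);
  logits[(token + 1) % 10] = 1;
  return { logits, state };
};
const streamed = [];
const out = generate(toyStep, [3, 7, 2], 3, (t) => streamed.push(t));
// streamed is [3, 4, 5]; out is [3, 7, 2, 3, 4, 5]
```

The three prompt tokens are never streamed; only the three generated ones are.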
@josephrocca I had to host the demo locally because huggingface keeps terminating the model downloads for some reason, but otherwise I can confirm that it works on my machine. Good job with getting the tokenizer and the quantization working!
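Yeah, that's what's happening. RWKV-v4 is bf16, which can't be losslessly converted to fp16 (bf16 has 8 exponent bits against fp16's 5, so very large and very small values fall outside fp16's range), so fp32 is the next best option. The fp32-converted model also compresses really well, since half the bytes in it are zero: each bf16 value becomes the upper 16 bits of an fp32 value, and the lower 16 bits are always zero. |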
Take a look at src/model_run.py. Moreover, use https://github.com/daquexian/onnx-simplifier to optimize the ONNX model. |
And the ONNX version might work on AMD and Intel GPUs. The DirectML backend supports them (on Windows 10). I tried that for RWKV-1. |
You can losslessly "transform" bf16 to fp16: the idea is to keep the same raw binary value. The float value will be totally different, but you can do an inverse transform in JS to losslessly recover the original bf16. |
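A minimal sketch of the decoding side of this trick, assuming the weights are stored as raw bf16 bit patterns (one `uint16` per weight) in the host's byte order, which is little-endian on practically all browser platforms. bf16 is exactly the upper half of an IEEE-754 float32, so widening is a 16-bit left shift and nothing is lost; it also shows why fp32 files made from bf16 compress so well, since the lower 16 bits of every widened value are zero. `bf16BitsToFloat32` and `float32ToBf16Bits` are hypothetical helper names, and the encoder truncates instead of rounding, which is only a simplification for the round-trip check.

```javascript
// Widen raw bf16 bit patterns (one uint16 per weight) to float32.
// bf16 is the upper 16 bits of an IEEE-754 float32, so this is lossless.
function bf16BitsToFloat32(bits) {
  const u32 = new Uint32Array(bits.length);
  for (let i = 0; i < bits.length; i++) {
    u32[i] = bits[i] << 16; // the low 16 bits of every result stay zero
  }
  return new Float32Array(u32.buffer); // same bytes, reinterpreted as floats
}

// Hypothetical encoder used only for the round-trip check: keeps the upper
// 16 bits of each float32 (truncation, not round-to-nearest).
function float32ToBf16Bits(values) {
  const u32 = new Uint32Array(Float32Array.from(values).buffer);
  const bits = new Uint16Array(u32.length);
  for (let i = 0; i < u32.length; i++) bits[i] = u32[i] >>> 16;
  return bits;
}

// In a page, a raw weight file could be read as (url is a placeholder):
//   const bits = new Uint16Array(await (await fetch(url)).arrayBuffer());
const values = [1, -2.5, 0.15625, 2 ** 100, 0]; // all exactly representable in bf16
const decoded = bf16BitsToFloat32(float32ToBf16Bits(values));
// 2 ** 100 survives the round trip; as an fp16 number it would overflow (max 65504).
```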
@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it - thanks for your help!! Strange that huggingface is terminating the download for you... 🤔 @BlinkDL Thanks for the tips! I'll look into the stuff you've mentioned. |
The Python code for RWKV-2 weight conversion to .bin (for tf.js):
You can gradually port it to RWKV-4 by matching the outputs for each layer. The Chinese RWKV-2 has a better UI: https://github.com/BlinkDL/AI-Writer/blob/main/docs/index.html The English RWKV-2: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng |
Add top-p, top-k, and temperature, and then it's very usable :) |
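A sketch of what adding those three knobs could look like in JS, using the usual definitions: temperature divides the logits before the softmax, top-k keeps the k most probable tokens, and top-p keeps the smallest set of most probable tokens whose cumulative probability reaches p. `sampleLogits` and its option names are hypothetical (not the demo's API), and `rand` is injectable so results can be made deterministic.

```javascript
// Sample a token id from logits with temperature, top-k, and top-p filtering.
function sampleLogits(logits, { temperature = 1.0, topK = 0, topP = 1.0, rand = Math.random } = {}) {
  const n = logits.length;
  if (temperature <= 0) {
    // Temperature 0 degenerates to greedy (argmax) decoding.
    let best = 0;
    for (let i = 1; i < n; i++) if (logits[i] > logits[best]) best = i;
    return best;
  }
  // Softmax of logits / temperature, subtracting the max for numerical stability.
  let max = -Infinity;
  for (let i = 0; i < n; i++) max = Math.max(max, logits[i]);
  const probs = new Float64Array(n);
  let sum = 0;
  for (let i = 0; i < n; i++) {
    probs[i] = Math.exp((logits[i] - max) / temperature);
    sum += probs[i];
  }
  // Candidate ids sorted by descending probability.
  let ids = Array.from({ length: n }, (_, i) => i).sort((a, b) => probs[b] - probs[a]);
  // Top-k: keep only the k most likely tokens (0 disables the filter).
  if (topK > 0) ids = ids.slice(0, topK);
  // Top-p: keep the smallest prefix whose cumulative probability reaches topP.
  if (topP < 1.0) {
    let cum = 0;
    let cut = ids.length;
    for (let j = 0; j < ids.length; j++) {
      cum += probs[ids[j]] / sum;
      if (cum >= topP) { cut = j + 1; break; }
    }
    ids = ids.slice(0, cut);
  }
  // Renormalise over the surviving candidates and draw one of them.
  let kept = 0;
  for (const id of ids) kept += probs[id];
  let r = rand() * kept;
  for (const id of ids) {
    r -= probs[id];
    if (r < 0) return id;
  }
  return ids[ids.length - 1]; // guard against floating-point round-off
}
```

Sorting the full 50277-entry vocabulary on every step keeps the sketch simple, but a partial selection would be cheaper if this shows up in profiles.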
It looks like the webgl backend has a lot of limitations. I did some testing by stripping out different parts of the model in order to see if I can get anything at all to work on the webgl backend. I think I got like four different error messages with different combinations. The bottom line is that I can't even get a matmul to work. https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/matmul_test.py It does work on the wasm backend! EDIT: It actually works on webgl if you do this: AXKuhta@75ad160 |
I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken: https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl @BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by |
That's great. Could you check whether https://github.com/daquexian/onnx-simplifier can help? Use https://github.com/lutzroeder/netron to visualize models. And then you can print() the outputs of interesting layers to find the culprit... gradually matching the results of webgl vs wasm. |
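The "print() the outputs of interesting layers" step can be partly automated. A minimal sketch, assuming each interesting layer's output has already been collected from both backends into objects keyed by layer name in graph order (how to collect them, e.g. by exposing them as extra graph outputs, is left open); `maxAbsDiff`, `firstDivergence`, the layer names, and the tolerance are illustrative choices.

```javascript
// Largest absolute element-wise difference; NaN anywhere counts as total divergence.
function maxAbsDiff(a, b) {
  if (a.length !== b.length) throw new Error(`length mismatch: ${a.length} vs ${b.length}`);
  let m = 0;
  for (let i = 0; i < a.length; i++) {
    const d = Math.abs(a[i] - b[i]);
    if (Number.isNaN(d)) return Infinity;
    if (d > m) m = d;
  }
  return m;
}

// Walk the layers in insertion order and report the first one that differs by more than `tol`.
function firstDivergence(reference, candidate, tol = 1e-3) {
  for (const name of Object.keys(reference)) {
    const diff = maxAbsDiff(reference[name], candidate[name]);
    if (diff > tol) return { name, diff };
  }
  return null; // everything matches within tolerance
}

// Usage with made-up outputs from two backends.
const wasmOut = { emb: [1, 2], block0: [0.5, 0.5], block1: [1, 1] };
const webglOut = { emb: [1, 2.0001], block0: [0.5, 0.7], block1: [NaN, 1] };
const found = firstDivergence(wasmOut, webglOut); // block0 is reported, not block1
```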
@BlinkDL After some painstaking debugging I got it to produce coherent output on webgl. The fix was really bizarre: add Here are the changes: https://github.com/AXKuhta/RWKV-LM/commits/onnx_webgl
I did try onnx-simplifier with RWKV-3, but it didn't find much to simplify. The graph was almost unchanged. I will retest with RWKV-4 though. |
@AXKuhta Nice! Can you upload the webgl-compatible 169m/430m models to hugging face so I can add them to the web demo? Also, I wonder if the +0.0 bug is something that would be worth reporting to the ONNX runtime team? |
@josephrocca I think it's better to keep all the web models in one place so I made two PRs in your huggingface repository. Oh, and by the way, I also improved my initial index.html a little to not create new tensors inside the loop and to remove leading_pad(). I think you should integrate these changes into your demo too. I ran some performance tests with the hardware that I have available:
These numbers are not very impressive 😹 I didn't try it on a real GPU with a wide memory bus, but I suspect it won't perform massively better. There are three different webgl bug reports to be made to onnxruntime:
|
@AXKuhta Maybe there are some hidden bottlenecks :) Check the time consumption of all major functions and code fragments. |
@AXKuhta Thanks! Great work. I've always struggled with the WebGL backend - I'm guessing that it doesn't get as much attention as wasm because it isn't a port of C++, but must be written from scratch IIUC. I'm hoping that WebGPU will change that situation and we'll get really serious GPU ML on the web. Another factor regarding performance that could be relevant here is that wasm can simply be faster for some models, but I'd have thought that this would only be the case for very small models. Some discussion in this article about tf.js: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html |
@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.
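The texture numbers can be checked with a little arithmetic, assuming one float per texel (which is consistent with the 6214x6214 figure): 50277 × 768 = 38,612,736 elements, and ⌈√38,612,736⌉ = 6214, because 6213² = 38,601,369 < 38,612,736 ≤ 6214² = 38,613,796. So the tensor cannot be laid out directly under a 16384 limit (50277 > 16384) but fits as a square, while under a hypothetical 4096 limit even the square does not fit. `textureLayout` is a hypothetical helper, not part of onnxruntime.

```javascript
// How a rows x cols float tensor relates to a maximum texture size,
// assuming one float per texel and a square fallback layout.
function textureLayout(rows, cols, maxSize) {
  const elements = rows * cols;
  const side = Math.ceil(Math.sqrt(elements)); // smallest square holding all elements
  return {
    elements,
    fitsDirectly: rows <= maxSize && cols <= maxSize,
    squareSide: side,
    fitsAsSquare: side <= maxSize,
  };
}

const head16k = textureLayout(50277, 768, 16384);
// elements 38612736, fitsDirectly false, squareSide 6214, fitsAsSquare true
const head4k = textureLayout(50277, 768, 4096);
// fitsAsSquare false: a 6214-wide square exceeds 4096
```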
@josephrocca Yeah, I think it's better to wait for WebGPU instead of pursuing WebGL any further. It seems to work well for graphics, but not so much for compute. |
We could probably try tf.js for the final matmul and see if it's faster. |
@AXKuhta @josephrocca And actually you can skip the final matmul when scanning the prompt (because we just need the hidden states). I will provide some more efficient code soon to quickly generate the initial hidden states from prompt. |
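A minimal sketch of that idea, assuming the model is split into two hypothetical calls: `runBlocks(token, state)`, which runs the embedding and layers and returns `{ hidden, state }`, and `runHead(hidden)`, the final matmul to logits. Only the last prompt token needs logits, because they are used to sample the first generated token, so the head runs once instead of once per prompt token.

```javascript
// `model.runBlocks(token, state)` -> { hidden, state } and `model.runHead(hidden)` -> logits
// are hypothetical stand-ins for the layer stack and the final [768, 50277] matmul.
function feedPrompt(model, promptTokens, state = null) {
  let hidden = null;
  for (const token of promptTokens) {
    ({ hidden, state } = model.runBlocks(token, state)); // no head call here
  }
  if (hidden === null) throw new Error("prompt must contain at least one token");
  // Only the last position's logits are needed, to sample the first new token.
  return { logits: model.runHead(hidden), state };
}

// Usage with a counting mock model.
let blockCalls = 0;
let headCalls = 0;
const mockModel = {
  runBlocks: (token, state) => {
    blockCalls++;
    return { hidden: [token], state: (state ?? 0) + token };
  },
  runHead: (hidden) => {
    headCalls++;
    return [hidden[0] * 2];
  },
};
const fed = feedPrompt(mockModel, [1, 2, 3]);
// 3 block calls, 1 head call, logits [6], state 6
```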
Oh, and please check the speed of onnxruntime in pytorch :) I wonder if it will be faster. You can actually install pytorch on Android too. |
@BlinkDL Ooh, somehow I didn't think of that before! There is an "only_execute_path_to_fetches" switch in onnxruntime that can be used to make this work even with existing .onnx files. It looks like they forgot to expose it to JavaScript, so I had to make a custom build of ort-wasm-simd.wasm with that flag toggled in the source. I found that it actually works:
I put the custom-built ort-wasm-simd.wasm and the index.html updated with fetches logic here if anyone wants to try this too. I think it should be possible to pack both the RNN-style model and the GPT-style model into a single .onnx graph. Since the weights are shared between the two, there would only be a minimal increase in file size. I'll wait for the new GPT code (The current one doesn't run without CUDA btw). |
Here's some performance numbers for RWKV-4 with pytorch and native onnxruntime:
But I think I made a bit of a mistake by not excluding sample_logits() from the pytorch version. It seems to take about 10 ms too. I need to rerun those tests more carefully. EDIT: I totally forgot that my test_onnx.py had sample_logits() too, so these comparisons are fair after all. |
Finally tested the webgl backend on a real GPU:
As seen above, the 430m model also works on webgl now. It turns out my state store/restore code was breaking it: with a 24-layer model, it would attempt to stack 24 tensors at once, which exceeds WebGL's limit of 16 input textures. I worked around this by stacking 12 tensors at a time, twice, then using torch.cat() to glue the two stacks together. The stacking code can be removed, but then the 430m model would have 120 individual inputs/outputs for state, which sounds scary. I guess this kind of vindicates the webgl backend? It does outperform wasm when used on a real GPU, and it can also run the non-quantized 430m model, while wasm can't. Of course, it is still significantly slower than native. @josephrocca I opened two new PRs in your huggingface repo, one with the updated 430m webgl model and the other removing the outdated model. |
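The stacking workaround generalizes to any layer count: split the per-layer state tensors into as few groups as possible, each within the input-texture limit. A minimal sketch, where `balancedGroups` is a hypothetical helper returning half-open index ranges; for 24 layers and a limit of 16 it yields two groups of 12, matching the workaround described above. (The figure of 120 separate state inputs is presumably 5 state tensors per layer × 24 layers.)

```javascript
// Split `count` per-layer state tensors into the fewest groups whose sizes stay
// within `maxInputs`, keeping group sizes as equal as possible.
// Returns half-open index ranges [start, end).
function balancedGroups(count, maxInputs) {
  if (count === 0) return [];
  const numGroups = Math.ceil(count / maxInputs);
  const base = Math.floor(count / numGroups);
  const extra = count % numGroups; // the first `extra` groups get one more tensor
  const groups = [];
  let start = 0;
  for (let g = 0; g < numGroups; g++) {
    const size = base + (g < extra ? 1 : 0);
    groups.push([start, start + size]);
    start += size;
  }
  return groups;
}

const groups430m = balancedGroups(24, 16); // two groups of 12: [[0, 12], [12, 24]]
const groups169m = balancedGroups(12, 16); // one group: [[0, 12]]
```

Each group would be stacked separately and the stacks concatenated afterwards, as in the torch.cat() workaround.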
@AXKuhta Thanks! I've accepted the pull request and updated the demo.
Note that the wasm runtime should be able to run the non-quantized, 1.7GB model with no problems if it had enough memory available. There's currently an arbitrary 2GB limit that needs to be raised: microsoft/onnxruntime#10957 (comment) The memory limits should be gone completely once we get Memory64: https://github.com/WebAssembly/memory64 |
So there is work ongoing to lift that limit. That's good to know 👍 |
Please try the raw binary BF16 trick too :) #7 (comment) And please show the progress (1/32 etc.) in the webpage |
@AXKuhta Another idea: the w.emb.weight shall be a simple Float32Array on CPU. |
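A minimal sketch of the CPU-side lookup, assuming `w.emb.weight` is exported as a flat row-major Float32Array of shape [vocabSize, embDim] (e.g. 50277 × 768); the returned row would then be passed to the model as its input in place of the token id. `embeddingRow` is a hypothetical helper. Because `subarray()` returns a view, no weights are copied; call `.slice()` on the result if the runtime needs a buffer of its own.

```javascript
// Return the embedding row for `token` from a row-major [vocabSize, embDim] Float32Array.
function embeddingRow(embWeight, embDim, token) {
  const vocabSize = embWeight.length / embDim;
  if (!Number.isInteger(token) || token < 0 || token >= vocabSize) {
    throw new RangeError(`token ${token} out of range [0, ${vocabSize})`);
  }
  // subarray() returns a view over the same buffer, so nothing is copied.
  return embWeight.subarray(token * embDim, (token + 1) * embDim);
}

// Usage: a toy table with 3 tokens and embDim 4, filled with 0..11.
const embTable = Float32Array.from({ length: 12 }, (_, i) => i);
const row2 = embeddingRow(embTable, 4, 2); // [8, 9, 10, 11]
```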
@BlinkDL That's totally doable. Placing it into a separate file seems like a reasonable way to accomplish this. It may fix the NaN problem with webgl on Snapdragon iGPUs, which happened exactly in w.emb.weight[ctx[-1]]. I think it was caused by a 4096x4096 texture size limit in Adreno GL ES drivers, unlike 16384x16384 on AMD/Nvidia/Intel. A 50277x768 tensor represented as a 6214x6214 texture thus fails to fit on Adreno. Final matmul is probably broken on Adreno too because of this.
I don't quite understand the idea here. Do you mean storing bf16 weights in files and then converting them to fp32 or fp16 at runtime? |
Yeah remove it from the ONNX model. The model will directly use the embedded vector as input. Saves VRAM and shall be much faster.
Unfortunately you will still have this problem when doing the head (output) matmul... But I think you can split 50277 into chunks.
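A sketch of the chunking idea, assuming the head weight is stored as [vocab, dim] rows, the layout of a PyTorch `nn.Linear` weight. Each chunk of at most `chunkRows` rows is an independent [chunkRows, dim] product, so in the browser each could become its own small matmul that fits the texture limit; here the products are computed in plain JS only so the splitting and concatenation can be checked against the unchunked result. `chunkedHead` and `chunkRows` are hypothetical.

```javascript
// logits[v] = dot(headWeight row v, hidden), computed chunk by chunk.
function chunkedHead(hidden, headWeight, vocab, dim, chunkRows) {
  if (!(chunkRows > 0)) throw new RangeError("chunkRows must be positive");
  const logits = new Float32Array(vocab);
  for (let start = 0; start < vocab; start += chunkRows) {
    const end = Math.min(start + chunkRows, vocab);
    // Rows [start, end) of the head form one small [end - start, dim] matrix.
    const chunk = headWeight.subarray(start * dim, end * dim);
    for (let r = 0; r < end - start; r++) {
      let acc = 0;
      for (let k = 0; k < dim; k++) acc += chunk[r * dim + k] * hidden[k];
      logits[start + r] = acc; // write this chunk's results into the full output
    }
  }
  return logits;
}

// Usage: 7 "tokens", dimension 3, chunks of at most 3 rows (3 + 3 + 1).
const dim = 3;
const vocab = 7;
const hidden = Float32Array.from([1, 2, 3]);
const headWeight = Float32Array.from({ length: vocab * dim }, (_, i) => i % 5);
const chunked = chunkedHead(hidden, headWeight, vocab, dim, 3);
const whole = chunkedHead(hidden, headWeight, vocab, dim, vocab); // single chunk
```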
Yeah storing bf16 weight as 16bit binary files. Then decode them at runtime in JS when loading the model. See https://github.com/BlinkDL/AI-Writer/blob/main/docs/eng/index.html#L231 for loading binary weights |
Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.