RWKV-4 169m/430m in browser with ORT Web / TF.js / tfjs-tflite? #7

Open
josephrocca opened this issue Aug 20, 2022 · 32 comments

Comments

@josephrocca

Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.

@AXKuhta

AXKuhta commented Aug 20, 2022

Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx

There are two things missing:

  • JavaScript implementation of the tokenizer
  • JavaScript implementation of sample_logits().

Running python -i -u export_onnx.py and then rnn_export() will export the model as rwkw.onnx, which can then be tested with test_onnx.py and loaded from index.html. The demo in index.html uses greedy sampling and you just sorta have to visit https://goose.ai/tokenizer in order to encode/decode the text. It works using the wasm backend, but unfortunately throws an error if you try the webgl backend.

@BlinkDL
Owner

BlinkDL commented Aug 20, 2022

Exporting to ONNX is something that I've been tinkering with and I can report that the 169m RWKV-4 model does run in browser. Here's my code: https://github.com/AXKuhta/RWKV-LM/tree/onnx

Great work :)

Did you get this error with webgl? cannot resolve operator 'Max' with opsets: ai.onnx v13

You can remove RWKV_HEAD_QK and RWKV-ffnPre which are not required for Pile models, and probably that will fix it.

P.S. Upgrade onnxruntime to the latest version, and then you can test CUDAExecutionProvider in Python. I think you might be using an older onnxruntime, because all new versions require explicitly setting providers when initializing InferenceSession().

@josephrocca
Author

josephrocca commented Aug 21, 2022

@AXKuhta Nice! I got a web demo going here (for 169m and 430m):

But it seems like something is going wrong - the model isn't "coherent" in using the context. For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris". I checked that the tokenizer is working properly, so I think it's something to do with the inference / context-handling code.

Some other random notes:

  • The models doubled in size when ported to ONNX - e.g. the 169m model goes from 339MB to 679MB. I quantized it down to 171MB, but that makes inference less than half as fast (~5 tokens/sec for quantized versus ~13 tokens/sec for original). I'm guessing the non-quantized version has been converted from f16 to f32, hence the size doubling? The demo includes both the normal and quantized versions.

  • @BlinkDL Yes, I got TypeError: cannot resolve operator 'Max' with opsets: ai.onnx v13 when trying to use the WebGL backend. How would I go about removing RWKV_HEAD_QK and RWKV-ffnPre? I made a conversion notebook here: https://colab.research.google.com/github/josephrocca/rwkv-v4-web/blob/main/RWKV_v4_ONNX_conversion.ipynb Is it as simple as adding a few commands to that, or is there more work involved?

  • The WebGL backend doesn't work with quantized models. It gives this error: TypeError: cannot resolve operator 'DequantizeLinear' with opsets: ai.onnx v13, com.microsoft.experimental v1, ai.onnx.preview.training v1, ai.onnx.training v1, com.ms.internal.nhwc v17, org.pytorch.aten v1, com.microsoft.nchwc v1, ai.onnx.ml v3, com.microsoft v1

  • I used a very overkill approach to getting the tokenizer working... https://github.com/josephrocca/tokenizers-pyodide I haven't looked into how different the tokenizer is from gpt 2/3, but if it's similar, then I guess it shouldn't be too hard to make an edited version of this https://github.com/josephrocca/gpt-2-3-tokenizer ?

@AXKuhta

AXKuhta commented Aug 21, 2022

For example, if you prompt the 430m model with "The capital of France is" it continues with "first of the, the city of Paris"

That seems familiar!

The => first
The capital => of
The capital of => the
The capital of France => ,
The capital of France is => the

It looks like you display the outputs during the prompt-feeding stage, which happens one token at a time.

That should fix it:

         let token = greedySampling(results.x.data);

         if (promptTokens.length == 0) {
+          if(streamingCallback) streamingCallback(token);
           ctx.push( token );
         } else {
           ctx.push( promptTokens.shift() );
         }
-
-        if(streamingCallback) streamingCallback(token);

         feeds.xx_att = results.xx_att_r;
         feeds.aa_att = results.aa_att_r;
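
For illustration, the control flow that this change produces can be sketched in Python with a stand-in model. The names `generate`, `argmax` and `toy_step` are hypothetical and are not part of the demo; `toy_step` only mimics the "token in, logits and state out" shape of one RNN step. Predictions made while the prompt is still being fed are discarded, and only tokens sampled after the prompt is exhausted are emitted.

```python
def argmax(values):
    """Greedy choice: index of the largest logit."""
    return max(range(len(values)), key=values.__getitem__)


def generate(step, prompt_tokens, n_new):
    """Feed the prompt one token at a time, then emit n_new greedy tokens."""
    if not prompt_tokens:
        raise ValueError("prompt must contain at least one token")

    state, logits = None, None
    for token in prompt_tokens:
        # Prompt-feeding stage: the state is updated, but the prediction
        # made after each prompt token is not shown to the user.
        logits, state = step(token, state)

    emitted = []
    for _ in range(n_new):
        # Generation stage: only now are sampled tokens streamed out.
        token = argmax(logits)
        emitted.append(token)
        logits, state = step(token, state)
    return emitted


def toy_step(token, state):
    """Hypothetical stand-in for the model: always predicts (token + 1) % 5."""
    logits = [0.0] * 5
    logits[(token + 1) % 5] = 1.0
    return logits, state


print(generate(toy_step, [0, 1, 2], 3))  # prints [3, 4, 0]
```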

@AXKuhta

AXKuhta commented Aug 21, 2022

@josephrocca I had to host the demo locally because huggingface keeps terminating the model downloads for some reason, but otherwise I can confirm that it works on my machine. Good job with getting the tokenizer and the quantization working!

I'm guessing the non-quantized version has been converted from f16 to f32, hence the size doubling?

Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.

@BlinkDL
Owner

BlinkDL commented Aug 21, 2022

take a look at src/model_run.py.
for the pile model, self.model_type == 'RWKV' and RWKV_HEAD_QK_DIM = 0 so you can remove some useless code.

moreover, use https://github.com/daquexian/onnx-simplifier to optimize the onnx model

@BlinkDL
Owner

BlinkDL commented Aug 21, 2022

And the onnx version might work for AMD & Intel gpus. The DirectML backend supports them (on win10).

I tried that for RWKV-1.

@BlinkDL
Owner

BlinkDL commented Aug 21, 2022

Yeah, that's what's happening. RWKV-v4 is bf16 which can't be losslessly converted to fp16, so fp32 is the next best option. The fp32-converted model also compresses really well since half the bytes in it are zero.

You can losslessly "transform" bf16 to fp16, and the idea is to use the same raw binary value. The float value will be totally different, but you can do an inverse transform in JS to losslessly recover the original bf16.
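
A minimal sketch of the raw-bits idea, written in Python rather than JS for brevity. It assumes the weights are shipped as little-endian unsigned 16-bit words (the container type is an assumption here; only the bit pattern matters) and that the loader widens them back to fp32 by placing the 16 bits in the upper half of a 32-bit word. The function names are hypothetical.

```python
import struct


def fp32_to_bf16_bits(x):
    """Return the bf16 bit pattern of x by keeping the top 16 bits of its fp32 bits.

    This is plain truncation; it is exact when x is already representable in bf16.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return bits >> 16


def bf16_bits_to_fp32(word):
    """Inverse transform: put the 16 bits back on top and zero-fill the low half."""
    (x,) = struct.unpack("<f", struct.pack("<I", (word & 0xFFFF) << 16))
    return x


def encode(values):
    """Pack values as little-endian 16-bit words, 2 bytes per weight."""
    words = [fp32_to_bf16_bits(v) for v in values]
    return struct.pack(f"<{len(words)}H", *words)


def decode(blob):
    """Unpack 16-bit words and widen each one back to an fp32 value."""
    words = struct.unpack(f"<{len(blob) // 2}H", blob)
    return [bf16_bits_to_fp32(w) for w in words]


weights = [1.5, -0.15625, 3.0, 0.0]  # all exactly representable in bf16
blob = encode(weights)
restored = decode(blob)
print(len(blob), restored == weights)
```

The same widening step also shows why the fp32-converted model compresses well: the two low-order bytes of every widened value are zero.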

@josephrocca
Author

josephrocca commented Aug 21, 2022

@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it - thanks for your help!! Strange that huggingface is terminating the download for you... 🤔

@BlinkDL Thanks for the tips! I'll look into the stuff you've mentioned.

@BlinkDL
Owner

BlinkDL commented Aug 21, 2022

Hi, really exciting project! I'm wondering if you've published the model conversion script that you used to create the js_models files from the .pth model file? It would be awesome to see how the larger and newer models like RWKV-4 169m/430m perform in the browser! I think the inference speed of RWKV opens up many new possibilities for language models on the web.

The python code for RWKV-2 weight conversion to .bin (for tf.js):

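import torch

# MODEL_NAME is the checkpoint path without the '.pth' extension; set it before running.
w = torch.load(MODEL_NAME + '.pth', map_location='cpu')  # load onto the CPU so .numpy() works
for x in w.keys():
    if 'copy_mask' in x:  # this is for headQK, which is not used in Pile models
        continue
    print(x, w[x].shape)

    # We are doing some pre-computations here. Change them to match RWKV-4,
    # or just skip all of them and do everything in JS first.
    if '.time_' in x:
        w[x] = w[x].squeeze()
    if '.time_decay' in x:
        w[x] = torch.exp(-torch.exp(w[x]))
    if '.time_first' in x:
        w[x] = torch.exp(w[x])

    # Dump the raw tensor bytes, one file per parameter (the output directory must already exist).
    w[x].numpy().tofile(f'20220425/{x}.bin')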

You can gradually port it to RWKV-4 by matching the outputs for each layer.

The Chinese RWKV-2 has a better UI: https://github.com/BlinkDL/AI-Writer/blob/main/docs/index.html

The English RWKV-2: https://github.com/BlinkDL/AI-Writer/tree/main/docs/eng

@BlinkDL
Owner

BlinkDL commented Aug 21, 2022

@AXKuhta Could have sworn I replied here earlier, sorry - apparently I didn't click send. I fixed the demo according to your comment soon after you posted it

Add top-p top-k and temperature and then it's very usable :)
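
A minimal sketch of what such a sampler could look like, in plain Python so that it is easy to port to JS. It is not the project's sample_logits(); the function signature, the parameter defaults, and the order in which top-k and top-p are applied are all assumptions made for illustration.

```python
import math
import random


def sample_logits(logits, temperature=1.0, top_k=0, top_p=1.0, rng=random):
    """Sample a token id from raw logits using temperature, top-k and top-p."""
    if temperature <= 0:
        # Treat a non-positive temperature as greedy decoding.
        return max(range(len(logits)), key=logits.__getitem__)

    # Softmax with temperature, shifted by the maximum for numerical stability.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token ids from most to least likely; top-k keeps the first k of them.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Top-p (nucleus): keep the smallest prefix whose probability mass reaches top_p.
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Draw from the kept tokens in proportion to their probabilities.
    r = rng.random() * mass
    acc = 0.0
    for i in kept:
        acc += probs[i]
        if r < acc:
            return i
    return kept[-1]


token = sample_logits([0.1, 5.0, 0.2], temperature=0.8, top_k=2, top_p=0.9)
```

Passing a seeded `random.Random` instance as `rng` makes the output reproducible, which helps when matching the JS port against a reference run.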

@AXKuhta

AXKuhta commented Aug 23, 2022

It looks like the webgl backend has a lot of limitations. I did some testing by stripping out different parts of the model in order to see if I can get anything at all to work on the webgl backend. I think I got like four different error messages with different combinations. The bottom line is that I can't even get a matmul to work.

matmul

https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/matmul_test.py
https://github.com/AXKuhta/RWKV-LM/blob/matmul/RWKV-v4/index.html

It does work on the wasm backend!

EDIT: It actually works on webgl if you do this: AXKuhta@75ad160

@AXKuhta

AXKuhta commented Aug 23, 2022

I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken:

https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl

@BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by torch.maximum(pp, ww), and I was able to suppress it by using torch.max(torch.stack([pp, ww]), 0).values instead. I also had to add a bunch of .view([768,1]) around matmul operations and then fix layer_norm() to stop it from producing NaNs. At first it looked like self.FF() always produced zeroes on webgl, but no, self.FF() does produce something.

@BlinkDL
Owner

BlinkDL commented Aug 23, 2022

I have been able to force the full model to run on webgl, but it doesn't produce anything coherent, so something's still broken:

https://github.com/AXKuhta/RWKV-LM/tree/onnx_webgl

@BlinkDL The "cannot resolve operator 'Max' with opsets: ai.onnx v13" error was caused by torch.maximum(pp, ww), and I was able to suppress it by using torch.max(torch.stack([pp, ww]), 0).values instead. I also had to add a bunch of .view([768,1]) around matmul operations and then fix layer_norm() to stop it from producing NaNs. At first it looked like self.FF() always produced zeroes on webgl, but no, self.FF() does produce something.

That's great. Could you check whether https://github.com/daquexian/onnx-simplifier can help? Use https://github.com/lutzroeder/netron to visualize models.

And then you can print() the outputs of interesting layers to find the culprit... gradually matching the results of webgl vs wasm.

@AXKuhta

AXKuhta commented Aug 24, 2022

@BlinkDL After some painstaking debugging I got it to produce coherent output on webgl. The fix was really bizarre: add + 0.0 in a bunch of places. Some nodes on the ONNX graph that follow matmul+reshape operations kept getting bugged inputs that looked like a single value across all 768 elements. Performing +0.0 with the bugged input fixes it.

Here are the changes: https://github.com/AXKuhta/RWKV-LM/commits/onnx_webgl

Could you check whether https://github.com/daquexian/onnx-simplifier can help?

I did try onnx-simplifier with RWKV-3, but it didn't find much to simplify. The graph was almost unchanged. I will retest with RWKV-4 though.

@josephrocca
Author

josephrocca commented Aug 24, 2022

@AXKuhta Nice! Can you upload the webgl-compatible 169m/430m models to hugging face so I can add them to the web demo?

Also, I wonder if the +0.0 bug is something that would be worth reporting to the ONNX runtime team?

@AXKuhta

AXKuhta commented Aug 25, 2022

@josephrocca I think it's better to keep all the web models in one place so I made two PRs in your huggingface repository. Oh, and by the way, I also improved my initial index.html a little to not create new tensors inside the loop and to remove leading_pad(). I think you should integrate these changes into your demo too.

I ran some performance tests with the hardware that I have available:

All tests performed in Chromium
169m model

========= WASM =========
Intel Core i7 2760QM:          280 ms per token
Intel Core i7 6650U:           204 ms per token
AMD A10-7800:                  331 ms per token
Snapdragon 865:                233 ms per token

========= WebGL =========
Intel Core i7 2760QM iGPU:     600 ms per token
Nvidia GeForce 520MX:          305 ms per token
Intel Core i7 6650U iGPU:      192 ms per token
AMD A10-7800 iGPU:             232 ms per token
Snapdragon 865 iGPU:           Produces NaNs

These numbers are not very impressive 😹

I didn't try it on a real GPU with a wide memory bus, but I suspect it won't perform massively better.

There are three different webgl bug reports to be made to onnxruntime:

  • Matmuls like [768, 768] @ [768] complain about dimension mismatch, must be converted to [768, 768] @ [768, 1]
  • NaNs produced by layer_norm() if there are negative inputs
  • This strange +0.0 stuff if I can reproduce it in a standalone fashion

@BlinkDL
Owner

BlinkDL commented Aug 25, 2022

@AXKuhta Maybe there are some hidden bottlenecks :) Check the time consumption of all major functions and code fragments.

@josephrocca
Author

@AXKuhta Thanks! Great work. I've always struggled with the WebGL backend - I'm guessing that it doesn't get as much attention as wasm because it isn't a port of C++, but must be written from scratch IIUC. I'm hoping that WebGPU will change that situation and we'll get really serious GPU ML on the web.

Another factor RE performance that could be relevant here is that wasm can just be faster for some models, but I'd have thought that this would only be the case for models that are very small. Some discussion in this article about tf.js: https://blog.tensorflow.org/2020/09/supercharging-tensorflowjs-webassembly.html


@AXKuhta

AXKuhta commented Aug 26, 2022

@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.

Gradually removing parts of the model until there is nothing left except input->output passthrough
Nvidia GeForce 520MX
169m model

Configuration                      Time       Change
Baseline full model                344 ms     N/A
Removed state store/restore        326 ms     -18 ms
Removed final matmul               145 ms     -181 ms
Removed 12 x self.FF()             60 ms      -85 ms
Removed 12 x self.SA()             30 ms      -30 ms
Removed 26 x self.LN()             16 ms      -14 ms
Removed w.emb.weight[ctx[-1]]      0.7 ms     -15.3 ms

@josephrocca Yeah, I think it's better to wait for WebGPU instead of pursuing WebGL any further. It seems to work well for graphics, but not so much for compute.

@BlinkDL
Owner

BlinkDL commented Aug 26, 2022

@BlinkDL The final [768, 50277] matmul is the slowest component. It's almost as slow as the entire model on WASM, which is kind of surprising, considering that GPUs are supposed to be good at matmul. It may be caused by the fact that it doesn't fit under the texture size limit of 16384 so onnxruntime does some magic to remap it into a 6214x6214 texture instead, possibly making it slow.

You could probably try tf.js for the final matmul and see if it's faster.

@BlinkDL
Owner

BlinkDL commented Aug 26, 2022

@AXKuhta @josephrocca And actually you can skip the final matmul when scanning the prompt (because we just need the hidden states).

I will provide some more efficient code soon to quickly generate the initial hidden states from prompt.

@BlinkDL
Owner

BlinkDL commented Aug 26, 2022

Oh and please check the speed of onnxruntime in pytorch :) I wonder if it will be faster.

You can actually install pytorch in Android too.

@AXKuhta

AXKuhta commented Aug 27, 2022

And actually you can skip the final matmul when scanning the prompt

@BlinkDL Ooh, somehow I didn't think of that before!

There is an "only_execute_path_to_fetches" switch in onnxruntime that can be used to make this work even with existing .onnx files. It looks like they forgot to expose it to JavaScript, so I had to make a custom build of ort-wasm-simd.wasm with that flag toggled in the source. I found that it actually works:

Intel Core i7 2760QM
169m model
WASM only_execute_path_to_fetches = true

Don't want the x output 	158ms per token
Want the x output 		258ms per token

I put the custom-built ort-wasm-simd.wasm and the index.html updated with fetches logic here if anyone wants to try this too.
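
As a rough worked example, the two per-token timings above can be turned into an estimate of how long it takes to scan a prompt. The sketch assumes the head matmul is skipped for every prompt token except the last one (whose logits are needed to sample the first new token) and that per-token cost is otherwise constant; `prompt_scan_ms` is a hypothetical helper, not part of the demo.

```python
def prompt_scan_ms(n_prompt, with_head_ms=258.0, without_head_ms=158.0, skip_head=True):
    """Estimate the time to feed n_prompt tokens, using the measured per-token costs."""
    if n_prompt <= 0:
        return 0.0
    if not skip_head:
        # Every prompt token pays for the final matmul.
        return n_prompt * with_head_ms
    # Only the last prompt token needs the x output.
    return (n_prompt - 1) * without_head_ms + with_head_ms


baseline = prompt_scan_ms(32, skip_head=False)
skipping = prompt_scan_ms(32)
print(baseline, skipping)  # prints 8256.0 5156.0
```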

I think it should be possible to pack both the RNN-style model and the GPT-style model into a single .onnx graph. Since the weights are shared between the two, there would only be a minimal increase in file size. I'll wait for the new GPT code (The current one doesn't run without CUDA btw).

@AXKuhta

AXKuhta commented Aug 27, 2022

Oh and please check the speed of onnxruntime in pytorch

Here are some performance numbers for RWKV-4 with pytorch and native onnxruntime:

Native pytorch + onnxruntime
169m model

Intel Core i7 2760QM    Pytorch    79.3 ms/token
Intel Core i7 2760QM    ONNX       152 ms/token     Note: ONNX forced to use 8 threads to hit full CPU utilization

Intel Core i7 6650U     Pytorch    62.1 ms/token
Intel Core i7 6650U     ONNX       129 ms/token     Note: ONNX forced to use 4 threads to hit full CPU utilization

Snapdragon 865          Pytorch    71.0 ms/token
Snapdragon 865          ONNX       180 ms/token

But I think I made a bit of a mistake by not excluding sample_logits() from the pytorch version. It seems to take around 10 ms too. I need to rerun those tests with more caution.

EDIT: I totally forgot that my test_onnx.py had sample_logits() too, so these comparisons are fair after all.

@AXKuhta

AXKuhta commented Aug 28, 2022

Finally tested the webgl backend on a real GPU:

GTX 1060 6GB
webgl

169m model		68.6 ms/token
430m model 		119 ms/token

As seen above, the 430m model also works on webgl now. It turns out my state store/restore code was breaking it: with a 24 layer model, it would attempt to stack 24 tensors at once, which would exceed the 16 input textures limit in WebGL. I worked around this by stacking 12 tensors at a time, twice, then using torch.cat() to glue two stacks.
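
The grouping logic can be sketched without torch. In the snippet below plain Python lists stand in for tensors, building a list stands in for torch.stack(), and flattening stands in for torch.cat(); the helper name and the constant are hypothetical and only illustrate how each stacking operation stays under the 16-input limit.

```python
MAX_INPUT_TEXTURES = 16  # WebGL limit on inputs to a single operation, as described above


def stack_in_groups(tensors, group_size=12):
    """Stack at most group_size tensors at a time, then concatenate the stacks."""
    if not 0 < group_size <= MAX_INPUT_TEXTURES:
        raise ValueError("group_size must be between 1 and the input texture limit")

    # Split the per-layer tensors into groups that each fit under the limit.
    groups = [tensors[i:i + group_size] for i in range(0, len(tensors), group_size)]

    # Stand-in for torch.stack(group): one stacked block per group.
    stacks = [list(group) for group in groups]

    # Stand-in for torch.cat(stacks): glue the blocks back together in order.
    return [t for stack in stacks for t in stack]


layers = [f"layer{i}_state" for i in range(24)]  # 24 layers, as in the 430m model
merged = stack_in_groups(layers)
print(len(merged), merged == layers)  # prints 24 True
```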

The stacking code can be removed, but then the 430m model will have 120 individual inputs/outputs for state, which sounds scary.

I guess this kind of vindicates the webgl backend? It does outperform wasm when used on a real GPU, and it can also run the non-quantized 430m model, while wasm can't. Of course, it is still significantly slower than native.

@josephrocca I opened two new PRs in your huggingface repo, one with the updated 430m webgl model and the other removing the outdated model.

@josephrocca
Author

@AXKuhta Thanks! I've accepted the pull request and updated the demo.

it can also run the non-quantized 430m model, while wasm can't

Note that the wasm runtime should be able to run the non-quantized, 1.7GB model with no problems if it had enough memory available. There's currently an arbitrary 2GB limit that needs to be raised: microsoft/onnxruntime#10957 (comment)

The memory limits should be gone completely once we get Memory64: https://github.com/WebAssembly/memory64

@AXKuhta

AXKuhta commented Aug 29, 2022

The memory limits should be gone completely once we get Memory64

So there is work ongoing to lift that limit. That's good to know 👍

@BlinkDL
Owner

BlinkDL commented Aug 30, 2022

@AXKuhta Thanks! I've accepted the pull request and updated the demo.

Please try the raw binary BF16 trick too :) #7 (comment)

And please show the progress (1/32 etc.) in the webpage

@BlinkDL
Owner

BlinkDL commented Sep 2, 2022

@AXKuhta Another idea: the w.emb.weight shall be a simple Float32Array on CPU.

@AXKuhta

AXKuhta commented Sep 3, 2022

Another idea: the w.emb.weight shall be a simple Float32Array on CPU.

@BlinkDL That's totally doable.

Placing it into a separate file seems like a reasonable way to accomplish this.

It may fix the NaN problem with webgl on Snapdragon iGPUs, which happened exactly in w.emb.weight[ctx[-1]]. I think it was caused by a 4096x4096 texture size limit in Adreno GL ES drivers, unlike 16384x16384 on AMD/Nvidia/Intel. A 50277x768 tensor represented as a 6214x6214 texture thus fails to fit on Adreno. Final matmul is probably broken on Adreno too because of this.
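
The arithmetic behind the 6214x6214 figure can be checked with a short sketch. It assumes one value per texel and a square layout, which matches the numbers quoted here but is not confirmed against the onnxruntime source; the helper names are hypothetical.

```python
import math


def square_texture_side(rows, cols):
    """Side of the smallest square texture that holds rows * cols values."""
    n = rows * cols
    side = math.isqrt(n)
    if side * side < n:
        side += 1  # round up when n is not a perfect square
    return side


def fits(rows, cols, max_texture_size):
    """True if the tensor fits in one square texture under the given size limit."""
    return square_texture_side(rows, cols) <= max_texture_size


side = square_texture_side(50277, 768)
print(side, fits(50277, 768, 16384), fits(50277, 768, 4096))  # prints 6214 True False
```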

Please try the raw binary BF16 trick too :) #7 (comment)

I don't quite understand the idea here. Do you mean storing bf16 weights in files and then converting them to fp32 or fp16 at runtime?

@BlinkDL
Owner

BlinkDL commented Sep 3, 2022

Placing it into a separate file seems like a reasonable way to accomplish this.

Yeah, remove it from the ONNX model. The model will directly use the embedded vector as input. That saves VRAM and should be much faster.

A 50277x768 tensor represented as a 6214x6214 texture thus fails to fit on Adreno

Unfortunately you will still have this problem when doing the head (output) matmul... But I think you can split 50277 into chunks.
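
A minimal sketch of how the chunk boundaries could be chosen, assuming each chunk of the head weight has to fit in one square texture at one value per texel (an assumption about the WebGL backend, not something verified here). `head_chunks` is a hypothetical helper; the per-chunk logits would then be concatenated in order to recover the full 50277-wide output.

```python
def head_chunks(vocab_size, n_embd, max_texture_size):
    """Split the vocabulary into (start, end) row ranges that each fit in one texture."""
    max_values = max_texture_size * max_texture_size
    rows_per_chunk = max_values // n_embd
    if rows_per_chunk == 0:
        raise ValueError("a single row does not fit under the texture size limit")
    return [
        (start, min(start + rows_per_chunk, vocab_size))
        for start in range(0, vocab_size, rows_per_chunk)
    ]


chunks = head_chunks(50277, 768, 4096)
print(chunks)  # prints [(0, 21845), (21845, 43690), (43690, 50277)]
```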

Do you mean storing bf16 weights in files and then converting them to fp32 or fp16 at runtime?

Yeah, storing the bf16 weights as 16-bit binary files, then decoding them at runtime in JS when loading the model.

See https://github.com/BlinkDL/AI-Writer/blob/main/docs/eng/index.html#L231 for loading binary weights

harrisonvanderbyl pushed a commit to harrisonvanderbyl/RWKV-LM that referenced this issue Jul 15, 2023
Merge pull request BlinkDL#7 from PicoCreator/dev-infctx-lr-final