How can I run this with wgpu? #3
burn-wgpu currently doesn't use the full device memory available, so llama2 can't run with it just yet, but I am working on a solution. Hopefully within the next few days I'll have it working with wgpu.
Thank you for your effort.
Could you please explain what exactly the current limitation is, and whether there are plans to solve it?
I tried some modifications:

    type GraphicsApi = AutoGraphicsApi;
    type Backend = WgpuBackend<GraphicsApi, Elem, i32>;
    let device = WgpuDevice::default();

and found a problem:

    In Device::create_bind_group
    Buffer binding 0 range 524288000 exceeds `max_*_buffer_binding_size` limit 134217728
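For reference, the 134217728 in that error is wgpu's default `max_storage_buffer_binding_size` (128 MiB), while 524288000 (~500 MB) is the size of one weight buffer. Below is a minimal standalone sketch of requesting the adapter's actual limits when creating a raw wgpu device; this is not the burn-wgpu internals, and the descriptor field names follow the pre-0.19 wgpu API, so they may differ in newer releases.

```rust
// Request a device whose buffer limits match what the hardware actually
// supports, instead of wgpu's conservative 128 MiB default.
async fn request_device_with_large_buffers(
    adapter: &wgpu::Adapter,
) -> (wgpu::Device, wgpu::Queue) {
    // Start from the adapter's advertised limits so we never ask for more
    // than the hardware can provide.
    let supported = adapter.limits();
    let limits = wgpu::Limits {
        max_storage_buffer_binding_size: supported.max_storage_buffer_binding_size,
        max_buffer_size: supported.max_buffer_size,
        ..wgpu::Limits::default()
    };
    adapter
        .request_device(
            &wgpu::DeviceDescriptor {
                label: Some("llama2-burn device"),
                features: wgpu::Features::empty(),
                limits,
            },
            None, // no API call tracing
        )
        .await
        .expect("failed to create wgpu device")
}
```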
By the way, I only loaded one transformer block because there wasn't enough memory available.
burn-wgpu has been updated to utilize the full GPU memory, so it should now work as long as your GPU has enough memory.
@Ma-Jian1 how did you fix issue No.1 (the `max_*_buffer_binding_size` limit error)?
I attempted to modify the code directly, but I am unsure if it is correct. I just want to test whether or not it will run on my laptop, without caring about the result.
I have the same problem. I'm using stas/tiny-random-llama-2. The jit backend of burn has this repeat function:

    pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
        input: JitTensor<R, E, D1>,
        dim: usize,
        times: usize,
    ) -> JitTensor<R, E, D1> {
        let mut shape = input.shape.clone();
        if shape.dims[dim] != 1 {
            panic!("Can only repeat dimension with dim=1");
        }

@Gadersd could you suggest any fix here? thx
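The panic means the jit `repeat` kernel only handles repeating a dimension whose size is 1. As a workaround at the model level (a sketch only, assuming burn's high-level `Tensor::cat`; this is not a fix to the jit kernel itself), the repeat can be emulated by concatenating copies along the same dimension:

```rust
use burn::tensor::{backend::Backend, Tensor};

/// Repeat `tensor` `times` times along `dim` by concatenating copies.
/// Unlike the jit `repeat` kernel quoted above, this also works when the
/// repeated dimension has a size greater than 1.
fn repeat_via_cat<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    dim: usize,
    times: usize,
) -> Tensor<B, D> {
    if times == 1 {
        return tensor;
    }
    // Tensors are cheap to clone (they share the underlying storage).
    let copies = vec![tensor; times];
    Tensor::cat(copies, dim)
}
```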
I want to test this project on my laptop with Intel Iris Xe Graphics. How can I achieve that?
My CPU memory is 16 GB.
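With an integrated GPU such as Iris Xe, the GPU shares the 16 GB of system RAM, so only a small checkpoint or a reduced number of layers will fit. A minimal sketch of pointing burn-wgpu at the integrated adapter (assuming the `WgpuDevice::IntegratedGpu` variant; type and variant names may differ across burn versions):

```rust
use burn_wgpu::{AutoGraphicsApi, WgpuBackend, WgpuDevice};

// f32 weights on an integrated GPU; the shared system RAM (16 GB here)
// is the effective memory budget, so prefer a small model.
type Backend = WgpuBackend<AutoGraphicsApi, f32, i32>;

fn main() {
    // Pick the first integrated adapter (Iris Xe) instead of the default
    // "best available" device.
    let device = WgpuDevice::IntegratedGpu(0);
    // ... build and run the model with `Backend` on `device` ...
    let _ = device;
}
```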