
feat: Rotary Positional Encoding #215

Merged: 35 commits, Nov 22, 2024

Conversation

@abourramouss (Contributor)

This PR implements RoPE (Rotary Positional Embedding).

#195
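
For context, a minimal NumPy sketch of the rotary embedding being implemented (interleaved-pair convention; the function name and the base=10000 default are illustrative assumptions, not Caten's API):

import numpy as np

def rope_reference(x, base=10000.0):
    # x: (seq_len, dim). Pairs lanes (2i, 2i+1) and rotates each pair
    # by the angle pos * base**(-2i/dim).
    seq_len, dim = x.shape
    half = dim // 2
    freqs = base ** (-np.arange(half) * 2.0 / dim)          # (half,)
    angles = np.arange(seq_len)[:, None] * freqs[None, :]   # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin
    out[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos
    return out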

@abourramouss marked this pull request as draft November 12, 2024 19:14
Two review threads on source/nn/positional-encoding.lisp (outdated, resolved)
@hikettei mentioned this pull request Nov 15, 2024
@abourramouss (Contributor, Author)

test-rope
  dtype=FLOAT32
    ✓ Shapes match
    ✓ Satisfying (atol=1.9999694) <= 2
    ✓ Satisfying (rtol=2.4570446) <= 3
T
NIL

Ideally the error should be on the 1e-5 scale. The issue is that the tensors aren't identical: some values differ.
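
For reference, a check in the spirit of the atol/rtol lines above could be written as the sketch below (the test suite's exact definitions may differ):

import numpy as np

def report_tolerance(actual, expected, eps=1e-12):
    # Worst-case absolute and relative deviation between two arrays
    diff = np.abs(actual - expected)
    atol = diff.max()
    rtol = (diff / (np.abs(expected) + eps)).max()
    print(f"atol={atol}, rtol={rtol}")
    return atol, rtol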

@hikettei (Owner)

Your code looks very clean! LGTM.
Regarding the accuracy issue, let me run some experiments (yes, it should fit within the 1e-5 scale); this might be due to a compiler issue.

@hikettei (Owner) commented Nov 20, 2024

The code looks very good. I will merge this once the error is brought down to the 1e-5 scale.

@hikettei (Owner) commented Nov 20, 2024

Hmm, with JIT this should be scheduled into a single kernel? (Don't worry about it, I will take this!)

  • I might have to add a new fusion pattern for RoPE (without relying on :reduction=t)
CATEN/TEST-SUITE> (call (Rope `(10)) (make-tensor `(10 10)))
{Tensor[float32] :shape (10 10) :id TID1450187
  :buffer nil
  :op #<RESHAPE {700BF75D83}>
  :requires-grad NIL
  :variables (TID1450176)
  :tracker #<TRACKER :order={row(0 1)} :shape=(10 10) :contiguous-p=T>}
CATEN/TEST-SUITE> (proceed *)
[graph-schedule] Schedule Graph:

FastGraph[seen=NIL, outputs=(val_39)] {
    { Allocate } : [ val_34 <- (1 10 5 2) where lowered-p=nil ]
    { Allocate } : [ val_32 <- (1 10 5) where lowered-p=nil ]
    {  KERNEL  } : [ val_35 <- val_23, val_22, val_21, val_32, val_34 where lowered-p=nil :name=FUSED_CONCATENATENODE1463528]
    { Allocate } : [ val_0 <- (1 10 5) where lowered-p=nil ]
    { Allocate } : [ val_23 <- (10 10) where lowered-p=nil ]
    { Allocate } : [ val_17 <- (10 5) where lowered-p=nil ]
    { Allocate } : [ val_15 <- (10) where lowered-p=nil ]
    { Allocate } : [ val_9 <- (5) where lowered-p=nil ]
    { Allocate } : [ val_7 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_5 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_3 <- NIL where lowered-p=nil ]
    { Allocate } : [ val_1 <- NIL where lowered-p=nil ]
    {  KERNEL  } : [ val_22, val_21, val_36 <- val_1, val_3, val_5, val_7, val_9, val_15, val_17, val_23, val_0, val_35 where lowered-p=nil :name=FUSED_CONCATENATENODE_COSNODE1463526]
    { Allocate } : [ val_37 <- (1 10 5 2) where lowered-p=nil ]
    {  KERNEL  } : [ val_38 <- val_37, val_36 where lowered-p=nil :name=FUSED_MOVE1463522]
    {   VMOP   } : [ val_39 <- val_38 where lowered-p=nil :name=FUSED_BACKWARD1463520]
}

[14:21:53, 11/20/2024 (GMT+9)] : JIT Compilation Start (AVM=MAIN1450194)

* (1/3) FUSED_CONCATENATENODE1463528
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_27 = -((val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_22[((_gid0+(5*_gid1))+_gid2)]));
        val_35[(((100*_gid0)+(10*_gid1))+(2*_gid2))] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_21[((_gid0+(5*_gid1))+_gid2)])+val_27);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.030271(sec)
* (2/3) FUSED_CONCATENATENODE_COSNODE1463526
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<1);_gid0+=1) {
    for (int _gid1=0;(_gid1<10);_gid1+=1) {
      for (int _gid2=0;(_gid2<5);_gid2+=1) {
        val_19 = (_gid1*exp2((((_gid2*-9.2103405)*0.1)*1.442695)));
        val_22[((5*_gid1)+_gid2)] = sin(val_19);
        val_21[((5*_gid1)+_gid2)] = sin((val_19+1.5707964));
        val_29 = (val_23[((_gid0+(10*_gid1))+(1+(2*_gid2)))]*val_21[((_gid0+(5*_gid1))+_gid2)]);
        val_36[((((100*_gid0)+(10*_gid1))+(2*_gid2))+1)] = ((val_23[((_gid0+(10*_gid1))+(2*_gid2))]*val_22[((_gid0+(5*_gid1))+_gid2)])+val_29);
      } // _gid2
    } // _gid1
  } // _gid0
}
Compilation Time : 0.04597(sec)
* (3/3) FUSED_MOVE1463522
=====> Lowering to blueprint
{
  for (int _gid0=0;(_gid0<100);_gid0+=1) {
    val_38[_gid0] = val_36[_gid0];
  } // _gid0
}
Compilation Time : 0.001554(sec)
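
As an aside, the second kernel shows two tricks visible in the generated code: exp2(x * 1.442695) is exp(x) (1.442695 ≈ log2 e, and -9.2103405 ≈ -ln 10000), and sin(θ + 1.5707964) is cos(θ), so a single sin routine fills both trig tables. A hand translation of that loop into Python, for inspecting the numerics only (not part of the PR):

import math

def fused_cos_kernel(val_23, val_21, val_22, val_36):
    # val_23: flattened (10, 10) input; val_21/val_22: (10, 5) trig tables;
    # val_36: flattened (1, 10, 5, 2) output. _gid0 ranges over 1 and is dropped.
    for g1 in range(10):      # position
        for g2 in range(5):   # pair index
            theta = g1 * math.exp(g2 * -9.2103405 * 0.1)  # = g1 * 10000**(-g2/10)
            val_22[5 * g1 + g2] = math.sin(theta)
            val_21[5 * g1 + g2] = math.sin(theta + 1.5707964)  # = cos(theta)
            val_29 = val_23[10 * g1 + 2 * g2 + 1] * val_21[5 * g1 + g2]
            # odd output lane of the rotation: x_even*sin + x_odd*cos
            val_36[10 * g1 + 2 * g2 + 1] = (
                val_23[10 * g1 + 2 * g2] * val_22[5 * g1 + g2] + val_29
            )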

@hikettei (Owner)

Will #245 fix anything for JIT=0?

@abourramouss (Contributor, Author)

I've been running some tests in torch and MLX. The issue is that MLX and PyTorch yield different results; my implementation is essentially inspired by MLX, while the test uses the torch implementation.

This is the code I've been using:

import torch
from torchtune.modules import RotaryPositionalEmbeddings
from mlx.nn import RoPE
import mlx.core as mx
import numpy as np

# Parameters
seq_len, num_heads, head_dim = 30, 30, 30
dim = head_dim
max_seq_len = seq_len
theta = 10000.0  # both libraries default their RoPE base to 10000

x = torch.rand(1, seq_len, num_heads, head_dim)
rope_torch = RotaryPositionalEmbeddings(dim=dim, max_seq_len=max_seq_len)

rope_mlx = RoPE(dims=head_dim)

# Feed both frameworks the same input
numpy_array = x.detach().cpu().numpy()

x_mlx = mx.array(numpy_array)

mx_output = rope_mlx(x_mlx)

print("Output MLX:")
# Convert the MLX output (not the raw input) back to torch for comparison
mlx_tensor = torch.tensor(np.array(mx_output))
print(type(mlx_tensor))

print("Output PyTorch:")
torch_tensor = rope_torch(x)
print(type(torch_tensor))

absolute_diff = torch.sum(torch.abs(torch_tensor - mlx_tensor))

relative_diff = absolute_diff / torch.sum(torch.abs(mlx_tensor))

# Output results
print(f"Total Absolute Difference: {absolute_diff.item()}")
print(f"Total Relative Difference: {relative_diff.item()}")

And the resulting absolute/relative differences:

Output MLX:
<class 'torch.Tensor'>
Output PyTorch:
<class 'torch.Tensor'>
Total Absolute Difference: 6711.056640625
Total Relative Difference: 0.49696406722068787

I will reimplement the call function following the torch implementation in order to bring rtol and atol down to the required values.
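
For what it's worth, one plausible source of the torch/MLX gap is the pairing convention: torchtune's RotaryPositionalEmbeddings rotates interleaved lanes (2i, 2i+1), while MLX's RoPE with the default traditional=False rotates half-split lanes (i, i + dim/2). A NumPy sketch of the two conventions (illustrative only):

import numpy as np

def rotate_interleaved(x, cos, sin):
    # torchtune-style: pair lanes (2i, 2i+1)
    out = np.empty_like(x)
    out[..., 0::2] = x[..., 0::2] * cos - x[..., 1::2] * sin
    out[..., 1::2] = x[..., 0::2] * sin + x[..., 1::2] * cos
    return out

def rotate_half_split(x, cos, sin):
    # MLX default-style: pair lanes (i, i + dim/2)
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

seq_len, dim = 8, 16
pos = np.arange(seq_len)[:, None]
freqs = (10000.0 ** (-np.arange(dim // 2) * 2.0 / dim))[None, :]
cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
x = np.random.rand(seq_len, dim).astype(np.float32)
print(np.abs(rotate_interleaved(x, cos, sin) - rotate_half_split(x, cos, sin)).sum())

Both are valid rotations, but they permute lanes differently, so comparing one elementwise against the other produces exactly the kind of large total difference shown above.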

@abourramouss marked this pull request as ready for review November 22, 2024 09:12
Three review threads on source/nn/positional-encoding.lisp (outdated, resolved)
abourramouss and others added 3 commits November 22, 2024 11:28
added: assertion for tensor number of dimensions = 4, make-list with initial true, multiple-value-bind instead of manual initialization, assertion instead of when
Co-Authored-By: hikettei <[email protected]>
hikettei added a commit to abourramouss/Caten that referenced this pull request Nov 22, 2024
@hikettei merged commit 8594f87 into hikettei:main Nov 22, 2024
1 of 6 checks passed