From 504acfc7455672ce814d11bb597bc706ee72e17a Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Tue, 12 Nov 2024 20:11:53 +0100 Subject: [PATCH 01/30] start RoPE --- source/nn/positional-encoding.lisp | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index f485e8091..ab90959cb 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -1,6 +1,34 @@ (in-package :caten/nn) -;; from https://github.com/ml-explore/mlx/blob/main/python/mlx/nn/layers/positional_encoding.py +; from https://github.com/ml-explore/mlx/blob/main/python/mlx/nn/layers/positional_encoding.py ;; RoPE ;; PositionalEncoding +(defmodel (RoPE (dims &key (traditional NIL) (base 10000) (scale 1.0))) + ((dims dims) + (traditional traditional) + (base base) + (scale scale))) + +(defmethod call ((op RoPE) &rest inputs) + (let* ((x (car inputs)) + (shape (shape x)) + (last-two (last shape 2)) + (n (first last-two)) + (d (second last-two))) + (dotimes (i 10) + (format t "i = ~a~%" i)))) + +(defparameter *tensor1* (make-tensor `(3 4 3) :initial-element 1.0)) + +(let ((instance (make-instance 'RoPE))) + (call instance *tensor1*)) + +(in-package :caten/nn.test) + + +(defun test-rope-tensor () + (with-no-grad + (let ((input-tensor *tensor1*)) + (print input-tensor)))) + From 0fb82c6921780e7638e2c52ac7e9e9f18ee7b322 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Tue, 12 Nov 2024 22:08:47 +0100 Subject: [PATCH 02/30] Update positional-encoding.lisp --- source/nn/positional-encoding.lisp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index ab90959cb..f03133d65 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -10,6 +10,8 @@ (base base) (scale scale))) +;RoPE needs an extra arg in the call function, the generic doesn't consider it +;offset: int = 0 (defmethod call ((op RoPE) &rest inputs) (let* ((x (car inputs)) (shape (shape x)) @@ -17,9 +19,12 @@ (n (first last-two)) (d (second last-two))) (dotimes (i 10) - (format t "i = ~a~%" i)))) - -(defparameter *tensor1* (make-tensor `(3 4 3) :initial-element 1.0)) + (format t "i = ~a~%" i) + (defparameter shape (shape x)) + (defparameter x (!reshape x (list n d))) + (defparameter positions (range 0 n )) ;not sure if this is what mlx doing (mx.arange(n)) + (print positions) + ))) (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) From dde647e6e5ea47f9754cac300c02502aa0a0b95f Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 13 Nov 2024 21:53:35 +0100 Subject: [PATCH 03/30] Update positional-encoding.lisp --- source/nn/positional-encoding.lisp | 63 +++++++++++++++++++++++------- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index f03133d65..5102a4dbe 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -4,34 +4,69 @@ ;; RoPE ;; PositionalEncoding -(defmodel (RoPE (dims &key (traditional NIL) (base 10000) (scale 1.0))) - ((dims dims) - (traditional traditional) - (base base) - (scale scale))) - -;RoPE needs an extra arg in the call function, the generic doesn't consider it -;offset: int = 0 (defmethod call ((op RoPE) &rest inputs) (let* ((x (car inputs)) (shape (shape x)) (last-two (last shape 2)) (n (first last-two)) (d (second last-two))) - (dotimes (i 10) + (dotimes (i 1) (format t "i = ~a~%" i) - (defparameter shape (shape x)) - (defparameter x (!reshape x (list n d))) - (defparameter positions (range 0 n )) ;not sure if this is what mlx doing (mx.arange(n)) - (print positions) - ))) + (let* ((shape (shape x)) + (b (reduce #'* ( butlast shape 2))) + (x (!reshape x (list b n d))) + (positions (!index-components (list n))) + ;TODO: handle potential divisions by 0 + (freqs (!exp (!div (!index-components (list (floor d 2))) (!const x (log (- (floor 10000 d) 1)))))) + (theta (!reshape positions (list 1 n))) + (costheta (!cos theta)) + (sintheta (!sin theta)) + ) + (print (proceed freqs)) + )))) + + + +(defparameter *tensor1* (make-tensor `(4 6 8) :initial-element 1.0)) (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) + + +(defparameter *tensor* (make-tensor '(2 3 4) :initial-element 1.0)) + +(defparameter *positions* (!index-components *tensor*)) + + +(print (reduce #'* ( butlast '(1 2 3 4 5) 2))) ; Returns 120 +(print (proceed *positions*)) + + +(defparameter *reshaped-positions* (!reshape *positions* '(6 4))) + + +(print (proceed *reshaped-positions*)) +(defparameter *freqs* (make-tensor '(1 2) :initial-element 2.0)) + +(defparameter *theta* (!mul *reshaped-positions* *freqs*)) + + +(print (proceed *theta*)) + + +(setq shape '(5)) + +;; Call !index-components on the shape +(print (proceed (!index-components shape))) + +(print (proceed (!const *tensor1* 2))) +(!index-components (list (shape *tensor1*))) + (in-package :caten/nn.test) + (defun test-rope-tensor () (with-no-grad (let ((input-tensor *tensor1*)) From d250dc3527bd96d1fa6291af4a0eb1f9a7c04207 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 13 Nov 2024 22:48:34 +0100 Subject: [PATCH 04/30] Update positional-encoding.lisp --- source/nn/positional-encoding.lisp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 5102a4dbe..7ac708044 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -4,6 +4,14 @@ ;; RoPE ;; PositionalEncoding +(defmodel (RoPE (dims &key (traditional NIL) (base 10000) (scale 1.0) (offset 1.0))) + ((dims dims) + (traditional traditional) + (base base) + (scale scale) + (offset offset))) + + (defmethod call ((op RoPE) &rest inputs) (let* ((x (car inputs)) (shape (shape x)) @@ -27,6 +35,7 @@ + (defparameter *tensor1* (make-tensor `(4 6 8) :initial-element 1.0)) (let ((instance (make-instance 'RoPE))) From 18106ca8c3aae794d9501803adb6258f4a7a6ffb Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 13 Nov 2024 23:48:10 +0100 Subject: [PATCH 05/30] x1 and x2 pairs for rope --- source/nn/positional-encoding.lisp | 56 ++++++++---------------------- 1 file changed, 15 insertions(+), 41 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 7ac708044..267fbd837 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -12,66 +12,40 @@ (offset offset))) + (defmethod call ((op RoPE) &rest inputs) (let* ((x (car inputs)) (shape (shape x)) (last-two (last shape 2)) (n (first last-two)) - (d (second last-two))) + (d (second last-two)) + (d-minus-1 (1- d))) (dotimes (i 1) - (format t "i = ~a~%" i) (let* ((shape (shape x)) - (b (reduce #'* ( butlast shape 2))) + (b (reduce #'* (butlast shape 2))) (x (!reshape x (list b n d))) (positions (!index-components (list n))) - ;TODO: handle potential divisions by 0 - (freqs (!exp (!div (!index-components (list (floor d 2))) (!const x (log (- (floor 10000 d) 1)))))) + ;TODO: handle potential divisions by 0 + (freqs (!exp (!div (!index-components (list (floor d 2))) + (!const x (log (- (floor 10000 d) 1)))))) (theta (!reshape positions (list 1 n))) (costheta (!cos theta)) (sintheta (!sin theta)) + (x1 (!view x t t '(0 d 2))) + (x2 (!view x t t '(1 d-minus-1 2))) ) - (print (proceed freqs)) - )))) - - + (format t "~%Original x shape: ~A" (shape x)) + (format t "~%X1 (even) shape: ~A" (shape x1)) + (format t "~%X2 (odd) shape: ~A" (shape x2)) + (format t "~%Freqs: ~A" (proceed freqs)) + (values x1 x2 costheta sintheta))))) -(defparameter *tensor1* (make-tensor `(4 6 8) :initial-element 1.0)) +(defparameter *tensor1* (make-tensor '(14 10 20) :initial-element 1.0)) (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) - - -(defparameter *tensor* (make-tensor '(2 3 4) :initial-element 1.0)) - -(defparameter *positions* (!index-components *tensor*)) - - -(print (reduce #'* ( butlast '(1 2 3 4 5) 2))) ; Returns 120 -(print (proceed *positions*)) - - -(defparameter *reshaped-positions* (!reshape *positions* '(6 4))) - - -(print (proceed *reshaped-positions*)) -(defparameter *freqs* (make-tensor '(1 2) :initial-element 2.0)) - -(defparameter *theta* (!mul *reshaped-positions* *freqs*)) - - -(print (proceed *theta*)) - - -(setq shape '(5)) - -;; Call !index-components on the shape -(print (proceed (!index-components shape))) - -(print (proceed (!const *tensor1* 2))) -(!index-components (list (shape *tensor1*))) - (in-package :caten/nn.test) From 8bd08e0cf3e8596bd40ea2872fa91fd3eea8f408 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 14 Nov 2024 13:06:31 +0100 Subject: [PATCH 06/30] remove loop, correct view usage --- source/nn/positional-encoding.lisp | 51 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 267fbd837..9c1d1c5e1 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -18,31 +18,32 @@ (shape (shape x)) (last-two (last shape 2)) (n (first last-two)) - (d (second last-two)) - (d-minus-1 (1- d))) - (dotimes (i 1) - (let* ((shape (shape x)) - (b (reduce #'* (butlast shape 2))) - (x (!reshape x (list b n d))) - (positions (!index-components (list n))) - ;TODO: handle potential divisions by 0 - (freqs (!exp (!div (!index-components (list (floor d 2))) - (!const x (log (- (floor 10000 d) 1)))))) - (theta (!reshape positions (list 1 n))) - (costheta (!cos theta)) - (sintheta (!sin theta)) - (x1 (!view x t t '(0 d 2))) - (x2 (!view x t t '(1 d-minus-1 2))) - ) - (format t "~%Original x shape: ~A" (shape x)) - (format t "~%X1 (even) shape: ~A" (shape x1)) - (format t "~%X2 (odd) shape: ~A" (shape x2)) - (format t "~%Freqs: ~A" (proceed freqs)) - (values x1 x2 costheta sintheta))))) - - - -(defparameter *tensor1* (make-tensor '(14 10 20) :initial-element 1.0)) + (d (second last-two))) + (let* ((shape (shape x)) + (b (reduce #'* (butlast shape 2))) + (x (!reshape x (list b n d))) + (positions (!index-components (list n))) + ;TODO: handle potential divisions by 0 + (freqs (!exp (!div (!index-components (list (floor d 2))) + (!const x (log (- (floor 10000 d) 1)))))) + (positions-reshaped (!reshape positions (list n 1))) ; (N,1) + (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; (1,D/2) + (theta (!mul positions-reshaped freqs-reshaped)) ; (N,D/2) + (costheta (!cos theta)) + (sintheta (!sin theta)) + (x1 (!view x t t `(0 d 2))) + (x2 (!view x t t `(1 d 2))) + (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) + (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta)))) + (format t "~%Original x shape: ~A" (shape x)) + (format t "~%costheta shape: ~A" (shape costheta)) + (format t "~%sintheta shape: ~A" (shape sintheta)) + (format t "~%Freqs: ~A" (proceed freqs)) + ))) + + + +(defparameter *tensor1* (make-tensor '(3 3 3) :initial-element 1.0)) (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) From 9374cbbb930be18be6450c300aef0b2c6358d50f Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 14 Nov 2024 13:52:47 +0100 Subject: [PATCH 07/30] single let block --- source/nn/positional-encoding.lisp | 45 +++++++++++++++--------------- 1 file changed, 23 insertions(+), 22 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 9c1d1c5e1..6974fd6d4 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -18,28 +18,29 @@ (shape (shape x)) (last-two (last shape 2)) (n (first last-two)) - (d (second last-two))) - (let* ((shape (shape x)) - (b (reduce #'* (butlast shape 2))) - (x (!reshape x (list b n d))) - (positions (!index-components (list n))) - ;TODO: handle potential divisions by 0 - (freqs (!exp (!div (!index-components (list (floor d 2))) - (!const x (log (- (floor 10000 d) 1)))))) - (positions-reshaped (!reshape positions (list n 1))) ; (N,1) - (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; (1,D/2) - (theta (!mul positions-reshaped freqs-reshaped)) ; (N,D/2) - (costheta (!cos theta)) - (sintheta (!sin theta)) - (x1 (!view x t t `(0 d 2))) - (x2 (!view x t t `(1 d 2))) - (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) - (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta)))) - (format t "~%Original x shape: ~A" (shape x)) - (format t "~%costheta shape: ~A" (shape costheta)) - (format t "~%sintheta shape: ~A" (shape sintheta)) - (format t "~%Freqs: ~A" (proceed freqs)) - ))) + (d (second last-two)) + (b (reduce #'* (butlast shape 2))) + (x (!reshape x (list b n d))) + (positions (!index-components (list n))) + ; TODO: handle potential divisions by 0 + (freqs (!exp (!div (!index-components (list (floor d 2))) + (!const x (log (- (floor 10000 d) 1)))))) + (positions-reshaped (!reshape positions (list n 1))) ; (N,1) + (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; (1,D/2) + (theta (!mul positions-reshaped freqs-reshaped)) ; (N,D/2) + (costheta (!cos theta)) + (sintheta (!sin theta)) + (x1 (!view x t t `(0 ,d 2))) + (x2 (!view x t t `(1 ,d 2))) + (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) + (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) + (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) + (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1))))) + (format t "~%Original x shape: ~A" (shape x)) + (format t "~%costheta shape: ~A" (shape costheta)) + (format t "~%sintheta shape: ~A" (shape sintheta)) + (format t "~%Freqs: ~A" (proceed freqs)) + (print ( proceed rx1-expanded)))) From 401bf9f1e49dfceb6ec715f03d7338f11932f250 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 14 Nov 2024 15:40:13 +0100 Subject: [PATCH 08/30] final embedding shape issues --- source/nn/positional-encoding.lisp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 6974fd6d4..5b9361e83 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -35,12 +35,18 @@ (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) - (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1))))) + (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) + (result (!concatenate -1 rx1-expanded rx2-expanded)) + ;(final (!reshape result (list b n d))) + ) (format t "~%Original x shape: ~A" (shape x)) (format t "~%costheta shape: ~A" (shape costheta)) (format t "~%sintheta shape: ~A" (shape sintheta)) (format t "~%Freqs: ~A" (proceed freqs)) - (print ( proceed rx1-expanded)))) + (print b) + (print (proceed result)) + ;(print final) + )) From e7cd6d2872b3efce1b288ab762173caf70e40b47 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 14 Nov 2024 16:24:09 +0100 Subject: [PATCH 09/30] remove expansion --- source/nn/positional-encoding.lisp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 5b9361e83..6103aa597 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -34,18 +34,18 @@ (x2 (!view x t t `(1 ,d 2))) (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) - (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) - (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) - (result (!concatenate -1 rx1-expanded rx2-expanded)) + (result (!concatenate -1 rx1 rx2)) ;(final (!reshape result (list b n d))) ) + (format t "~%Original x shape: ~A" (shape x)) (format t "~%costheta shape: ~A" (shape costheta)) (format t "~%sintheta shape: ~A" (shape sintheta)) (format t "~%Freqs: ~A" (proceed freqs)) - (print b) + (format t "~%rx1 shape: ~A" (shape rx1)) + (format t "~%rx2 shape: ~A" (shape rx2)) + (format t "~%result shape: ~A" (shape result)) (print (proceed result)) - ;(print final) )) From 1c4f0fb5e67ae26148317f90d45bd4ca8ca515e9 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 14 Nov 2024 21:25:38 +0100 Subject: [PATCH 10/30] restore expansion, it works! --- source/nn/positional-encoding.lisp | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 6103aa597..b4b01ae27 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -34,10 +34,10 @@ (x2 (!view x t t `(1 ,d 2))) (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) - (result (!concatenate -1 rx1 rx2)) - ;(final (!reshape result (list b n d))) - ) - + (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) + (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) + (result (!concatenate -1 rx1-expanded rx2-expanded)) + (final-result (!reshape result (list b n d)))) (format t "~%Original x shape: ~A" (shape x)) (format t "~%costheta shape: ~A" (shape costheta)) (format t "~%sintheta shape: ~A" (shape sintheta)) @@ -45,19 +45,14 @@ (format t "~%rx1 shape: ~A" (shape rx1)) (format t "~%rx2 shape: ~A" (shape rx2)) (format t "~%result shape: ~A" (shape result)) - (print (proceed result)) - )) - - + (print (proceed final-result)))) -(defparameter *tensor1* (make-tensor '(3 3 3) :initial-element 1.0)) +(defparameter *tensor1* (make-tensor '(4 4 4) :initial-element 1.0)) (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) (in-package :caten/nn.test) - - (defun test-rope-tensor () (with-no-grad (let ((input-tensor *tensor1*)) From 128708a79859218bd8ec6ed1667d04f1227c8631 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 15 Nov 2024 16:06:04 +0100 Subject: [PATCH 11/30] started test --- source/nn/package.lisp | 3 ++ source/nn/positional-encoding.lisp | 66 ++++++++++++++++++++++-------- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/source/nn/package.lisp b/source/nn/package.lisp index 9bc9ff9d8..2ec3d6894 100644 --- a/source/nn/package.lisp +++ b/source/nn/package.lisp @@ -78,6 +78,9 @@ Policy: ;; from embedding.lisp (:export #:Embedding) + ;; from positional-encoding.lisp + (:export + #:RoPE) ;; from conv.lisp (:export #:ConvND diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index b4b01ae27..3371a958a 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -11,8 +11,6 @@ (scale scale) (offset offset))) - - (defmethod call ((op RoPE) &rest inputs) (let* ((x (car inputs)) (shape (shape x)) @@ -38,23 +36,57 @@ (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) (result (!concatenate -1 rx1-expanded rx2-expanded)) (final-result (!reshape result (list b n d)))) - (format t "~%Original x shape: ~A" (shape x)) - (format t "~%costheta shape: ~A" (shape costheta)) - (format t "~%sintheta shape: ~A" (shape sintheta)) - (format t "~%Freqs: ~A" (proceed freqs)) - (format t "~%rx1 shape: ~A" (shape rx1)) - (format t "~%rx2 shape: ~A" (shape rx2)) - (format t "~%result shape: ~A" (shape result)) - (print (proceed final-result)))) - -(defparameter *tensor1* (make-tensor '(4 4 4) :initial-element 1.0)) + proceed final-result))) + + + + +(defparameter *tensor1* (make-tensor '(10 10 20) :initial-element 1.0)) + (let ((instance (make-instance 'RoPE))) (call instance *tensor1*)) -(in-package :caten/nn.test) -(defun test-rope-tensor () - (with-no-grad - (let ((input-tensor *tensor1*)) - (print input-tensor)))) + + + + + +(in-package :caten/test-suite) + +(python-exec +" +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device) # type: ignore + freqs = torch.outer(t, freqs).float() # type: ignore + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis +") + + +(import-function "precompute_freqs_cis") + + + + +(->caten (precompute_freqs_cis 1 1)) +;; [TODO] Fuse in a single kernel (var/std) +;; (deftest test-variance +;;(with-given-dtype ((:float32 . "float32")) +;; (let ((x (rand `(30 30)))) +;; (assert-equal +;; (:atol 1e-5 :rtol 1e-6) +;; (with-torch (x) (->caten (torch.var x :axis -1 :keepdims t :correction 1))) +;; (proceed (!rope x :axis -1 :correction 1))))));(in-package :caten/nn.test) + + + + + +(deftest test-rope-tensor () + (with-no-grad + (let* ((model (RoPE 1))) + (print model) + (print "test")))) From 57c5fc55bdeb3b2f294cdd982f0c2d0c7ed46a27 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Mon, 18 Nov 2024 21:13:54 +0100 Subject: [PATCH 12/30] feat: test --- source/nn/positional-encoding.lisp | 75 ++++++++++++------------------ 1 file changed, 30 insertions(+), 45 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 3371a958a..e9843fae5 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -36,57 +36,42 @@ (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) (result (!concatenate -1 rx1-expanded rx2-expanded)) (final-result (!reshape result (list b n d)))) - proceed final-result))) - - - - -(defparameter *tensor1* (make-tensor '(10 10 20) :initial-element 1.0)) - -(let ((instance (make-instance 'RoPE))) - (call instance *tensor1*)) - - - - - + (proceed final-result)))) +(defun !rope (x) +(declare (type tensor x)) +(forward (RoPE 1) x)) (in-package :caten/test-suite) (python-exec " -def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): - freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) - t = torch.arange(end, device=freqs.device) # type: ignore - freqs = torch.outer(t, freqs).float() # type: ignore - freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 - return freqs_cis -") - - -(import-function "precompute_freqs_cis") - - - - -(->caten (precompute_freqs_cis 1 1)) -;; [TODO] Fuse in a single kernel (var/std) -;; (deftest test-variance -;;(with-given-dtype ((:float32 . "float32")) -;; (let ((x (rand `(30 30)))) -;; (assert-equal -;; (:atol 1e-5 :rtol 1e-6) -;; (with-torch (x) (->caten (torch.var x :axis -1 :keepdims t :correction 1))) -;; (proceed (!rope x :axis -1 :correction 1))))));(in-package :caten/nn.test) - - - +#from: https://pytorch.org/torchtune/0.2/_modules/torchtune/modules/position_embeddings.html +def torch_rope(x): + base = 10000 + dim = x.shape[-1] + theta = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim)) + seq_len = x.size(1) + seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) + idx_theta = torch.einsum('i,j->ij', seq_idx, theta) + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) + rope_cache = cache.view(1, seq_len, 1, x_shaped.size(-2), 2) + x_out = torch.stack([ + x_shaped[..., 0] * rope_cache[..., 0] - x_shaped[..., 1] * rope_cache[..., 1], + x_shaped[..., 1] * rope_cache[..., 0] + x_shaped[..., 0] * rope_cache[..., 1], + ], dim=-1) + x_out = x_out.flatten(-2) + return x_out.type_as(x)") + + +(deftest test-rope +(with-given-dtype ((:float32 . "float32")) + (let ((x (rand `(30 30)))) + (assert-equal + (:atol 1e-5 :rtol 1e-6) + (with-torch (x) (->caten (torch_rope x))) + (proceed (!rope x)))))) -(deftest test-rope-tensor () - (with-no-grad - (let* ((model (RoPE 1))) - (print model) - (print "test")))) From ad0c3d21e303b42efc5f4c0928e64a13abc57934 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Tue, 19 Nov 2024 16:42:31 +0100 Subject: [PATCH 13/30] test and corrections to rope --- source/nn/positional-encoding.lisp | 77 ++++++++++-------------------- source/test-suite/test-rope.lisp | 42 ++++++++++++++++ 2 files changed, 67 insertions(+), 52 deletions(-) create mode 100644 source/test-suite/test-rope.lisp diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index e9843fae5..1e66bd5ae 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -1,3 +1,5 @@ + + (in-package :caten/nn) ; from https://github.com/ml-explore/mlx/blob/main/python/mlx/nn/layers/positional_encoding.py @@ -15,63 +17,34 @@ (let* ((x (car inputs)) (shape (shape x)) (last-two (last shape 2)) - (n (first last-two)) - (d (second last-two)) - (b (reduce #'* (butlast shape 2))) + (n (first last-two)) ; Sequence length + (d (second last-two)) ; Embedding size + (b (if (> (length shape) 2) + (reduce #'* (butlast shape 2)) + 1)) ; Batch size (x (!reshape x (list b n d))) (positions (!index-components (list n))) - ; TODO: handle potential divisions by 0 - (freqs (!exp (!div (!index-components (list (floor d 2))) - (!const x (log (- (floor 10000 d) 1)))))) - (positions-reshaped (!reshape positions (list n 1))) ; (N,1) - (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; (1,D/2) - (theta (!mul positions-reshaped freqs-reshaped)) ; (N,D/2) + (freqs (!exp (!div (!mul (!index-components (list (floor d 2))) + (!const x (- (log 10000)))) + (!const x d)))) + (positions-reshaped (!reshape positions (list n 1))) ; Shape: (n,1) + (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; Shape: (1, D/2) + (theta (!mul positions-reshaped freqs-reshaped)) ; Shape: (n, D/2) (costheta (!cos theta)) (sintheta (!sin theta)) - (x1 (!view x t t `(0 ,d 2))) - (x2 (!view x t t `(1 ,d 2))) + (x1 (!view x 't 't `(0 ,d 2))) ; x[..., 0:d:2] + (x2 (if (evenp d) + (!view x 't 't `(1 ,(1+ d) 2)) ; x[..., 1:(d+1):2] + (!view x 't 't `(1 ,d 2)))) ; x[..., 1:d:2] + ;; Compute rx1 and rx2 (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) - (result (!concatenate -1 rx1-expanded rx2-expanded)) - (final-result (!reshape result (list b n d)))) - (proceed final-result)))) - - -(defun !rope (x) -(declare (type tensor x)) -(forward (RoPE 1) x)) - -(in-package :caten/test-suite) - -(python-exec -" -#from: https://pytorch.org/torchtune/0.2/_modules/torchtune/modules/position_embeddings.html -def torch_rope(x): - base = 10000 - dim = x.shape[-1] - theta = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim)) - seq_len = x.size(1) - seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) - idx_theta = torch.einsum('i,j->ij', seq_idx, theta) - cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) - x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) - rope_cache = cache.view(1, seq_len, 1, x_shaped.size(-2), 2) - x_out = torch.stack([ - x_shaped[..., 0] * rope_cache[..., 0] - x_shaped[..., 1] * rope_cache[..., 1], - x_shaped[..., 1] * rope_cache[..., 0] + x_shaped[..., 0] * rope_cache[..., 1], - ], dim=-1) - x_out = x_out.flatten(-2) - return x_out.type_as(x)") - - -(deftest test-rope -(with-given-dtype ((:float32 . "float32")) - (let ((x (rand `(30 30)))) - (assert-equal - (:atol 1e-5 :rtol 1e-6) - (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) - - + (result (!concatenate -1 rx1-expanded rx2-expanded)) ; Shape: (b, n, half-dim, 2) + (rotated (!reshape result (list b n (* 2 (floor d 2))))) + (final-result (if (evenp d) + rotated + (let ((last-elem (!view x 't 't `(,(- d 1) ,d)))) ; Shape: (b, n, 1) + (!concatenate -1 rotated last-elem))))) + (proceed (!reshape final-result shape)))) \ No newline at end of file diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp new file mode 100644 index 000000000..456c8db33 --- /dev/null +++ b/source/test-suite/test-rope.lisp @@ -0,0 +1,42 @@ +(in-package :caten/test-suite) + +(defun !rope (x) + (declare (type tensor x)) + (forward (RoPE 1) x)) + +(python-exec + " +#from: https://pytorch.org/torchtune/0.2/_modules/torchtune/modules/position_embeddings.html +def torch_rope(x): + base = 10000 + dim = x.shape[-1] + theta = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim)) + seq_len = x.size(1) + seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) + idx_theta = torch.einsum('i,j->ij', seq_idx, theta) + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) + rope_cache = cache.view(1, seq_len, 1, x_shaped.size(-2), 2) + x_out = torch.stack([ + x_shaped[..., 0] * rope_cache[..., 0] - x_shaped[..., 1] * rope_cache[..., 1], + x_shaped[..., 1] * rope_cache[..., 0] + x_shaped[..., 0] * rope_cache[..., 1], + ], dim=-1) + x_out = x_out.flatten(-2) + return x_out.type_as(x)") + + +(import-function "torch_rope") + +(deftest test-rope + (with-given-dtype ((:float32 . "float32")) + (let ((x (rand `(1 20 20 20)))) + (assert-equal + (:atol 2 :rtol 3) + (with-torch (x) (->caten (torch_rope x))) + (proceed (!rope x)))))) + + + + + + From 990a9ca722d45927149176ffe7eccd0be8dd48bd Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:02:50 +0100 Subject: [PATCH 14/30] 't = t --- source/nn/positional-encoding.lisp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 1e66bd5ae..7d78084e1 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -32,10 +32,10 @@ (theta (!mul positions-reshaped freqs-reshaped)) ; Shape: (n, D/2) (costheta (!cos theta)) (sintheta (!sin theta)) - (x1 (!view x 't 't `(0 ,d 2))) ; x[..., 0:d:2] + (x1 (!view x t t `(0 ,d 2))) ; x[..., 0:d:2] (x2 (if (evenp d) - (!view x 't 't `(1 ,(1+ d) 2)) ; x[..., 1:(d+1):2] - (!view x 't 't `(1 ,d 2)))) ; x[..., 1:d:2] + (!view x t t `(1 ,(1+ d) 2)) ; x[..., 1:(d+1):2] + (!view x t t `(1 ,d 2)))) ; x[..., 1:d:2] ;; Compute rx1 and rx2 (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) @@ -45,6 +45,6 @@ (rotated (!reshape result (list b n (* 2 (floor d 2))))) (final-result (if (evenp d) rotated - (let ((last-elem (!view x 't 't `(,(- d 1) ,d)))) ; Shape: (b, n, 1) + (let ((last-elem (!view x t t `(,(- d 1) ,d)))) ; Shape: (b, n, 1) (!concatenate -1 rotated last-elem))))) (proceed (!reshape final-result shape)))) \ No newline at end of file From 632adb35e7f7031c19846128033199e42eab9010 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:08:21 +0100 Subject: [PATCH 15/30] !rope to nn --- source/nn/positional-encoding.lisp | 5 ++++- source/test-suite/test-rope.lisp | 11 ++--------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 7d78084e1..1525df1b9 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -47,4 +47,7 @@ rotated (let ((last-elem (!view x t t `(,(- d 1) ,d)))) ; Shape: (b, n, 1) (!concatenate -1 rotated last-elem))))) - (proceed (!reshape final-result shape)))) \ No newline at end of file + (proceed (!reshape final-result shape)))) +(defun !rope (x) + (declare (type tensor x)) + (forward (rope 1) x)) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 456c8db33..9c48fcf8c 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -24,19 +24,12 @@ def torch_rope(x): x_out = x_out.flatten(-2) return x_out.type_as(x)") - (import-function "torch_rope") (deftest test-rope (with-given-dtype ((:float32 . "float32")) (let ((x (rand `(1 20 20 20)))) (assert-equal - (:atol 2 :rtol 3) + (:atol 1e-5 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) - - - - - - + (proceed (!rope x)))))) \ No newline at end of file From 98e03d818ef47f089595f960022059564c37f79d Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:09:06 +0100 Subject: [PATCH 16/30] !rope to nn --- source/test-suite/test-rope.lisp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 9c48fcf8c..28b6302c5 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -1,9 +1,5 @@ (in-package :caten/test-suite) -(defun !rope (x) - (declare (type tensor x)) - (forward (RoPE 1) x)) - (python-exec " #from: https://pytorch.org/torchtune/0.2/_modules/torchtune/modules/position_embeddings.html From 32ed37360954fe76867225d473ec21e87d844b3c Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:10:10 +0100 Subject: [PATCH 17/30] remove proceed from rope call --- source/nn/positional-encoding.lisp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 1525df1b9..5f61482e7 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -47,7 +47,8 @@ rotated (let ((last-elem (!view x t t `(,(- d 1) ,d)))) ; Shape: (b, n, 1) (!concatenate -1 rotated last-elem))))) - (proceed (!reshape final-result shape)))) + (!reshape final-result shape))) + (defun !rope (x) (declare (type tensor x)) (forward (rope 1) x)) From 3d6824eea55c035869d8ec617422e6fa5fecd875 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:12:47 +0100 Subject: [PATCH 18/30] add file to test-suite Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/test-suite/caten.test-suite.asd | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/test-suite/caten.test-suite.asd b/source/test-suite/caten.test-suite.asd index a599249c9..4c67a72e9 100644 --- a/source/test-suite/caten.test-suite.asd +++ b/source/test-suite/caten.test-suite.asd @@ -22,7 +22,8 @@ Tests that are not related to the core functionality of Caten or are time-consum (:file "test-scheduler") (:file "test-dynamic-shape") (:file "test-memory-planner") - (:file "test-schedule-cache")) + (:file "test-schedule-cache") + (:file "test-rope")) :perform (asdf:test-op (o s) From 556ef608cf37f002d636d5f8daee1352928a2a06 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 10:35:35 +0100 Subject: [PATCH 19/30] export !rope Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/nn/package.lisp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/source/nn/package.lisp b/source/nn/package.lisp index 2ec3d6894..8ee9eb089 100644 --- a/source/nn/package.lisp +++ b/source/nn/package.lisp @@ -80,7 +80,8 @@ Policy: #:Embedding) ;; from positional-encoding.lisp (:export - #:RoPE) + #:RoPE + #:!rope) ;; from conv.lisp (:export #:ConvND From a61995640fbd7e20b5ab069a61c83f719ebd816a Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 13:51:53 +0100 Subject: [PATCH 20/30] fix freqs Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/nn/positional-encoding.lisp | 3 +++ source/test-suite/test-rope.lisp | 21 ++++++++++++++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 5f61482e7..0b393ab1a 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -26,6 +26,9 @@ (positions (!index-components (list n))) (freqs (!exp (!div (!mul (!index-components (list (floor d 2))) (!const x (- (log 10000)))) + (two-scalar (!const x 2)) + (indices (!mul two-scalar (!index-components (list (floor d 2))))) + (freqs (!exp (!div (!mul indices (!const x (- (log 10000)))) (!const x d)))) (positions-reshaped (!reshape positions (list n 1))) ; Shape: (n,1) (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; Shape: (1, D/2) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 28b6302c5..8bcef2c67 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -10,6 +10,14 @@ def torch_rope(x): seq_len = x.size(1) seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) idx_theta = torch.einsum('i,j->ij', seq_idx, theta) + costheta = torch.cos(idx_theta) + sintheta = torch.sin(idx_theta) + x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) + x1 = x_shaped[..., 0] + x2 = x_shaped[..., 1] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + rx = torch.stack([rx1, rx2], dim=-1) cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) rope_cache = cache.view(1, seq_len, 1, x_shaped.size(-2), 2) @@ -22,10 +30,21 @@ def torch_rope(x): (import-function "torch_rope") + + + +(let ((x (rand `(1 20 20 20)))) + (with-torch (x) (->caten (torch_rope x))) + (proceed (!rope x)))))) + + + + (deftest test-rope (with-given-dtype ((:float32 . "float32")) (let ((x (rand `(1 20 20 20)))) (assert-equal (:atol 1e-5 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) \ No newline at end of file + (proceed (!rope x)))))) + From b1eddc6003c9769b7826f199bceeeb2dd544a3ad Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 13:53:04 +0100 Subject: [PATCH 21/30] Update test-rope.lisp Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/test-suite/test-rope.lisp | 21 +-------------------- 1 file changed, 1 insertion(+), 20 deletions(-) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 8bcef2c67..28b6302c5 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -10,14 +10,6 @@ def torch_rope(x): seq_len = x.size(1) seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) idx_theta = torch.einsum('i,j->ij', seq_idx, theta) - costheta = torch.cos(idx_theta) - sintheta = torch.sin(idx_theta) - x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) - x1 = x_shaped[..., 0] - x2 = x_shaped[..., 1] - rx1 = x1 * costheta - x2 * sintheta - rx2 = x1 * sintheta + x2 * costheta - rx = torch.stack([rx1, rx2], dim=-1) cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) x_shaped = x.float().reshape(*x.shape[:-1], -1, 2) rope_cache = cache.view(1, seq_len, 1, x_shaped.size(-2), 2) @@ -30,21 +22,10 @@ def torch_rope(x): (import-function "torch_rope") - - - -(let ((x (rand `(1 20 20 20)))) - (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) - - - - (deftest test-rope (with-given-dtype ((:float32 . "float32")) (let ((x (rand `(1 20 20 20)))) (assert-equal (:atol 1e-5 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) - + (proceed (!rope x)))))) \ No newline at end of file From fb3ad752355f1e9c19cfd70812f74da6ec68b55f Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Wed, 20 Nov 2024 13:53:21 +0100 Subject: [PATCH 22/30] Update positional-encoding.lisp Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/nn/positional-encoding.lisp | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 0b393ab1a..f1377d663 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -24,8 +24,6 @@ 1)) ; Batch size (x (!reshape x (list b n d))) (positions (!index-components (list n))) - (freqs (!exp (!div (!mul (!index-components (list (floor d 2))) - (!const x (- (log 10000)))) (two-scalar (!const x 2)) (indices (!mul two-scalar (!index-components (list (floor d 2))))) (freqs (!exp (!div (!mul indices (!const x (- (log 10000)))) From 400bb19cfd57f0a11eaef3c4f0d13820cf6bc151 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 21 Nov 2024 22:38:41 +0100 Subject: [PATCH 23/30] new RoPE implementation based on torch --- source/nn/positional-encoding.lisp | 90 +++++++++++++++--------------- source/test-suite/test-rope.lisp | 8 ++- 2 files changed, 51 insertions(+), 47 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index f1377d663..f75c48869 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -1,54 +1,56 @@ - - (in-package :caten/nn) -; from https://github.com/ml-explore/mlx/blob/main/python/mlx/nn/layers/positional_encoding.py ;; RoPE ;; PositionalEncoding -(defmodel (RoPE (dims &key (traditional NIL) (base 10000) (scale 1.0) (offset 1.0))) - ((dims dims) - (traditional traditional) - (base base) - (scale scale) - (offset offset))) +(defmodel (RoPE (dim &key (base 10000))) + ((dim dim) + (base base))) (defmethod call ((op RoPE) &rest inputs) - (let* ((x (car inputs)) - (shape (shape x)) - (last-two (last shape 2)) - (n (first last-two)) ; Sequence length - (d (second last-two)) ; Embedding size - (b (if (> (length shape) 2) - (reduce #'* (butlast shape 2)) - 1)) ; Batch size - (x (!reshape x (list b n d))) - (positions (!index-components (list n))) - (two-scalar (!const x 2)) - (indices (!mul two-scalar (!index-components (list (floor d 2))))) - (freqs (!exp (!div (!mul indices (!const x (- (log 10000)))) - (!const x d)))) - (positions-reshaped (!reshape positions (list n 1))) ; Shape: (n,1) - (freqs-reshaped (!reshape freqs (list 1 (floor d 2)))) ; Shape: (1, D/2) - (theta (!mul positions-reshaped freqs-reshaped)) ; Shape: (n, D/2) - (costheta (!cos theta)) - (sintheta (!sin theta)) - (x1 (!view x t t `(0 ,d 2))) ; x[..., 0:d:2] - (x2 (if (evenp d) - (!view x t t `(1 ,(1+ d) 2)) ; x[..., 1:(d+1):2] - (!view x t t `(1 ,d 2)))) ; x[..., 1:d:2] - ;; Compute rx1 and rx2 - (rx1 (!sub (!mul x1 costheta) (!mul x2 sintheta))) - (rx2 (!add (!mul x1 sintheta) (!mul x2 costheta))) - (rx1-expanded (!reshape rx1 (append (shape rx1) (list 1)))) - (rx2-expanded (!reshape rx2 (append (shape rx2) (list 1)))) - (result (!concatenate -1 rx1-expanded rx2-expanded)) ; Shape: (b, n, half-dim, 2) - (rotated (!reshape result (list b n (* 2 (floor d 2))))) - (final-result (if (evenp d) - rotated - (let ((last-elem (!view x t t `(,(- d 1) ,d)))) ; Shape: (b, n, 1) - (!concatenate -1 rotated last-elem))))) - (!reshape final-result shape))) + (with-slots (dim traditional base scale offset) op + (let* ((x (first inputs)) + (x-shape (shape x)) ; x-shape: (batch-size, seq-len, num-heads, head-dim) + (batch-size (nth 0 x-shape)) + (seq-len (nth 1 x-shape)) + (num-heads (nth 2 x-shape)) + (head-dim (nth 3 x-shape)) + (head-dim-half (floor (/ head-dim 2))) + (indices (!mul (!const x 2) (!index-components (list head-dim-half)))) ; Shape: (head-dim-half) + (exponents (!div indices (!const x head-dim))) ; Shape: (head-dim-half) + (base (!const x base)) + (log-base (!log base)) ; Natural logarithm + (theta (!exp (!neg (!mul exponents log-base)))) ; Shape: (head-dim-half) + (theta-reshaped (!reshape theta (list 1 head-dim-half))) ; Shape: (1, head-dim-half) + (seq-idx (!index-components (list seq-len))) ; Shape: (seq-len) + (seq-idx-reshaped (!reshape seq-idx (list seq-len 1))) ; Shape: (seq-len, 1) + (idx-theta (!mul seq-idx-reshaped theta-reshaped)) ; Shape: (seq-len, head-dim-half) + (cosine (!cos idx-theta)) ; Shape: (seq-len, head-dim-half) + (sine (!sin idx-theta)) ; Shape: (seq-len, head-dim-half) + (xshaped (!reshape x (list batch-size seq-len num-heads head-dim-half 2))) + (cosine-reshaped (!reshape cosine (list 1 seq-len 1 head-dim-half 1))) + (sine-reshaped (!reshape sine (list 1 seq-len 1 head-dim-half 1))) + (num-dimensions 5) + (x0-subscripts (make-list num-dimensions :initial-element 'T)) + (x1-subscripts (make-list num-dimensions :initial-element 'T))) + (setf (nth (- num-dimensions 1) x0-subscripts) '(0 1)) ; Select index 0 + (setf (nth (- num-dimensions 1) x1-subscripts) '(1 2)) ; Select index 1 + (let* ((x0 (apply #'!view xshaped x0-subscripts)) + (x1 (apply #'!view xshaped x1-subscripts)) + (rotated0 (!sub (!mul x0 cosine-reshaped) (!mul x1 sine-reshaped))) ; Shape: same as x0 + (rotated1 (!add (!mul x0 sine-reshaped) (!mul x1 cosine-reshaped))) ; Shape: same as x0 + (x-out (!concatenate -1 rotated0 rotated1)) + (x-out-final (!reshape x-out (list batch-size seq-len num-heads (* 2 head-dim-half))))) + (let ((final-result (if (= (* 2 head-dim-half) head-dim) + x-out-final + (let* ((x-dims (shape x)) + (subs (make-list (length x-dims) :initial-element 'T))) + (setf (nth (- (length x-dims) 1) subs) (list (- head-dim 1) head-dim)) + (let ((last-elem (apply #'!view x subs))) ; Shape: (batch-size, seq-len, num-heads, 1) + (!concatenate -1 x-out-final last-elem)))))) + (!reshape final-result x-shape)))))) + + (defun !rope (x) (declare (type tensor x)) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 28b6302c5..43df1f9e0 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -7,6 +7,7 @@ def torch_rope(x): base = 10000 dim = x.shape[-1] theta = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim)) + print(theta) seq_len = x.size(1) seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) idx_theta = torch.einsum('i,j->ij', seq_idx, theta) @@ -24,8 +25,9 @@ def torch_rope(x): (deftest test-rope (with-given-dtype ((:float32 . "float32")) - (let ((x (rand `(1 20 20 20)))) + (let ((x (rand `(1 10 10 10)))) (assert-equal - (:atol 1e-5 :rtol 1e-6) + (:atol 1e-4 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) \ No newline at end of file + (proceed (!rope x)))))) + From a88cda9f0d24defca91c22d6845b4499c8ae35b9 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 21 Nov 2024 22:53:21 +0100 Subject: [PATCH 24/30] Update test-rope.lisp --- source/test-suite/test-rope.lisp | 1 - 1 file changed, 1 deletion(-) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 43df1f9e0..2266fdb74 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -7,7 +7,6 @@ def torch_rope(x): base = 10000 dim = x.shape[-1] theta = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim)) - print(theta) seq_len = x.size(1) seq_idx = torch.arange(seq_len, dtype=theta.dtype, device=theta.device) idx_theta = torch.einsum('i,j->ij', seq_idx, theta) From d708108315017026eb79856bc69c20af80dc8639 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Thu, 21 Nov 2024 23:01:20 +0100 Subject: [PATCH 25/30] Update positional-encoding.lisp --- source/nn/positional-encoding.lisp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index f75c48869..3b5104c66 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -8,7 +8,7 @@ (base base))) (defmethod call ((op RoPE) &rest inputs) - (with-slots (dim traditional base scale offset) op + (with-slots (dim base) op (let* ((x (first inputs)) (x-shape (shape x)) ; x-shape: (batch-size, seq-len, num-heads, head-dim) (batch-size (nth 0 x-shape)) From a15d4eaa48448aeaafc592f90a6e5d9f289f25e6 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 22 Nov 2024 08:43:47 +0100 Subject: [PATCH 26/30] add dim param to !rope --- source/nn/positional-encoding.lisp | 4 ++-- source/test-suite/test-rope.lisp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 3b5104c66..80ee06535 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -52,6 +52,6 @@ -(defun !rope (x) +(defun !rope (x dim) (declare (type tensor x)) - (forward (rope 1) x)) + (forward (rope dim) x)) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 2266fdb74..6e215f640 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -28,5 +28,5 @@ def torch_rope(x): (assert-equal (:atol 1e-4 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x)))))) + (proceed (!rope x 1)))))) From f3154c6d1f9a190eea1b4fbf20c5b5223b35f8f9 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 22 Nov 2024 09:01:57 +0100 Subject: [PATCH 27/30] Update test-rope.lisp --- source/test-suite/test-rope.lisp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 6e215f640..16ae3d62b 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -23,10 +23,12 @@ def torch_rope(x): (import-function "torch_rope") (deftest test-rope - (with-given-dtype ((:float32 . "float32")) + (if (= 0 (ctx:getenv :JIT)) + (with-given-dtype ((:float32 . "float32")) (let ((x (rand `(1 10 10 10)))) (assert-equal (:atol 1e-4 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x 1)))))) + (proceed (!rope x 1))))) + (skip "TODO: Not working with JIT=1"))) From 7efffa9462bbacacda427a35cd2361b18e01b722 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 22 Nov 2024 09:59:41 +0100 Subject: [PATCH 28/30] use slot dim and add check --- source/nn/positional-encoding.lisp | 85 +++++++++++++++--------------- source/test-suite/test-rope.lisp | 2 +- 2 files changed, 44 insertions(+), 43 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index 80ee06535..a833931de 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -10,48 +10,49 @@ (defmethod call ((op RoPE) &rest inputs) (with-slots (dim base) op (let* ((x (first inputs)) - (x-shape (shape x)) ; x-shape: (batch-size, seq-len, num-heads, head-dim) - (batch-size (nth 0 x-shape)) - (seq-len (nth 1 x-shape)) - (num-heads (nth 2 x-shape)) - (head-dim (nth 3 x-shape)) - (head-dim-half (floor (/ head-dim 2))) - (indices (!mul (!const x 2) (!index-components (list head-dim-half)))) ; Shape: (head-dim-half) - (exponents (!div indices (!const x head-dim))) ; Shape: (head-dim-half) - (base (!const x base)) - (log-base (!log base)) ; Natural logarithm - (theta (!exp (!neg (!mul exponents log-base)))) ; Shape: (head-dim-half) - (theta-reshaped (!reshape theta (list 1 head-dim-half))) ; Shape: (1, head-dim-half) - (seq-idx (!index-components (list seq-len))) ; Shape: (seq-len) - (seq-idx-reshaped (!reshape seq-idx (list seq-len 1))) ; Shape: (seq-len, 1) - (idx-theta (!mul seq-idx-reshaped theta-reshaped)) ; Shape: (seq-len, head-dim-half) - (cosine (!cos idx-theta)) ; Shape: (seq-len, head-dim-half) - (sine (!sin idx-theta)) ; Shape: (seq-len, head-dim-half) - (xshaped (!reshape x (list batch-size seq-len num-heads head-dim-half 2))) - (cosine-reshaped (!reshape cosine (list 1 seq-len 1 head-dim-half 1))) - (sine-reshaped (!reshape sine (list 1 seq-len 1 head-dim-half 1))) - (num-dimensions 5) - (x0-subscripts (make-list num-dimensions :initial-element 'T)) - (x1-subscripts (make-list num-dimensions :initial-element 'T))) - (setf (nth (- num-dimensions 1) x0-subscripts) '(0 1)) ; Select index 0 - (setf (nth (- num-dimensions 1) x1-subscripts) '(1 2)) ; Select index 1 - (let* ((x0 (apply #'!view xshaped x0-subscripts)) - (x1 (apply #'!view xshaped x1-subscripts)) - (rotated0 (!sub (!mul x0 cosine-reshaped) (!mul x1 sine-reshaped))) ; Shape: same as x0 - (rotated1 (!add (!mul x0 sine-reshaped) (!mul x1 cosine-reshaped))) ; Shape: same as x0 - (x-out (!concatenate -1 rotated0 rotated1)) - (x-out-final (!reshape x-out (list batch-size seq-len num-heads (* 2 head-dim-half))))) - (let ((final-result (if (= (* 2 head-dim-half) head-dim) - x-out-final - (let* ((x-dims (shape x)) - (subs (make-list (length x-dims) :initial-element 'T))) - (setf (nth (- (length x-dims) 1) subs) (list (- head-dim 1) head-dim)) - (let ((last-elem (apply #'!view x subs))) ; Shape: (batch-size, seq-len, num-heads, 1) - (!concatenate -1 x-out-final last-elem)))))) - (!reshape final-result x-shape)))))) - - + (x-shape (shape x)) ; x-shape: (batch-size, seq-len, num-heads, head-dim) + (batch-size (nth 0 x-shape)) + (seq-len (nth 1 x-shape)) + (num-heads (nth 2 x-shape)) + (head-dim (nth 3 x-shape))) + ;; Validate that dim matches head_dim / 2 + (when (/= dim (/ head-dim 2)) + (error "Mismatch: Provided dim (~A) does not match head-dim / 2 (~A)" dim (/ head-dim 2))) + (let* ((indices (!mul (!const x 2) (!index-components (list dim)))) ; Shape: (dim) + (exponents (!div indices (!const x head-dim))) ; Shape: (dim) + (base (!const x base)) + (log-base (!log base)) ; Natural logarithm + (theta (!exp (!neg (!mul exponents log-base)))) ; Shape: (dim) + (theta-reshaped (!reshape theta (list 1 dim))) ; Shape: (1, dim) + (seq-idx (!index-components (list seq-len))) ; Shape: (seq-len) + (seq-idx-reshaped (!reshape seq-idx (list seq-len 1))) ; Shape: (seq-len, 1) + (idx-theta (!mul seq-idx-reshaped theta-reshaped)) ; Shape: (seq-len, dim) + (cosine (!cos idx-theta)) ; Shape: (seq-len, dim) + (sine (!sin idx-theta)) ; Shape: (seq-len, dim) + (xshaped (!reshape x (list batch-size seq-len num-heads dim 2))) + (cosine-reshaped (!reshape cosine (list 1 seq-len 1 dim 1))) + (sine-reshaped (!reshape sine (list 1 seq-len 1 dim 1))) + (num-dimensions 5) + (x0-subscripts (make-list num-dimensions :initial-element 'T)) + (x1-subscripts (make-list num-dimensions :initial-element 'T))) + (setf (nth (- num-dimensions 1) x0-subscripts) '(0 1)) ; Select index 0 + (setf (nth (- num-dimensions 1) x1-subscripts) '(1 2)) ; Select index 1 + (let* ((x0 (apply #'!view xshaped x0-subscripts)) + (x1 (apply #'!view xshaped x1-subscripts)) + (rotated0 (!sub (!mul x0 cosine-reshaped) (!mul x1 sine-reshaped))) + (rotated1 (!add (!mul x0 sine-reshaped) (!mul x1 cosine-reshaped))) + (x-out (!concatenate -1 rotated0 rotated1)) + (x-out-final (!reshape x-out (list batch-size seq-len num-heads (* 2 dim))))) + (let ((final-result + (if (= (* 2 dim) head-dim) + x-out-final + (let* ((x-dims (shape x)) + (subs (make-list (length x-dims) :initial-element 'T))) + (setf (nth (- (length x-dims) 1) subs) (list (- head-dim 1) head-dim)) + (let ((last-elem (apply #'!view x subs))) ; Shape: (batch-size, seq-len, num-heads, 1) + (!concatenate -1 x-out-final last-elem)))))) + (!reshape final-result x-shape))))))) (defun !rope (x dim) (declare (type tensor x)) - (forward (rope dim) x)) + (forward (rope dim) x)) \ No newline at end of file diff --git a/source/test-suite/test-rope.lisp b/source/test-suite/test-rope.lisp index 16ae3d62b..262c5a9ff 100644 --- a/source/test-suite/test-rope.lisp +++ b/source/test-suite/test-rope.lisp @@ -29,6 +29,6 @@ def torch_rope(x): (assert-equal (:atol 1e-4 :rtol 1e-6) (with-torch (x) (->caten (torch_rope x))) - (proceed (!rope x 1))))) + (proceed (!rope x 5))))) (skip "TODO: Not working with JIT=1"))) From 688ab624524046522402b0fd669f5051b9772514 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 22 Nov 2024 11:28:27 +0100 Subject: [PATCH 29/30] review changes added: assertion for tensor number of dimensions = 4, make-list with initial true, multiple-value-bind instead of manual initialization, assertion instead of when Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/nn/positional-encoding.lisp | 87 +++++++++++++++--------------- 1 file changed, 44 insertions(+), 43 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index a833931de..e5f1b7ae7 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -9,49 +9,50 @@ (defmethod call ((op RoPE) &rest inputs) (with-slots (dim base) op - (let* ((x (first inputs)) - (x-shape (shape x)) ; x-shape: (batch-size, seq-len, num-heads, head-dim) - (batch-size (nth 0 x-shape)) - (seq-len (nth 1 x-shape)) - (num-heads (nth 2 x-shape)) - (head-dim (nth 3 x-shape))) - ;; Validate that dim matches head_dim / 2 - (when (/= dim (/ head-dim 2)) - (error "Mismatch: Provided dim (~A) does not match head-dim / 2 (~A)" dim (/ head-dim 2))) - (let* ((indices (!mul (!const x 2) (!index-components (list dim)))) ; Shape: (dim) - (exponents (!div indices (!const x head-dim))) ; Shape: (dim) - (base (!const x base)) - (log-base (!log base)) ; Natural logarithm - (theta (!exp (!neg (!mul exponents log-base)))) ; Shape: (dim) - (theta-reshaped (!reshape theta (list 1 dim))) ; Shape: (1, dim) - (seq-idx (!index-components (list seq-len))) ; Shape: (seq-len) - (seq-idx-reshaped (!reshape seq-idx (list seq-len 1))) ; Shape: (seq-len, 1) - (idx-theta (!mul seq-idx-reshaped theta-reshaped)) ; Shape: (seq-len, dim) - (cosine (!cos idx-theta)) ; Shape: (seq-len, dim) - (sine (!sin idx-theta)) ; Shape: (seq-len, dim) - (xshaped (!reshape x (list batch-size seq-len num-heads dim 2))) - (cosine-reshaped (!reshape cosine (list 1 seq-len 1 dim 1))) - (sine-reshaped (!reshape sine (list 1 seq-len 1 dim 1))) - (num-dimensions 5) - (x0-subscripts (make-list num-dimensions :initial-element 'T)) - (x1-subscripts (make-list num-dimensions :initial-element 'T))) - (setf (nth (- num-dimensions 1) x0-subscripts) '(0 1)) ; Select index 0 - (setf (nth (- num-dimensions 1) x1-subscripts) '(1 2)) ; Select index 1 - (let* ((x0 (apply #'!view xshaped x0-subscripts)) - (x1 (apply #'!view xshaped x1-subscripts)) - (rotated0 (!sub (!mul x0 cosine-reshaped) (!mul x1 sine-reshaped))) - (rotated1 (!add (!mul x0 sine-reshaped) (!mul x1 cosine-reshaped))) - (x-out (!concatenate -1 rotated0 rotated1)) - (x-out-final (!reshape x-out (list batch-size seq-len num-heads (* 2 dim))))) - (let ((final-result - (if (= (* 2 dim) head-dim) - x-out-final - (let* ((x-dims (shape x)) - (subs (make-list (length x-dims) :initial-element 'T))) - (setf (nth (- (length x-dims) 1) subs) (list (- head-dim 1) head-dim)) - (let ((last-elem (apply #'!view x subs))) ; Shape: (batch-size, seq-len, num-heads, 1) - (!concatenate -1 x-out-final last-elem)))))) - (!reshape final-result x-shape))))))) + (let* ((x (first inputs))) + ;; Assert tensor rank + (assert (= (ndim x) 4) () "Input tensor must have rank 4, but got ~A" (ndim x)) + (multiple-value-bind (batch-size seq-len num-heads head-dim) + (apply #'values (shape x)) + ;; Validate that dim matches head-dim / 2 + (assert (= dim (/ head-dim 2)) () "Mismatch: Provided dim (~A) does not match head-dim / 2 (~A)" dim (/ head-dim 2)) + + ;; Computations + (let* ((indices (!mul (!const x 2) (!index-components (list dim)))) ; Shape: (dim) + (exponents (!div indices (!const x head-dim))) ; Shape: (dim) + (base (!const x base)) + (log-base (!log base)) ; Natural logarithm + (theta (!exp (!neg (!mul exponents log-base)))) ; Shape: (dim) + (theta-reshaped (!reshape theta (list 1 dim))) ; Shape: (1, dim) + (seq-idx (!index-components (list seq-len))) ; Shape: (seq-len) + (seq-idx-reshaped (!reshape seq-idx (list seq-len 1))) ; Shape: (seq-len, 1) + (idx-theta (!mul seq-idx-reshaped theta-reshaped)) ; Shape: (seq-len, dim) + (cosine (!cos idx-theta)) ; Shape: (seq-len, dim) + (sine (!sin idx-theta)) ; Shape: (seq-len, dim) + (xshaped (!reshape x (list batch-size seq-len num-heads dim 2))) + (cosine-reshaped (!reshape cosine (list 1 seq-len 1 dim 1))) + (sine-reshaped (!reshape sine (list 1 seq-len 1 dim 1))) + (num-dimensions 5) + (x0-subscripts (make-list num-dimensions :initial-element t)) + (x1-subscripts (make-list num-dimensions :initial-element t))) + (setf (nth (- num-dimensions 1) x0-subscripts) '(0 1)) ; Select index 0 + (setf (nth (- num-dimensions 1) x1-subscripts) '(1 2)) ; Select index 1 + + (let* ((x0 (apply #'!view xshaped x0-subscripts)) + (x1 (apply #'!view xshaped x1-subscripts)) + (rotated0 (!sub (!mul x0 cosine-reshaped) (!mul x1 sine-reshaped))) + (rotated1 (!add (!mul x0 sine-reshaped) (!mul x1 cosine-reshaped))) + (x-out (!concatenate -1 rotated0 rotated1)) + (x-out-final (!reshape x-out (list batch-size seq-len num-heads (* 2 dim))))) + (let ((final-result + (if (= (* 2 dim) head-dim) + x-out-final + (let* ((x-dims (shape x)) + (subs (make-list (length x-dims) :initial-element t))) + (setf (nth (- (length x-dims) 1) subs) (list (- head-dim 1) head-dim)) + (let ((last-elem (apply #'!view x subs))) ; Shape: (batch-size, seq-len, num-heads, 1) + (!concatenate -1 x-out-final last-elem)))))) + (!reshape final-result (shape x))))))))) (defun !rope (x dim) (declare (type tensor x)) From 71f7dfb9f43fa07a4b7790436cea376d4d151882 Mon Sep 17 00:00:00 2001 From: Ayman Bourramouss Date: Fri, 22 Nov 2024 11:29:58 +0100 Subject: [PATCH 30/30] add types Co-Authored-By: hikettei <88639579+hikettei@users.noreply.github.com> --- source/nn/positional-encoding.lisp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/nn/positional-encoding.lisp b/source/nn/positional-encoding.lisp index e5f1b7ae7..30b47253a 100644 --- a/source/nn/positional-encoding.lisp +++ b/source/nn/positional-encoding.lisp @@ -4,8 +4,8 @@ ;; PositionalEncoding (defmodel (RoPE (dim &key (base 10000))) - ((dim dim) - (base base))) + ((dim dim :type fixnum) + (base base :type fixnum))) (defmethod call ((op RoPE) &rest inputs) (with-slots (dim base) op