Skip to content

Commit

Permalink
Feat: defcall (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
hikettei authored Nov 29, 2024
1 parent 9df37b3 commit 962a37d
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 25 deletions.
34 changes: 14 additions & 20 deletions external/llm/layers.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@
((c-fc (Linear dim hidden-dim))
(c-proj (Linear hidden-dim dim))))

(defmethod call ((model FeedForward) &rest inputs)
  ;; Only the first element of INPUTS is used (NIL when INPUTS is empty).
  ;; It is passed through c-fc, then !gelu, then c-proj.
  (let ((x (first inputs)))
    (forward (slot-value model 'c-proj)
             (!gelu (forward (slot-value model 'c-fc) x)))))
(defcall (model FeedForward) (X[~])
  ;; x is bound by defcall from the declared input X[~].
  ;; Pipeline: c-fc -> !gelu -> c-proj.
  (with-slots (c-fc c-proj) model
    (forward c-proj
             (!gelu
              (forward c-fc x)))))

(defmodel (TransformerBlock (dim n-heads &key (norm-eps 1e-5) (max-seq-len 1024)))
((attn (Attention dim n-heads max-seq-len))
Expand All @@ -79,19 +78,14 @@
(ln-f (LayerNorm `(,dim) :eps norm-eps))
(lm-head (Linear dim vocab-size :bias nil))))

;; Forward pass of Transformer. The &rest inputs are expected to be
;; (tokens start-pos); the result is (!argmax logits).
(defmethod call ((model Transformer) &rest inputs)
;; Spread the &rest list into named values. A missing input becomes NIL
;; and is rejected by the assert below.
(multiple-value-bind (tokens start-pos) (apply #'values inputs)
(assert (and tokens start-pos))
;; The values returned by `st` are discarded here; it is presumably called
;; only to check the inputs against this shape signature -- TODO confirm.
(st "Tokens[batch sentence_length] Start_Pos[] -> Tokens[batch sentence_length]" (tokens start-pos))
;; (assert (numberp (second (shape tokens))) () "The second dimension of the tensor (seq_len) must be a fixnum, getting ~a" (shape tokens))
(with-slots ((wte wte) (wpe wpe) (h h) (ln-f ln-f) (lm-head lm-head)) model
(let* ((token-emb (forward wte tokens))
;; Position indices: start-pos added to (!index-components `(1 seq-len)),
;; cast to the dtype of tokens, then passed through wpe.
(pos-emb (forward wpe (!cast (!add start-pos (!index-components `(1 ,(second (shape tokens))))) (dtype-of tokens))))
(hi (!add token-emb pos-emb))
(seq-len (iconst (second (shape tokens))))
;; A (1 1 seq-len start-pos+seq-len) tensor filled with (-inf), passed to
;; !triu with :diagonal start-pos+1. Presumably the causal attention
;; mask -- confirm against the Attention layer.
(mask (!triu (!full `(1 1 ,seq-len ,(!+ start-pos (iconst seq-len))) (-inf)) :diagonal (!+ (iconst 1) start-pos)))
;; Dummy binding: runs every element of h in order inside the let*,
;; threading hi through each one via setf.
(_ (dolist (hn h) (setf hi (forward hn hi mask start-pos))))
(logits (forward lm-head (forward ln-f hi))))
(declare (ignore _))
;; (!argmax (!view logits t -1 t))
(!argmax logits)))))
(defcall (model Transformer) (Tokens[Batch Seq-Len] Start-Pos[])
  ;; tokens, start-pos and seq-len are bound by defcall from the input
  ;; declaration above. The result is (!argmax logits).
  (with-slots (wte wpe h ln-f lm-head) model
    (let* ((token-emb (forward wte tokens))
           ;; start-pos added to (!index-components `(1 seq-len)), cast to the dtype of tokens.
           (positions (!cast (!add start-pos (!index-components `(1 ,seq-len))) (dtype-of tokens)))
           (hidden (!add token-emb (forward wpe positions)))
           (mask (!triu (!full `(1 1 ,seq-len ,(!+ start-pos (iconst seq-len))) (-inf))
                        :diagonal (!+ (iconst 1) start-pos))))
      ;; Thread the hidden state through every element of h, in order.
      (dolist (layer h)
        (setf hidden (forward layer hidden mask start-pos)))
      ;; (!argmax (!view logits t -1 t))
      (!argmax (forward lm-head (forward ln-f hidden))))))
2 changes: 1 addition & 1 deletion source/apis/documentation.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ Compute the backward pass of the compiled computational graph (AVM). Note that t
(docs:define-page ("Models" "packages/caten.apis.models.md")
(docs:title "Models")
(docs:body "TODO")
)
(docs:doc/macro "defcall" 'defcall))

(docs:define-page ("AOT" "packages/caten.apis.aot.md")
(docs:title "AOT")
Expand Down
42 changes: 42 additions & 0 deletions source/apis/model.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,45 @@
,@(loop for slot-name in slot-names
collect `(setf (slot-value ,name ',slot-name) ,slot-name))
,@body)))))))

(defmacro defcall ((model-bind model) (&rest inputs) &body body)
"
```
(defcall (model-bind model) (&rest inputs) &body body)
```
A macro to write `defmethod call` in a more concise way.
`inputs` is written in shape-tracker notation (e.g. `Tokens[Batch Seq-Len] Start-Pos[]`).
Inside `body`, each input name and each symbol appearing in the declared shapes is bound as a variable;
shape symbols are bound to the value looked up for them in the table returned by `%solve-st`.
### Example
```lisp
(defcall (model Transformer) (Tokens[Batch Seq-Len] Start-Pos[])
(with-slots ((wte wte) (wpe wpe) (h h) (ln-f ln-f) (lm-head lm-head)) model
(let* ((token-emb (forward wte tokens))
(pos-emb (forward wpe (!cast (!add start-pos (!index-components `(1 ,seq-len))) (dtype-of tokens))))
(hi (!add token-emb pos-emb))
(mask (!triu (!full `(1 1 ,seq-len ,(!+ start-pos (iconst seq-len))) (-inf)) :diagonal (!+ (iconst 1) start-pos)))
(_ (dolist (hn h) (setf hi (forward hn hi mask start-pos))))
(logits (forward lm-head (forward ln-f hi))))
(declare (ignore _))
;; (!argmax (!view logits t -1 t))
(!argmax logits))))
```
"
;; where: the printed form of INPUTS with the outer parentheses stripped,
;; or NIL when no inputs were declared.
(let* ((where (princ-to-string inputs))
(where (when inputs (subseq where 1 (1- (length where)))))
;; wt: the notation "<where> -> <where>" parsed by %parse-st, or an empty
;; tracker built with make-st when there are no inputs.
(wt (when where (%parse-st (format nil "~a -> ~a" where where))))
(wt (if where wt (make-st "" nil nil)))
;; keys: every distinct element found in the declared input shapes.
(keys (when wt (remove-duplicates (flatten (map 'list #'at-shape (st-bf wt)))))))
;; NOTE: from here on INPUTS names a gensym for the method's &rest list,
;; shadowing the macro parameter of the same name.
(with-gensyms (inputs solved)
`(defmethod call ((,model-bind ,model) &rest ,inputs)
;; The number of runtime arguments must equal the number of declared inputs.
(assert (= ,(length (st-bf wt)) (length ,inputs))
()
"(call ~a &rest inputs): The number of inputs does not match defined inputs ~a~%call is defined as: ~a" ',model ,inputs ,where)
;; Run `st` on the inputs; its return values are discarded (presumably
;; used only as a shape check -- TODO confirm).
,(when where `(st ,(format nil "~a -> ~a" where where) (,inputs)))
;; Bind one variable per declared input. Names are interned, at
;; macroexpansion time, into the package current at that point.
(multiple-value-bind (,@(map 'list (compose #'intern #'princ-to-string #'at-name) (st-bf wt))) (apply #'values ,inputs)
;; With :return-solved t, %solve-st returns its `solved` table instead of new tensors.
(let ((,solved ,(when where `(%solve-st (%parse-st ,(format nil "~a -> ~a" where where)) nil nil :tensors ,inputs :return-solved t))))
(declare (ignorable ,solved))
;; Bind one variable per shape key to its entry in the solved table.
(let (,@(loop for key in keys collect `(,(intern (princ-to-string key)) (gethash ,key ,solved))))
(declare (ignorable ,@(map 'list (compose #'intern #'princ-to-string) keys)))
,@body)))))))
2 changes: 1 addition & 1 deletion source/apis/package.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
#:float-type-of
)
;; from model.lisp
(:export #:defmodel #:call)
(:export #:defmodel #:call #:defcall)
;; from conditions.lisp
(:export
#:caten-forward-error
Expand Down
7 changes: 4 additions & 3 deletions source/apis/shape-tracker.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
else
collect tns))))))

(defun %solve-st (st lazy-solve allow-broadcast &rest tensors &aux (tensors (flatten tensors)))
(defun %solve-st (st lazy-solve allow-broadcast &key tensors (return-solved nil) &aux (tensors (flatten tensors)))
"lazy-solve = (symbol . value)"
(declare (type ShapeTracker st)
(type list tensors)
Expand Down Expand Up @@ -224,6 +224,7 @@
(at-shape (find-at (at-name at) :key #'st-aft)))
(tensor-tr base)
nil))))))
(when return-solved (return-from %solve-st solved))
(apply #'values (map 'list #'make-new-tensor (st-aft st))))))
(defun parse-where (where)
"Verifies the where form"
Expand Down Expand Up @@ -267,14 +268,14 @@ TODO: Add LazyAssertion which applies shape check even for symbols
"
(declare (type string st-notation))
(let ((st (%st->list (%parse-st st-notation))))
`(%solve-st ,st ,(parse-where where) nil ,@input-tensors)))
`(%solve-st ,st ,(parse-where where) nil :tensors (list ,@input-tensors))))

(defmacro bc (st-notation (&rest input-tensors) &rest where)
"## [macro] bc
Performs the same operation as `st`, but also applies broadcasting.
It calls !reshape and !view internally; therefore, it must not be used inside the forward method."
(declare (type string st-notation))
(let ((st (%st->list (%parse-st st-notation))))
`(%solve-st ,st ,(parse-where where) t ,@input-tensors)))
`(%solve-st ,st ,(parse-where where) t :tensors (list ,@input-tensors))))

(defun broadcast-elwise (a b)
  "Apply `bc` with the notation A[~] B[~] -> A[~] B[~] to A and B, and return all of the values it produces as a list."
  (multiple-value-call #'list
    (bc "A[~] B[~] -> A[~] B[~]" (a b))))

0 comments on commit 962a37d

Please sign in to comment.