Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: defcall #268

Merged
merged 2 commits into from
Nov 29, 2024
Merged

Feat: defcall #268

merged 2 commits into from
Nov 29, 2024

Conversation

hikettei
Copy link
Owner

@hikettei hikettei commented Nov 28, 2024

  • Replace defmethod call with defcall
  • Documentation
(defmethod call ((model Transformer) &rest inputs)
  (multiple-value-bind (tokens start-pos) (apply #'values inputs)
    (assert (and tokens start-pos))
    (st "Tokens[batch sentence_length] Start_Pos[] -> Tokens[batch sentence_length]" (tokens start-pos))
    ;; (assert (numberp (second (shape tokens))) () "The second dimension of the tensor (seq_len) must be a fixnum, getting ~a" (shape tokens))
    (with-slots ((wte wte) (wpe wpe) (h h) (ln-f ln-f) (lm-head lm-head)) model
      (let* ((token-emb (forward wte tokens))
	     (pos-emb   (forward wpe (!cast (!add start-pos (!index-components `(1 ,(second (shape tokens))))) (dtype-of tokens))))
	     (hi (!add token-emb pos-emb))
             (seq-len (iconst (second (shape tokens))))
	     (mask (!triu (!full `(1 1 ,seq-len ,(!+ start-pos (iconst seq-len))) (-inf)) :diagonal (!+ (iconst 1) start-pos)))
	     (_ (dolist (hn h) (setf hi (forward hn hi mask start-pos))))
	     (logits (forward lm-head (forward ln-f hi))))
	(declare (ignore _))
	;; (!argmax (!view logits t -1 t))
        (!argmax logits)))))

===> replace

(defcall (model Transformer) (Tokens[Batch Seq-Len] Start-Pos[])
  (with-slots ((wte wte) (wpe wpe) (h h) (ln-f ln-f) (lm-head lm-head)) model
    (let* ((token-emb (forward wte tokens))
	   (pos-emb   (forward wpe (!cast (!add start-pos (!index-components `(1 ,seq-len))) (dtype-of tokens))))
	   (hi (!add token-emb pos-emb))
	   (mask (!triu (!full `(1 1 ,seq-len ,(!+ start-pos (iconst seq-len))) (-inf)) :diagonal (!+ (iconst 1) start-pos)))
	   (_ (dolist (hn h) (setf hi (forward hn hi mask start-pos))))
	   (logits (forward lm-head (forward ln-f hi))))
      (declare (ignore _))
      ;; (!argmax (!view logits t -1 t))
      (!argmax logits))))

@hikettei hikettei marked this pull request as ready for review November 29, 2024 04:49
@hikettei hikettei merged commit 962a37d into main Nov 29, 2024
6 checks passed
@hikettei hikettei deleted the defcall branch November 29, 2024 05:03
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant