Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Pointer Address to caten/aasm, and revisit the semantics of :MOVE #160

Merged
merged 6 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions source/aasm/attrs.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ out <- x ^ y (if integer)
out <- move(x, y)
where move(x, y) is x = y
```
- _jit_dont_render_me[boolean] (TODO)
"
:slots ((_jit_dont_render_me :initform nil)))
")

(defnode (:BinaryOps :MAX) (BinaryOps JITAble)
"Computes the maximum value of two tensors in read, writing the result to the first write.
Expand Down
4 changes: 1 addition & 3 deletions source/apis/iseq.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,7 @@
(push final-node (graph-nodes graph))
(session/assign session grad-id final-node))
(let* ((subgrad (session/read session (car rest-grads)))
;; [TODO] Remove _jit_dont_render_me option (should only used in: ./ajit/graph.lisp, air->expr)
;; JIT do not want to render the ir below.
(final-node (make-node :BinaryOps :MOVE (list grad-id) (list final-grad-id (node->id subgrad)) :_jit_dont_render_me t)))
(final-node (make-node :BinaryOps :MOVE (list grad-id) (list final-grad-id (node->id subgrad)))))
(push final-node (graph-nodes graph))
(session/assign session grad-id final-node))))))
(session-grad->grads session)))
Expand Down
42 changes: 33 additions & 9 deletions source/codegen/blueprint.lisp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
(defpackage :caten/codegen/blueprint
(:documentation "The `Blueprint` represents a transformed computation graph format of `caten/AASM` that incorporates loop information. The `lower-schedule-item` method infers loop boundaries based on `Schedule-item` and performs lowering into a format that includes :FOR/:ENDFOR nodes.
The `Blueprint` is a data structure closer to the `Renderer` than AASM, and it is used for loop optimization and by the Renderer.")
(:use :cl :caten/air :caten/codegen/expr :alexandria)
(:use :cl :caten/air :caten/codegen/expr :alexandria :caten/codegen/expr-cache)
(:import-from
:caten/codegen/shape-inference
#:read-type-relay
Expand Down Expand Up @@ -41,7 +41,10 @@ The `Blueprint` is a data structure closer to the `Renderer` than AASM, and it i
(:import-from
:caten/codegen/exprify
#:graph-exprify
#:graph-scalarify)
#:graph-scalarify
#:expr-set-iterations
#:graph-propagate-pointer-id-type
#:expr-rewrite-edge-with-pointer-id)
(:export
#:lower-schedule-item
#:print-blueprint))
Expand Down Expand Up @@ -459,7 +462,7 @@ Depends=~a Reduce=~a Users=~a
(node-reads node))
nil)))

(defmethod schedule-item-infer-io-buffers ((node Node) (bp-items list))
(defmethod schedule-item-infer-io-buffers ((node Node) (bp-items list) rewrite-map)
(assert (eql (node-type node) :Schedule-Item))
(let ((seen) (read-items) (write-items))
(loop for item in bp-items
Expand All @@ -477,10 +480,27 @@ Depends=~a Reduce=~a Users=~a
do (push write seen)))
(setf (node-reads node) (map 'list #'car read-items)
(node-writes node) (map 'list #'car write-items)
(getattr node :read-types) (map 'list #'cdr read-items)
(getattr node :write-types) (map 'list #'cdr write-items)
(getattr node :storage-id-src) (map 'list #'car read-items)
(getattr node :storage-id-dst) (map 'list #'car write-items))))
seen nil)
(labels ((reduce-address-of (id)
(if (gethash id rewrite-map)
(reduce-address-of (gethash id rewrite-map))
id))
(address-of (id)
(read-ptrid (reduce-address-of id)))
(only-unseen (list)
(loop for l in list
for id = (address-of (car l))
if (null (find id seen))
do (push id seen) and collect l))
(make-pair (list)
(only-unseen (remove-duplicates list :key #'(lambda (x) (address-of (car x)))))))
(multiple-value-bind (write-items read-items)
(values (make-pair write-items) (make-pair read-items))
(setf
(getattr node :read-types) (map 'list #'cdr read-items)
(getattr node :write-types) (map 'list #'cdr write-items)
(getattr node :storage-id-src) (map 'list (compose #'reduce-address-of #'car) read-items)
(getattr node :storage-id-dst) (map 'list (compose #'reduce-address-of #'car) write-items))))))

(defmethod schedule-item-gather-dynamic-shapes ((node Node) base-graph)
(flet ((is-dynamic-shape-p (val)
Expand Down Expand Up @@ -540,8 +560,12 @@ Depends=~a Reduce=~a Users=~a
(setf (ctx-blueprint ctx) (simplify-blueprint (ctx-blueprint ctx))
(ctx-blueprint ctx) (graph-scalarify (ctx-blueprint ctx) node scheduled-graph)
(ctx-blueprint ctx) (graph-exprify (ctx-blueprint ctx) node scheduled-graph))
(expr-set-iterations (ctx-blueprint ctx))
(multiple-value-bind (new-bp id-rewrite-map) (graph-propagate-pointer-id-type (ctx-blueprint ctx) scheduled-graph)
(setf (ctx-blueprint ctx) new-bp)
;; Infer the input/output buffers again, they can be removed during the op fusion.
(schedule-item-infer-io-buffers node (ctx-blueprint ctx) id-rewrite-map)
(expr-rewrite-edge-with-pointer-id (ctx-blueprint ctx) id-rewrite-map))
(when (and (>= (ctx:getenv :JIT_DEBUG) 2) (null (getattr node :cache-name)))
(print-blueprint (ctx-blueprint ctx) t))
;; Infer the input/output buffers again, they can be removed during the op fusion.
(schedule-item-infer-io-buffers node (ctx-blueprint ctx))
(setf (getattr node :blueprint) (ctx-blueprint ctx)))))
13 changes: 10 additions & 3 deletions source/codegen/expr-cache.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
#:stash-expr
#:restore-expr
#:expr-cache-reduce-alias
#:read-newid))
#:cache-pointer-map
#:read-newid
#:read-ptrid))

(in-package :caten/codegen/expr-cache)

Expand All @@ -24,11 +26,12 @@
((cache :type hash-table :initform (make-hash-table :test 'equal) :accessor cache-table)
(id2expr :type hash-table :initform (make-hash-table :test 'equal) :accessor id2expr-table)
(global-counter :initform 0 :type fixnum :accessor global-counter)
(pointer-map :type hash-table :accessor cache-pointer-map :initarg :pointer-map)
(global-reduce-alias :type hash-table :initform (make-hash-table :test 'equal) :accessor expr-cache-reduce-alias))
(:documentation "Creates a cached object for (scalar) EXPR graph."))

(defmacro with-expr-cache (() &body body)
`(let ((*expr-cache* (make-instance 'Expr-Cache)))
(defmacro with-expr-cache ((&key (pointer-map '(make-hash-table))) &body body)
  "Runs BODY with *expr-cache* bound to a newly created Expr-Cache instance.

POINTER-MAP is a form; it is evaluated at run time, each time the expansion is
entered, and its value is passed as the :pointer-map initarg of the cache.
When omitted, a new empty hash table is created on every entry."
  ;; The default form is quoted on purpose. An unquoted (make-hash-table) default
  ;; would be evaluated once, at macroexpansion time, and the resulting table object
  ;; would be spliced into the expansion as a literal shared by every execution of
  ;; that call site.
  `(let ((*expr-cache* (make-instance 'Expr-Cache :pointer-map ,pointer-map)))
     ,@body))

(defun expr-id (num)
  "Returns a string made of the prefix \"_expr_id_\" followed by NUM printed as by PRINC."
  (concatenate 'string "_expr_id_" (princ-to-string num)))
Expand Down Expand Up @@ -66,3 +69,7 @@
(defun read-newid (id)
  "Looks ID up in the reduce-alias table of the current *expr-cache*.

Returns the stored entry when it is non-NIL, otherwise ID itself.
Signals an assertion failure when *expr-cache* is NIL."
  (assert *expr-cache*)
  (let ((alias (gethash id (expr-cache-reduce-alias *expr-cache*))))
    (if alias
        alias
        id)))

(defun read-ptrid (id)
  "Looks ID up in the pointer map of the current *expr-cache*.

Returns the stored entry when it is non-NIL, otherwise ID itself.
Signals an assertion failure when *expr-cache* is NIL."
  (assert *expr-cache*)
  (let* ((pointer-map (cache-pointer-map *expr-cache*))
         (address (gethash id pointer-map)))
    (or address id)))
76 changes: 51 additions & 25 deletions source/codegen/exprify.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@
(:export
#:graph-scalarify
#:graph-exprify
#:graph-propagete-reduction))
#:graph-propagate-pointer-id-type
#:expr-set-iterations
#:expr-rewrite-edge-with-pointer-id))

(in-package :caten/codegen/exprify)

Expand Down Expand Up @@ -210,7 +212,7 @@
(loop for item in (graph-nodes (expr-graph (getattr expr :expr)))
append (node-writes item)))

(defun expr-only-leaf-are-arguments (nodes)
(defun expr-only-leaf-are-arguments (nodes schedule-graph)
"Rewrites the expr in nodes, to have only the leaf nodes as an argument."
(loop for node in nodes
if (eql (node-type node) :EXPR) do
Expand All @@ -219,7 +221,10 @@
for node in (graph-nodes expr)
;; See renderer.lisp: for example, the first argument of MOVE is not rendered.
;; [Note] Add more node types here if you find an argument that is listed but not actually used in the rendered kernel.
if (find (node-type node) `(:MOVE :CAST :!= :< :INDEX-COMPONENTS :LOAD :STORE))
if (and
(find (node-type node) `(:MOVE :CAST :!= :< :INDEX-COMPONENTS :LOAD :STORE))
(id->value schedule-graph (car (node-reads node)))
(getattr (id->value schedule-graph (car (node-reads node))) :allocate-p))
collect (car (node-reads node))
if (not (eql (node-type node) :Aref))
collect (car (node-writes node))))
Expand Down Expand Up @@ -316,27 +321,12 @@
;; C <- A + B
;; =>
;; A += B
(expr-only-leaf-are-arguments (graph-propagate-reduction (rewriter 0 (length new-bp)) replaceable))))))

(defun rewrite-expr-aref (expr replace)
(declare (type graph expr))
(dolist (n (graph-nodes expr))
(if (eql (node-type n) :AREF)
(multiple-value-bind (id type is) (funcall replace (car (node-writes n)) (getattr n :buffer) (getattr n :space))
(setf (node-writes n) (list id)
(getattr n :buffer) type
(getattr n :space) is))
(setf (node-reads n)
(map
'list
#'(lambda (x)
(let ((id (funcall replace x nil nil)))
(or id x)))
(node-reads n))))))

(defmethod graph-propagate-reduction (blueprint replaceable)
(rewriter 0 (length new-bp))))))
;; [TODO] Clean up this function!
(defun graph-propagate-pointer-id-type (blueprint schedule-graph)
(assert *expr-cache*)
(let ((id->tgt (expr-cache-reduce-alias *expr-cache*))) ;; id -> (list new_id new_type new_is)
(let ((rewrite-map (make-hash-table))
(id->tgt (expr-cache-reduce-alias *expr-cache*))) ;; id -> (list new_id new_type new_is)
(loop for bp in blueprint
if (and (eql (node-type bp) :EXPR) (getattr bp :reduction))
do (setf (gethash (car (node-writes bp)) id->tgt)
Expand All @@ -349,7 +339,9 @@
(let ((new-id (final-new-id id)))
(if (null new-id)
(values id type is)
(apply #'values new-id)))))
(progn
(setf (gethash id rewrite-map) (car new-id))
(values id (second new-id) (third new-id)))))))
(macrolet ((updt (expr &key (reader) (ireader) (treader))
`(loop for nth upfrom 0
for read in (,reader ,expr)
Expand All @@ -366,7 +358,39 @@
do (updt bp :reader node-reads :ireader relay-read-iters :treader relay-reads)
(updt bp :reader node-writes :ireader relay-write-iters :treader relay-writes)
(rewrite-expr-aref (expr-graph (getattr bp :expr)) #'new))
(expr-set-iterations blueprint)))))
(values (expr-only-leaf-are-arguments blueprint schedule-graph) rewrite-map)))))

(defun rewrite-expr-aref (expr replace)
  "Destructively renames the ids used by every node of the graph EXPR.

REPLACE is called with three arguments:
- For an :AREF node it receives (written-id buffer space) and its three return
  values become the node's write id, :buffer and :space. A NIL second or third
  value leaves the existing :buffer or :space unchanged.
- For any other node it receives (read-id nil nil) for each read. A NIL primary
  value keeps the original read id."
  (declare (type graph expr))
  (flet ((rename-read (read-id)
           ;; Fall back to the original id when REPLACE has nothing to offer.
           (or (funcall replace read-id nil nil) read-id)))
    (loop for node in (graph-nodes expr)
          do (cond
               ((eql (node-type node) :AREF)
                (multiple-value-bind (new-id new-buffer new-space)
                    (funcall replace
                             (car (node-writes node))
                             (getattr node :buffer)
                             (getattr node :space))
                  (setf (node-writes node) (list new-id))
                  (setf (getattr node :buffer) (or new-buffer (getattr node :buffer)))
                  (setf (getattr node :space) (or new-space (getattr node :space)))))
               (t
                (setf (node-reads node)
                      (map 'list #'rename-read (node-reads node))))))))

(defun expr-rewrite-edge-with-pointer-id (blueprint map)
  "Renames ids of every :EXPR node in BLUEPRINT according to the hash table MAP.

An id is resolved by following MAP repeatedly until an id with no (non-NIL)
entry is reached. For each :EXPR node the elements of its reads and writes that
are non-NIL symbols are overwritten in place with their resolved id, and the
same resolution is then applied inside the node's :expr graph through
rewrite-expr-aref."
  (labels ((newid (id &optional _ __)
             ;; The two optional arguments exist only so that this function can be
             ;; handed to rewrite-expr-aref, which calls its callback with three arguments.
             (declare (ignore _ __))
             (let ((next (gethash id map)))
               (if next
                   (newid next)
                   id)))
           (rewrite-in-place (ids)
             ;; Overwrite the list cells directly so the node's own list is updated.
             (loop for cell on ids
                   for id = (car cell)
                   when (and id (symbolp id))
                     do (setf (car cell) (newid id)))))
    (dolist (bp blueprint)
      (when (eql (node-type bp) :EXPR)
        (rewrite-in-place (node-reads bp))
        (rewrite-in-place (node-writes bp))
        (rewrite-expr-aref (expr-graph (getattr bp :expr)) #'newid)))))

(defun expr-set-iterations (blueprint)
(loop with gids = nil
Expand All @@ -387,3 +411,5 @@
(when is
(setf (getattr bp :iterations) (ensure-iteration-space-length is (getattr bp :iterations))))))
blueprint)


85 changes: 45 additions & 40 deletions source/codegen/jit.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ caten/codegen overview:
(:import-from
:caten/codegen/rewriting-rules
#:apply-rewriting-rules
#:graph-infer-pointer-address
#:schedule-item-write-define-global)
(:import-from
:caten/codegen/scheduler
Expand All @@ -74,7 +75,8 @@ caten/codegen overview:
#:print-progress)
(:import-from
:caten/codegen/expr-cache
#:with-expr-cache)
#:with-expr-cache
#:read-ptrid)
(:import-from
:caten/codegen/memory-planner
#:run-memory-planner)
Expand Down Expand Up @@ -136,7 +138,7 @@ caten/codegen overview:
(append
(getattr si :storage-id-dst) ;; optimized by memory-planner
(map 'list #'car (getattr si :dynamic-shapes))
(node-reads si))
(getattr si :storage-id-src))
:output-buffer-n (length (node-writes si))
:kernel-info (make-compiled-kernel-from-si si graph)))

Expand Down Expand Up @@ -203,7 +205,6 @@ caten/codegen overview:

(defun schedule-graph->vmop (avm graph &aux (map (id->output-map graph)))
(declare (type Graph graph))
(verify-graph graph)
(let ((nodes) (allocated))
(flet ((merge-id (id)
(multiple-value-bind (deps new-seen) (get-subgraph (avm-graph avm) id allocated)
Expand Down Expand Up @@ -325,17 +326,18 @@ caten/codegen overview:
(remove-duplicates
(loop for node in (graph-nodes (avm-graph avm))
if (and (eql (node-type node) :LOAD) (symbolp (getattr node :value)))
collect (getattr node :value)))))
collect (getattr node :value))))
(pointer-map (graph-infer-pointer-address (avm-graph avm))))
(declare (type Graph schedule-graph))
;; 5. Minifying the number of schedules, (reuse kernels)
(minify-equivalent-schedule schedule-graph)
;; 6. Start JIT Compilation. (Performing by group)
(let ((total-kernels (count-if #'(lambda (x) (getattr x :jitable)) (graph-nodes schedule-graph))))
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(print-info "JIT Compilation Start (AVM=~a)" (avm-name avm)))
;; [TODO] mapc is pmapc
(with-progress (total-kernels :debug (if (>= (ctx:getenv :JIT_DEBUG) 2) 1 -1) :timeit nil)
(with-expr-cache () ;; Initialize a cache to treat (EXPR: a*b) as a symbolic and make symbolic collapsed loops as an affine loop.
(with-expr-cache (:pointer-map pointer-map) ;; Initialize a cache to treat (EXPR: a*b) as a symbolic and make symbolic collapsed loops as an affine loop.
;; 5. Minifying the number of schedules, (reuse kernels)
(minify-equivalent-schedule schedule-graph)
;; 6. Start JIT Compilation. (Performing by group)
(let ((total-kernels (count-if #'(lambda (x) (getattr x :jitable)) (graph-nodes schedule-graph))))
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(print-info "JIT Compilation Start (AVM=~a)" (avm-name avm)))
;; [TODO] mapc is pmapc
(with-progress (total-kernels :debug (if (>= (ctx:getenv :JIT_DEBUG) 2) 1 -1) :timeit nil)
(mapc
#'(lambda (x &aux (start (get-internal-real-time)))
(when (and (getattr x :jitable) (getattr x :cache-name))
Expand Down Expand Up @@ -366,30 +368,33 @@ caten/codegen overview:
(format t "=====> Optimized kernel~%")
(print-blueprint (getattr x :blueprint) t)))
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(format t "Compilation Time : ~A(sec)" (float (/ (- (get-internal-real-time) start) internal-time-units-per-second)))))))
(graph-nodes schedule-graph)))))
;; 10. Running memory-planner, update the storage-id
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Running the memory planner..."))
(verify-graph schedule-graph)
(setf schedule-graph (->graph schedule-graph))
;;(run-memory-planner schedule-graph) disable until fixing weirdness in Padding2D/AutoDiff
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Rendering ..."))
(dolist (s (graph-nodes schedule-graph))
(when (and (getattr s :jitable) (getattr s :blueprint))
(schedule-item-write-define-global s)
(setf (getattr s :rendered-object) (%render-kernel renderer s))))
;; 11. Complete (Render by the renderer)
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Compiling ..."))
(%compile-kernel renderer (graph-nodes schedule-graph) dir)
(let ((new-graph (schedule-graph->vmop avm schedule-graph)))
(setf (avm-graph avm) new-graph
(avm-tape-length avm) (length (graph-nodes new-graph))
(avm-pc avm) 0
(avm-variables avm) (make-hash-table)))
avm))
(format t "Compilation Time : ~A(sec)" (float (/ (- (get-internal-real-time) start) internal-time-units-per-second)))))
(schedule-item-write-define-global x)))
(graph-nodes schedule-graph)))
;; 10. Running memory-planner, update the storage-id
(setf schedule-graph (->graph schedule-graph))
(verify-graph schedule-graph)
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Running the memory planner..."))
(dolist (item (graph-nodes schedule-graph))
(setf (getattr item :storage-id-src) (map 'list #'read-ptrid (getattr item :storage-id-src))
(getattr item :storage-id-dst) (map 'list #'read-ptrid (getattr item :storage-id-dst))))
;; (run-memory-planner schedule-graph) disable until fixing weirdness in Padding2D/AutoDiff
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Rendering ..."))
(dolist (s (graph-nodes schedule-graph))
(when (and (getattr s :jitable) (getattr s :blueprint))
(setf (getattr s :rendered-object) (%render-kernel renderer s))))
;; 11. Complete (Render by the renderer)
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(fresh-line)
(print-info "Compiling ..."))
(%compile-kernel renderer (graph-nodes schedule-graph) dir)
(let ((new-graph (schedule-graph->vmop avm schedule-graph)))
(setf (avm-graph avm) new-graph
(avm-tape-length avm) (length (graph-nodes new-graph))
(avm-pc avm) 0
(avm-variables avm) (make-hash-table)))
avm))))
Loading
Loading