Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Integrate the local and global memory-planner #193

Merged
merged 28 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions source/apis/iseq.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@
;; (Forward Mode) First, Simplify the forward graph in :Module/:Func level
(dolist (f external-simplifiers) (funcall f forward-graph))
;; Second, lower an :module into a list of :func
(lower-all forward-graph) ;; lower-all is O(N)
(lower-all forward-graph) ;; SLOW
;; (Backward Mode) First, create a reverse-mode backward tape from the sorted forward graph.
;; the tapes consequent after the allocation of prev-grad.
(when (null no-grad) (setf iseq-bw (%make-graph-backward session iseq :iseq-bw iseq-bw)))
Expand Down Expand Up @@ -268,7 +268,7 @@
(let ((merged-graph (->fast-graph merged-graph)))
(lower-all merged-graph)
;; Function-level whole optimization
(dolist (f external-simplifiers)
(dolist (f external-simplifiers) ;; Slow but O(n)
(funcall f merged-graph :debug-opt (= 1 (the fixnum (ctx:getenv :PROFILE_SIMPLIFIER)))))
;; verify and complete
(verify-graph merged-graph)
Expand Down
1 change: 1 addition & 0 deletions source/codegen/jit.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ caten/codegen overview:
(push alloc nodes)
(when view (push view nodes)))
(push w allocated))
(dolist (w (node-writes node)) (push w allocated))
(push (make-compiled-kernel-node node graph) nodes)
;; Merging view after the JIT_KERNEL invocation
(loop for w in (node-writes node)
Expand Down
221 changes: 98 additions & 123 deletions source/codegen/memory-planner.lisp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
(defpackage :caten/codegen/memory-planner
(:use :cl :caten/air :caten/avm :caten/codegen/shape-inference :caten/codegen/expr)
(:documentation "`Memory Planner` is a data structure that abstracts the allocation and freeing of memory over time.
It is responsible for optimizing memory allocation by overlapping allocation to minimize the maximum memory usage (heap_size) required for all the time `t`.")
(:use :cl :caten/air :caten/avm :caten/codegen/shape-inference :caten/codegen/expr :alexandria)
(:export
#:run-memory-planner))
(in-package :caten/codegen/memory-planner)
Expand Down Expand Up @@ -80,97 +82,93 @@ MemoryBlock(id) is allocated when t=create, preserved until t become `release`."
(let ((node (id->value graph id)))
(when (and node (eql (node-type node) :Allocate))
(when (getattr node :from)
;; If :from is specified => the input should not be destructed.
;; Memory Planner is not allowed to destruct the input. (like: having a weight/parameter)
t))))

(defmethod run-memory-planner-global ((schedule-graph Graph) (symbolics list) (base-graph Graph))
"write_1, write_2 = f(write_suite_1, write_suite_2, *[dynamic_shape + read_buffers])
The goal of run-memory-planner is to reduce the number of :allocate-p object in schedule-graph, by rewriting write_suite1 and write_suite2."
(let* ((trace-table (make-hash-table))
(id2type (make-hash-table))
(lock-table (make-hash-table))
(total-time (length (graph-nodes schedule-graph)))
(outputs (append (graph-outputs schedule-graph) symbolics))
(constants))
(loop for node in (graph-nodes schedule-graph)
for nth upfrom 0
for lock-p = (null (getattr node :jitable)) do
(loop for val in (getattr node :storage-id-src)
for typ in (getattr node :read-types)
for time = `(,nth ,@(gethash val trace-table))
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (find val constants)))
do (setf (gethash val id2type) typ (gethash val trace-table) time)) ;; (incf consume)
(loop for val in (getattr node :storage-id-dst)
for typ in (getattr node :write-types)
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (gethash val trace-table)))
;; ID2Type -> the variable name and its type
;; TraceTable -> the variable name and timestamps of the variable (when it's used)
;; LockTable -> Set T to lock (never become in-place)
do (setf (gethash val id2type) typ
(gethash val trace-table) (list nth)
(gethash val lock-table) lock-p)))
(let* ((memory-blocks
(loop for key in (alexandria:hash-table-keys trace-table)
for typ = (gethash key id2type)
collect
;; [Note] A memory block lives in the range of [min{t}, max{t})
;; Plus, If the same task (e.g.: T0(x) -> T1(x) -> T0(x+1)) is scheduled, the memory block lives from 0 to 2.
(make-memoryblock
key typ
(apply #'min (gethash key trace-table))
;; Set the longest time for the output variables (not to destruct it, and users can see the result)
(if (find key outputs)
total-time
(apply #'max (gethash key trace-table)))
:lock (gethash key lock-table))))
;; Minimize the peak memory usage
(solved (greedy-solve-dsa memory-blocks total-time))
;; Retrive the solution. A hash table of OLD_MEMORY_ID -> NEW_MEMORY_ID
(alias-map (make-hash-table)))
(loop for mb in solved
do (setf (gethash (memoryblock-id mb) alias-map) (or (memoryblock-answer mb) (memoryblock-id mb))))
(flet ((newid (id) (or (gethash id alias-map) id)))
(dolist (node (graph-nodes schedule-graph))
(when (getattr node :jitable)
(setf (getattr node :storage-id-dst) (map 'list #'newid (getattr node :storage-id-dst)))))))))
(defun rewrite-bp-with-newid (item newid)
"Rewrites the given schedule item with newid"
(dolist (bp (getattr item :blueprint))
(setf (node-writes bp) (map 'list newid (node-writes bp))
(node-reads bp) (map 'list newid (node-reads bp)))
(when (eql (node-type bp) :EXPR)
(dolist (item (graph-nodes (expr-graph (getattr bp :EXPR))))
(when (eql (node-type item) :AREF)
(setf (getattr item :storage-id) (funcall newid (car (node-writes item))))))))
;; Remove Duplicated :DEFINE_GLOBAL
(setf (getattr item :blueprint)
(loop with seen = nil
for item in (getattr item :blueprint)
if (or (not (eql (node-type item) :DEFINE-GLOBAL))
(null (find (car (node-writes item)) seen)))
collect item
if (eql (node-type item) :DEFINE-GLOBAL)
do (push (car (node-writes item)) seen)))
(let* ((reads (map 'list #'cons (getattr item :storage-id-src) (getattr item :read-types)))
(writes (map 'list #'cons (getattr item :storage-id-dst) (getattr item :write-types)))
(reads (remove-duplicates reads :key (compose newid #'car)))
(writes (remove-duplicates writes :key (compose newid #'car)))
(seen))
(flet ((only-unseen (items)
(loop for (id . type) in items
if (null (find (funcall newid id) seen))
do (push (funcall newid id) seen) and collect (cons id type))))
(multiple-value-bind (writes reads) (values (only-unseen writes) (only-unseen reads))
(setf (getattr item :storage-id-src) (map 'list (compose newid #'car) reads)
(getattr item :storage-id-dst) (map 'list (compose newid #'car) writes)
(getattr item :read-types) (map 'list #'cdr reads)
(getattr item :write-types) (map 'list #'cdr writes))))))

(defun run-memory-planner-local (item schedule-graph symbolics base-graph)
"Minimizes the number of allocation buffers that are only used in the item."
(declare (type node item) (type graph schedule-graph))
(assert (eql (node-type item) :Schedule-Item))
(let* ((blueprint (getattr item :blueprint))
(defun apply-memory-planner (schedule-graph symbolics base-graph)
(declare (type graph schedule-graph))
(let* ((nodes
(loop for node in (graph-nodes schedule-graph)
if (getattr node :jitable)
append (getattr node :blueprint)
else
collect node))
(total-time (length nodes))
(trace-table (make-hash-table))
(id2type (make-hash-table))
(lock-table (make-hash-table))
(total-time (length blueprint))
(outputs ;; a list of buffers that do no changed by the memory-planner
(append ;; If the output were read by other kernels, it should be optimized by the global memory-planner.
(graph-outputs schedule-graph)
symbolics
(loop for node in (graph-nodes schedule-graph)
if (not (eql (node-id node) (node-id item)))
append (node-reads node))))
(constants))
(loop for node in blueprint
symbolics)))
(dolist (s symbolics) (setf (gethash s lock-table) t))
;; Creating a timestamp table for each node and variable.
(loop for node in nodes
for nth upfrom 0
if (not (eql (node-class node) :Render)) do
(loop for val in (node-reads node)
for typ in (relay-reads (read-type-relay node))
for time = `(,nth ,@(gethash val trace-table))
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (find val constants)))
do (setf (gethash val id2type) typ (gethash val trace-table) time)) ;; (incf consume)
(loop for val in (node-writes node)
for typ in (relay-writes (read-type-relay node))
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (gethash val trace-table)))
;; ID2Type -> the variable name and its type
;; TraceTable -> the variable name and timestamps of the variable (when it's used)
;; LockTable -> Set T to lock (never become in-place)
do (setf (gethash val id2type) typ
(gethash val trace-table) (list nth))))
if (eql (node-type node) :Schedule-Item) ; Optimization for non-jitable instructions (like: foreign kernel calls, allocation, pause/backward)
do (loop for val in (getattr node :storage-id-src)
for typ in (getattr node :read-types)
for time = `(,nth ,@(gethash val trace-table))
if (id-is-input-p val base-graph) do (push val outputs)
if (symbolp val)
do (setf (gethash val id2type) typ (gethash val trace-table) time))
(loop for val in (getattr node :storage-id-dst)
for typ in (getattr node :write-types)
for time = `(,nth ,@(gethash val trace-table))
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (gethash val trace-table)))
do (setf (gethash val id2type) typ) (gethash val trace-table) (list nth))
if (and
(not (eql (node-type node) :Schedule-Item)) ; For jitable and lowered instructions
(not (eql (node-class node) :Render)))
do (loop for val in (node-reads node)
for typ in (relay-reads (read-type-relay node))
for time = `(,nth ,@(gethash val trace-table))
if (id-is-input-p val base-graph) do (push val outputs)
if (symbolp val)
do (setf (gethash val id2type) typ (gethash val trace-table) time))
(loop for val in (node-writes node)
for typ in (relay-writes (read-type-relay node))
if (id-is-input-p val base-graph) do (push val outputs)
if (and (symbolp val) (null (gethash val trace-table)))
;; ID2Type -> the variable name and its type
;; TraceTable -> the variable name and timestamps of the variable (when it's used)
;; LockTable -> Set T to lock (never become in-place)
do (setf (gethash val id2type) typ
(gethash val trace-table) (list nth))))
(let* ((memory-blocks
(loop for key in (alexandria:hash-table-keys trace-table)
for typ = (gethash key id2type)
Expand All @@ -191,41 +189,25 @@ The goal of run-memory-planner is to reduce the number of :allocate-p object in
(alias-map (make-hash-table)))
(loop for mb in solved
do (setf (gethash (memoryblock-id mb) alias-map) (or (memoryblock-answer mb) (memoryblock-id mb))))
(flet ((newid (id) (or (gethash id alias-map) id)))
(dolist (bp (getattr item :blueprint))
(setf (node-writes bp) (map 'list #'newid (node-writes bp))
(node-reads bp) (map 'list #'newid (node-reads bp)))
(when (eql (node-type bp) :EXPR)
(dolist (item (graph-nodes (expr-graph (getattr bp :EXPR))))
(when (eql (node-type item) :AREF)
(setf (getattr item :storage-id) (newid (car (node-writes item))))))))
;; remove duplicated define-global
(setf (getattr item :blueprint)
(loop with seen = nil
for item in (getattr item :blueprint)
if (or (not (eql (node-type item) :DEFINE-GLOBAL))
(null (find (car (node-writes item)) seen)))
collect item
if (eql (node-type item) :DEFINE-GLOBAL)
do (push (car (node-writes item)) seen)))
(let* ((reads (map 'list #'cons (getattr item :storage-id-src) (getattr item :read-types)))
(writes (map 'list #'cons (getattr item :storage-id-dst) (getattr item :write-types)))
(reads (remove-duplicates reads :key (alexandria:compose #'newid #'car)))
(writes (remove-duplicates writes :key (alexandria:compose #'newid #'car)))
(seen))
(flet ((only-unseen (items)
(loop for (id . type) in items
if (null (find (newid id) seen))
do (push (newid id) seen) and collect (cons id type))))
(multiple-value-bind (reads writes) (values (only-unseen reads) (only-unseen writes))
(setf (getattr item :storage-id-src) (map 'list (alexandria:compose #'newid #'car) reads)
(getattr item :storage-id-dst) (map 'list (alexandria:compose #'newid #'car) writes)
(getattr item :read-types) (map 'list #'cdr reads)
(getattr item :write-types) (map 'list #'cdr writes))))))
alias-map)))
;; Note(hikettei): is this recursively applied? especially for schedule cached and big graph.
;; As of this writing(2024/11/10), i am unsure if this is correct. Should be tested by GPT2 in the next pr.
(labels ((newid (id)
(if (gethash id alias-map)
(if (eql (gethash id alias-map) id)
id
(newid (gethash id alias-map)))
id)))
(when (>= (ctx:getenv :JIT_DEBUG) 4)
(format t "[DEBUG] MemoryPlanner: minimized alias-map~%")
(maphash
#'(lambda (k v)
(format t " | newid(~a) = ~a, alias-map[~a] = ~a~%" k (newid k) k v))
alias-map))
(dolist (node (graph-nodes schedule-graph))
(rewrite-bp-with-newid node #'newid))))))

(defun buffer-sizeof (buffer)
"Returns the size of the buffer in bits"
"Computes the size of the buffer in bits."
(assert (every #'numberp (buffer-shape buffer)))
(* (apply #'* (buffer-shape buffer)) (caten/common.dtype:dtype/size-of (buffer-dtype buffer))))

Expand Down Expand Up @@ -271,14 +253,7 @@ The goal of run-memory-planner is to reduce the number of :allocate-p object in
(let ((static-graph-p (null symbolics)))
(multiple-value-bind (before-count before-size)
(when (>= (ctx:getenv :JIT_DEBUG) 2) (evaluate schedule-graph static-graph-p))
;; First, applying the memory-planner kernel by kernel.
;; The goal is to reduce the number of arguments in the kernel.
(dolist (item (graph-nodes schedule-graph))
(when (and (getattr item :jitable) (getattr item :blueprint))
(run-memory-planner-local item schedule-graph symbolics base-graph)))
;; Second, applying the memory-planner in the schedule-graph level
;; The goal here is to reduce the number of :allocate-p object in schedule-graph.
(run-memory-planner-global schedule-graph symbolics base-graph)
(apply-memory-planner schedule-graph symbolics base-graph)
(mapc #'remove-extra-node-writes-to (graph-nodes schedule-graph))
(multiple-value-bind (after-count after-size)
(when (>= (ctx:getenv :JIT_DEBUG) 2) (evaluate schedule-graph static-graph-p))
Expand Down
Loading