Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BugFix: MHA in rtol<1e-5 #270

Merged
merged 21 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion source/air/viz.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,28 @@ Visualizes the graph using graphviz(requirement). Set open=t to open the resulti
(node (node-id node) (node-name node) (helper/color :node) "filled, solid"))
(:Module
(node (node-id node) (render-attrs (node-name node) node (getattrs node)) (helper/color :module) "filled, solid"))
(:Graph
(if (eql (node-type node) :Schedule-Item)
(if (getattr node :allocate-p)
(let ((alloc (car (getattr node :items))))
(assert alloc)
(if (getattr alloc :from)
(node
(node-id node)
(format nil "Input[~a] ~a" (car (node-writes alloc)) (subseq (node-reads alloc) 0 (getattr alloc :nrank)))
(helper/color :input)
"filled, solid")
(node
(node-id node)
(format nil "TmpAlloc[~a]" (subseq (node-reads alloc) 0 (getattr alloc :nrank)))
(helper/color :chain)
"filled, solid")))
(if (getattr node :jitable)
(node (node-id node) (getattr node :name) (helper/color :node) "filled, solid")
(let ((node (car (getattr node :items))))
(assert node)
(node (node-id node) (format nil "[VMOP] ~a" (node-type node)) (helper/color :chain) "filled, solid"))))
(node (node-id node) (node-name node) (helper/color :movement) "filled, solid")))
(otherwise
(node (node-id node) (node-name node) (helper/color :movement) "filled, solid")))))
(dolist (node (graph-nodes graph))
Expand All @@ -125,6 +147,16 @@ Visualizes the graph using graphviz(requirement). Set open=t to open the resulti
(with-open-file (stream htmlpath :direction :output :if-exists :supersede :if-does-not-exist :create)
(format stream "<html><p><b><font size=\"5\">~a</b></p><body><img src=\"~a.png\"></body></html>" title pathname)
(uiop:launch-program (list "open" htmlpath) :output t)))))

(defun compute-n-children (graph id &aux (seen nil) (count 0))
(labels ((explore (id)
(when (find id seen) (return-from explore))
(let ((val (id->value graph id)))
(when val
(push id seen) (incf count)
(mapc #'explore (node-reads val))))))
(explore id)
count))
;; [TODO] optimize screen-width automatically
(defparameter *indent* 0)
(defun pprint-graph (graph &key (screen-width 140) (stream t)
Expand Down Expand Up @@ -189,6 +221,12 @@ The function `pprint-graph` prints the graph in a tree-like structure. `screen-w
(let ((item (format nil "~a~a~a~%" (indent lastp) (if (zerop *indent*) "" " ") (princ-node node))))
(princ item out)
(> (length item) screen-width)))
(child-weights (node lastp-map)
(let* ((weights (map 'list #'(lambda (x) (compute-n-children graph x)) (node-reads node)))
(paired (map 'list #'list weights (node-reads node) lastp-map)))
;; Small children first
;; argsort
(map 'list #'cdr (sort paired #'< :key #'car))))
(explore (id &optional (lastp nil))
(when (find id seen)
(let ((node (id->value graph id)))
Expand All @@ -210,7 +248,9 @@ The function `pprint-graph` prints the graph in a tree-like structure. `screen-w
(let ((*indent* (+ 2 *indent*))
(lastp-map (make-list (length (node-reads node)))))
(when lastp-map (setf (car lastp-map) t))
(mapc #'explore (node-reads node) (nreverse lastp-map)))))))))
(loop for pair in (child-weights node (reverse lastp-map))
for nth upfrom 0
do (explore (car pair) (second pair))))))))))
(setf stashed (copy-list (graph-outputs graph)))
(dotimes (i screen-width) (princ "=" out))
(format out "~%")
Expand Down
3 changes: 2 additions & 1 deletion source/codegen/jit.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,8 @@ caten/codegen overview:
(let ((new-graph (schedule-graph->avm-graph base-graph schedule-graph)))
(when (>= (ctx:getenv :JIT_DEBUG) 4)
(print-info "Final Scheduling Graph:")
(print schedule-graph)
(pprint-graph schedule-graph)
(when (= (ctx:getenv :DOT) 2) (->dot schedule-graph :title "Schedule Graph (Final)"))
(print-info "Final VM Graph:")
(print new-graph))
(setf (avm-graph avm) new-graph
Expand Down
50 changes: 43 additions & 7 deletions source/codegen/memory-planner.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,20 @@ MemoryBlock(id) is allocated when t=create, preserved until t become `release`."
(dolist (k (cdr s2)) (setf stack (remove k stack :test #'eql)))
(null stack)))))
;; Paper: Best-Fit Heuristic https://arxiv.org/pdf/1804.10001
(defun greedy-solve-dsa (I total-time)
(defun greedy-solve-dsa (I total-time black-lists)
"A greedy solver for minimizing `peak_mem`"
(declare (type list I))
(let ((locked))
(labels ((choose-from-fragments (mb time &aux (candidates nil))
(loop for candidate in I
if (and (null (find (memoryblock-id candidate) locked))
(freed-p candidate time)
(null (find (memoryblock-id candidate) (gethash (memoryblock-id mb) black-lists)))
(not (= -1 (buffer-nrank (memoryblock-type mb))))
(not (= -1 (buffer-nrank (memoryblock-type candidate))))
(buffer-shape (memoryblock-type mb)) ;; <=> assure the memory-block is a tensor
(buffer-size-eq (memoryblock-type candidate) (memoryblock-type mb))
(equal (buffer-dtype (memoryblock-type candidate))
(buffer-dtype (memoryblock-type mb)))
(equal (buffer-dtype (memoryblock-type candidate)) (buffer-dtype (memoryblock-type mb)))
;; [TODO] If offsets were created but size are equivalent; they are not cached right?
(equal (buffer-views (memoryblock-type candidate)) (buffer-views (memoryblock-type mb))))
do (push candidate candidates))
Expand Down Expand Up @@ -144,21 +144,58 @@ MemoryBlock(id) is allocated when t=create, preserved until t become `release`."
append (getattr node :blueprint)
else
collect node))
(exprs (apply #'make-graph (loop for n in nodes if (eql (node-type n) :EXPR) collect n)))
(total-time (length nodes))
(trace-table (make-hash-table))
(id2type (make-hash-table))
(lock-table (make-hash-table))
(outputs ;; a list of buffers that do no changed by the memory-planner
(append ;; If the output were read by other kernels, it should be optimized by the global memory-planner.
(graph-outputs schedule-graph)
symbolics)))
symbolics))
(id->depend-loops (make-hash-table))
(black-list-table (make-hash-table)))
(loop with stacks = nil
for node in nodes
if (eql (node-type node) :FOR) do
(push node stacks)
else if (eql (node-type node) :ENDFOR) do
(setf stacks (remove (getattr node :idx) stacks :key #'(lambda (x) (getattr x :idx))))
else
do (setf (gethash (node-id node) id->depend-loops) stacks))
(dolist (s symbolics) (setf (gethash s lock-table) t))
;; Creating a timestamp table for each node and variable.
(loop for node in nodes
for nth upfrom 0
if (eql (node-type node) :EXPR) do
;; Some hacks to keep the dependency of reduce ops
;; for ...
;; for ...
;; for ...
;; x = a * b + c
;; out = x
;; Here, out should not be mutated as x, a, b, and c.
;; Enumerate such pairs and record then to the black-list-table.
(loop with node-loops = (gethash (node-id node) id->depend-loops)
for read in (node-reads node)
for val = (loop for e in (graph-nodes exprs) if (find read (node-writes e)) collect e) do
(loop for r in val
for parent-loops = (gethash (node-id r) id->depend-loops)
if (and r (not (eql (node-id r) (node-id node)))
(getattr r :reduction :allow-undefined t)
;; reduction after elementwise will never a solution
(intersection node-loops parent-loops :key #'node-id) ;; intersects
(not (= (length node-loops) (length parent-loops)))) ;; but partially
do (dolist (read-id (node-reads r))
(dolist (write-id (node-writes node))
;; Explicit the mutation from W to R is invaild.
(push read-id (gethash write-id black-list-table))))))
if (eql (node-type node) :Schedule-Item) ; Optimization for non-jitable instructions (like: foreign kernel calls, allocation, pause/backward)
do (assert (= (length (getattr node :storage-id-src)) (length (getattr node :read-types))))
(assert (= (length (getattr node :storage-id-dst)) (length (getattr node :write-types))))
;; Lock the allocation (its the minimum requirement for running the graph)
(when (getattr node :allocate-p)
(setf (gethash (car (node-writes node)) lock-table) t))
(loop for val in (getattr node :storage-id-src)
for typ in (getattr node :read-types)
for time = `(,nth ,@(gethash val trace-table))
Expand Down Expand Up @@ -187,8 +224,7 @@ MemoryBlock(id) is allocated when t=create, preserved until t become `release`."
;; ID2Type -> the variable name and its type
;; TraceTable -> the variable name and timestamps of the variable (when it's used)
;; LockTable -> Set T to lock (never become in-place)
do (setf (gethash val id2type) typ
(gethash val trace-table) (list nth))))
do (setf (gethash val id2type) typ (gethash val trace-table) (list nth))))
(let* ((memory-blocks
(loop for key in (alexandria:hash-table-keys trace-table)
for typ = (gethash key id2type)
Expand All @@ -204,7 +240,7 @@ MemoryBlock(id) is allocated when t=create, preserved until t become `release`."
(apply #'max (gethash key trace-table)))
:lock (gethash key lock-table))))
;; Minimize the peak memory usage
(solved (greedy-solve-dsa memory-blocks total-time))
(solved (greedy-solve-dsa memory-blocks total-time black-list-table))
;; Retrive the solution. A hash table of OLD_MEMORY_ID -> NEW_MEMORY_ID
(alias-map (make-hash-table)))
(loop for mb in solved
Expand Down
17 changes: 11 additions & 6 deletions source/codegen/scheduler.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,9 @@ Otherwise, returns NIL. (= not fusable)"
if (group-reduce-dims child-group)
do (return-from group-chase-down-reduction-p t)))
nil))

;; [TODO] Simplify the algorithm. at least :permute in apply-view-fusor may not be necessary.
(defmethod groups-rewrite-views-in-the-same-space ((parent-group Group) (tgt-group Group) view)
"If you want to schedule parent-group and tgt-group in the same space, this method fixes the view to be compatible in the same kernel. If failed, returns NIL."
(symbol-macrolet ((->ng (return-from groups-rewrite-views-in-the-same-space nil))
(->ok (return-from groups-rewrite-views-in-the-same-space t)))
(let* ((r1 (group-rank tgt-group)) (r2 (group-rank parent-group))
Expand Down Expand Up @@ -616,8 +617,7 @@ Returns T if merging parent-group and tgt-group is possible. Sometime rewrites v
(when (groups-reduce-permute-p tgt-group parent-group) ->ng) ;; Reduce -> Permute -> Reduce is not fusable.
(if (= r1 r2)
(when (buffer-mergeable-p (ctx-graph ctx) (group-get-type tgt-group) (group-get-type parent-group))
;; View Rewriting here?
;; Permute rewriting here?
;; [TODO] apply-view-fusor here?
->ok)
(let ((optimal-p
(or (eql pattern :case1)
Expand All @@ -631,9 +631,14 @@ Returns T if merging parent-group and tgt-group is possible. Sometime rewrites v
;; 1. Injective + Same Ranks
(when (= r1 r2)
(when (buffer-mergeable-p (ctx-graph ctx) (group-get-type parent-group) (group-get-type tgt-group))
;; Reject :MOVE+:MOVE fusion if view is provided.
(when (and view (= (length (group-items parent-group)) (length (group-items tgt-group)) 1)
(eql :MOVE (node-type (car (group-items parent-group))))
(eql :MOVE (node-type (car (group-items tgt-group)))))
->ng)
(let ((permute (and view (getattr view :permute))))
(when permute (apply-view-fusor r1 (loop repeat r1 collect nil) parent-group :permute permute))
->ok))
(when permute (apply-view-fusor (length permute) (loop repeat (length permute) collect nil) parent-group :permute permute)))
->ok)
->ng)
;; 2. Injective + Different Ranks
(when (group-chase-down-reduction-p ctx tgt-group restart-point)->ng) ; If tgt-group is used by the reduction => better to merge it with that.
Expand Down Expand Up @@ -798,7 +803,7 @@ Creates a schedule-graph(FastGraph) from the given `graph`."
(schedule-graph (apply #'make-graph (map 'list #'(lambda (x) (group->schedule-item x ctx)) groups))))
(setf (graph-outputs schedule-graph) (graph-outputs graph) schedule-graph (->fast-graph schedule-graph)) ; Convert the schedule graph into FastGraph
(mapc #'verify-group groups)
(apply-move-after-reduction schedule-graph)
(apply-move-after-reduction schedule-graph) ;; :reduction T cannot be an output of schedule item.
(when (>= (the fixnum (ctx:getenv :JIT_DEBUG)) 3)
(format t "[graph-schedule] scheduled graph:~%")
(pprint-graph schedule-graph))
Expand Down
11 changes: 9 additions & 2 deletions source/test-suite/helpers.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
dtype)
(buffer-shape (tensor-buffer tensor))))

(defun ->torch (tensor &key (dtype *torch-dtype*)) (remote-objects (torch.from_numpy (->numpy tensor :dtype dtype))))
(defun sync-visible-size (tensor)
(when (arrayp (buffer-value (tensor-buffer tensor)))
(when (not (= (apply #'* (buffer-shape (tensor-buffer tensor))) (array-total-size (buffer-value (tensor-buffer tensor)))))
(setf tensor (proceed (!copy tensor)))))
tensor)

(defun ->torch (tensor &key (dtype *torch-dtype*))
(remote-objects (torch.from_numpy (->numpy (sync-visible-size tensor) :dtype dtype))))

(defun torch-shape (tensor) (remote-objects* (py.list (chain tensor (size)))))

Expand Down Expand Up @@ -159,7 +166,7 @@
(defmacro assert-equal ((&key (rtol 1e-7) (atol 0.0)) torch-form lisp-form)
`(let ((torch ,torch-form) (lisp ,lisp-form))
(ok (equal (shape torch) (shape lisp)) "Shapes match")
(multiple-value-bind (atol1 rtol1) (compute-rtol-atol (buffer-value (tensor-buffer torch)) (buffer-value (tensor-buffer lisp)))
(multiple-value-bind (atol1 rtol1) (compute-rtol-atol (buffer-value (tensor-buffer (sync-visible-size torch))) (buffer-value (tensor-buffer (sync-visible-size lisp))))
(ok (<= atol1 ,atol) (format nil "Satisfying (atol=~a) <= ~a" atol1 ,atol))
(ok (<= rtol1 ,rtol) (format nil "Satisfying (rtol=~a) <= ~a" rtol1 ,rtol)))))

Expand Down
Loading
Loading