Fix: Propagate Scalar anywhere (#139)
hikettei authored Oct 11, 2024
1 parent 0a42af9 commit eb177db
Showing 5 changed files with 39 additions and 41 deletions.
51 changes: 24 additions & 27 deletions source/ajit/memory-planner.lisp
@@ -211,7 +211,12 @@ In Caten IR, there are two ways to allocate a variable.
if (eql (node-type ir) :FOR)
collect (newid-from-str group (getattr ir :idx))))
(meta-ids)
(out))
(out)
(candidates
(map
'list
#'(lambda (x) (mp-newid mp x))
(map 'list #'argument-name (kernel-renderer-args kernel)))))
(flet ((cleanup-buffer (buffer)
(setf (buffer-shape buffer) (map 'list #'reveal-buffer (buffer-shape buffer))
(buffer-shape buffer) (loop for s in (buffer-shape buffer)
@@ -300,11 +305,12 @@ In Caten IR, there are two ways to allocate a variable.
:metadata (make-uconst-buffer))
out)
(push s meta-ids)))))
(setf out (remove-duplicates out :key #'argument-name))
(when (>= (mp-debug mp) 1)
;; TODO: Display memory-size
)
(setf (kernel-renderer-args kernel) out)
(when candidates
(setf out (loop for o in out
if (find (argument-name o) candidates)
collect o)))
(setf out (remove-duplicates out :key #'argument-name)
(kernel-renderer-args kernel) out)
out)))
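
The new `candidates` binding above keeps only the arguments whose (renamed) names still occur in `kernel-renderer-args`, which is how propagated scalars drop out of a kernel's signature. A minimal, self-contained sketch of that filtering idiom; `keep-live-arguments` and the pair-based argument layout are illustrative assumptions, not code from this repository:

(defun keep-live-arguments (arguments candidates)
  "Return only the (name . buffer) pairs in ARGUMENTS whose name occurs in CANDIDATES."
  (remove-if-not #'(lambda (arg) (find (car arg) candidates)) arguments))

;; (keep-live-arguments '((a . 0) (b . 1) (c . 2)) '(a c)) ;=> ((A . 0) (C . 2))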

(defstruct (MemoryBlock
@@ -392,7 +398,7 @@ MemoryBlock(id) is allocated when t=create, preserved until t becomes `release`."
(apply-creation time))
I)))

(defmethod memory-plan ((mp MemoryPlanner) &aux (avm (mp-avm mp)))
(defmethod memory-plan ((mp MemoryPlanner) solve &aux (avm (mp-avm mp)))
"This is a toplevel for memory-oriented optimization techniques.
Resources:
- https://arxiv.org/pdf/2203.00448
@@ -455,7 +461,7 @@ Lifespan:
(apply #'max (gethash key trace-table)))
:lock (gethash key lock-table))))
;; Minimize the peak memory usage
(solved (greedy-solve-dsa memory-blocks total-time))
(solved (if solve (greedy-solve-dsa memory-blocks total-time) memory-blocks))
;; Retrieve the solution. A hash table of OLD_MEMORY_ID -> NEW_MEMORY_ID
(alias-map (mp-alias mp)))
(loop for mb in solved
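
When `solve` is non-nil, `greedy-solve-dsa` runs a greedy dynamic-storage-allocation pass: two MemoryBlocks may only share storage when their lifespans (the create/release interval described in the docstring above) do not overlap. A small sketch of that overlap test, using an assumed struct layout that mirrors, but is not, the repository's MemoryBlock definition:

(defstruct lifespan-sketch create release)

(defun can-share-storage-p (a b)
  "Two blocks can reuse one buffer iff one is released no later than the other is created."
  (or (<= (lifespan-sketch-release a) (lifespan-sketch-create b))
      (<= (lifespan-sketch-release b) (lifespan-sketch-create a))))

;; (can-share-storage-p (make-lifespan-sketch :create 0 :release 3)
;;                      (make-lifespan-sketch :create 3 :release 5)) ;=> T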
@@ -503,32 +509,23 @@
collect
(group->polyhedral-group group kr)))))
(dolist (p polyhedrons)
;; Final Chance to apply Loop Fusion
;; Kernels with complicated memory access relations, like Matmul+Transpose or Conv, are fused here first
;; Polyhedral Group is the result of splitting a group -> multiple groups?
(when (and p (every #'(lambda (x) (typep x 'Polyhedral-Auto-Scheduler)) p))
(affine-fusion p)))

;; Tiling, Vectorizing, Parallelizing (CPU/GPU), Loop Fission here
;; [TODO] Apply the changes to mp-kernels, mp-groups
))
(dolist (pg p)
(polyhedral-auto-schedule pg)))))

(defmethod retrive-kernels ((mp MemoryPlanner))
"Finalizes the result of memory-planner, retriving the final rendering-graph"
(flet ((prune ()
"Applies the dead code elimination"
(setf (mp-kernels mp) (dead-kernel-elimination (mp-groups mp) (mp-kernels mp) (append (avm-fw-outputs (mp-avm mp)) (avm-bw-outputs (mp-avm mp)))))))
(prune)

;; [Note] Auto_Scheduler is *work in progress*
;; (mp-auto-schedule! mp)
;; (prune)

;; 1. Mutate output buffers into scalars
(memory-plan mp nil)
(optimize-memory-load mp)
;; 2. Hide Latency Optimization
;; - The arrays should be loaded all at once
;; - and the results stored at the end.
(prune)
(setf (mp-id2buffer mp) (make-hash-table))
(memory-plan mp t)
(when (= 1 (ctx:getenv :AUTO_SCHEDULER))
(mp-auto-schedule! mp)
(prune))
;; [TODO] Apply multiexpr in the final fused graph.
;; [TODO] Add dead graph.nodes elimination here; ^ may produce unused ops.
(loop for group in (mp-groups mp)
for kernels in (mp-kernels mp)
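
The `prune` flet above delegates to `dead-kernel-elimination`, dropping kernels that do not contribute to the forward/backward outputs of the AVM. A rough sketch of such a backward liveness sweep; the kernel-writes/kernel-reads accessors are passed in as hypothetical functions and the real dead-kernel-elimination signature differs:

(defun prune-dead-kernels (kernels live-outputs kernel-writes kernel-reads)
  "Keep a kernel only if it writes something still live, then mark its reads as live."
  (let ((live (copy-list live-outputs)))
    (reverse
     (loop for k in (reverse kernels)
           when (intersection (funcall kernel-writes k) live)
             do (setf live (union live (funcall kernel-reads k)))
             and collect k))))

;; Walking the kernel list backwards keeps the last writer of each live value and
;; everything it transitively depends on; everything else is dead and removed.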
18 changes: 9 additions & 9 deletions source/ajit/polyhedral-group.lisp
@@ -138,14 +138,7 @@ A Polyhedral form of the fused schedule group.
(union-map-union WaR RaW)
WaW)))
(setf (pg-dependencies pg) dependencies)))
(format t "Before~%")
(format t "~%~a~%" (build pg))
(print (pprint-schedule (pg-schedule pg)))
;;(let ((new (schedule pg)))
;; (format t "After~%")
;; (setf (pg-schedule pg) new)
;; (format t "~%~a~%" (build pg)))
)
(setf (pg-schedule pg) (schedule pg)))

(defmethod schedule ((pg Polyhedral-Auto-Scheduler))
(let ((serialize-sccs 0)
@@ -444,7 +437,7 @@ Reference: https://www.researchgate.net/publication/347152973_PET-to-MLIR_A_poly
(polyhedral-group-base polyhedral-group))
;; ~~ Scheduling Language ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
(defclass Polyhedral-Config ()
nil
((parallelizable-loops))
(:documentation ""))
;; [Design] Only effects on Render-Graph and Pipelining
;; - Matmul+Transpose has to be fused
@@ -458,6 +451,13 @@ Reference: https://www.researchgate.net/publication/347152973_PET-to-MLIR_A_poly
;; - That way, ConvND should fit into a single kernel
;; - Assuming ^ is the prerequisite: Embedding/Gemm, Tile, Loop Collapse, Vectorize

(defmethod polyhedral-auto-schedule ((pg Polyhedral-Auto-Scheduler))
;; [TODO] TileOuterBand
;; Work in progress...
)

(defmethod polyhedral-auto-schedule ((pg Polyhedral-Group)))

(defmethod loop-reorder ((pg Polyhedral-Auto-Scheduler) order)

)
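
The `Polyhedral-Config` class above now declares a `parallelizable-loops` slot, while `polyhedral-auto-schedule` and `loop-reorder` remain stubs. A sketch of how such a config slot is typically exposed once it is consumed; the :initarg/:accessor names are assumptions, not part of this commit:

(defclass polyhedral-config-sketch ()
  ((parallelizable-loops
    :initarg :parallelizable-loops
    :initform nil
    :accessor config-parallelizable-loops
    :documentation "Indices of loop bands that are safe to mark as parallel.")))

;; (config-parallelizable-loops
;;  (make-instance 'polyhedral-config-sketch :parallelizable-loops '(0 2))) ;=> (0 2)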
5 changes: 2 additions & 3 deletions source/ajit/scheduler.lisp
@@ -352,12 +352,11 @@ DEBUG=4 to debug both DEBUG=3 and DEBUG=4.
(when (group-polyhedron x)
(funcall (compose #'remove-iteration-ir #'poly-pipeline #'group-polyhedron) x)))
groups)
(let* ((1_ (mapc #'post-simplify-multiexpr groups))
(let* ((_ (mapc #'post-simplify-multiexpr groups))
;; Note: (make-instance 'MemoryPlanner ...) will rewrite the graph of :reduction; it is destructive.
;; Subsequent optimizations do not assume the `graph` is a DAG.
;; Graph-Level optimization should be performed just before it.
(mp (make-instance 'MemoryPlanner :avm avm :groups groups :debug debug :device backend))
(2_ (memory-plan mp))
(kernels (retrive-kernels mp))
(blueprints/codes
(loop for group in groups
@@ -366,7 +365,7 @@
collect
(multiple-value-list (render-to-string backend group (format nil "e~a" nth) avm debug kernel))))
(final-code (%render-program-toplevel backend (with-output-to-string (out) (dolist (c blueprints/codes) (princ (second c) out))))))
(declare (ignore 1_ 2_))
(declare (ignore _))
(when (>= (ctx:getenv :JIT_DEBUG) 2)
(format t "Final JIT Schedule:~%")
(loop for nth upfrom 0
3 changes: 1 addition & 2 deletions source/apis/tensor.lisp
@@ -131,8 +131,7 @@ Create a new lazy tensor.
"make-tensor: Cannot initialize a tensor.~%~%Shape should be specified as an integer (>1), tensor, or symbol.~% Butgot: ~a~% Shape=~a" s shape))
(let ((buff (%internal-make-tensor nil shape :dtype dtype :order order :id id :requires-grad requires-grad :views views)))
(setf (tensor-op buff) (make-instance 'Allocate :buffer buff :initial-element initial-element :from from)
;; (tensor-shape buff) (map 'list #'(lambda (x) (if (tensor-p x) (or (try-fold-constant x) x) x)) (tensor-shape buff))
)
(tensor-shape buff) (map 'list #'(lambda (x) (if (tensor-p x) (or (try-fold-constant x) x) x)) (tensor-shape buff)))
buff))
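
The re-enabled `tensor-shape` line maps over the shape and replaces each constant tensor with its folded value, keeping the tensor itself when folding returns NIL. A minimal sketch of that fold-or-keep idiom, with a hypothetical predicate/folder pair standing in for Caten's `tensor-p`/`try-fold-constant`:

(defun fold-or-keep (shape foldable-p folder)
  "Replace each dimension by its folded constant when FOLDER succeeds, otherwise keep it."
  (map 'list
       #'(lambda (s) (if (funcall foldable-p s) (or (funcall folder s) s) s))
       shape))

;; (fold-or-keep '(2 n 4) #'symbolp (lambda (s) (and (eq s 'n) 3))) ;=> (2 3 4)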

(defun make-scalar (value &key (dtype *default-float*) (order *default-order*) (id (gensym "SID")) (requires-grad nil))
3 changes: 3 additions & 0 deletions source/common/contextvar.lisp
@@ -138,6 +138,9 @@ Usage:
(:CI
0 :int identity
"Set 1 if the test is running under Github Actions")
(:AUTO_SCHEDULER
0 :int #.(oneof "AUTO_SCHEDULER" 1 `(0 1))
"Set 1 to enable auto-scheduler for JIT")
(:JIT
0 :int identity
"Set 1 to use JIT_BACKEND, 0 to use VM_BACKEND")