Various enhancements and refactorings on caten/ajit #110

Merged on Sep 25, 2024 (72 commits)
Commits
8a88278
wip
hikettei Sep 20, 2024
999c36e
needs refactor for isl ...
hikettei Sep 20, 2024
443f6a8
still misunderstanding waw/war/raw?
hikettei Sep 20, 2024
49ae1b1
wip
hikettei Sep 20, 2024
658e328
refactor: get 4d matmul working
hikettei Sep 20, 2024
c64f8eb
fixed?
hikettei Sep 20, 2024
5c7688d
serial t
hikettei Sep 20, 2024
7aa5044
a little tweak for double free?
hikettei Sep 20, 2024
349b96e
Embedding in a single kernel
hikettei Sep 20, 2024
4501a5c
fix syntax error
hikettei Sep 20, 2024
f8f5e41
update
hikettei Sep 20, 2024
8ce740f
a lil tweak (normal, randn still wont work)
hikettei Sep 20, 2024
8103234
cannot permute assertion
hikettei Sep 20, 2024
85faa26
spacing
hikettei Sep 21, 2024
4eb1060
fix around scalar handling
hikettei Sep 21, 2024
495ca35
.
hikettei Sep 21, 2024
61333e7
hotfix: fix typo
hikettei Sep 21, 2024
d86a183
fuse w/ prev-rank
hikettei Sep 21, 2024
d1c6300
disable manual loop fusion for now
hikettei Sep 21, 2024
74be6d3
progresses on refactoring
hikettei Sep 22, 2024
5297644
wip
hikettei Sep 22, 2024
7433c11
hmmmmm
hikettei Sep 22, 2024
f8f2d97
hmm
hikettei Sep 22, 2024
b62935d
rem
hikettei Sep 22, 2024
57a898b
Enhancement: EXPR and EXPR Simplifier
hikettei Sep 22, 2024
2e655de
MultiExpr Simplify
hikettei Sep 22, 2024
232e234
update
hikettei Sep 22, 2024
efb5a24
simplify pttn
hikettei Sep 22, 2024
82b9e77
wip: refactor
hikettei Sep 23, 2024
dfd6762
doc
hikettei Sep 23, 2024
7490cb0
refac
hikettei Sep 23, 2024
d10fbac
work in progress
hikettei Sep 23, 2024
fa4f1ac
enhancemnts on the renderer
hikettei Sep 23, 2024
06df3af
wip
hikettei Sep 23, 2024
246afde
wip
hikettei Sep 23, 2024
250f065
wip: new initial scheduler algorithm
hikettei Sep 23, 2024
dee8cba
wip: fuse symbol as const
hikettei Sep 23, 2024
6cdccd5
wip: todo, rewriting
hikettei Sep 23, 2024
3898095
add: graph-weakly-connected-p
hikettei Sep 23, 2024
b7376c0
.
hikettei Sep 23, 2024
53ae339
avoid fusing unrelated dimensions
hikettei Sep 23, 2024
b0ef773
fix: randn
hikettei Sep 23, 2024
3b61afe
WIP: IndexComponent Fusion
hikettei Sep 24, 2024
434018a
wip
hikettei Sep 24, 2024
670ba1a
fix: scalar scheduling
hikettei Sep 24, 2024
a2ff240
passing view tests
hikettei Sep 24, 2024
af16154
revisit scalar mutation
hikettei Sep 24, 2024
36abc34
passes on jit
hikettei Sep 24, 2024
5920a5b
18 and 1 3 are not mergeable
hikettei Sep 24, 2024
e9ebab7
skip ConvND
hikettei Sep 24, 2024
b500d41
Skip ConvND
hikettei Sep 24, 2024
7af3f71
skip?
hikettei Sep 24, 2024
c338ad2
padding schedule idx
hikettei Sep 24, 2024
f09747e
logsoftmax is not a inplace?
hikettei Sep 24, 2024
9fcafc5
2d+1 padiding
hikettei Sep 24, 2024
e74e87d
unroll anywhere
hikettei Sep 24, 2024
f546d7e
i think we should abandon this branch ...
hikettei Sep 24, 2024
ccd5f02
i think we should abandon this branch ...
hikettei Sep 24, 2024
ff0d5ac
wip
hikettei Sep 24, 2024
0580760
Fix: memory-planner (different views are not merged)
hikettei Sep 25, 2024
943bbcc
fix memory corruption
hikettei Sep 25, 2024
6562656
wip: no transpose?
hikettei Sep 25, 2024
e5a9cb3
del comment
hikettei Sep 25, 2024
014c605
Fix: Unrolling Softmax
hikettei Sep 25, 2024
f821c15
remove context-handle
hikettei Sep 25, 2024
190fea4
Fix: Unrollment especially when transposed
hikettei Sep 25, 2024
f05f29c
regression tests
hikettei Sep 25, 2024
813c6c5
giving up index-component folding for a now
hikettei Sep 25, 2024
d05bcfa
clean up codes
hikettei Sep 25, 2024
1113308
update test
hikettei Sep 25, 2024
af18f26
no padding?
hikettei Sep 25, 2024
4674966
need pads
hikettei Sep 25, 2024
2 changes: 1 addition & 1 deletion source/aasm/attrs.lisp
@@ -18,7 +18,7 @@
(_reads_old_for_multiexpr :initarg :_reads_old_for_multiexpr :initform nil)
(_reads :initarg :_reads)
(_writes :initarg :_writes)
-(declare-type :initarg :declare-type))
+(declare-type :initarg :declare-type :initform nil))
(:documentation "This node is jitable.
- declare-type[boolean] When this option is set to T, it is necessary to declare the types of the variables included in. e.g.:
```
10 changes: 10 additions & 0 deletions source/air/graph.lisp
@@ -239,3 +239,13 @@ To sort the graph properly, resolve the following isolated graph dependencies.
if (or (find (node-id node) valid-write-ids) ;; node exists in a valid path
(special-p (node-class node)))
collect node)))))

(defmethod graph-weakly-connected-p ((Graph graph) from to &aux (seen))
"Returns T if `to` is found while exploring the graph from `from` (i.e., `from` weakly depends on `to`)."
(labels ((explore (id &aux (node (id->value graph id)))
(when (and node (null (find id seen)))
(push id seen)
(when (find to (node-reads node)) (return-from graph-weakly-connected-p t))
(mapc #'explore (node-reads node)))))
(explore from)
nil))
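The traversal added above can be tried in isolation. The following is a minimal sketch, assuming a toy graph stored as a hash table from node id to the list of ids it reads; `make-toy-graph` and `toy-weakly-connected-p` are hypothetical names used only for illustration and are not part of caten/air.

```lisp
;; Toy stand-in for the graph: a hash table mapping a node id to the ids it reads.
(defun make-toy-graph (edges)
  "EDGES is a list of (id . reads). Returns a hash table id -> reads."
  (let ((table (make-hash-table :test #'eq)))
    (loop for (id . reads) in edges
          do (setf (gethash id table) reads))
    table))

(defun toy-weakly-connected-p (table from to)
  "Returns T if TO is reachable from FROM by following read edges."
  (let ((seen nil))
    (labels ((explore (id)
               (multiple-value-bind (reads found-p) (gethash id table)
                 ;; Skip unknown ids and ids that were already visited.
                 (when (and found-p (not (member id seen)))
                   (push id seen)
                   (when (member to reads)
                     (return-from toy-weakly-connected-p t))
                   (mapc #'explore reads)))))
      (explore from)
      nil)))
```

With edges `c -> b -> a`, `(toy-weakly-connected-p g 'c 'a)` holds while the reverse query does not, because only read edges are followed. The `seen` list keeps the walk from revisiting nodes.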
2 changes: 1 addition & 1 deletion source/air/package.lisp
@@ -8,7 +8,7 @@
(:export #:make-node #:copy-node)
(:export #:lower #:mutate)
(:export #:Graph #:FastGraph #:make-graph #:copy-graph #:graph-p #:graph-seen #:graph-outputs #:Graph-nodes #:id->value #:id->users #:remnode #:verify-graph
-#:insert-nodes #:->graph #:->fast-graph #:%graph-nodes-table)
+#:insert-nodes #:->graph #:->fast-graph #:%graph-nodes-table #:graph-weakly-connected-p)
(:export #:getattrs #:getattr #:remattr)
(:export #:defsimplifier)
(:export #:Attribute #:defnode #:debug/render-defined-nodes #:debug/attrs-by-module #:node-build-documentation-by-class #:verify-args #:dump-into-list))
2 changes: 1 addition & 1 deletion source/ajit/attrs.lisp
@@ -21,7 +21,7 @@ for(int idx=upfrom, below, by)
(defnode (:Render :ENDFOR) () "
RenderGraph:
```
-} / idx
+} // idx
```"
:slots ((idx)))

18 changes: 9 additions & 9 deletions source/ajit/backends/clang.lisp
@@ -175,8 +175,8 @@ Compiled with: ~a"
(defmethod %render-expr ((lang Clang) (op (eql :Aref)) lhs rhs z)
(assert (null z))
(assert (and lhs rhs))
-(let ((ref (render-isl-aref rhs :genid #'(lambda (x) (nth x *access*)))))
-(if (string= ref "")
+(let ((ref (render-aref lang rhs :genid #'(lambda (x) (nth x *access*)))))
+(if (string= "0" ref)
(if (args-p lhs)
(format nil "(*~(~a~)~a)" lhs (unroll-suffix rhs *suffix*))
(format nil "~(~a~)~a" lhs (unroll-suffix rhs *suffix*)))
@@ -186,7 +186,7 @@ Compiled with: ~a"
(assert (buffer-p (expr-y lhs)))
(assert (null z))
(let ((strides (map 'list #'(lambda (x) (render-expr lang x)) rhs)))
-(format nil "(~a)" (render-isl-aref (expr-y lhs) :genid #'(lambda (x) (intern (or (nth x *access*) (car *access*)))) :strides strides))))
+(format nil "~a" (render-aref lang (expr-y lhs) :genid #'(lambda (x) (intern (or (nth x *access*) (car *access*)))) :strides strides))))

(defmethod %render-expr ((lang Clang) (op (eql :NOT)) lhs rhs z)
(assert (and lhs (null rhs) (null z)))
@@ -295,16 +295,16 @@ Compiled with: ~a"
(dotimes (i (* 2 indent)) (princ " " out))
(format out ,designator ,@args)
(format out "~%"))))
-(labels ((render-aref (id type)
-(let ((ref (render-isl-aref type :genid #'(lambda (x) (nth x access)))))
-(if (string= ref "")
+(labels ((%render-aref (id type)
+(let ((ref (render-aref lang type :genid #'(lambda (x) (nth x access)))))
+(if (string= ref "0")
(if (args-p id)
(format nil "(*~(~a~)~a)" id (unroll-suffix type *suffix*))
(format nil "~(~a~)~a" id (unroll-suffix type *suffix*)))
(format nil "~(~a~)[~(~a~)]" id ref)))))
(loop with *access* = access
for node in (graph-nodes graph)
-for type = (read-type-relay node) do
+for type = (read-type-relay node) do
(ecase (node-type node)
(:ALLOCATE
(line "~(~a~) ~(~a~)~a;"
@@ -328,11 +328,11 @@ Compiled with: ~a"
(if (car (getattr node :declare-type))
(format nil "~a " (->cdtype (buffer-dtype ct)))
"")
-(render-aref c ct) (render-aref a at) (render-aref b bt)))))
+(%render-aref c ct) (%render-aref a at) (%render-aref b bt)))))
(:EXPR
(multiple-value-bind (at) (apply #'values (relay-writes type))
(line "~(~a~)~(~a~) = ~(~a~);"
(if (car (getattr node :declare-type))
(format nil "~a " (->cdtype (buffer-dtype at)))
"")
-(render-aref (car (node-writes node)) at) (render-expr lang (getattr node :EXPR)))))))))))
+(%render-aref (car (node-writes node)) at) (render-expr lang (getattr node :EXPR)))))))))))
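As the hunks above read, a rendered index equal to `"0"` now takes the scalar branch (previously the empty string did), and buffers passed as arguments are dereferenced as pointers. A minimal sketch of that branching follows. `toy-render-aref` and its `pointer-p` parameter are hypothetical stand-ins for `%render-aref` and `args-p`, and the unroll suffix is left out.

```lisp
(defun toy-render-aref (id index-expr pointer-p)
  "Renders a C-style reference to ID. INDEX-EXPR is an already rendered index string."
  (if (string= index-expr "0")
      ;; Index \"0\": no subscript is emitted; arguments are dereferenced.
      (if pointer-p
          (format nil "(*~(~a~))" id)
          (format nil "~(~a~)" id))
      ;; Otherwise emit a subscripted access; ~( ... ~) downcases the output.
      (format nil "~(~a~)[~(~a~)]" id index-expr)))
```

For example, `(toy-render-aref 'x "4*_gid0+_gid1" nil)` yields a subscripted access on `x`, while an index of `"0"` with `pointer-p` set yields a dereference.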
2 changes: 2 additions & 0 deletions source/ajit/caten.ajit.asd
@@ -12,6 +12,8 @@
(:file "type-relay")
(:file "polyhedral")
(:file "renderer")
(:file "scheduled-items")
(:file "group")
(:file "scheduler")
(:file "isl-objects")
(:file "isl-ast-helpers")
188 changes: 188 additions & 0 deletions source/ajit/group.lisp
@@ -0,0 +1,188 @@
(in-package :caten/ajit)

;; ~~ From AVM Into Polyhedral Model Compilation ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
;; polyhedral compilation to determine the parallelization strategy
;; If we do; compile from avm into ISL, optimizing
;; This is the toplevel of all optimization stuff
;; Set no-writes=t to skip O(N) search algorithm.
(defstruct (Group
(:constructor make-group (nodes realize-on-vm &key (no-writes nil)
&aux
(args (when (null no-writes) (nodes-depends-on nodes)))
(shapes (when (null no-writes) (nodes-gather-args nodes))))))
(graph (apply #'make-graph nodes) :type graph)
(sched nil :type list)
(realize-on-vm realize-on-vm :type boolean)
(polyhedron nil :type (or null Polyhedral))
(render-graph nil :type (or null Graph))
(across-time-deps nil :type list)
(args args :type list)
(shapes shapes :type list)
(writes (when (null no-writes) (nodes-output-ids nodes)) :type list)
(id (gensym)))

(defmethod max-dimension-in-group ((group group))
(apply
#'max
(or
(loop for node in (graph-nodes (group-graph group))
append
(loop for var in `(,@(relay-reads (read-type-relay node)) ,@(relay-writes (read-type-relay node)))
if var
collect
(buffer-nrank var)))
(list 0))))
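The `(or ... (list 0))` fallback above matters because `max` requires at least one argument, so an empty list of ranks would otherwise signal an error. A small sketch of the same idiom, with `max-rank` as a hypothetical helper that takes plain numbers instead of buffers:

```lisp
(defun max-rank (ranks)
  "RANKS may contain NIL for missing buffers; returns 0 when nothing remains."
  (let ((present (remove nil ranks)))
    ;; (apply #'max nil) is an error, so fall back to a one-element list.
    (apply #'max (or present (list 0)))))
```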

(defun relocate-independent-allocations! (graph)
"
X A+B
| |
Y Alloc <- If A+B is an independent operation, Alloc and its subgraph
\ / can be relocated into the top of graph.
Z"
(declare (type graph graph))
(let ((alloc-candidates
(loop for node in (graph-nodes graph)
if (eql (node-type node) :Allocate)
collect node)))
(labels ((subgraph (alloc)
(loop for r in (node-reads alloc)
if (symbolp r)
append (get-subgraph r graph)))
(alloc-p (node alloc subgraph)
(or (find (node-id node) subgraph :key #'node-id)
(and
(eql (node-type node) :Allocate)
(eql (node-id node) (node-id alloc)))))
(relocate (alloc subgraph)
(setf (graph-nodes graph)
(append
subgraph
(list alloc)
(loop for node in (graph-nodes graph)
unless (alloc-p node alloc subgraph)
collect node))))
(isolated-p (alloc subgraph)
(when subgraph
(loop with writes = (apply #'append (map 'list #'node-writes subgraph))
for node in (graph-nodes graph)
unless (alloc-p node alloc subgraph)
do (when (intersection (node-reads node) writes) (return-from isolated-p nil))))
t))
(loop for alloc in alloc-candidates
for subgraph = (subgraph alloc)
if (isolated-p alloc subgraph)
do (relocate alloc subgraph)))))

(defun relocate-independent-loop-bound-computation! (graph)
"Applies the same relocation as relocate-independent-allocations! against views, simplifying the scheduling for dynamic-shaped kernels."
(declare (type graph graph))
(let ((view-candidates
(loop for node in (graph-nodes graph)
if (eql (node-type node) :View)
collect node)))
(labels ((subgraph (view)
(loop for r in (node-reads view)
if (symbolp r)
append (get-subgraph r graph)))
(view-p (node view subgraph)
(or (find (node-id node) subgraph :key #'node-id)
(and
(eql (node-type node) :View)
(eql (node-id node) (node-id view)))))
(relocate (view subgraph)
(setf (graph-nodes graph)
(append
subgraph
(list view)
(loop for node in (graph-nodes graph)
unless (view-p node view subgraph)
collect node))))
(isolated-p (view subgraph)
(when subgraph
(loop with writes = (apply #'append (map 'list #'node-writes subgraph))
for node in (graph-nodes graph)
unless (view-p node view subgraph)
do (when (intersection (node-reads node) writes) (return-from isolated-p nil))))
t))
(loop for view in view-candidates
for subgraph = (subgraph view)
if (isolated-p view subgraph)
do (relocate view subgraph)))))

(defun recursive-split-into-subgroups (group)
(declare (type group group))
(let ((graph (group-graph group))
(stashed-path)
(seen))
(labels ((finalize-group (group)
;; Infers group-writes
(make-group (graph-nodes (group-graph group)) (group-realize-on-vm group)))
(force-realize-on-vm (node)
(or
(eql (node-type node) :pause/backward)
(eql (node-type node) :Allocate)
(and (eql (node-type node) :LOAD)
(symbolp (getattr node :value))
(= 0 (buffer-nrank (car (relay-writes (read-type-relay node))))))))
(explore (id)
(declare (type symbol id))
(let ((node (id->value graph id)))
(when (and node (null (find (node-id node) seen :key #'node-id)))
;; dynamic shapes are stashed and excluded from the graph, or exist in the toplevel?
(push node seen)
(if (force-realize-on-vm node)
(progn
(push node stashed-path)
nil)
(make-group
(append
(loop for read in (node-reads node)
for parent = (when (symbolp read) (explore read))
if parent
append (graph-nodes (group-graph parent)))
(list node))
nil
:no-writes t)))))
(restart-from-stashed-node (node)
(list
(make-group
(list node)
t
:no-writes t)
(make-group
(loop for read in (node-reads node)
for parent = (when (symbolp read) (explore read))
if parent
append (graph-nodes (group-graph parent)))
nil
:no-writes t))))
(let ((new-groups (map 'list #'explore (group-writes group))))
;; TODO: Remove duplicated LOAD! they are in stashed-path
(loop while stashed-path
do (setf new-groups (append (restart-from-stashed-node (pop stashed-path)) new-groups)))
(loop for g in new-groups
;; Empty group can be removed
if (and g (graph-nodes (group-graph g))) collect (finalize-group g))))))

(defun split-into-subgroups (graph)
"Graphs are first broken into subgroups only after:
- Tensor is shaped by a tensor
- :PAUSE/BACKWARD
Input: graph (AVM Graph)"
(declare (type graph graph))
(let ((groups))
(labels ((force-realize-on-vm (node) (or (eql (node-type node) :pause/backward))))
(apply
#'append
(map
'list
#'recursive-split-into-subgroups
`(,@(loop for node in (graph-nodes graph)
if (force-realize-on-vm node)
collect (make-group (nreverse groups) nil)
and collect (make-group (list node) t)
and do (setf groups nil)
else
do (push node groups))
,(make-group (nreverse groups) nil)))))))
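The first pass of `split-into-subgroups` cuts the node list at every node that must be realized on the VM and gives each such node a group of its own. A minimal sketch of that partitioning on plain lists is below. `split-at-barriers` is a hypothetical helper, and unlike the code above it drops empty runs immediately instead of filtering them out later.

```lisp
(defun split-at-barriers (items barrier-p)
  "Splits ITEMS into runs; each item satisfying BARRIER-P becomes its own run."
  (let ((pending nil)
        (runs nil))
    (dolist (item items)
      (if (funcall barrier-p item)
          (progn
            ;; Flush whatever was collected before the barrier, then the barrier itself.
            (when pending (push (nreverse pending) runs))
            (push (list item) runs)
            (setf pending nil))
          (push item pending)))
    ;; Flush the trailing run, if any.
    (when pending (push (nreverse pending) runs))
    (nreverse runs)))
```

Calling it on `'(a b :pause c d)` with `#'keywordp` as the barrier test separates the list into the nodes before the pause, the pause itself, and the nodes after it.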
2 changes: 1 addition & 1 deletion source/ajit/helpers.lisp
@@ -14,7 +14,7 @@ should be used instead"
(declare (type function function) (type hash-table hash-table))
(let ((keys (hash-table-keys hash-table)))
(assert (every #'numberp keys))
-(loop for key in (sort keys #'>)
+(loop for key in (sort keys #'<)
do (funcall function key (gethash key hash-table)))))

(defun render-list (list) (apply #'concatenate 'string (butlast (loop for n in list append (list (format nil "~a" n) ", ")))))
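The hunk above flips the comparator from `#'>` to `#'<`, so entries are visited in ascending key order. A self-contained sketch of that iteration, assuming numeric keys; `map-hash-in-key-order` is a hypothetical name, and the keys are collected with `loop` rather than an external `hash-table-keys` helper:

```lisp
(defun map-hash-in-key-order (function table)
  "Calls FUNCTION with (key value) for each entry of TABLE, smallest numeric key first."
  (let ((keys (loop for key being the hash-keys of table collect key)))
    (assert (every #'numberp keys))
    ;; SORT is destructive, which is fine here because KEYS is a fresh list.
    (dolist (key (sort keys #'<))
      (funcall function key (gethash key table)))))
```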
8 changes: 4 additions & 4 deletions source/ajit/isl-ast-helpers.lisp
@@ -37,8 +37,8 @@
(if (expr-z expr)
(format stream "~(~a~)(~(~a~), ~(~a~), ~(~a~))" (expr-op expr) (expr-x expr) (expr-y expr) (expr-z expr))
(if (expr-y expr)
-(format stream "~(~a~)(~(~a~), ~(~a~))" (expr-op expr) (expr-x expr) (expr-y expr))
-(format stream "~(~a~)(~(~a~))" (expr-op expr) (expr-x expr))))))
+(format stream "~(~a~)(~(~a~), ~(~a~))" (expr-op expr) (expr-x expr) (and (not (buffer-p (expr-y expr))) (expr-y expr)))
+(format stream "~(~a~)(~(~a~)~a)" (expr-op expr) (expr-x expr) (if (and (eql (expr-op expr) :Const) (numberp (expr-x expr))) ":num" ""))))))

(defstruct (ASTFor
(:constructor make-for (idx from to by body execute-once)))
@@ -110,12 +110,12 @@
(let* ((id (isl::%isl-ast-expr-id-get-id ast))
(name (cffi:foreign-string-to-lisp (isl::%isl-id-get-name id))))
(declare (type string name))
-(make-expr :Const name)))
+(make-const name nil)))
(:ast-expr-int
(let* ((id (isl::%isl-ast-expr-int-get-val ast))
(num (isl::%isl-val-get-d id)))
(declare (type number num))
-(let ((num (round num))) (make-expr :Const num))))
+(let ((num (round num))) (make-const num nil))))
(:ast-expr-op
(let* ((n-arg (isl::%isl-ast-expr-get-op-n-arg ast))
(args (loop for nth upfrom 0 below n-arg