-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(example): Add Back2Back GEMM example. (#23)
* Add b2b_gemm example. * chore: Add python b2b_gemm program. * fix access map codegen for load. * Add some attached edge. * Add cast task. * Add cast codegen in python example. * Add rD to sD storer codegen. * Delete unused comments. * Delete unused code.
- Loading branch information
Showing
14 changed files
with
392 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
''' | ||
Back-to-Back GEMM example. | ||
''' | ||
import context | ||
|
||
from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType | ||
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap | ||
from pythriller import Block, DType | ||
|
||
if __name__ == '__main__': | ||
# Initialize the Thriller flow runtime. | ||
initialize_thriller_flow() | ||
|
||
# Define reg layout for A, B, C, D, acc. | ||
RegLayoutA = Layout.RowMajor | ||
RegLayoutB = Layout.RowMajor | ||
RegLayoutC = Layout.RowMajor | ||
RegLayoutD = Layout.RowMajor | ||
RegLayoutAcc = Layout.RowMajor | ||
|
||
# Define shared layout for A, B, C, D. | ||
SharedLayoutA = Layout.RowMajor | ||
SharedLayoutB = Layout.ColMajor | ||
SharedLayoutC = Layout.ColMajor | ||
SharedLayoutD = Layout.RowMajor | ||
|
||
# Define global layout for A, B, C, D. | ||
GlobalLayoutA = Layout.RowMajor | ||
GlobalLayoutB = Layout.ColMajor | ||
GlobalLayoutC = Layout.ColMajor | ||
GlobalLayoutD = Layout.RowMajor | ||
|
||
# Define Reg Dim for A, B, C, D, acc. | ||
RegDimA = [64, 64] | ||
RegDimB = [64, 64] | ||
RegDimC = [64, 64] | ||
RegDimD = [64, 64] | ||
RegDimAcc = [64, 64] | ||
|
||
# Define Shared Dim for A, B, C, D. | ||
SharedDimA = [64, 64] | ||
SharedDimB = [64, 64] | ||
SharedDimC = [64, 64] | ||
SharedDimD = [64, 64] | ||
|
||
# Define Global Dim for A, B, C, D. | ||
GlobalDimA = [256, 256] | ||
GlobalDimB = [256, 256] | ||
GlobalDimC = [256, 256] | ||
GlobalDimD = [256, 256] | ||
|
||
# Define Reg Tensor for A, B, C, D, acc. | ||
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile) | ||
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile) | ||
rC = Tensor("rC", RegDimC, RegLayoutC, TensorType.RegTile) | ||
rD = Tensor("rD", RegDimD, RegLayoutD, TensorType.RegTile) | ||
rAcc = Tensor("rAcc", RegDimAcc, RegLayoutAcc, TensorType.RegTile) | ||
rAccHalf = Tensor("rAccHalf", RegDimAcc, RegLayoutAcc, TensorType.RegTile) | ||
|
||
# Define Shared Tensor for A, B, C, D. | ||
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile) | ||
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile) | ||
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile) | ||
sD = Tensor("sD", SharedDimD, SharedLayoutD, TensorType.SharedTile) | ||
|
||
# Define Global Tensor for A, B, C, D. | ||
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile) | ||
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile) | ||
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile) | ||
gD = Tensor("gD", GlobalDimD, GlobalLayoutD, TensorType.GlobalTile) | ||
|
||
# Define Reg Node for A, B, C, D, acc. | ||
NodeRA = Node.tensor(rA) | ||
NodeRB = Node.tensor(rB) | ||
NodeRC = Node.tensor(rC) | ||
NodeRD = Node.tensor(rD) | ||
NodeRAcc = Node.tensor(rAcc) | ||
NodeRAccHalf = Node.tensor(rAccHalf) | ||
# Define A, B, Acc to GEMM Node. | ||
RegABGemmCNode = Node.gemm(NodeRA, NodeRB, NodeRAcc) | ||
|
||
# Define Acc, C, D to GEMM Node. | ||
RegAccCGemmDNode = Node.gemm(NodeRAccHalf, NodeRC, NodeRD) | ||
|
||
# Build Cast Node for Acc -> Half | ||
AccCastNode = Node.cast(rAcc, rAccHalf, DType.F32, DType.Half) | ||
|
||
# Define Shared Node for A, B, C, D. | ||
NodeSA = Node.tensor(sA) | ||
NodeSB = Node.tensor(sB) | ||
NodeSC = Node.tensor(sC) | ||
NodeSD = Node.tensor(sD) | ||
|
||
# Define Global Node for A, B, C, D. | ||
NodeGA = Node.tensor(gA) | ||
NodeGB = Node.tensor(gB) | ||
NodeGC = Node.tensor(gC) | ||
NodeGD = Node.tensor(gD) | ||
|
||
# Define Edge for A, B, Acc, Gemm | ||
RegEdgeA = Edge(NodeRA, RegABGemmCNode) | ||
RegEdgeB = Edge(NodeRB, RegABGemmCNode) | ||
RegEdgeAcc = Edge(NodeRAcc, RegABGemmCNode) | ||
|
||
# Define Edge for NodeRAcc, AccCastNode, NodeRAccHalf | ||
EdgeAccCastIn = Edge(NodeRAcc, AccCastNode) | ||
EdgeAccCastOut = Edge(AccCastNode, NodeRAccHalf) | ||
|
||
# Define iteration variable for A, B, Acc, Gemm loop. | ||
# Iterate over the register tiles along the kTK dimension. | ||
IterVarI = IterationVar("i", (0, 1)) | ||
|
||
# Iterate over K. | ||
IterVarK = IterationVar("k", (0, 1)) | ||
|
||
# Iterator over N. | ||
IterVarN = IterationVar("n", (0, 4)) | ||
|
||
# Build AccessMap from sA, sB load into rA, rB. | ||
AccessMapSA2RA = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarI]) | ||
AccessMapSB2RB = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarI]) | ||
|
||
# Build AccessMap from sC load into rC. | ||
AccessMapSC2RC = AccessMap([0], [[], []], [[], []], []) | ||
|
||
# Build AccessMap from gA, gB load into sA, sB. | ||
AccessMapGA2SA = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarK]) | ||
AccessMapGB2SB = AccessMap([0], [[[1, 0], [0, 1]], [[0, 0]]], [ | ||
[0, 0], [0]], [IterVarK, IterVarN]) | ||
|
||
# Build AccessMap from gC load into sC. | ||
AccessMapGC2SC = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarN]) | ||
|
||
# Build AccessMap from rAcc store into sD. | ||
AccessMapRAcc2GD = AccessMap([0], [[[1]], [[0]]], [[0], [0]], []) | ||
|
||
# Build AccessMap from rD store into sD. | ||
AccessMapRD2SD = AccessMap([0], [[], []], [[], []], []) | ||
|
||
# Build AccessMap from sD store into gD. | ||
AccessMapSD2GD = AccessMap([0], [[], []], [[], []], []) | ||
|
||
# Build Attached Edge for load sA, sB into rA, rB. | ||
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA) | ||
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB) | ||
|
||
# Build Attached Edge for load sC into rC. | ||
AttachedEdgeSC2RC = AttachedEdge(sC, rC, AccessMapSC2RC) | ||
|
||
# Build Attached Edge for load gA, gB into sA, sB. | ||
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA) | ||
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB) | ||
|
||
# Build Attached Edge for load gC into sC. | ||
AttachedEdgeGC2SC = AttachedEdge(gC, sC, AccessMapGC2SC) | ||
|
||
# Build Attached Edge for store rD into sD. | ||
AttachedEdgeRD2SD = AttachedEdge(rD, sD, AccessMapRD2SD) | ||
|
||
# Build Attached Edge for store sD into gD. | ||
AttachedEdgeSD2GD = AttachedEdge(sD, gD, AccessMapSD2GD) | ||
|
||
# Build rA, rB, Acc, Gemm Graph. | ||
RegABGemmGraph = Graph() | ||
|
||
# Add Nodes to the Graph. | ||
RegABGemmGraph.add_nodes([NodeRA, NodeRB, NodeRAcc, RegABGemmCNode]) | ||
# Add Edges to the Graph. | ||
RegABGemmGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeAcc]) | ||
# Connect the Graph. | ||
RegABGemmGraph.connect() | ||
|
||
# Print codegen for Reg Graph. | ||
print(RegABGemmGraph.codegen()) | ||
|
||
# Build Block for adding attached edge sA, sB into rA, rB. | ||
BlockRegABGemm = Block( | ||
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [], RegABGemmGraph, [IterVarI]) | ||
# Print codegen for Block. | ||
print(BlockRegABGemm.codegen()) | ||
|
||
# Define Block Node for `BlockRegABGemm`. | ||
BlockRegABGemmNode = Node.block(BlockRegABGemm) | ||
|
||
# Build Graph for BlockRegABGemm. | ||
BlockRegABGemmGraph = Graph() | ||
# Add Nodes to the Graph. | ||
BlockRegABGemmGraph.add_nodes([BlockRegABGemmNode]) | ||
|
||
# Connect the Graph. | ||
BlockRegABGemmGraph.connect() | ||
# Print codegen for BlockRegABGemmGraph. | ||
print(BlockRegABGemmGraph.codegen()) | ||
|
||
# Build Block for adding attached edge gA, gB into sA, sB. | ||
BlockSharedABGemm = Block( | ||
[AttachedEdgeGA2SA, AttachedEdgeGB2SB], [], BlockRegABGemmGraph, [IterVarK]) | ||
# Print codegen for Block. | ||
print(BlockSharedABGemm.codegen()) | ||
|
||
# Build Node for `BlockSharedABGemm`. | ||
BlockSharedABGemmNode = Node.block(BlockSharedABGemm) | ||
|
||
# Build Graph for BlockSharedABGemm. | ||
BlockSharedABGemmGraph = Graph() | ||
# Add Nodes to the Graph. | ||
BlockSharedABGemmGraph.add_nodes( | ||
[BlockSharedABGemmNode, AccCastNode, RegAccCGemmDNode]) | ||
|
||
# Build Edge for connecting BlockSharedABGemmNode and RegAccCGemmDNode. | ||
EdgeBlockCast = Edge( | ||
BlockSharedABGemmNode, AccCastNode) | ||
|
||
EdgeCastGemm = Edge(AccCastNode, RegAccCGemmDNode) | ||
|
||
# Add Edges to the Graph. | ||
BlockSharedABGemmGraph.add_edges([EdgeBlockCast, EdgeCastGemm]) | ||
|
||
# Connect the Graph. | ||
BlockSharedABGemmGraph.connect() | ||
# Print codegen for BlockSharedABGemmGraph. | ||
print(BlockSharedABGemmGraph.codegen()) | ||
|
||
# Build Block for adding attached edge gC into sC. | ||
BlockSharedCGemm = Block( | ||
[AttachedEdgeGC2SC, AttachedEdgeSC2RC], [AttachedEdgeRD2SD, AttachedEdgeSD2GD], BlockSharedABGemmGraph, [IterVarN]) | ||
|
||
# Print codegen for Block. | ||
print(BlockSharedCGemm.codegen()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
import sys | ||
|
||
sys.path.insert( | ||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
from .context import initialize_thriller_flow, Layout, TensorType | ||
from .context import Graph, Node, Edge, Gemm, AttachedEdge, Tensor | ||
from .context import Block, IterationVar, AccessMap | ||
from .context import Block, IterationVar, AccessMap, DType |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
use pyo3::prelude::*; | ||
|
||
#[pyclass(module = "dtype", name = "DType")] | ||
pub enum PyDType { | ||
F32, | ||
F64, | ||
Half, | ||
CutlassHalf, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.