Skip to content

Commit

Permalink
feat(example): Add Back2Back GEMM example. (#23)
Browse files Browse the repository at this point in the history
* Add b2b_gemm example.

* chore: Add python b2b_gemm program.

* fix access map codegen for load.

* Add some attached edges.

* Add cast task.

* Add cast codegen in python example.

* Add rD to sD storer codegen.

* Delete unused comments.

* Delete unused code.
  • Loading branch information
KuangjuX authored Sep 26, 2024
1 parent 83e6ddd commit 73b2313
Show file tree
Hide file tree
Showing 14 changed files with 392 additions and 23 deletions.
229 changes: 229 additions & 0 deletions thriller-bindings/examples/b2b_gemm/b2b_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
'''
Back-to-Back GEMM example.

Builds, with pythriller, a first GEMM node over rA, rB and rAcc, a cast node
from rAcc to rAccHalf, and a second GEMM node over rAccHalf, rC and rD. The
graphs are wrapped in nested Blocks together with their load/store attached
edges, and the generated code is printed at each nesting level.
'''
import context

from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
from pythriller import Graph, Node, Edge, AttachedEdge, IterationVar, AccessMap
from pythriller import Block, DType

if __name__ == '__main__':
# Initialize the Thriller flow runtime before any pythriller object is built.
initialize_thriller_flow()

# Layouts of the register tiles for A, B, C, D and the accumulator
# (all row-major).
RegLayoutA = Layout.RowMajor
RegLayoutB = Layout.RowMajor
RegLayoutC = Layout.RowMajor
RegLayoutD = Layout.RowMajor
RegLayoutAcc = Layout.RowMajor

# Layouts of the shared tiles for A, B, C, D
# (A and D row-major, B and C column-major).
SharedLayoutA = Layout.RowMajor
SharedLayoutB = Layout.ColMajor
SharedLayoutC = Layout.ColMajor
SharedLayoutD = Layout.RowMajor

# Layouts of the global tiles for A, B, C, D
# (same row/column-major choice as the shared tiles).
GlobalLayoutA = Layout.RowMajor
GlobalLayoutB = Layout.ColMajor
GlobalLayoutC = Layout.ColMajor
GlobalLayoutD = Layout.RowMajor

# Dimensions of the register tiles for A, B, C, D and the accumulator.
RegDimA = [64, 64]
RegDimB = [64, 64]
RegDimC = [64, 64]
RegDimD = [64, 64]
RegDimAcc = [64, 64]

# Dimensions of the shared tiles for A, B, C, D.
# NOTE(review): identical to the register tile dimensions above - confirm
# this is intended for the example.
SharedDimA = [64, 64]
SharedDimB = [64, 64]
SharedDimC = [64, 64]
SharedDimD = [64, 64]

# Dimensions of the global tiles for A, B, C, D.
GlobalDimA = [256, 256]
GlobalDimB = [256, 256]
GlobalDimC = [256, 256]
GlobalDimD = [256, 256]

# Register tensors for A, B, C, D and the accumulator.
# rAccHalf reuses the accumulator's dims and layout; it is the destination of
# the F32 -> Half cast node built later in this script.
rA = Tensor("rA", RegDimA, RegLayoutA, TensorType.RegTile)
rB = Tensor("rB", RegDimB, RegLayoutB, TensorType.RegTile)
rC = Tensor("rC", RegDimC, RegLayoutC, TensorType.RegTile)
rD = Tensor("rD", RegDimD, RegLayoutD, TensorType.RegTile)
rAcc = Tensor("rAcc", RegDimAcc, RegLayoutAcc, TensorType.RegTile)
rAccHalf = Tensor("rAccHalf", RegDimAcc, RegLayoutAcc, TensorType.RegTile)

# Shared tensors for A, B, C, D.
sA = Tensor("sA", SharedDimA, SharedLayoutA, TensorType.SharedTile)
sB = Tensor("sB", SharedDimB, SharedLayoutB, TensorType.SharedTile)
sC = Tensor("sC", SharedDimC, SharedLayoutC, TensorType.SharedTile)
sD = Tensor("sD", SharedDimD, SharedLayoutD, TensorType.SharedTile)

# Global tensors for A, B, C, D.
gA = Tensor("gA", GlobalDimA, GlobalLayoutA, TensorType.GlobalTile)
gB = Tensor("gB", GlobalDimB, GlobalLayoutB, TensorType.GlobalTile)
gC = Tensor("gC", GlobalDimC, GlobalLayoutC, TensorType.GlobalTile)
gD = Tensor("gD", GlobalDimD, GlobalLayoutD, TensorType.GlobalTile)

# Wrap each register tensor (A, B, C, D, acc, acc-as-half) in a graph node.
NodeRA = Node.tensor(rA)
NodeRB = Node.tensor(rB)
NodeRC = Node.tensor(rC)
NodeRD = Node.tensor(rD)
NodeRAcc = Node.tensor(rAcc)
NodeRAccHalf = Node.tensor(rAccHalf)
# First GEMM node, built over the rA, rB and rAcc nodes.
RegABGemmCNode = Node.gemm(NodeRA, NodeRB, NodeRAcc)

# Second GEMM node, built over the rAccHalf, rC and rD nodes.
RegAccCGemmDNode = Node.gemm(NodeRAccHalf, NodeRC, NodeRD)

# Cast node converting rAcc (DType.F32) into rAccHalf (DType.Half).
# Unlike Node.gemm, Node.cast is given the Tensor objects, not their nodes.
AccCastNode = Node.cast(rAcc, rAccHalf, DType.F32, DType.Half)

# Wrap the shared tensors for A, B, C, D in graph nodes.
# NOTE(review): NodeSA..NodeSD are not referenced again in this script.
NodeSA = Node.tensor(sA)
NodeSB = Node.tensor(sB)
NodeSC = Node.tensor(sC)
NodeSD = Node.tensor(sD)

# Wrap the global tensors for A, B, C, D in graph nodes.
# NOTE(review): NodeGA..NodeGD are not referenced again in this script.
NodeGA = Node.tensor(gA)
NodeGB = Node.tensor(gB)
NodeGC = Node.tensor(gC)
NodeGD = Node.tensor(gD)

# Edges from the rA, rB and rAcc nodes into the first GEMM node.
RegEdgeA = Edge(NodeRA, RegABGemmCNode)
RegEdgeB = Edge(NodeRB, RegABGemmCNode)
RegEdgeAcc = Edge(NodeRAcc, RegABGemmCNode)

# Edges NodeRAcc -> AccCastNode -> NodeRAccHalf.
# NOTE(review): these two edges are never added to a graph in this script.
EdgeAccCastIn = Edge(NodeRAcc, AccCastNode)
EdgeAccCastOut = Edge(AccCastNode, NodeRAccHalf)

# Iteration variables used by the access maps and blocks below.
# Each is built from a name and a (start, end)-style pair.
# NOTE(review): whether the second bound is inclusive is not visible here -
# confirm against pythriller.IterationVar.
# Iterate over the register tiles along the kTK dimension.
IterVarI = IterationVar("i", (0, 1))

# Iterate over K.
IterVarK = IterationVar("k", (0, 1))

# Iterate over N.
IterVarN = IterationVar("n", (0, 4))

# NOTE(review): the AccessMap arguments look like (loop depths, per-side access
# matrices, per-side offsets, iteration variables) - confirm against
# pythriller.AccessMap before editing any of the literals below.

# Build AccessMap for loading sA, sB into rA, rB (indexed by IterVarI).
AccessMapSA2RA = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarI])
AccessMapSB2RB = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarI])

# Build AccessMap for loading sC into rC (no iteration variables).
AccessMapSC2RC = AccessMap([0], [[], []], [[], []], [])

# Build AccessMap for loading gA, gB into sA, sB.
# gA -> sA is indexed by IterVarK only; gB -> sB uses both IterVarK and IterVarN.
AccessMapGA2SA = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarK])
AccessMapGB2SB = AccessMap([0], [[[1, 0], [0, 1]], [[0, 0]]], [
[0, 0], [0]], [IterVarK, IterVarN])

# Build AccessMap for loading gC into sC (indexed by IterVarN).
AccessMapGC2SC = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [IterVarN])

# Build AccessMap for storing rAcc.
# NOTE(review): the variable name says gD while the original comment said sD;
# the map is also not referenced again in this script - confirm the intent.
AccessMapRAcc2GD = AccessMap([0], [[[1]], [[0]]], [[0], [0]], [])

# Build AccessMap for storing rD into sD (no iteration variables).
AccessMapRD2SD = AccessMap([0], [[], []], [[], []], [])

# Build AccessMap for storing sD into gD (no iteration variables).
AccessMapSD2GD = AccessMap([0], [[], []], [[], []], [])

# Attached edges pair a source Tensor, a destination Tensor and the AccessMap
# describing the transfer. They take the Tensor objects, not the Node wrappers.

# Attached edges for loading sA, sB into rA, rB.
AttachedEdgeSA2RA = AttachedEdge(sA, rA, AccessMapSA2RA)
AttachedEdgeSB2RB = AttachedEdge(sB, rB, AccessMapSB2RB)

# Attached edge for loading sC into rC.
AttachedEdgeSC2RC = AttachedEdge(sC, rC, AccessMapSC2RC)

# Attached edges for loading gA, gB into sA, sB.
AttachedEdgeGA2SA = AttachedEdge(gA, sA, AccessMapGA2SA)
AttachedEdgeGB2SB = AttachedEdge(gB, sB, AccessMapGB2SB)

# Attached edge for loading gC into sC.
AttachedEdgeGC2SC = AttachedEdge(gC, sC, AccessMapGC2SC)

# Attached edge for storing rD into sD.
AttachedEdgeRD2SD = AttachedEdge(rD, sD, AccessMapRD2SD)

# Attached edge for storing sD into gD.
AttachedEdgeSD2GD = AttachedEdge(sD, gD, AccessMapSD2GD)

# Innermost level: graph holding the rA, rB, rAcc nodes and the first GEMM node.
RegABGemmGraph = Graph()

# Add Nodes to the Graph.
RegABGemmGraph.add_nodes([NodeRA, NodeRB, NodeRAcc, RegABGemmCNode])
# Add Edges to the Graph.
RegABGemmGraph.add_edges([RegEdgeA, RegEdgeB, RegEdgeAcc])
# Connect the Graph.
RegABGemmGraph.connect()

# Print codegen for Reg Graph.
print(RegABGemmGraph.codegen())

# Block wrapping RegABGemmGraph with the sA -> rA and sB -> rB attached edges
# in the first list and an empty second list, iterating over IterVarI.
# NOTE(review): the first list is presumably the block's loads and the second
# its stores (the last Block in this script passes store edges there) - confirm.
BlockRegABGemm = Block(
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [], RegABGemmGraph, [IterVarI])
# Print codegen for Block.
print(BlockRegABGemm.codegen())

# Define Block Node for `BlockRegABGemm`.
BlockRegABGemmNode = Node.block(BlockRegABGemm)

# Graph containing only the BlockRegABGemm node (no edges are added).
BlockRegABGemmGraph = Graph()
# Add Nodes to the Graph.
BlockRegABGemmGraph.add_nodes([BlockRegABGemmNode])

# Connect the Graph.
BlockRegABGemmGraph.connect()
# Print codegen for BlockRegABGemmGraph.
print(BlockRegABGemmGraph.codegen())

# Block wrapping BlockRegABGemmGraph with the gA -> sA and gB -> sB attached
# edges, iterating over IterVarK.
# NOTE(review): AccessMapGB2SB also references IterVarN, which is only supplied
# by the outer BlockSharedCGemm below - confirm this is intended.
BlockSharedABGemm = Block(
[AttachedEdgeGA2SA, AttachedEdgeGB2SB], [], BlockRegABGemmGraph, [IterVarK])
# Print codegen for Block.
print(BlockSharedABGemm.codegen())

# Build Node for `BlockSharedABGemm`.
BlockSharedABGemmNode = Node.block(BlockSharedABGemm)

# Graph chaining the first-GEMM block, the cast node and the second GEMM node.
BlockSharedABGemmGraph = Graph()
# Add Nodes to the Graph.
BlockSharedABGemmGraph.add_nodes(
[BlockSharedABGemmNode, AccCastNode, RegAccCGemmDNode])

# Edges BlockSharedABGemmNode -> AccCastNode -> RegAccCGemmDNode: the block
# reaches the second GEMM node through the cast node.
EdgeBlockCast = Edge(
BlockSharedABGemmNode, AccCastNode)

EdgeCastGemm = Edge(AccCastNode, RegAccCGemmDNode)

# Add Edges to the Graph.
BlockSharedABGemmGraph.add_edges([EdgeBlockCast, EdgeCastGemm])

# Connect the Graph.
BlockSharedABGemmGraph.connect()
# Print codegen for BlockSharedABGemmGraph.
print(BlockSharedABGemmGraph.codegen())

# Outermost block: wraps BlockSharedABGemmGraph with the gC -> sC and sC -> rC
# attached edges in the first list and the rD -> sD and sD -> gD attached edges
# in the second list, iterating over IterVarN.
BlockSharedCGemm = Block(
[AttachedEdgeGC2SC, AttachedEdgeSC2RC], [AttachedEdgeRD2SD, AttachedEdgeSD2GD], BlockSharedABGemmGraph, [IterVarN])

# Print codegen for Block.
print(BlockSharedCGemm.codegen())
5 changes: 5 additions & 0 deletions thriller-bindings/examples/b2b_gemm/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import os
import sys

# Prepend the directory two levels above this file to the module search path,
# so that importing this module makes packages located there importable
# (presumably the in-tree `pythriller` package - confirm the directory layout).
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
2 changes: 1 addition & 1 deletion thriller-bindings/pythriller/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .context import initialize_thriller_flow, Layout, TensorType
from .context import Graph, Node, Edge, Gemm, AttachedEdge, Tensor
from .context import Block, IterationVar, AccessMap
from .context import Block, IterationVar, AccessMap, DType
9 changes: 9 additions & 0 deletions thriller-bindings/src/dtype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use pyo3::prelude::*;

/// Element data types selectable from Python, exposed as the `DType` class.
///
/// `PyNode::cast` translates each variant into the corresponding
/// `thriller_core::DataType` variant.
#[pyclass(module = "dtype", name = "DType")]
pub enum PyDType {
/// Translated to `DataType::Float32`.
F32,
/// Translated to `DataType::Float64`.
F64,
/// Translated to `DataType::Half`.
Half,
/// Translated to `DataType::Cutlasshalf`
/// (presumably the CUTLASS half-precision type - confirm in thriller-core).
CutlassHalf,
}
35 changes: 33 additions & 2 deletions thriller-bindings/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use pyo3::prelude::*;
use pyo3::types::PyList;

use thriller_core::{
AccessMap, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode, ThrillerNodeInner,
AccessMap, Convert, DataType, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode,
ThrillerNodeInner,
};

use crate::block::PyBlock;
use crate::buffer::PyBuffer;
use crate::{block::PyBlock, dtype::PyDType};

use std::{cell::RefCell, rc::Rc};

Expand Down Expand Up @@ -92,6 +93,36 @@ impl PyNode {
PyNode(Rc::new(RefCell::new(node)))
}

#[staticmethod]
fn cast(
src: PyRef<PyBuffer>,
dst: PyRef<PyBuffer>,
sdtype: PyRef<PyDType>,
ddtype: PyRef<PyDType>,
) -> Self {
let sdtype = match *sdtype {
PyDType::F32 => DataType::Float32,
PyDType::F64 => DataType::Float64,
PyDType::Half => DataType::Half,
PyDType::CutlassHalf => DataType::Cutlasshalf,
};

let ddtype = match *ddtype {
PyDType::F32 => DataType::Float32,
PyDType::F64 => DataType::Float64,
PyDType::Half => DataType::Half,
PyDType::CutlassHalf => DataType::Cutlasshalf,
};

let sbuf = Rc::clone(&src.0);
let dbuf = Rc::clone(&dst.0);

let cast = Convert::new(sbuf, dbuf, sdtype, ddtype);
let node = ThrillerNode::new(ThrillerNodeInner::Op(Box::new(cast)));

PyNode(Rc::new(RefCell::new(node)))
}

fn codegen(&self) -> PyResult<String> {
let node = self.0.borrow();
node.emit()
Expand Down
8 changes: 3 additions & 5 deletions thriller-bindings/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#![deny(missing_docs)]

use access::PyAccessMap;
use dtype::PyDType;
use pyo3::prelude::*;

use block::{PyAttachedEdge, PyBlock};
Expand All @@ -16,6 +17,7 @@ use var::PyIterationVar;
mod access;
mod block;
mod buffer;
mod dtype;
mod graph;
mod op;
mod var;
Expand All @@ -35,18 +37,14 @@ fn thriller_flow(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyBuffer>()?;
m.add_class::<PyLayout>()?;
m.add_class::<PyBufType>()?;

m.add_class::<PyDType>()?;
m.add_class::<PyGraph>()?;
m.add_class::<PyNode>()?;
m.add_class::<PyEdge>()?;

m.add_class::<PyGemm>()?;

m.add_class::<PyBlock>()?;
m.add_class::<PyAttachedEdge>()?;

m.add_class::<PyIterationVar>()?;

m.add_class::<PyAccessMap>()?;

Ok(())
Expand Down
8 changes: 6 additions & 2 deletions thriller-core/src/access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,13 @@ impl AccessMap {
// Emit each access-row coefficient multiplied by its iteration variable.
for (cindex, access_col) in access_row.iter().enumerate() {
let ivar = &ivars[cindex];
// Join the terms with " + "; no separator is emitted before the first term.
if cindex != 0 {
code.push_str(" + ");
}
code.push_str(
format!(
"{access}*{ivar}",
"{access} * {ivar}",
access = *access_col,
ivar = ivar.get_name()
)
Expand All @@ -112,7 +116,7 @@ impl AccessMap {
}

if *offset != 0 {
code.push_str(format!("+{}", offset).as_str());
code.push_str(format!(" + {}", offset).as_str());
}

access.push(code);
Expand Down
12 changes: 8 additions & 4 deletions thriller-core/src/dataflow/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,24 +142,28 @@ impl ThrillerBlock {
(BufType::SharedTile, BufType::RegTile) => {
insert_copy_async = true;
code += format!(
"{indent}loader_tile_s2r_{sid}_to_{did}({sbuf_var}, {dbuf_var});\n",
"{indent}loader_tile_s2r_{sid}_to_{did}({sbuf_var}({src_access}), {dbuf_var}({target_access}));\n",
indent = indent,
sid = sbuf_id,
did = dbuf_id,
sbuf_var = sbuf_var,
dbuf_var = dbuf_var
src_access = source_access_code,
dbuf_var = dbuf_var,
target_access = target_access_code
)
.as_str();
}

(BufType::GlobalTile, BufType::SharedTile) => {
code += format!(
"{indent}loader_tile_g2s_{sid}_to_{did}({sbuf_var}, {dbuf_var});\n",
"{indent}loader_tile_g2s_{sid}_to_{did}({sbuf_var}({src_access}), {dbuf_var}({target_access}));\n",
indent = indent,
sid = sbuf_id,
did = dbuf_id,
sbuf_var = sbuf_var,
dbuf_var = dbuf_var
src_access = source_access_code,
dbuf_var = dbuf_var,
target_access = target_access_code
)
.as_str();
}
Expand Down
Loading

0 comments on commit 73b2313

Please sign in to comment.