Skip to content

Commit

Permalink
Add pythonbind for allocate_var pass and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Nov 22, 2024
1 parent 6b3e204 commit e48b39a
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 21 deletions.
11 changes: 9 additions & 2 deletions thriller-bindings/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use pyo3::prelude::*;
use pyo3::types::PyList;

use thriller_core::{
AccessMap, Convert, DataType, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode,
ThrillerNodeInner,
AccessMap, AllocateVar, Convert, DataType, Gemm, GraphPass, Task, ThrillerEdge, ThrillerGraph,
ThrillerNode, ThrillerNodeInner,
};

use crate::buffer::PyBuffer;
Expand Down Expand Up @@ -53,6 +53,13 @@ impl PyGraph {
self.0.borrow_mut().connect();
}

fn allocate_var(&mut self) -> PyResult<String> {
let mut graph = self.0.borrow_mut();
let mut pass = AllocateVar::new();
pass.run(&mut graph);
Ok(pass.code().clone())
}

fn codegen(&self) -> PyResult<String> {
self.0
.borrow()
Expand Down
19 changes: 2 additions & 17 deletions thriller-bindings/tests/pass/allocate_var.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
'''
Whole GEMM is an example of GEMM that utilizes all memory hierarchies
of NVIDIA GPU.
'''
import context

from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType
Expand Down Expand Up @@ -123,15 +119,13 @@

# Print codegen for Reg Graph.
reg_code = RegGraph.codegen()
print(reg_code)

# Build Block for Shared to Register.
SharedToRegBlock = Block(
[AttachedEdgeSA2RA, AttachedEdgeSB2RB], [AttachedEdgeSC2RC], RegGraph, [LoopIterS2R])

# Print codegen for Shared to Register Block.
shared_to_reg_code = SharedToRegBlock.codegen()
print(shared_to_reg_code)

# Define BlockNode for SharedToRegBlock
SharedBlockNode = Node.block(SharedToRegBlock)
Expand All @@ -150,14 +144,5 @@
# Connect Shared Graph.
SharedGraph.connect()

# Print codegen for Shared Graph.
shared_code = SharedGraph.codegen()
print(shared_code)

# Build Block for Global to Shared.
GlobalToSharedBlock = Block(
[AttachedEdgeGA2SA, AttachedEdgeGB2SB], [AttachedEdgeSC2GC], SharedGraph, [LoopIterG2S])

# Print codegen for Global to Shared Block.
global_to_shared_code = GlobalToSharedBlock.codegen()
print(global_to_shared_code)
allocate_vars = SharedGraph.allocate_var()
print(allocate_vars)
1 change: 1 addition & 0 deletions thriller-core/src/dataflow/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ pub use block::ThrillerBlock;
pub use edge::{AttachedEdge, ThrillerEdge};
pub use graph::ThrillerGraph;
pub use node::{ThrillerNode, ThrillerNodeInner};
pub use pass::{AllocateVar, GraphPass};
17 changes: 16 additions & 1 deletion thriller-core/src/dataflow/pass/allocate_var.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
use super::GraphPass;
use crate::{dataflow::ThrillerGraph, BufType, ThrillerNodeInner};

/// AllocateVar
pub struct AllocateVar {
code: String,
}

impl AllocateVar {
#[doc(hidden)]
pub fn new() -> Self {
Self {
code: String::new(),
}
}

#[doc(hidden)]
pub fn code(&self) -> String {
self.code.clone()
}
}

impl GraphPass for AllocateVar {
fn run(&mut self, graph: &mut ThrillerGraph) {
// Transver the graph and allocate variables.
Expand All @@ -27,7 +42,7 @@ impl GraphPass for AllocateVar {

&BufType::RegTile | &BufType::RegVec => {
self.code +=
format!("Shared{} {};\n", buf.get_name(), buf.get_name()).as_str();
format!("Reg{} {};\n", buf.get_name(), buf.get_name()).as_str();
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions thriller-core/src/dataflow/pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ mod allocate_edge;
mod allocate_var;
mod gen_iterator;

pub use allocate_var::AllocateVar;

/// A trait for graph passes.
pub trait GraphPass {
/// Run the pass on the graph.
fn run(&mut self, graph: &mut ThrillerGraph);
}
3 changes: 2 additions & 1 deletion thriller-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ mod var;
pub use access::{AccessMap, AccessMatrix, AccessOffset};
pub use buffer::{BufType, Buffer};
pub use dataflow::{
AttachedEdge, ThrillerBlock, ThrillerEdge, ThrillerGraph, ThrillerNode, ThrillerNodeInner,
AllocateVar, AttachedEdge, GraphPass, ThrillerBlock, ThrillerEdge, ThrillerGraph, ThrillerNode,
ThrillerNodeInner,
};
pub use dtype::DataType;
pub use engine::{BlockLayout, BlockShape, ThrillerEngine};
Expand Down

0 comments on commit e48b39a

Please sign in to comment.