From e48b39a2d2477b7ae08532e066cb7a6b1b087c1a Mon Sep 17 00:00:00 2001 From: kuangjux <18630816527@163.com> Date: Fri, 22 Nov 2024 16:20:29 +0800 Subject: [PATCH] Add pythonbind for allocate_var pass and test. --- thriller-bindings/src/graph.rs | 11 +++++++++-- thriller-bindings/tests/pass/allocate_var.py | 19 ++----------------- thriller-core/src/dataflow/mod.rs | 1 + .../src/dataflow/pass/allocate_var.rs | 17 ++++++++++++++++- thriller-core/src/dataflow/pass/mod.rs | 3 +++ thriller-core/src/lib.rs | 3 ++- 6 files changed, 33 insertions(+), 21 deletions(-) diff --git a/thriller-bindings/src/graph.rs b/thriller-bindings/src/graph.rs index bc76d21..9349b2b 100644 --- a/thriller-bindings/src/graph.rs +++ b/thriller-bindings/src/graph.rs @@ -2,8 +2,8 @@ use pyo3::prelude::*; use pyo3::types::PyList; use thriller_core::{ - AccessMap, Convert, DataType, Gemm, Task, ThrillerEdge, ThrillerGraph, ThrillerNode, - ThrillerNodeInner, + AccessMap, AllocateVar, Convert, DataType, Gemm, GraphPass, Task, ThrillerEdge, ThrillerGraph, + ThrillerNode, ThrillerNodeInner, }; use crate::buffer::PyBuffer; @@ -53,6 +53,13 @@ impl PyGraph { self.0.borrow_mut().connect(); } + fn allocate_var(&mut self) -> PyResult { + let mut graph = self.0.borrow_mut(); + let mut pass = AllocateVar::new(); + pass.run(&mut graph); + Ok(pass.code().clone()) + } + fn codegen(&self) -> PyResult { self.0 .borrow() diff --git a/thriller-bindings/tests/pass/allocate_var.py b/thriller-bindings/tests/pass/allocate_var.py index 69ea323..4ac0b34 100644 --- a/thriller-bindings/tests/pass/allocate_var.py +++ b/thriller-bindings/tests/pass/allocate_var.py @@ -1,7 +1,3 @@ -''' -Whole GEMM is an example of GEMM that utilizes all memory hierarchies -of NVIDIA GPU. -''' import context from pythriller import initialize_thriller_flow, Layout, Tensor, TensorType @@ -123,7 +119,6 @@ # Print codegen for Reg Graph. reg_code = RegGraph.codegen() - print(reg_code) # Build Block for Shared to Register. SharedToRegBlock = Block( @@ -131,7 +126,6 @@ # Print codegen for Shared to Register Block. shared_to_reg_code = SharedToRegBlock.codegen() - print(shared_to_reg_code) # Define BlockNode for SharedToRegBlock SharedBlockNode = Node.block(SharedToRegBlock) @@ -150,14 +144,5 @@ # Connect Shared Graph. SharedGraph.connect() - # Print codegen for Shared Graph. - shared_code = SharedGraph.codegen() - print(shared_code) - - # Build Block for Global to Shared. - GlobalToSharedBlock = Block( - [AttachedEdgeGA2SA, AttachedEdgeGB2SB], [AttachedEdgeSC2GC], SharedGraph, [LoopIterG2S]) - - # Print codegen for Global to Shared Block. - global_to_shared_code = GlobalToSharedBlock.codegen() - print(global_to_shared_code) + allocate_vars = SharedGraph.allocate_var() + print(allocate_vars) diff --git a/thriller-core/src/dataflow/mod.rs b/thriller-core/src/dataflow/mod.rs index e5c7703..2307d9e 100644 --- a/thriller-core/src/dataflow/mod.rs +++ b/thriller-core/src/dataflow/mod.rs @@ -8,3 +8,4 @@ pub use block::ThrillerBlock; pub use edge::{AttachedEdge, ThrillerEdge}; pub use graph::ThrillerGraph; pub use node::{ThrillerNode, ThrillerNodeInner}; +pub use pass::{AllocateVar, GraphPass}; diff --git a/thriller-core/src/dataflow/pass/allocate_var.rs b/thriller-core/src/dataflow/pass/allocate_var.rs index 6e2cc30..095f1b7 100644 --- a/thriller-core/src/dataflow/pass/allocate_var.rs +++ b/thriller-core/src/dataflow/pass/allocate_var.rs @@ -1,10 +1,25 @@ use super::GraphPass; use crate::{dataflow::ThrillerGraph, BufType, ThrillerNodeInner}; +/// AllocateVar pub struct AllocateVar { code: String, } +impl AllocateVar { + #[doc(hidden)] + pub fn new() -> Self { + Self { + code: String::new(), + } + } + + #[doc(hidden)] + pub fn code(&self) -> String { + self.code.clone() + } +} + impl GraphPass for AllocateVar { fn run(&mut self, graph: &mut ThrillerGraph) { // Transver the graph and allocate variables. @@ -27,7 +42,7 @@ impl GraphPass for AllocateVar { &BufType::RegTile | &BufType::RegVec => { self.code += - format!("Shared{} {};\n", buf.get_name(), buf.get_name()).as_str(); + format!("Reg{} {};\n", buf.get_name(), buf.get_name()).as_str(); } } } diff --git a/thriller-core/src/dataflow/pass/mod.rs b/thriller-core/src/dataflow/pass/mod.rs index c6fa882..9e2cad6 100644 --- a/thriller-core/src/dataflow/pass/mod.rs +++ b/thriller-core/src/dataflow/pass/mod.rs @@ -4,7 +4,10 @@ mod allocate_edge; mod allocate_var; mod gen_iterator; +pub use allocate_var::AllocateVar; + /// A trait for graph passes. pub trait GraphPass { + /// Run the pass on the graph. fn run(&mut self, graph: &mut ThrillerGraph); } diff --git a/thriller-core/src/lib.rs b/thriller-core/src/lib.rs index dff2f25..8f48d23 100644 --- a/thriller-core/src/lib.rs +++ b/thriller-core/src/lib.rs @@ -49,7 +49,8 @@ mod var; pub use access::{AccessMap, AccessMatrix, AccessOffset}; pub use buffer::{BufType, Buffer}; pub use dataflow::{ - AttachedEdge, ThrillerBlock, ThrillerEdge, ThrillerGraph, ThrillerNode, ThrillerNodeInner, + AllocateVar, AttachedEdge, GraphPass, ThrillerBlock, ThrillerEdge, ThrillerGraph, ThrillerNode, + ThrillerNodeInner, }; pub use dtype::DataType; pub use engine::{BlockLayout, BlockShape, ThrillerEngine};