[REFACTOR] Compile to ILA Asm (apache#11)
* change to uint8_t in ila runtime

* [ refactor ] pre-compile to ILA asm

* [ impl ] jit

* [ add ] mlp model

* [ add ] quantized model in PT

* [ fix ] run quantized

* [ refactor ] AoT compiler

* hmm? I don't remember touching this file

* [ fix ] naming issue and resolve some warnings

* [ fix ] turn off size check

* [ refactor ] cast according to annotation

* [ fix ] dtype
AD1024 authored and gussmith23 committed Dec 29, 2021
1 parent d035d2a commit 359a5c3
Showing 13 changed files with 26,363 additions and 202 deletions.
25,533 changes: 25,533 additions & 0 deletions include/tvm/support/json.hpp

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/ilavta.py
@@ -26,7 +26,7 @@ def _func_wrapper(attrs, *args):
 # _register_external_op_helper("nn.batch_matmul")
 _register_external_op_helper("nn.bias_add")
 _register_external_op_helper("nn.dense")
-_register_external_op_helper("nn.relu")
+# _register_external_op_helper("nn.relu")


 def make_pattern_conv2d():
@@ -57,10 +57,10 @@ def make_pattern_relu():

 @register_pattern_table("ilavta")
 def pattern_table():
-    conv2d_pat = ("ilavta.conv2d", make_pattern_conv2d())
+    # conv2d_pat = ("ilavta.conv2d", make_pattern_conv2d())
     matmul_pat = ("ilavta.batch_matmul", make_pattern_batch_matmul())
     dense_pat = ("ilavta.dense", make_pattern_dense())
     bias_add_pat = ("ilavta.bias_add", make_pattern_bias_add())
     relu_pat = ("ilavta.relu", make_pattern_relu())
-    ilavta_patterns = [conv2d_pat, matmul_pat, dense_pat, bias_add_pat, relu_pat]
+    ilavta_patterns = [matmul_pat, dense_pat, bias_add_pat, relu_pat]
     return ilavta_patterns
36 changes: 36 additions & 0 deletions src/relay/backend/contrib/ilavta/ilavta_codegen.cc
@@ -9,7 +9,9 @@
 #include <iostream>
 #include <numeric>
 #include <sstream>
+#include <set>

+#include "ilavta_codegen_utils.h"
 #include "../../utils.h"

 #include "../../../../runtime/contrib/json/json_node.h"
@@ -32,6 +34,7 @@ class ILAVTAJSONSerializer : public backend::contrib::JSONSerializer {
   std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override {
     Expr expr = GetRef<Expr>(cn);
     std::string name;
+    std::string filename;

     if (const auto* op_node = cn->op.as<OpNode>()) {
       name = op_node->name;
@@ -45,10 +48,36 @@ class ILAVTAJSONSerializer : public backend::contrib::JSONSerializer {
     }
     if (name == "ilavta.dense") {
       LOG(INFO) << "ilavta.dense pattern";
+      auto input_shape = GetShape(cn->args[0]->checked_type());
+      auto weight_shape = GetShape(cn->args[1]->checked_type());
+      int batch = input_shape[0];
+      int n_inp_cols = input_shape[1];
+      int n_wgt_rows = weight_shape[0];
+      int info[] = {batch, n_inp_cols, n_wgt_rows};
+      filename = GetCompiledFilename("dense", info, 3);
+      if (this->compiled_func.find(filename) == this->compiled_func.end()) {
+        filename = CompileGEMM(batch, n_inp_cols, n_wgt_rows, "./prog_frag/" + filename);
+      }
     } else if (name == "ilavta.bias_add") {
       LOG(INFO) << "ilavta.bias_add pattern";
+      auto input_shape = GetShape(cn->args[0]->checked_type());
+      int batch = input_shape[0];
+      int n_feat = input_shape[1];
+      int info[] = {batch, n_feat};
+      filename = GetCompiledFilename("bias_add", info, 2);
+      if (this->compiled_func.find(filename) == this->compiled_func.end()) {
+        filename = CompilBiasAdd(batch, n_feat, "./prog_frag/" + filename);
+      }
     } else if (name == "ilavta.relu") {
       LOG(INFO) << "ilavta.relu pattern";
+      auto input_shape = GetShape(cn->args[0]->checked_type());
+      int batch = input_shape[0];
+      int n_feat = input_shape[1];
+      int info[] = {batch, n_feat};
+      filename = GetCompiledFilename("relu", info, 2);
+      if (this->compiled_func.find(filename) == this->compiled_func.end()) {
+        filename = CompileRelu(batch, n_feat, "./prog_frag/" + filename);
+      }
     }
   } else {
     LOG(FATAL) << "ILAVTA runtime does not support calls to "
@@ -64,8 +93,15 @@ class ILAVTAJSONSerializer : public backend::contrib::JSONSerializer {
     auto node = std::make_shared<JSONGraphNode>(name, /* name_ */
                                                 "kernel", /* op_type_ */
                                                 inputs, 1 /* num_outputs_ */);
+    std::vector<std::string> vec;
+    std::vector<dmlc::any> compiler_attr;
+    vec.push_back(filename);
+    compiler_attr.emplace_back(vec);
+    node->SetAttr("asm_file", compiler_attr);
     return AddNode(node, GetRef<Expr>(cn));
   }
+ private:
+  std::set<std::string> compiled_func;

 };  // class ILAVTAJSONSerializer

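For context, the runtime side can read this attribute back off the deserialized graph node. A minimal sketch against TVM's JSONGraphNode attribute API; the helper name LookupAsmFile is illustrative and not part of this commit:

std::string LookupAsmFile(tvm::runtime::json::JSONGraphNode& node) {
  // The serializer above stored a std::vector<std::string> holding one filename.
  if (!node.HasAttr("asm_file")) return "";
  return node.GetAttr<std::vector<std::string>>("asm_file")[0];
}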
208 changes: 208 additions & 0 deletions src/relay/backend/contrib/ilavta/ilavta_codegen_utils.cc
@@ -0,0 +1,208 @@
#include <iomanip>
#include "ilavta_codegen_utils.h"

namespace tvm {
namespace relay {
namespace contrib {

using namespace nlohmann;
using addr_byte_pairs = std::vector<std::pair<vta_phy_addr_t, uint8_t>>;

json byte_pairs_to_json(const addr_byte_pairs& byte_pairs) {
  std::vector<json> pair_list;

  for (const auto& pair : byte_pairs) {
    std::stringstream addr_stream;
    addr_stream << "0x" << std::setfill('0') << std::setw(sizeof(vta_phy_addr_t) * 2)
                << std::hex << pair.first;

    std::stringstream byte_stream;
    // casting to uint32_t because uint8_t's are treated as char literals, not ints
    byte_stream << "0x" << std::setfill('0') << std::setw(2)
                << std::hex << static_cast<uint32_t>(pair.second);

    pair_list.push_back({
        {"addr", addr_stream.str()},
        {"value", byte_stream.str()}
    });
  }

  return pair_list;
}
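// Illustrative round-trip (values not from the original file; assumes a
// 32-bit vta_phy_addr_t, so addresses print with 8 hex digits):
//   addr_byte_pairs pairs = {{0x1000, 0xAB}};
//   byte_pairs_to_json(pairs).dump()  // [{"addr":"0x00001000","value":"0xab"}]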

json getGEMMAsm(int uop_bgn, int uop_end) {
  return {
      {"name", "gemm"},
      {"reset_f", 0},
      {"uop_bgn", uop_bgn},
      {"uop_end", uop_end},
      {"iter_o", 1},
      {"iter_i", 1},
      {"dst_fo", 0},
      {"dst_fi", 0},
      {"src_fo", 0},
      {"src_fi", 0},
      {"wgt_fo", 0},
      {"wgt_fi", 0}
  };
}

json get2DLoadStoreAsm(int opcode, int mem_type, int sram_id, int dram_id, int y_size, int x_size) {
  std::string cmd_type;
  switch (opcode) {
    case VTA_OPCODE_LOAD:
      cmd_type = "load_";
      break;
    case VTA_OPCODE_STORE:
      cmd_type = "store_";
      break;
    default:
      fprintf(stderr, "Unknown load / store: %d", opcode);
      exit(-1);
  }
  switch (mem_type) {
    case VTA_MEM_ID_INP:
      cmd_type += "inp";
      break;
    case VTA_MEM_ID_WGT:
      cmd_type += "wgt";
      break;
    case VTA_MEM_ID_UOP:
      cmd_type += "uop";
      break;
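    // ILA asm naming: ACC-memory loads are emitted as "bias" and OUT-memory
    // stores as "acc" in the command names below.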
    case VTA_MEM_ID_ACC:
      cmd_type += "bias";
      break;
    case VTA_MEM_ID_OUT:
      cmd_type += "acc";
      break;
  }
  if (cmd_type == "load_uop") {
    return {
        {"name", cmd_type},
        {"sram_id", sram_id},
        {"dram_id", dram_id},
        {"x_size", x_size}
    };
  } else if (cmd_type == "load_wgt" || cmd_type == "load_bias" || opcode == VTA_OPCODE_STORE) {
    return {
        {"name", cmd_type},
        {"sram_id", sram_id},
        {"dram_id", dram_id},
        {"y_size", y_size},
        {"x_size", x_size},
        {"x_stride", 1}
    };
  } else if (cmd_type == "load_inp") {
    return {
        {"name", cmd_type},
        {"sram_id", sram_id},
        {"dram_id", dram_id},
        {"y_size", y_size},
        {"x_size", x_size},
        {"x_stride", 1},
        {"y_pad0", 0},
        {"x_pad0", 0},
        {"y_pad1", 0},
        {"x_pad1", 0}
    };
  } else {
    fprintf(stderr, "Command %s not supported by ASM", cmd_type.c_str());
    exit(-1);
  }
}

json getAluAsm(int alu_opcode, int uop_bgn, int uop_end, bool use_imm, int imm) {
  int asm_opcode = -1;
  std::string op_name = "";
  switch (alu_opcode) {
    case VTA_ALU_OPCODE_MIN: asm_opcode = 0; op_name = "min"; break;
    case VTA_ALU_OPCODE_MAX: asm_opcode = 1; op_name = "max"; break;
    case VTA_ALU_OPCODE_ADD: asm_opcode = 2; op_name = "add"; break;
    case VTA_ALU_OPCODE_SHR: asm_opcode = 3; op_name = "shr"; break;
    default:
      fprintf(stderr, "ALU Opcode %d is not valid", alu_opcode);
      exit(-1);
  }
  return {
      {"name", "alu_" + op_name},
      {"reset_f", 0},
      {"uop_bgn", uop_bgn},
      {"uop_end", uop_end},
      {"iter_o", 1},
      {"iter_i", 1},
      {"dst_fo", 0},
      {"dst_fi", 0},
      {"src_fo", 0},
      {"src_fi", 0},
      {"alu_op", asm_opcode},
      {"use_imm", use_imm},
      {"imm", imm}
  };
}
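// Illustrative call (values not from the original file): the ReLU program
// below emits getAluAsm(VTA_ALU_OPCODE_MAX, 0, uop_size, 1, 0), i.e. an
// "alu_max" record with alu_op = 1, use_imm = true, and imm = 0.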

std::string write_to_file(const std::string& filename, const json& data) {
  std::ofstream out_file(filename + "_prog_frag.json");
  out_file << std::setw(4) << data << "\n";
  return filename + "_prog_frag.json";
}

std::string CompileGEMM(int batch, size_t n_inp_cols, size_t n_wgt_rows, std::string filename) {
  size_t in_dim = n_inp_cols % VTA_BLOCK_IN != 0 ? n_inp_cols / VTA_BLOCK_IN + 1 : n_inp_cols / VTA_BLOCK_IN;
  size_t out_dim = n_wgt_rows % VTA_BLOCK_OUT != 0 ? n_wgt_rows / VTA_BLOCK_OUT + 1 : n_wgt_rows / VTA_BLOCK_OUT;
  size_t uop_size = batch * in_dim * out_dim;
  json prog_frag = {};
  prog_frag["asm"] = json::array({});
  auto& prog = prog_frag["asm"];
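  // Load uops, weights, and input tiles; run one gemm over the whole uop
  // range; store the output tiles.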
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_UOP, 0, 0, 1, uop_size));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_WGT, 0, 0, out_dim * in_dim, 1));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_INP, 0, 0, batch * in_dim, 1));
  prog.push_back(getGEMMAsm(0, uop_size));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_STORE, VTA_MEM_ID_OUT, 0, 0, batch * out_dim, 1));
  return write_to_file(filename, prog_frag);
}
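// Worked example for the blocking above (assuming VTA's default
// VTA_BLOCK_IN = VTA_BLOCK_OUT = 16): batch = 1, n_inp_cols = 784, and
// n_wgt_rows = 128 give in_dim = 49 and out_dim = 8, so
// uop_size = 1 * 49 * 8 = 392: the program loads 392 uops, 392 weight
// blocks, and 49 input blocks, runs one gemm over uops [0, 392), and
// stores 8 output blocks.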

std::string CompilBiasAdd(int batch, size_t n_feat, std::string filename) {
  size_t in_dim = n_feat % VTA_BLOCK_IN != 0 ? n_feat / VTA_BLOCK_IN + 1 : n_feat / VTA_BLOCK_IN;
  size_t uop_size = batch * in_dim;
  json prog_frag = {
      {"asm", json::array({})}
  };
  auto& prog = prog_frag["asm"];
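  // Load uops, then the input rows into ACC, then the bias vector into ACC
  // immediately after them; add element-wise; store the result.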
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_UOP, 0, 0, 1, uop_size));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_ACC, 0, 0, batch * in_dim, 1));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_ACC, batch * in_dim, batch * in_dim, in_dim, 1));
  prog.push_back(getAluAsm(VTA_ALU_OPCODE_ADD, 0, uop_size, 0, 0));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_STORE, VTA_MEM_ID_OUT, 0, 0, batch * in_dim, 1));
  return write_to_file(filename, prog_frag);
}

std::string CompileRelu(int batch, size_t n_feat, std::string filename) {
  size_t in_dim = n_feat % VTA_BLOCK_IN != 0 ? n_feat / VTA_BLOCK_IN + 1 : n_feat / VTA_BLOCK_IN;
  size_t uop_size = batch * in_dim;
  json prog_frag = {
      {"asm", json::array({})}
  };
  auto& prog = prog_frag["asm"];
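  // Load uops and input, apply max(x, 0) via ALU MAX with immediate 0
  // (i.e. ReLU), then store the result.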
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_UOP, 0, 0, 1, uop_size));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_LOAD, VTA_MEM_ID_ACC, 0, 0, batch * in_dim, 1));
  prog.push_back(getAluAsm(VTA_ALU_OPCODE_MAX, 0, uop_size, 1, 0));
  prog.push_back(get2DLoadStoreAsm(VTA_OPCODE_STORE, VTA_MEM_ID_OUT, 0, 0, batch * in_dim, 1));
  return write_to_file(filename, prog_frag);
}

std::string GetCompiledFilename(const std::string op_name, const int* input_info, const int num_info) {
  std::stringstream ss;
  ss << op_name << "_";
  for (int i = 0; i < num_info; ++i) {
    ss << input_info[i] << "_";
  }
  return ss.str();
}
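// Illustrative (values not from the original file): for the dense case,
// info = {1, 784, 128} yields the cache key "dense_1_784_128_"; the caller
// prepends "./prog_frag/" and write_to_file appends "_prog_frag.json",
// giving "./prog_frag/dense_1_784_128__prog_frag.json".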

}  // namespace contrib
}  // namespace relay
}  // namespace tvm

22 changes: 22 additions & 0 deletions src/relay/backend/contrib/ilavta/ilavta_codegen_utils.h
@@ -0,0 +1,22 @@
#ifndef ILAVTA_CODEGEN_UTILS_H__
#define ILAVTA_CODEGEN_UTILS_H__
#include <vta/hw_spec.h>
#include <vta/driver.h>
#include <tvm/support/json.hpp>
#include <fstream>
#include <sstream>

namespace tvm {
namespace relay {
namespace contrib {

std::string CompileGEMM(int batch, size_t n_inp_cols, size_t n_wgt_rows, std::string filename);
std::string CompilBiasAdd(int batch, size_t n_feat, std::string filename);
std::string CompileRelu(int batch, size_t n_feat, std::string filename);
std::string GetCompiledFilename(const std::string op_name, const int* input_info, const int num_info);

}  // namespace contrib
}  // namespace relay
}  // namespace tvm

#endif // ILAVTA_CODEGEN_UTILS_H__
