Skip to content

Commit

Permalink
Hack to enclude schedule string map (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and tmoreau89 committed Dec 1, 2018
1 parent e63defe commit 8d5203f
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 3 deletions.
2 changes: 2 additions & 0 deletions vta/include/vta/hw_spec.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extern "C" {
#define VTA_OUT_WIDTH (1 << VTA_LOG_OUT_WIDTH)
/*! Accumulator data type width */
#define VTA_ACC_WIDTH (1 << VTA_LOG_ACC_WIDTH)
/*! Accumulator truncation bits */
#define VTA_ACC_TRUC_BITS 24
/*! log2 of ALU data type width */
#define VTA_LOG_ALU_WIDTH (VTA_LOG_ACC_WIDTH - 1)
/*! ALU data type width */
Expand Down
8 changes: 8 additions & 0 deletions vta/python/vta/top/vta_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
['batch', 'height', 'width', 'in_filter', 'out_filter',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

_SCHEDULE_STR_MAP = {}


def find_schedules(layer, vt_only=False, best_only=False):
""" Returns a schedule for a given a layer.
Expand Down Expand Up @@ -415,6 +418,11 @@ def _traverse(op):
else:
pad_data = None
wrkld = _get_workload(data, pad_data, kernel, output)

if wrkld in _SCHEDULE_STR_MAP and planStr is None:
planStr = _SCHEDULE_STR_MAP[wrkld]
logging.info("Apply pre-cached schedule for %s->%s", str(wrkld) , planStr)

if planStr:
matchObj = re.match( r'b(\d+)_oc(\d+)_ic(\d+)_h(\d+)_w(\d+)_oct(\d+)_ht(\d+)', planStr)
b_factor = int(matchObj.group(1))
Expand Down
15 changes: 13 additions & 2 deletions vta/src/sim/sim_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,14 @@ class Device {
}
}

int32_t IntTrunc(int32_t value, int32_t bits) {
if (bits >= 32) return value;
int leftbits = (32 - bits);
value = value & ((1 << bits) -1);
value = (value << leftbits) >> leftbits;
return value;
}

void RunGEMM(const VTAGemInsn* op) {
if (!op->reset_reg) {
prof_->gemm_counter += op->iter_out * op->iter_in * (op->uop_end - op->uop_bgn);
Expand Down Expand Up @@ -452,6 +460,7 @@ class Device {
sum +=
inp.GetSigned(i * VTA_BLOCK_IN + k) *
wgt.GetSigned(j * VTA_BLOCK_IN + k);
sum = IntTrunc(sum, VTA_ACC_TRUC_BITS);
}
acc.SetSigned(acc_offset, sum);
}
Expand Down Expand Up @@ -540,11 +549,13 @@ class Device {
BitPacker<VTA_ACC_WIDTH> dst(acc_.BeginPtr(dst_index));
BitPacker<VTA_ACC_WIDTH> src(acc_.BeginPtr(src_index));
for (int k = 0; k < VTA_BLOCK_OUT; ++k) {
int32_t value;
if (use_imm) {
dst.SetSigned(k, func(dst.GetSigned(k), op->imm));
value = func(dst.GetSigned(k), op->imm);
} else {
dst.SetSigned(k, func(dst.GetSigned(k), src.GetSigned(k)));
value = func(dst.GetSigned(k), src.GetSigned(k));
}
dst.SetSigned(k, IntTrunc(value, VTA_ACC_TRUC_BITS));
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion vta/tests/python/integration/test_benchmark_topi_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,9 @@ def get_ref_data():
return a_np, w_np, b_np

def verify(s, check_correctness):
mod = vta.build(s, [data, kernel_arg, bias, coeff, res], "ext_dev",
mod = vta.build(s,
[data, kernel_arg, bias, coeff, res],
"ext_dev",
env.target_host, name="conv2d")
temp = util.tempdir()

Expand Down

0 comments on commit 8d5203f

Please sign in to comment.