Skip to content

Commit

Permalink
[HARDWARE] [AUTOTVM] Allowing AutoTVM to tune for different VTA designs and adding ZCU102 support (apache#19)
Browse files Browse the repository at this point in the history

* compilation support for ZCU102

* dll caching to save on dynamic reconfiguration time

* introducing model signature to differentiate between VTA variants

* disable tophub to avoid falling back on invalid schedules

* reconfig when targeting an FPGA to reset the hardware

* typo fix

* addressing comments
  • Loading branch information
tmoreau89 committed Dec 1, 2018
1 parent e47976c commit 49d63dc
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 38 deletions.
4 changes: 4 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,10 @@ def run_through_rpc(measure_input, build_result,
try:
# upload built module
remote = request_remote(*remote_args)
if measure_input.target.device_name == 'vta':
from vta import program_fpga, reconfig_runtime
program_fpga(remote, None)
reconfig_runtime(remote)
remote.upload(build_result.filename)
func = remote.load_module(os.path.split(build_result.filename)[1])
ctx = remote.context(str(measure_input.target), 0)
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/autotvm/task/nnvm_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..util import get_const_tuple
from .task import create, register
from .dispatcher import ApplyHistoryBest

logger = logging.getLogger('autotvm')

Expand Down Expand Up @@ -240,7 +241,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):

# run compiler to collect all TOPI calls during compilation
nnvm.compiler.engine.clear_cache()
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
with ApplyHistoryBest([]):
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
nnvm.compiler.engine.clear_cache()

logger.disabled = old_state
Expand Down
3 changes: 2 additions & 1 deletion vta/hardware/xilinx/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ $(BIT_PATH): $(IP_PATH)
mkdir -p $(HW_BUILD_PATH)
cd $(HW_BUILD_PATH) && \
$(VIVADO) -mode tcl -source $(SCRIPT_DIR)/ultra96.tcl \
-tclargs $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) $(VTA_CLOCK_FREQ) $(VTA_GEMM_II) \
-tclargs $(VTA_TARGET) $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) \
$(VTA_CLOCK_FREQ) $(VTA_GEMM_II) \
$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \
$(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \
$(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE)
Expand Down
10 changes: 9 additions & 1 deletion vta/hardware/xilinx/scripts/hls.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,21 @@ proc init_design {target per g_ii a_ii inp_width wgt_width out_width acc_width b
set_part {xc7z020clg484-1}
} elseif {$target=="ultra96"} {
set_part {xczu3eg-sbva484-1-e}
} elseif {$target=="zcu102"} {
set_part {xczu9eg-ffvb1156-2-e}
}

# Max bus width (supported by Vivado)
set max_width 1024

# Set axi width (TODO derive from top level config)
set axi_width 128
if {$target=="pynq"} {
set axi_width 64
} elseif {$target=="ultra96"} {
set axi_width 128
} elseif {$target=="zcu102"} {
set axi_width 128
}

# Set the clock frequency
create_clock -period $per -name default
Expand Down
42 changes: 24 additions & 18 deletions vta/hardware/xilinx/scripts/ultra96.tcl
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,24 @@ if { [string first $scripts_vivado_version $current_vivado_version] == -1 } {
}

# Parse argument list, derive the clock to utilize
if { [llength $argv] eq 13 } {
set ip_path [lindex $argv 0]
set num_threads [lindex $argv 1]
set clock_freq [lindex $argv 2]
set gemm_ii [lindex $argv 3]
set inp_width [expr 1 << [lindex $argv 4]]
set wgt_width [expr 1 << [lindex $argv 5]]
set out_width [expr 1 << [lindex $argv 6]]
set batch [expr 1 << [lindex $argv 7]]
set out_block [expr 1 << [lindex $argv 8]]
set in_block [expr 1 << [lindex $argv 9]]
set inp_mem_size [expr 1 << [lindex $argv 10]]
set wgt_mem_size [expr 1 << [lindex $argv 11]]
set out_mem_size [expr 1 << [lindex $argv 12]]
if { [llength $argv] eq 14 } {
set target [lindex $argv 0]
set ip_path [lindex $argv 1]
set num_threads [lindex $argv 2]
set clock_freq [lindex $argv 3]
set gemm_ii [lindex $argv 4]
set inp_width [expr 1 << [lindex $argv 5]]
set wgt_width [expr 1 << [lindex $argv 6]]
set out_width [expr 1 << [lindex $argv 7]]
set batch [expr 1 << [lindex $argv 8]]
set out_block [expr 1 << [lindex $argv 9]]
set in_block [expr 1 << [lindex $argv 10]]
set inp_mem_size [expr 1 << [lindex $argv 11]]
set wgt_mem_size [expr 1 << [lindex $argv 12]]
set out_mem_size [expr 1 << [lindex $argv 13]]
} else {
puts "Arg list incomplete: <path to ip dir> <num threads> <clock freq> <gemm ii> \
<inp width> <wgt_width> <out_width> <batch> <batch> <out_block> <in_block
puts "Arg list incomplete: <target> <path to ip dir> <num threads> <clock freq> \
<gemm ii> <inp width> <wgt_width> <out_width> <batch> <batch> <out_block> <in_block> \
<inp_mem_size> <wgt_mem_size> <out_mem_size>"
return 1
}
Expand Down Expand Up @@ -82,8 +83,13 @@ set compute_ip "${ip_path}/vta_compute/solution0/impl/ip/xilinx_com_hls_compute_
set store_ip "${ip_path}/vta_store/solution0/impl/ip/xilinx_com_hls_store_1_0.zip"

# Create custom project
create_project -force $proj_name $proj_path -part xczu3eg-sbva484-1-e
set_property BOARD_PART em.avnet.com:ultra96:part0:1.0 [current_project]
if { ${target} eq "ultra96" } {
create_project -force $proj_name $proj_path -part xczu3eg-sbva484-1-e
set_property BOARD_PART em.avnet.com:ultra96:part0:1.0 [current_project]
} elseif { ${target} eq "zcu102" } {
create_project -force $proj_name $proj_path -part xczu9eg-ffvb1156-2-e
set_property BOARD_PART xilinx.com:zcu102:part0:3.2 [current_project]
}

# Update IP repository with generated IP
file mkdir $ip_lib
Expand Down
17 changes: 16 additions & 1 deletion vta/python/vta/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def __init__(self, cfg):
self._mock_env = None
self._dev_ctx = None
self._last_env = None
# derive bitstream name
# derive bitstream name
self.BITSTREAM = "{}/{}/{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
self.HW_VER.replace('.', '_'),
self.TARGET,
Expand All @@ -171,6 +171,21 @@ def __init__(self, cfg):
if self.MUL_EN and self.ALU_EN:
self.BITSTREAM += "_mul"
self.BITSTREAM += ".bit"
# model - autoTVM signature that identifies VTA configuration.
# This is WIP: knobs that could influence the efficacy of the
# schedule have been left out for now.
self.MODEL = "{}-{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}".format(
self.TARGET,
self.BATCH,
self.BLOCK_IN,
self.BLOCK_OUT,
self.INP_WIDTH,
self.WGT_WIDTH,
self.OUT_WIDTH,
self.LOG_UOP_BUFF_SIZE,
self.LOG_INP_BUFF_SIZE,
self.LOG_WGT_BUFF_SIZE,
self.LOG_ACC_BUFF_SIZE)

def __enter__(self):
self._last_env = Environment.current
Expand Down
27 changes: 19 additions & 8 deletions vta/python/vta/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ctypes
import json
import tvm
from shutil import copyfile
from tvm._ffi.base import c_str
from tvm import rpc
from tvm.contrib import cc
Expand Down Expand Up @@ -87,14 +88,24 @@ def reconfig_runtime(cfg_json):
if pkg.same_config(old_cfg):
logging.info("Skip reconfig_runtime due to same config.")
return
cflags = ["-O2", "-std=c++11"]
cflags += pkg.cflags
ldflags = pkg.ldflags
lib_name = dll_path
source = pkg.lib_source
logging.info("Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
dll_path, '\n\t'.join(cflags), '\n\t'.join(source), '\n\t'.join(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags)
# check if a dll matching the configuration has been cached
dll_root, dll_ext = os.path.splitext(dll_path)
cached_dll_path = dll_root + '-' + pkg.signature + dll_ext
if os.path.isfile(cached_dll_path):
copyfile(cached_dll_path, dll_path)
logging.info("Swapping in cached dll: source=%s, destination=%s",
cached_dll_path, dll_path)
else:
cflags = ["-O2", "-std=c++11"]
cflags += pkg.cflags
ldflags = pkg.ldflags
lib_name = dll_path
source = pkg.lib_source
logging.info("Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
dll_path, '\n\t'.join(cflags), '\n\t'.join(source), '\n\t'.join(ldflags))
cc.create_shared(lib_name, source, cflags + ldflags)
copyfile(dll_path, cached_dll_path)
logging.info("Caching dll to: %s", cached_dll_path)
with open(cfg_path, "w") as outputfile:
outputfile.write(pkg.cfg_json)

Expand Down
15 changes: 15 additions & 0 deletions vta/python/vta/pkg_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ def __init__(self, cfg, proj_root):
def cflags(self):
return self.include_path + self.macro_defs

@property
def signature(self):
    """Return a short string identifying this configuration.

    The string is the ``TARGET`` entry of ``cfg_dict``, a dash, and then
    the values of ten ``LOG_*`` entries (batch, block in/out, input/weight/
    output widths, and uop/input/weight/accumulator buffer sizes) joined
    by underscores, in that fixed order.

    Raises
    ------
    KeyError
        If any of the required entries is missing from ``cfg_dict``.
    """
    # Look up TARGET first so a missing key is reported in the same order
    # as the positional lookups would report it.
    target = self.cfg_dict["TARGET"]
    log_keys = (
        "LOG_BATCH",
        "LOG_BLOCK_IN",
        "LOG_BLOCK_OUT",
        "LOG_INP_WIDTH",
        "LOG_WGT_WIDTH",
        "LOG_OUT_WIDTH",
        "LOG_UOP_BUFF_SIZE",
        "LOG_INP_BUFF_SIZE",
        "LOG_WGT_BUFF_SIZE",
        "LOG_ACC_BUFF_SIZE",
    )
    fields = ["{}".format(self.cfg_dict[key]) for key in log_keys]
    return "{}-{}".format(target, "_".join(fields))

@property
def cfg_json(self):
    """Return ``cfg_dict`` serialized as JSON text with two-space indentation."""
    serialized = json.dumps(self.cfg_dict, indent=2)
    return serialized
Expand Down
5 changes: 2 additions & 3 deletions vta/scripts/tune_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,11 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
return s, [data, kernel, bias, res]

if __name__ == '__main__':
model = env.TARGET
N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32'

task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype),
target=tvm.target.vta(model), target_host=env.target_host, template_key='direct')
target=tvm.target.vta(env.MODEL), target_host=env.target_host, template_key='direct')
print(task.config_space)

# logging config (for printing tuning log to the screen)
Expand All @@ -63,7 +62,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):

measure_option = autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(model, 'fleet', 9190, number=4, repeat=3, timeout=30,
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190, number=4, repeat=3, timeout=30,
check_correctness=True))

tuner = autotvm.tuner.RandomTuner(task)
Expand Down
9 changes: 4 additions & 5 deletions vta/scripts/tune_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,17 @@ def tune_tasks(tasks,
os.remove(tmp_log_file)

if __name__ == '__main__':
device_key = env.TARGET

tuning_opt = {
'log_filename': 'resnet-18.log',
'log_filename': 'resnet-18-{}.log'.format(env.MODEL),

'tuner': 'random',
'n_trial': 1e9,
'early_stopping': None,

'measure_option': autotvm.measure_option(
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
runner=autotvm.RPCRunner(device_key, 'fleet', 9190,
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190,
number=4, repeat=3, timeout=60,
check_correctness=True))
}
Expand All @@ -182,7 +181,7 @@ def tune_tasks(tasks,
register_vta_tuning_tasks()

print("Extract tasks...")
target = tvm.target.vta(device_key)
target = tvm.target.vta(env.MODEL)
target_host = env.target_host
tasks = extract_tasks(sym, params, target, target_host)

Expand All @@ -203,7 +202,7 @@ def tune_tasks(tasks,

# upload module to device
print("Upload...")
remote = autotvm.measure.request_remote(device_key, 'fleet', 9190, timeout=10000)
remote = autotvm.measure.request_remote(env.TARGET, 'fleet', 9190, timeout=10000)
remote.upload(tmp.relpath(filename))
rlib = remote.load_module(filename)

Expand Down

0 comments on commit 49d63dc

Please sign in to comment.