Skip to content

Commit

Permalink
update tutorial and tophub
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored and tmoreau89 committed Jan 2, 2019
1 parent 8841df9 commit d370a98
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
15 changes: 9 additions & 6 deletions vta/python/vta/top/arm_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,20 @@
from topi.nn import conv2d, conv2d_alter_layout
from topi import generic

@conv2d.register(["vta"])
def compute(*args, **kwargs):
    """Conv2d compute for the VTA target.

    Dispatches to the generic topi ``conv2d`` compute under an ``arm_cpu``
    target whose ``model`` is copied from the currently active target, so
    the CPU fallback uses model-specific configurations.
    Arguments are forwarded unchanged to ``topi.nn.conv2d``.
    """
    # Derive the model (e.g. pynq/ultra96) from whatever target is active.
    target = tvm.target.current_target()
    with tvm.target.arm_cpu(model=target.model):
        return conv2d(*args, **kwargs)

@generic.schedule_conv2d_nchw.register(["vta"])
def schedule(*args, **kwargs):
    """Conv2d NCHW schedule for the VTA target.

    Re-enters the generic ``schedule_conv2d_nchw`` under an ``arm_cpu``
    target whose ``model`` is copied from the currently active target, so
    scheduling decisions match the board model in use.
    Arguments are forwarded unchanged to ``generic.schedule_conv2d_nchw``.
    """
    # Derive the model (e.g. pynq/ultra96) from whatever target is active.
    target = tvm.target.current_target()
    with tvm.target.arm_cpu(model=target.model):
        return generic.schedule_conv2d_nchw(*args, **kwargs)

@conv2d_alter_layout.register(["vta"])
def alter(*args, **kwargs):
    """Conv2d layout-alteration hook for the VTA target.

    Re-enters ``conv2d_alter_layout`` under an ``arm_cpu`` target whose
    ``model`` is copied from the currently active target, so layout
    transforms are chosen per board model.
    Arguments are forwarded unchanged to ``topi.nn.conv2d_alter_layout``.
    """
    # Derive the model (e.g. pynq/ultra96) from whatever target is active.
    target = tvm.target.current_target()
    with tvm.target.arm_cpu(model=target.model):
        return conv2d_alter_layout(*args, **kwargs)
25 changes: 11 additions & 14 deletions vta/tutorials/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,18 @@ def classify(m, image):
# Helper function to compile the NNVM graph
# Takes in a path to a graph file, params file, and device target
# Returns the NNVM graph object, a compiled library object, and the params dict
def generate_graph(graph_fn, params_fn, device="vta"):
def generate_graph(graph_fn, params_fn, target):
# Measure build start time
build_start = time.time()

# Derive the TVM target
target = tvm.target.create("llvm -device={}".format(device))

# Derive the LLVM compiler flags
# When targeting the Pynq, cross-compile to ARMv7 ISA
if env.TARGET == "sim":
target_host = "llvm"
elif env.TARGET == "pynq":
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
target_host = "llvm -target=armv7-none-linux-gnueabihf"
elif env.TARGET == 'ultra96':
target_host = "llvm -target=aarch64-linux-gnu"

# Load the ResNet-18 graph and parameters
sym = nnvm.graph.load_json(open(graph_fn).read())
Expand Down Expand Up @@ -153,10 +152,6 @@ def generate_graph(graph_fn, params_fn, device="vta"):
# Read in ImageNet Categories
synset = eval(open(os.path.join(data_dir, categ_fn)).read())

# Download pre-tuned op parameters of conv2d for ARM CPU used in VTA
autotvm.tophub.check_backend('vta')


######################################################################
# Setup the Pynq Board's RPC Server
# ---------------------------------
Expand Down Expand Up @@ -198,17 +193,19 @@ def generate_graph(graph_fn, params_fn, device="vta"):
# ------------------------
# Build the ResNet graph runtime, and configure the parameters.

# Select the compilation target. Uncomment one of the alternatives below
# to run on the ARM CPU instead of the VTA accelerator, or to target the
# Ultra96 board instead of the Pynq.
# target = tvm.target.create('llvm -device=arm_cpu -model=pynq')    # run arm cpu on pynq
# target = tvm.target.create('llvm -device=arm_cpu -model=ultra96') # run arm cpu on ultra96
# target = tvm.target.create('llvm -device=vta -model=pynq')        # run vta on pynq
# target = tvm.target.create('llvm -device=vta -model=ultra96')     # run vta on ultra96
target = tvm.target.create('llvm -device=vta -model=pynq')

# Device context: the FPGA extension device when targeting VTA,
# otherwise the remote board's CPU.
ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)

# Build the graph runtime from the downloaded graph/params files.
graph, lib, params = generate_graph(os.path.join(data_dir, graph_fn),
                                    os.path.join(data_dir, params_fn),
                                    target)
m = graph_runtime.create(graph, lib, ctx)

# Set the parameters
Expand Down

0 comments on commit d370a98

Please sign in to comment.