Skip to content

Commit

Permalink
[Relay][VM] Fix loading late bound consts when none exist (apache#10087)
Browse files Browse the repository at this point in the history
* Fix loading late bound consts when none exist

* Simplify comment

* -mFix skipping of verilator tests

* Skip loading late-bound consts if none present

* Remove semi-related fix to verilator test skipping

* Remove more test-skip fixing for pr hygiene

* No-op for ci

* No-op for ci
  • Loading branch information
michalpiszczek authored and ylc committed Feb 16, 2022
1 parent 1c8a407 commit c4f343d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,10 @@ void Executable::MoveLateBoundConstantsToFile(const std::string& path, size_t by
}

void Executable::LoadLateBoundConstantsFromStream(dmlc::Stream* stream) {
if (late_bound_constant_names.empty()) {
VLOG(1) << "Found no late-bound constants to load";
return;
}
ICHECK_EQ(late_bound_constant_names.size(), constants.size());
Map<String, NDArray> map = runtime::LoadParams(stream);
VLOG(1) << "loaded " << map.size() << " late-bound constants";
Expand Down
62 changes: 62 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,68 @@ def test_large_constants():
tvm.testing.assert_allclose(expected, actual.numpy())


def test_load_late_bound_consts_with_no_late_bound_consts():
"""Check that load_late_bound_consts handles a model with no late bound consts."""
target = tvm.target.Target("llvm")
dev = tvm.cpu()

const_data = np.random.rand(1).astype("float64")
x = relay.var("x", shape=(1,), dtype="float64")
const = relay.const(const_data, dtype="float64")

func = relay.Function([x], relay.op.add(x, const))
mod = tvm.IRModule.from_expr(func)

vm_exec = vm.compile(mod, target=target)

temp = utils.tempdir()
path_consts = temp.relpath("consts")
path_dso = temp.relpath("lib.so")

# Ensure const_data is below the byte threshold for a late-bound const.
byte_limit = len(const_data.tobytes()) + 1
vm_exec.move_late_bound_consts(path_consts, byte_limit=byte_limit)
vm_exec.mod.export_library(path_dso)

mod = runtime.load_module(path_dso)
mod["load_late_bound_consts"](path_consts)

x_data = np.random.rand(1).astype("float64")
loaded_vm = runtime.vm.VirtualMachine(mod, dev)
actual = loaded_vm.invoke("main", x_data)
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())


def test_vm_save_and_load_without_designating_late_bound_consts():
"""Check that a VM can be saved and loaded without late-bound consts in play.
Specifically, this test ensures that the machinery behind late-bound const
loading does not assume the need to load late-bound consts (and cause an error)
when the user did not choose to designate any consts as such.
"""
target = tvm.target.Target("llvm")
dev = tvm.cpu()

const_data = np.random.rand(1).astype("float64")
x = relay.var("x", shape=(1,), dtype="float64")
const = relay.const(const_data, dtype="float64")

func = relay.Function([x], relay.op.add(x, const))
mod = tvm.IRModule.from_expr(func)

vm_exec = vm.compile(mod, target=target)

code, lib = vm_exec.save()
exe = runtime.vm.Executable.load_exec(code, lib)

x_data = np.random.rand(1).astype("float64")
loaded_vm = runtime.vm.VirtualMachine(exe, dev)
actual = loaded_vm.invoke("main", x_data)
expected = x_data + const_data
tvm.testing.assert_allclose(expected, actual.numpy())


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit c4f343d

Please sign in to comment.