Skip to content

Commit

Permalink
Expose workspace size in tvmgen_default.h (#9510)
Browse files Browse the repository at this point in the history
This PR exposes the workspace size as a macro 
TVMGEN_DEFAULT_WORKSPACE_SIZE in tvmgen_default.h 
(or TVMGEN_<MODEL_NAME>_WORKSPACE_SIZE in 
tvmgen_<model_name>.h in the case that the model name is not default).

This functionality is useful for microTVM/AOT use cases 
where it's useful to know the workspace size at compile time.
  • Loading branch information
grant-arm authored Nov 22, 2021
1 parent 328d7c7 commit 8e1425d
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 32 deletions.
11 changes: 8 additions & 3 deletions python/tvm/micro/model_library_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,15 @@ class UnsupportedInModelLibraryFormatError(Exception):
"""Raised when export_model_library_format does not support the given Module tree."""


def generate_c_interface_header(module_name, inputs, outputs, devices, include_path):
def generate_c_interface_header(
module_name, inputs, outputs, devices, workspace_size, include_path
):
"""Generate C Interface header to be included in MLF"""
mangled_name = to_c_variable_style(prefix_generated_name(module_name))
metadata_header = os.path.join(include_path, f"{mangled_name}.h")

interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
interface_c_module = interface_c_create(module_name, inputs, outputs, devices)
interface_c_module = interface_c_create(module_name, inputs, outputs, devices, workspace_size)

with open(metadata_header, "w") as header_file:
header_file.write(interface_c_module.get_source())
Expand Down Expand Up @@ -325,7 +327,10 @@ def _export_graph_model_library_format(
include_path.mkdir()
inputs, outputs = _get_inputs_and_outputs_from_module(mod)
devices = mod.get_devices()
generate_c_interface_header(mod.libmod_name, inputs, outputs, devices, include_path)
workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"])
generate_c_interface_header(
mod.libmod_name, inputs, outputs, devices, workspace_size, include_path
)

parameters_dir = tempdir / "parameters"
parameters_dir.mkdir()
Expand Down
23 changes: 19 additions & 4 deletions src/target/source/interface_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ using namespace tvm::relay::backend;
class InterfaceCNode : public runtime::ModuleNode {
public:
InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> outputs,
Array<String> devices)
: module_name_(module_name), inputs_(inputs), outputs_(outputs), devices_(devices) {}
Array<String> devices, int workspace_size)
: module_name_(module_name),
inputs_(inputs),
outputs_(outputs),
devices_(devices),
workspace_size_(workspace_size) {}
const char* type_key() const { return "h"; }

std::string GetSource(const std::string& format) final {
Expand All @@ -60,6 +64,7 @@ class InterfaceCNode : public runtime::ModuleNode {
}

EmitRunFunction(code);
EmitWorkspaceSize(code);
EmitLowerHeaderGuard(code);

return code.str();
Expand Down Expand Up @@ -140,15 +145,25 @@ class InterfaceCNode : public runtime::ModuleNode {
code_stream << ");\n";
}

void EmitWorkspaceSize(std::stringstream& code_stream) {
std::string workspace_size_name =
ToCConstantStyle(PrefixGeneratedName({module_name_, "WORKSPACE_SIZE"}));
code_stream << "/*!\n"
<< " * \\brief Workspace size for TVM module \"" << module_name_ << "\"\n"
<< " */\n"
<< "#define " << workspace_size_name << " " << workspace_size_ << "\n";
}

std::string module_name_;
Array<String> inputs_;
Array<String> outputs_;
Array<String> devices_;
int workspace_size_;
};

runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
Array<String> outputs, Array<String> devices) {
auto n = make_object<InterfaceCNode>(module_name, inputs, outputs, devices);
Array<String> outputs, Array<String> devices, int workspace_size) {
auto n = make_object<InterfaceCNode>(module_name, inputs, outputs, devices, workspace_size);
return runtime::Module(n);
}

Expand Down
51 changes: 34 additions & 17 deletions tests/cpp/target/source/interface_c_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace tvm {
namespace codegen {

runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
Array<String> outputs, Array<String> devices);
Array<String> outputs, Array<String> devices, int workspace_size);

namespace {

Expand All @@ -49,7 +49,8 @@ TEST(InterfaceAPI, ContainsHeaderGuards) {
<< "#endif\n\n"
<< "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n";

runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {});
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str()));
Expand All @@ -69,7 +70,8 @@ TEST(InterfaceAPI, ContainsRunFunction) {
<< " struct tvmgen_ultimate_cat_spotter_outputs* outputs\n"
<< ");\n";

runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {});
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(run_function.str()));
Expand All @@ -91,7 +93,7 @@ TEST(InterfaceAPI, ContainsRunFunctionWithDevices) {
<< ");\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(run_function.str()));
Expand All @@ -107,7 +109,8 @@ TEST(InterfaceAPI, ContainsInputStructSingle) {
<< " void* input;\n"
<< "};\n\n";

runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {});
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
Expand All @@ -122,7 +125,7 @@ TEST(InterfaceAPI, ContainsInputStructMany) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
Expand All @@ -137,15 +140,15 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
}

TEST(InterfaceAPI, ContainsInputStructClash) {
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, 0);
ASSERT_THROW(test_module->GetSource(), InternalError);
}

Expand All @@ -159,7 +162,8 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) {
<< " void* output;\n"
<< "};\n\n";

runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {});
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
Expand All @@ -174,7 +178,7 @@ TEST(InterfaceAPI, ContainsOutputStructMany) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
Expand All @@ -189,15 +193,15 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
}

TEST(InterfaceAPI, ContainsOutputStructClash) {
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, 0);
ASSERT_THROW(test_module->GetSource(), InternalError);
}

Expand All @@ -210,7 +214,8 @@ TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) {
<< "struct tvmgen_ultimate_cat_spotter_devices {\n"
<< "};\n\n";

runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {});
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str())));
Expand All @@ -227,7 +232,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSingle) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
Expand All @@ -242,7 +247,7 @@ TEST(InterfaceAPI, ContainsDeviceStructMany) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device1", "device2"});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device1", "device2"}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
Expand All @@ -257,18 +262,30 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
<< "};\n\n";

runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+1", "device+2"});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+1", "device+2"}, 0);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
}

TEST(InterfaceAPI, ContainsDeviceStructClash) {
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+", "device-"});
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+", "device-"}, 0);
ASSERT_THROW(test_module->GetSource(), InternalError);
}

TEST(InterfaceAPI, ContainsWorkspaceSize) {
runtime::Module test_module =
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 765432);
std::string header_source = test_module->GetSource();

ASSERT_THAT(header_source,
HasSubstr("* \\brief Workspace size for TVM module \"ultimate_cat_spotter\""));

ASSERT_THAT(header_source,
HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_WORKSPACE_SIZE 765432"));
}

} // namespace
} // namespace codegen
} // namespace tvm
4 changes: 2 additions & 2 deletions tests/micro/zephyr/test_zephyr_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug):
model_files_path = os.path.join(tar_temp_dir, "include")
os.mkdir(model_files_path)
header_path = generate_c_interface_header(
lowered.libmod_name, ["input_1"], ["output"], [], model_files_path
lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
)
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))

Expand Down Expand Up @@ -147,7 +147,7 @@ def test_qemu_make_fail(temp_dir, board, west_cmd, tvm_debug):
model_files_path = os.path.join(tar_temp_dir, "include")
os.mkdir(model_files_path)
header_path = generate_c_interface_header(
lowered.libmod_name, ["input_1"], ["output"], [], model_files_path
lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
)
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
test_utils.create_header_file(
Expand Down
26 changes: 20 additions & 6 deletions tests/python/relay/aot/aot_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,17 @@ def emit_data_linkage(output_file, data_linkage):
)


def emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage):
def emit_main_prologue(
main_file, custom_prologue, workspace_bytes, data_linkage, compiled_models, interface_api
):
# Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
main_file.write(
f"#define WORKSPACE_SIZE ({workspace_bytes} + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
)
workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
if interface_api == "c":
for compiled_model in compiled_models:
model = compiled_model.model
workspace_define += f" + TVMGEN_{model.name.upper()}_WORKSPACE_SIZE"
workspace_define += " + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
main_file.write(workspace_define)
emit_data_linkage(main_file, data_linkage)
main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n")
main_file.write("tvm_workspace_t app_workspace;\n")
Expand Down Expand Up @@ -516,7 +522,14 @@ def create_main(
model = compiled_model.model
emit_main_data(main_file, model.inputs, model.outputs, model.name)

emit_main_prologue(main_file, custom_prologue, workspace_bytes, data_linkage)
emit_main_prologue(
main_file,
custom_prologue,
workspace_bytes,
data_linkage,
compiled_models,
interface_api,
)
emit_main_init_memory_manager(main_file)

if interface_api == "c":
Expand Down Expand Up @@ -679,7 +692,8 @@ def run_and_check(
t.extractall(base_path)

workspace_bytes += model.extra_memory_in_bytes
workspace_bytes += mlf_extract_workspace_size_bytes(tar_file)
if interface_api == "packed":
workspace_bytes += mlf_extract_workspace_size_bytes(tar_file)

for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W", "_", key)
Expand Down

0 comments on commit 8e1425d

Please sign in to comment.