From 75d6cbe605879c8b859c11428eabc8b3cdf36b45 Mon Sep 17 00:00:00 2001
From: Han Qi
Date: Thu, 24 Mar 2022 16:24:32 -0700
Subject: [PATCH] [4/5] Testing jit module in flatbuffer in Python. (#74387)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74387

Make temporary Python bindings for flatbuffer to test ScriptModule save/load.

(Note: this ignores all push blocking failures!)

Test Plan: unittest

Reviewed By: iseeyuan

Differential Revision: D34968080

fbshipit-source-id: d23b16abda6e4b7ecf6b1198ed6e00908a3db903
(cherry picked from commit 5cbbc390c5f54146a1c469106ab4a6286c754325)
---
 setup.py                                      |  10 +
 test/jit/test_save_load.py                    | 404 ++++++++++++++++++
 test/test_jit.py                              |   2 +-
 tools/linter/clang_tidy/__main__.py           |   3 +
 torch/CMakeLists.txt                          |   6 +
 torch/_C_flatbuffer/__init__.pyi              |  10 +
 torch/csrc/init_flatbuffer_module.cpp         |  96 +++++
 torch/csrc/jit/mobile/flatbuffer_loader.cpp   |  49 ++-
 torch/csrc/jit/python/script_init.cpp         |   1 +
 .../serialization/flatbuffer_serializer.cpp   |  40 +-
 torch/csrc/stub_with_flatbuffer.c             |  18 +
 torch/jit/__init__.py                         |   2 +-
 torch/jit/_serialization.py                   |  71 +++
 13 files changed, 679 insertions(+), 33 deletions(-)
 create mode 100644 torch/_C_flatbuffer/__init__.pyi
 create mode 100644 torch/csrc/init_flatbuffer_module.cpp
 create mode 100644 torch/csrc/stub_with_flatbuffer.c

diff --git a/setup.py b/setup.py
index 410edfefa7e72b..9f99a73e8d1da8 100644
--- a/setup.py
+++ b/setup.py
@@ -824,7 +824,16 @@ def make_relative_rpath_args(path):
                   include_dirs=[],
                   library_dirs=library_dirs,
                   extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
+    C_flatbuffer = Extension("torch._C_flatbuffer",
+                             libraries=main_libraries,
+                             sources=["torch/csrc/stub_with_flatbuffer.c"],
+                             language='c',
+                             extra_compile_args=main_compile_args + extra_compile_args,
+                             include_dirs=[],
+                             library_dirs=library_dirs,
+                             extra_link_args=extra_link_args + main_link_args + make_relative_rpath_args('lib'))
     extensions.append(C)
+    extensions.append(C_flatbuffer)

     if not IS_WINDOWS:
         DL = Extension("torch._dl",
@@ -932,6 +941,7 @@ def print_box(msg):
                 'bin/*',
                 'test/*',
                 '_C/*.pyi',
+                '_C_flatbuffer/*.pyi',
                 'cuda/*.pyi',
                 'optim/*.pyi',
                 'autograd/*.pyi',
diff --git a/test/jit/test_save_load.py b/test/jit/test_save_load.py
index 445caf2f40e18c..45e65e5e8184db 100644
--- a/test/jit/test_save_load.py
+++ b/test/jit/test_save_load.py
@@ -528,3 +528,407 @@ def forward(self, x):
         self.assertFalse(m_loaded_params['bar.bias'].is_meta)
         self.assertTrue(m_buffers['buffer'].is_meta)
         self.assertTrue(m_loaded_buffers['buffer'].is_meta)
+
+
+class TestSaveLoadFlatbuffer(JitTestCase):
+    def test_different_modules(self):
+        """
+        Exercise the situation where we have the same qualified name
+        in two different CompilationUnits on save/load.
+ """ + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + self.foo = torch.nn.Linear(2, 2) + self.bar = torch.nn.Linear(2, 2) + + def forward(self, x): + x = self.foo(x) + x = self.bar(x) + return x + + first_script_module = torch.jit.script(Foo()) + first_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) + first_saved_module.seek(0) + + clear_class_registry() + + class Foo(torch.nn.Module): + def __init__(self): + super(Foo, self).__init__() + self.foo = torch.nn.Linear(2, 2) + + def forward(self, x): + x = self.foo(x) + return x + + second_script_module = torch.jit.script(Foo()) + second_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) + second_saved_module.seek(0) + + clear_class_registry() + + self.assertEqual( + first_script_module._c.qualified_name, second_script_module._c.qualified_name + ) + + class ContainsBoth(torch.nn.Module): + def __init__(self): + super().__init__() + self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) + self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + + def forward(self, x): + x = self.first(x) + x = self.second(x) + return x + + sm = torch.jit.script(ContainsBoth()) + contains_both = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) + contains_both.seek(0) + sm = torch.jit.jit_module_from_flatbuffer(contains_both) + + def test_different_functions(self): + """ + Exercise the situation where we have the same qualified name + in two different CompilationUnits on save/load. + """ + def lol(x): + return x + + class Foo(torch.nn.Module): + def forward(self, x): + return lol(x) + + first_script_module = torch.jit.script(Foo()) + first_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) + first_saved_module.seek(0) + + clear_class_registry() + + def lol(x): # noqa: F811 + return "hello" + + class Foo(torch.nn.Module): + def forward(self, x): + return lol(x) + + second_script_module = torch.jit.script(Foo()) + second_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) + second_saved_module.seek(0) + + clear_class_registry() + + self.assertEqual( + first_script_module._c.qualified_name, second_script_module._c.qualified_name + ) + + class ContainsBoth(torch.nn.Module): + def __init__(self): + super().__init__() + self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) + self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + + def forward(self, x): + x = self.first(x) + x = self.second(x) + return x + + sm = torch.jit.script(ContainsBoth()) + contains_both = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) + contains_both.seek(0) + sm = torch.jit.jit_module_from_flatbuffer(contains_both) + + def test_different_interfaces(self): + """ + Exercise the situation where we have the same qualified name + in two different CompilationUnits on save/load. 
+ """ + @torch.jit.interface + class MyInterface(object): + def bar(self, x: Tensor) -> Tensor: + pass + + @torch.jit.script + class ImplementInterface(object): + def __init__(self): + pass + + def bar(self, x): + return x + + class Foo(torch.nn.Module): + __annotations__ = {"interface": MyInterface} + + def __init__(self): + super().__init__() + self.interface = ImplementInterface() + + def forward(self, x): + return self.interface.bar(x) + + first_script_module = torch.jit.script(Foo()) + first_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) + first_saved_module.seek(0) + + clear_class_registry() + + @torch.jit.interface + class MyInterface(object): + def not_bar(self, x: Tensor) -> Tensor: + pass + + @torch.jit.script # noqa: F811 + class ImplementInterface(object): # noqa: F811 + def __init__(self): + pass + + def not_bar(self, x): + return x + + class Foo(torch.nn.Module): + __annotations__ = {"interface": MyInterface} + + def __init__(self): + super().__init__() + self.interface = ImplementInterface() + + def forward(self, x): + return self.interface.not_bar(x) + + second_script_module = torch.jit.script(Foo()) + second_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(torch.jit.script(Foo()), second_saved_module) + second_saved_module.seek(0) + + clear_class_registry() + + self.assertEqual( + first_script_module._c.qualified_name, second_script_module._c.qualified_name + ) + + class ContainsBoth(torch.nn.Module): + def __init__(self): + super().__init__() + self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module)) + self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module)) + + def forward(self, x): + x = self.first(x) + x = self.second(x) + return x + + sm = torch.jit.script(ContainsBoth()) + contains_both = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(sm, contains_both) + contains_both.seek(0) + sm = torch.jit.jit_module_from_flatbuffer(contains_both) + + def test_many_collisions(self): + class MyCoolNamedTuple(NamedTuple): + a: int + + @torch.jit.interface + class MyInterface(object): + def bar(self, x: Tensor) -> Tensor: + pass + + @torch.jit.script + class ImplementInterface(object): + def __init__(self): + pass + + def bar(self, x): + return x + + def lol(x): + return x + + class Foo(torch.nn.Module): + interface: MyInterface + + def __init__(self): + super().__init__() + self.foo = torch.nn.Linear(2, 2) + self.bar = torch.nn.Linear(2, 2) + self.interface = ImplementInterface() + + def forward(self, x): + x = self.foo(x) + x = self.bar(x) + x = lol(x) + x = self.interface.bar(x) + + return x, MyCoolNamedTuple(a=5) + + + first_script_module = torch.jit.script(Foo()) + first_saved_module = io.BytesIO() + torch.jit.save_jit_module_to_flatbuffer(first_script_module, first_saved_module) + first_saved_module.seek(0) + + clear_class_registry() + + @torch.jit.interface + class MyInterface(object): + def not_bar(self, x: Tensor) -> Tensor: + pass + + @torch.jit.script # noqa: F811 + class ImplementInterface(object): # noqa: F811 + def __init__(self): + pass + + def not_bar(self, x): + return x + + def lol(x): # noqa: F811 + return "asdofij" + + class MyCoolNamedTuple(NamedTuple): # noqa: F811 + a: str + + class Foo(torch.nn.Module): + interface: MyInterface + + def __init__(self): + super().__init__() + self.foo = torch.nn.Linear(2, 2) + self.interface = ImplementInterface() + + def forward(self, x): + x = self.foo(x) + 
+                self.interface.not_bar(x)
+                x = lol(x)
+                return x, MyCoolNamedTuple(a="hello")
+
+        second_script_module = torch.jit.script(Foo())
+        second_saved_module = io.BytesIO()
+        torch.jit.save_jit_module_to_flatbuffer(second_script_module, second_saved_module)
+        second_saved_module.seek(0)
+
+        clear_class_registry()
+
+        self.assertEqual(
+            first_script_module._c.qualified_name, second_script_module._c.qualified_name
+        )
+
+        class ContainsBoth(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.add_module("second", torch.jit.jit_module_from_flatbuffer(second_saved_module))
+                self.add_module("first", torch.jit.jit_module_from_flatbuffer(first_saved_module))
+
+            def forward(self, x):
+                x, named_tuple_1 = self.first(x)
+                x, named_tuple_2 = self.second(x)
+                return len(x + named_tuple_2.a) + named_tuple_1.a
+
+        sm = torch.jit.script(ContainsBoth())
+        contains_both = io.BytesIO()
+        torch.jit.save_jit_module_to_flatbuffer(sm, contains_both)
+        contains_both.seek(0)
+        sm = torch.jit.jit_module_from_flatbuffer(contains_both)
+
+    def test_save_load_using_pathlib(self):
+        class MyMod(torch.jit.ScriptModule):
+            @torch.jit.script_method
+            def forward(self, a):
+                return 2 * a
+
+        m = MyMod()
+
+        # Save then load.
+        with TemporaryFileName() as fname:
+            path = pathlib.Path(fname)
+            torch.jit.save_jit_module_to_flatbuffer(m, path)
+            m2 = torch.jit.jit_module_from_flatbuffer(path)
+
+        x = torch.tensor([1., 2., 3., 4.])
+        self.assertTrue(torch.equal(m(x), m2(x)))
+
+    def test_save_namedtuple_input_only(self):
+        """
+        Even if a NamedTuple is only used as an input argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: int
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x: FooTuple) -> torch.Tensor:
+                return torch.tensor(3)
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded(FooTuple(a=5))
+        self.assertEqual(output, torch.tensor(3))
+
+    def test_save_namedtuple_output_only(self):
+        """
+        Even if a NamedTuple is only used as an output argument, saving and
+        loading should work correctly.
+        """
+        global FooTuple  # see [local resolution in python]
+
+        class FooTuple(NamedTuple):
+            a: int
+
+        class MyModule(torch.nn.Module):
+            def forward(self) -> Optional[FooTuple]:
+                return None
+
+        m_loaded = self.getExportImportCopy(torch.jit.script(MyModule()))
+        output = m_loaded()
+        self.assertEqual(output, None)
+
+    def test_save_load_params_buffers_submodules(self):
+        """
+        Check that parameters, buffers, and submodules are the same after loading.
+        """
+
+        class Submodule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.add_module("submodule_a", Submodule())
+                self.register_parameter("parameter_a", torch.nn.Parameter(torch.randn(4)))
+                self.register_buffer("buffer", torch.randn(4))
+                self.t = torch.rand(4)  # not buffer
+
+                self.parameter_b = torch.nn.Parameter(torch.randn(4))
+                self.submodule_b = Submodule()
+
+        m = TestModule()
+        m_loaded = self.getExportImportCopy(torch.jit.script(m))
+
+        # Check submodules.
+        self.assertEqual(len(list(m.named_modules())), len(list(m_loaded.named_modules())))
+        for m_s, loaded_s in zip(m.named_modules(), m_loaded.named_modules()):
+            m_name, _ = m_s
+            loaded_name, _ = loaded_s
+            self.assertEqual(m_name, loaded_name)
+
+        # Check parameters.
+        self.assertEqual(len(list(m.parameters())), len(list(m_loaded.parameters())))
+        for m_p, loaded_p in zip(m.parameters(), m_loaded.parameters()):
+            self.assertEqual(m_p, loaded_p)
+
+        # Check buffers.
+        self.assertEqual(len(list(m.named_buffers())), len(list(m_loaded.named_buffers())))
+        for m_b, loaded_b in zip(m.named_buffers(), m_loaded.named_buffers()):
+            m_name, m_buffer = m_b
+            loaded_name, loaded_buffer = loaded_b
+            self.assertEqual(m_name, loaded_name)
+            self.assertEqual(m_buffer, loaded_buffer)
diff --git a/test/test_jit.py b/test/test_jit.py
index 620824a9b6f0c7..b5d2457f1919c3 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -30,7 +30,7 @@
 from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing  # noqa: F401
 from jit.test_peephole import TestPeephole  # noqa: F401
 from jit.test_alias_analysis import TestAliasAnalysis  # noqa: F401
-from jit.test_save_load import TestSaveLoad  # noqa: F401
+from jit.test_save_load import TestSaveLoad, TestSaveLoadFlatbuffer  # noqa: F401
 from jit.test_save_load_for_op_version import TestSaveLoadForOpVersion  # noqa: F401
 from jit.test_module_containers import TestModuleContainers  # noqa: F401
 from jit.test_python_bindings import TestPythonBindings  # noqa: F401
diff --git a/tools/linter/clang_tidy/__main__.py b/tools/linter/clang_tidy/__main__.py
index fa6403a64bb664..c602f57ef3bd39 100644
--- a/tools/linter/clang_tidy/__main__.py
+++ b/tools/linter/clang_tidy/__main__.py
@@ -76,6 +76,9 @@ def clang_search_dirs() -> List[str]:
     "-torch/csrc/jit/serialization/export.cpp",
     "-torch/csrc/jit/serialization/import.cpp",
     "-torch/csrc/jit/serialization/import_legacy.cpp",
+    "-torch/csrc/jit/serialization/mobile_bytecode_generated.cpp",
+    "-torch/csrc/init_flatbuffer_module.cpp",
+    "-torch/csrc/stub_with_flatbuffer.c",
     "-torch/csrc/onnx/init.cpp",
     "-torch/csrc/cuda/nccl.*",
     "-torch/csrc/cuda/python_nccl.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 723267336c3d7b..4dddf7b33d71bf 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -44,6 +44,9 @@ set(TORCH_PYTHON_SRCS
 )
 append_filelist("libtorch_python_core_sources" TORCH_PYTHON_SRCS)

+list(APPEND TORCH_PYTHON_SRCS
+  ${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp)
+
 # NB: This has to match the condition under which the JIT test directory
 # is included (at the time of writing that's in caffe2/CMakeLists.txt).
 if(BUILD_TEST)
@@ -389,6 +392,9 @@ set_source_files_properties(
 # Disable certain warnings for GCC-9.X
 if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 9.0.0))
   set_source_files_properties(${TORCH_SRC_DIR}/csrc/Module.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
+  set_source_files_properties(
+    ${TORCH_SRC_DIR}/csrc/init_flatbuffer_module.cpp
+    PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
   set_source_files_properties(${TORCH_SRC_DIR}/csrc/autograd/python_variable.cpp PROPERTIES COMPILE_FLAGS "-Wno-cast-function-type")
 endif()
diff --git a/torch/_C_flatbuffer/__init__.pyi b/torch/_C_flatbuffer/__init__.pyi
new file mode 100644
index 00000000000000..3a2ff059b0ed9d
--- /dev/null
+++ b/torch/_C_flatbuffer/__init__.pyi
@@ -0,0 +1,10 @@
+from torch._C import LiteScriptModule, ScriptModule
+
+def _load_mobile_module_from_file(filename: str): ...
+def _load_mobile_module_from_bytes(bytes_: bytes): ...
+def _load_jit_module_from_file(filename: str): ...
+def _load_jit_module_from_bytes(bytes_: bytes): ...
+def _save_mobile_module(m: LiteScriptModule, filename: str): ...
+def _save_jit_module(m: ScriptModule, filename: str): ...
+def _save_mobile_module_to_bytes(m: LiteScriptModule) -> bytes: ...
+def _save_jit_module_to_bytes(m: ScriptModule) -> bytes: ...
diff --git a/torch/csrc/init_flatbuffer_module.cpp b/torch/csrc/init_flatbuffer_module.cpp
new file mode 100644
index 00000000000000..05db14efdf525d
--- /dev/null
+++ b/torch/csrc/init_flatbuffer_module.cpp
@@ -0,0 +1,96 @@
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include // NOLINT
+#include
+#include
+#include
+#include
+#include
+
+namespace py = pybind11;
+
+static std::shared_ptr<char> copyStr(const std::string& bytes) {
+  size_t size = (bytes.size() / FLATBUFFERS_MAX_ALIGNMENT + 1) *
+      FLATBUFFERS_MAX_ALIGNMENT;
+#ifdef _WIN32
+  std::shared_ptr<char> bytes_copy(
+      static_cast<char*>(_aligned_malloc(size, FLATBUFFERS_MAX_ALIGNMENT)),
+      _aligned_free);
+#else
+  std::shared_ptr<char> bytes_copy(
+      static_cast<char*>(aligned_alloc(FLATBUFFERS_MAX_ALIGNMENT, size)), free);
+#endif
+  memcpy(bytes_copy.get(), bytes.data(), bytes.size());
+  return bytes_copy;
+}
+
+
+extern "C"
+#ifdef _WIN32
+__declspec(dllexport)
+#endif
+PyObject* initModuleFlatbuffer() {
+  using namespace torch::jit;
+  PyMethodDef m[] = {{nullptr, nullptr, 0, nullptr}}; // NOLINT
+  static struct PyModuleDef torchmodule = {
+      PyModuleDef_HEAD_INIT,
+      "torch._C_flatbuffer",
+      nullptr,
+      -1,
+      m,
+  }; // NOLINT
+  PyObject* module = PyModule_Create(&torchmodule);
+  auto pym = py::handle(module).cast<py::module>();
+  pym.def(
+      "_load_mobile_module_from_file",
+      [](const std::string& filename) {
+        return torch::jit::load_mobile_module_from_file(filename);
+      });
+  pym.def(
+      "_load_mobile_module_from_bytes",
+      [](const std::string& bytes) {
+        auto bytes_copy = copyStr(bytes);
+        return torch::jit::parse_and_initialize_mobile_module(
+            bytes_copy, bytes.size());
+      });
+  pym.def(
+      "_load_jit_module_from_file",
+      [](const std::string& filename) {
+        return torch::jit::load_jit_module_from_file(filename);
+      });
+  pym.def(
+      "_load_jit_module_from_bytes",
+      [](const std::string& bytes) {
+        auto bytes_copy = copyStr(bytes);
+        return torch::jit::parse_and_initialize_jit_module(
+            bytes_copy, bytes.size());
+      });
+  pym.def(
+      "_save_mobile_module",
+      [](const torch::jit::mobile::Module& module, const std::string& filename) {
+        return torch::jit::save_mobile_module(module, filename);
+      });
+  pym.def(
+      "_save_jit_module",
+      [](const torch::jit::Module& module, const std::string& filename) {
+        return torch::jit::save_jit_module(module, filename);
+      });
+  pym.def(
+      "_save_mobile_module_to_bytes",
+      [](const torch::jit::mobile::Module& module) {
+        auto detached_buffer = torch::jit::save_mobile_module_to_bytes(module);
+        return py::bytes(
+            reinterpret_cast<char*>(detached_buffer.data()),
+            detached_buffer.size());
+      });
+  pym.def(
+      "_save_jit_module_to_bytes",
+      [](const torch::jit::Module& module) {
+        auto detached_buffer = torch::jit::save_jit_module_to_bytes(module);
+        return py::bytes(
+            reinterpret_cast<char*>(detached_buffer.data()),
+            detached_buffer.size());
+      });
+  return module;
+}
diff --git a/torch/csrc/jit/mobile/flatbuffer_loader.cpp b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
index 94846467ed37ca..618913c1394583 100644
--- a/torch/csrc/jit/mobile/flatbuffer_loader.cpp
+++ b/torch/csrc/jit/mobile/flatbuffer_loader.cpp
@@ -254,29 +254,32 @@ std::unique_ptr FlatbufferLoader::parseFunction(
   function->set_register_size(method->register_size());
   if (method->schema()) {
-    auto parseArgList = [this](const auto* args_fb) {
-      std::vector<c10::Argument> args;
-      for (const auto* arg_tb : *args_fb) {
-        IValue default_value = getIValue(arg_tb->default_value());
-        TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
-        auto arg = c10::Argument(
-            arg_tb->name()->str(),
-            std::move(type_ptr),
-            c10::nullopt /*N*/,
-            std::move(default_value));
-        args.emplace_back(std::move(arg));
-      }
-      return args;
-    };
-    c10::FunctionSchema schema(
-        method->qn()->str(),
-        "" /*overload_name*/,
-        parseArgList(method->schema()->arguments()),
-        parseArgList(method->schema()->returns()),
-        false /*is_varargs*/,
-        false /*is_varret*/);
-
-    function->setSchema(std::move(schema));
+    try {
+      auto parseArgList = [this](const auto* args_fb) {
+        std::vector<c10::Argument> args;
+        for (const auto* arg_tb : *args_fb) {
+          IValue default_value = getIValue(arg_tb->default_value());
+          TypePtr type_ptr = getOrCreateTypeAnnotations(arg_tb->type());
+          auto arg = c10::Argument(
+              arg_tb->name()->str(),
+              std::move(type_ptr),
+              c10::nullopt /*N*/,
+              std::move(default_value));
+          args.emplace_back(std::move(arg));
+        }
+        return args;
+      };
+      c10::FunctionSchema schema(
+          method->qn()->str(),
+          "" /*overload_name*/,
+          parseArgList(method->schema()->arguments()),
+          parseArgList(method->schema()->returns()),
+          false /*is_varargs*/,
+          false /*is_varret*/);
+
+      function->setSchema(std::move(schema));
+    } catch (const c10::Error& e) { // the schema is optional; tolerate parse failures
+    }
   }
   return function;
 }
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 35a81094cf9669..2b5f8bb3557eb8 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -42,6 +42,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
index 4425e0813281e0..89ab78bdf9e446 100644
--- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
+++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp
@@ -33,6 +33,25 @@ namespace {
 // We will store IValue NONE in index 0 in flatbuffer.
 constexpr int kNoneIndex = 0;

+static TypePtr realType(TypePtr type) {
+  if (auto dyn = type->castRaw<c10::DynamicType>()) {
+    return dyn->fallback();
+  } else {
+    return type;
+  }
+}
+
+auto print_type(const c10::Type& t) -> c10::optional<std::string> {
+  auto namedType = t.cast<c10::NamedType>();
+  if (namedType && namedType->name()) {
+    return namedType->name().value().qualifiedName();
+  }
+  if (auto dyn = t.castRaw<c10::DynamicType>()) {
+    return dyn->fallback()->annotation_str();
+  }
+  return c10::nullopt;
+}
+
 class FlatbufferSerializer {
  public:
   FlatbufferSerializer() = default;
@@ -157,21 +176,21 @@ flatbuffers::Offset FlatbufferSerializer::
   return_vec.reserve(returns.size());
   for (const auto& arg : args) {
     int index = storeIValueAndGetIndex(fbb, arg.default_value());
-    TORCH_INTERNAL_ASSERT(arg.type()->kind() != c10::DynamicType::Kind);
     arg_vec.emplace_back(CreateArg(
         fbb,
         fbb.CreateSharedString(arg.name()),
-        fbb.CreateSharedString(arg.type()->annotation_str(type_printer)),
+        fbb.CreateSharedString(
+            realType(arg.type())->annotation_str(type_printer)),
         index));
   }

   for (const auto& ret : returns) {
     int index = storeIValueAndGetIndex(fbb, ret.default_value());
-    TORCH_INTERNAL_ASSERT(ret.type()->kind() != c10::DynamicType::Kind);
     return_vec.emplace_back(CreateArg(
         fbb,
         fbb.CreateSharedString(ret.name()),
-        fbb.CreateSharedString(ret.type()->annotation_str(type_printer)),
+        fbb.CreateSharedString(
+            realType(ret.type())->annotation_str(type_printer)),
         index));
   }
   return CreateSchema(
@@ -219,8 +238,7 @@ flatbuffers::Offset FlatbufferSerializer::
   std::vector<flatbuffers::Offset<flatbuffers::String>> type_offsets;

   for (const TypePtr& t : code.types_) {
-    auto type_str = t->annotation_str();
-    TORCH_INTERNAL_ASSERT(t->kind() != c10::DynamicType::Kind);
+    auto type_str = realType(t)->annotation_str();
     if (type_str.find(torch_prefix) == 0) {
       TORCH_CHECK(
           type_str.find(class_prefix) == 0,
@@ -243,6 +261,9 @@ flatbuffers::Offset FlatbufferSerializer::
     if (namedType && namedType->name()) {
       return namedType->name().value().qualifiedName();
     }
+    if (auto dyn = t.castRaw<c10::DynamicType>()) {
+      return dyn->fallback()->annotation_str();
+    }
     return c10::nullopt;
   };

@@ -398,7 +419,8 @@ flatbuffers::Offset FlatbufferSerializer::listToFB(
   return CreateList(
       fbb,
       fbb.CreateVector(items),
-      fbb.CreateSharedString(list.type()->annotation_str()));
+      fbb.CreateSharedString(
+          realType(list.type())->annotation_str(print_type)));
 }

 flatbuffers::Offset FlatbufferSerializer::dictToFB(
@@ -415,11 +437,13 @@ flatbuffers::Offset FlatbufferSerializer::dictToFB(
     int value_index = storeIValueAndGetIndex(fbb, entry.value());
     values.push_back(value_index);
   }
+
   return CreateDict(
       fbb,
       fbb.CreateVector(keys),
       fbb.CreateVector(values),
-      fbb.CreateSharedString(ivalue.type()->annotation_str()));
+      fbb.CreateSharedString(
+          realType(ivalue.type())->annotation_str(print_type)));
 }

 flatbuffers::Offset FlatbufferSerializer::
diff --git a/torch/csrc/stub_with_flatbuffer.c b/torch/csrc/stub_with_flatbuffer.c
new file mode 100644
index 00000000000000..6f7c159634e097
--- /dev/null
+++ b/torch/csrc/stub_with_flatbuffer.c
@@ -0,0 +1,18 @@
+#include <Python.h> // NOLINT
+
+#ifdef _WIN32
+__declspec(dllimport)
+#endif
+extern PyObject* initModuleFlatbuffer(void);
+
+#ifndef _WIN32
+#ifdef __cplusplus
+extern "C"
+#endif
+__attribute__((visibility("default"))) PyObject* PyInit__C_flatbuffer(void);
+#endif
+
+PyMODINIT_FUNC PyInit__C_flatbuffer(void)
+{
+  return initModuleFlatbuffer();
+}
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index d20032d350cbac..9c70b77df77d26 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -47,7 +47,7 @@
     _get_trace_graph,
 )
 from torch.jit._async import fork, wait
-from torch.jit._serialization import save, load
+from torch.jit._serialization import save, load, jit_module_from_flatbuffer, save_jit_module_to_flatbuffer
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph, set_fusion_strategy
 from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
 from torch.jit._ir_utils import _InsertPoint
diff --git a/torch/jit/_serialization.py b/torch/jit/_serialization.py
index f2c32c3a19bcc8..3911cb411c5b2a 100644
--- a/torch/jit/_serialization.py
+++ b/torch/jit/_serialization.py
@@ -182,3 +182,74 @@ def validate_map_location(map_location=None):
         validate_cuda_device(map_location)

     return map_location
+
+
+def jit_module_from_flatbuffer(f):
+    try:
+        import torch._C_flatbuffer as ff
+    except ImportError:
+        print("Please include //caffe2:_C_flatbuffer as a dependency.")
+        raise
+
+    if isinstance(f, string_classes):
+        if not os.path.exists(f):  # type: ignore[type-var]
+            raise ValueError("The provided filename {} does not exist".format(f))  # type: ignore[str-bytes-safe]
+        if os.path.isdir(f):
+            raise ValueError("The provided filename {} is a directory".format(f))  # type: ignore[str-bytes-safe]
+
+    if isinstance(f, str) or isinstance(f, pathlib.Path):
+        f = str(f)
+        return wrap_cpp_module(ff._load_jit_module_from_file(f))
+    else:
+        return wrap_cpp_module(ff._load_jit_module_from_bytes(f.read()))
+
+
+def save_jit_module_to_flatbuffer(m, f):
+    r"""
+    Save an offline version of this module for use in a separate process. The
+    saved module serializes all of the methods, submodules, parameters, and
+    attributes of this module. It can be loaded into the C++ API using
+    ``torch::jit::load_jit_module_from_file(filename)`` or into the Python API with
+    :func:`torch.jit.jit_module_from_flatbuffer`.
+
+    To be able to save a module, it must not make any calls to native Python
+    functions. This means that all submodules must be subclasses of
+    :class:`ScriptModule` as well.
+
+    .. DANGER::
+        All modules, no matter their device, are always loaded onto the CPU
+        during loading. This is different from :func:`torch.load`'s semantics
+        and may change in the future.
+
+    Args:
+        m: A :class:`ScriptModule` to save.
+        f: A string containing a file path.
+
+    Example:
+
+    .. testcode::
+
+        import torch
+        import io
+
+        class MyModule(torch.nn.Module):
+            def forward(self, x):
+                return x + 10
+
+        m = torch.jit.script(MyModule())
+
+        # Save to file
+        torch.jit.save_jit_module_to_flatbuffer(m, 'scriptmodule.ff')
+    """
+    try:
+        import torch._C_flatbuffer as ff
+    except ImportError:
+        print("Please include //caffe2:_C_flatbuffer as a dependency.")
+        raise
+    if isinstance(f, str) or isinstance(f, pathlib.Path):
+        f = str(f)
+        ff._save_jit_module(m._c, f)
+    else:
+        s = ff._save_jit_module_to_bytes(m._c)
+        f.write(s)
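
Usage sketch (not part of the diff): a minimal round trip through the Python API this patch adds, assuming a build where torch._C_flatbuffer is available. It mirrors what the new tests do; MyModule and the in-memory BytesIO buffer are illustrative only, since save_jit_module_to_flatbuffer also accepts a file path or pathlib.Path.

    import io

    import torch


    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 10


    m = torch.jit.script(MyModule())

    # Save the ScriptModule into an in-memory buffer in flatbuffer format.
    buf = io.BytesIO()
    torch.jit.save_jit_module_to_flatbuffer(m, buf)
    buf.seek(0)

    # Load it back and check that the round trip preserves behavior.
    m2 = torch.jit.jit_module_from_flatbuffer(buf)
    x = torch.tensor([1.0, 2.0, 3.0])
    assert torch.equal(m(x), m2(x))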