Skip to content

Commit

Permalink
[pytorch][PR] Add ability for a mobile::Module to save as flatbuffer (p…
Browse files Browse the repository at this point in the history
…ytorch#70201)

Summary:
Pull Request resolved: pytorch#70201

Included functions:
save_mobile_module -> saves a mobile::Module to flatbuffer
load_mobile_module_from_file -> loads a flatbuffer into mobile::Module
parse_mobile_module -> parses from bytes or deserialized flatbuffer module object

Compared to previous attempts, this diff only adds flatbuffer to cmake target and leaves fbcode/xplat ones unchanged.

Test Plan: unittest

Reviewed By: malfet, gmagogsfm

Differential Revision: D33239362

fbshipit-source-id: b9ca36b83d6af2d78cc50b9eb9e2a6fa7fce0763
  • Loading branch information
qihqi authored and facebook-github-bot committed Jan 13, 2022
1 parent 7a93d8b commit 1bc3571
Show file tree
Hide file tree
Showing 20 changed files with 5,132 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ jobs:
- name: Ensure canonical include
if: always()
run: |
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' || (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
(! git --no-pager grep -In $'#include "' -- ./c10 ./aten ./torch/csrc ':(exclude)aten/src/ATen/native/quantized/cpu/qnnpack/**' ':(exclude)torch/csrc/jit/serialization/mobile_bytecode_generated.h'|| (echo "The above lines have include with quotes; please convert them to #include <xxxx>"; false))
- name: Ensure no versionless Python shebangs
if: always()
run: |
Expand Down
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,6 @@
[submodule "third_party/breakpad"]
path = third_party/breakpad
url = https://github.com/driazati/breakpad.git
[submodule "third_party/flatbuffers"]
path = third_party/flatbuffers
url = https://github.com/google/flatbuffers.git
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -1692,6 +1692,7 @@ cc_library(
":aten_headers",
":caffe2_headers",
"//c10:headers",
"@com_github_google_flatbuffers//:flatbuffers",
"@local_config_python//:python_headers",
"@onnx",
],
Expand Down Expand Up @@ -1725,6 +1726,8 @@ cc_library(
],
)) + libtorch_core_sources + libtorch_distributed_sources + torch_cpp_srcs + libtorch_extra_sources + jit_core_sources + [
":cpp_generated_code",
"torch/csrc/jit/serialization/flatbuffer_serializer.cpp",
"torch/csrc/jit/mobile/flatbuffer_loader.cpp"
],
copts = TORCH_COPTS,
defines = [
Expand Down
5 changes: 5 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,8 @@ new_local_repository(
build_file = "@//third_party:cudnn.BUILD",
path = "/usr/",
)

local_repository(
name = "com_github_google_flatbuffers",
path = "third_party/flatbuffers",
)
5 changes: 5 additions & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/mobile/interpreter.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/model_compatibility.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/module.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/flatbuffer_loader.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
Expand Down Expand Up @@ -595,6 +596,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/serialization/export.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_bytecode.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/export_module.cpp
${TORCH_SRC_DIR}/csrc/jit/serialization/flatbuffer_serializer.cpp
${TORCH_SRC_DIR}/csrc/jit/codegen/fuser/cpu/fused_kernel.cpp
${TORCH_SRC_DIR}/csrc/jit/api/module_save.cpp
${TORCH_SRC_DIR}/csrc/utils/byte_order.cpp
Expand Down Expand Up @@ -1645,6 +1647,9 @@ if(APPLE AND USE_PYTORCH_METAL)
endif()
endif()


target_link_libraries(torch_cpu PRIVATE flatbuffers)

# Note [Global dependencies]
# Some libraries (e.g. OpenMPI) like to dlopen plugins after they're initialized,
# and they assume that all of their symbols will be available in the global namespace.
Expand Down
3 changes: 3 additions & 0 deletions cmake/Dependencies.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1996,3 +1996,6 @@ if(USE_KINETO)
message(STATUS "Configured Kineto")
endif()
endif()

# Include google/FlatBuffers
include(${CMAKE_CURRENT_LIST_DIR}/FlatBuffers.cmake)
10 changes: 10 additions & 0 deletions cmake/FlatBuffers.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
set(FlatBuffers_Include ${PROJECT_SOURCE_DIR}/third_party/flatbuffers/include)
file(GLOB FlatBuffers_Library_SRCS
${FlatBuffers_Include}/flatbuffers/*.h
)
add_library(flatbuffers INTERFACE)
target_sources(
flatbuffers
INTERFACE ${FlatBuffers_Library_SRCS}
)
target_include_directories(flatbuffers INTERFACE ${FlatBuffers_Include})
5 changes: 5 additions & 0 deletions test/cpp/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ set(JIT_TEST_SRCS
${JIT_TEST_ROOT}/test_script_profile.cpp
${JIT_TEST_ROOT}/test_shape_analysis.cpp
${JIT_TEST_ROOT}/test_jit_logging_levels.cpp
${JIT_TEST_ROOT}/test_flatbuffer.cpp
)

if(USE_CUDA)
Expand All @@ -101,6 +102,10 @@ add_executable(test_jit
${JIT_TEST_SRCS}
)

target_link_libraries(
test_jit PRIVATE flatbuffers)


# TODO temporary until we can delete the old gtest polyfills.
target_compile_definitions(test_jit PRIVATE USE_GTEST)

Expand Down
Loading

0 comments on commit 1bc3571

Please sign in to comment.