From d32f026eababa0427fdbabe5e16f0bb4e56569a0 Mon Sep 17 00:00:00 2001 From: Garret Catron Date: Fri, 2 Aug 2019 16:57:20 -0700 Subject: [PATCH] Ported PyTorchLoaderTest to EE2 (#3372) Summary: Ported PyTorchLoaderTest to EE2 Documentation: Progress on https://github.com/pytorch/glow/issues/3239 Pull Request resolved: https://github.com/pytorch/glow/pull/3372 Test Plan: Verify CI passes Differential Revision: D16627018 Pulled By: gcatron fbshipit-source-id: d62f1d17f526c646c00169ab0685ca4d0a4db321 --- torch_glow/src/CMakeLists.txt | 2 +- torch_glow/src/CachingGraphRunner.cpp | 6 ++---- torch_glow/src/CachingGraphRunner.h | 4 ++-- torch_glow/src/PyTorchLoaderTest.cpp | 2 +- 4 files changed, 6 insertions(+), 8 deletions(-) diff --git a/torch_glow/src/CMakeLists.txt b/torch_glow/src/CMakeLists.txt index e4b003ecb9..2846081c15 100644 --- a/torch_glow/src/CMakeLists.txt +++ b/torch_glow/src/CMakeLists.txt @@ -24,7 +24,7 @@ target_link_libraries(_torch_glow torch c10 Support - ExecutionEngine + ExecutionEngine2 Graph Importer Support diff --git a/torch_glow/src/CachingGraphRunner.cpp b/torch_glow/src/CachingGraphRunner.cpp index decf6b8690..c5dbeeaf5b 100644 --- a/torch_glow/src/CachingGraphRunner.cpp +++ b/torch_glow/src/CachingGraphRunner.cpp @@ -32,10 +32,8 @@ llvm::Error CachingGraphRunner::runGraph(const torch::jit::Node *node, const char *const functionName = "PyTorchFunction"; glow::Function *f = nullptr; + executionEngine_.setBackendName(executionEngine_.getBackendName()); auto &mod = executionEngine_.getModule(); - if ((f = mod.getFunction(functionName))) { - mod.eraseFunction(f); - } f = mod.createFunction(functionName); std::vector inputPlaceholders; std::vector outputPlaceholders; @@ -44,7 +42,7 @@ llvm::Error CachingGraphRunner::runGraph(const torch::jit::Node *node, *f, *graph, inputs, inputPlaceholders, outputPlaceholders)); glow::CompilationContext cctx; - executionEngine_.compile(f, cctx); + executionEngine_.compile(cctx); glow::PlaceholderBindings bindings; for (size_t i = 0; i < inputs.size(); ++i) { diff --git a/torch_glow/src/CachingGraphRunner.h b/torch_glow/src/CachingGraphRunner.h index b22917fed9..4a385189a8 100644 --- a/torch_glow/src/CachingGraphRunner.h +++ b/torch_glow/src/CachingGraphRunner.h @@ -20,13 +20,13 @@ #include #include -#include "glow/ExecutionEngine/ExecutionEngine.h" +#include "glow/ExecutionEngine/ExecutionEngine2.h" /// Responsible for maintaining a mapping from PyTorch subgraphs and their /// unique input types to compiled Glow Functions. class CachingGraphRunner { /// Glow ExecutionEngine. - glow::ExecutionEngine executionEngine_; + glow::ExecutionEngine2 executionEngine_; public: CachingGraphRunner() = default; diff --git a/torch_glow/src/PyTorchLoaderTest.cpp b/torch_glow/src/PyTorchLoaderTest.cpp index 58b1d81f03..2e2c722500 100644 --- a/torch_glow/src/PyTorchLoaderTest.cpp +++ b/torch_glow/src/PyTorchLoaderTest.cpp @@ -16,7 +16,7 @@ #include "PyTorchFileLoader.h" #include "PyTorchModelLoader.h" -#include "glow/ExecutionEngine/ExecutionEngine.h" +#include "glow/ExecutionEngine/ExecutionEngine2.h" #include "gtest/gtest.h" #include #include