Skip to content

Commit

Permalink
Ported PyTorchLoaderTest to EE2 (#3372)
Browse files Browse the repository at this point in the history
Summary:
Ported PyTorchLoaderTest to EE2

Documentation:

Progress on #3239
Pull Request resolved: #3372

Test Plan: Verify CI passes

Differential Revision: D16627018

Pulled By: gcatron

fbshipit-source-id: d62f1d17f526c646c00169ab0685ca4d0a4db321
  • Loading branch information
gcatron authored and facebook-github-bot committed Aug 3, 2019
1 parent c014ee2 commit d32f026
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion torch_glow/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ target_link_libraries(_torch_glow
torch
c10
Support
ExecutionEngine
ExecutionEngine2
Graph
Importer
Support
Expand Down
6 changes: 2 additions & 4 deletions torch_glow/src/CachingGraphRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ llvm::Error CachingGraphRunner::runGraph(const torch::jit::Node *node,
const char *const functionName = "PyTorchFunction";

glow::Function *f = nullptr;
executionEngine_.setBackendName(executionEngine_.getBackendName());
auto &mod = executionEngine_.getModule();
if ((f = mod.getFunction(functionName))) {
mod.eraseFunction(f);
}
f = mod.createFunction(functionName);
std::vector<glow::Placeholder *> inputPlaceholders;
std::vector<glow::Placeholder *> outputPlaceholders;
Expand All @@ -44,7 +42,7 @@ llvm::Error CachingGraphRunner::runGraph(const torch::jit::Node *node,
*f, *graph, inputs, inputPlaceholders, outputPlaceholders));

glow::CompilationContext cctx;
executionEngine_.compile(f, cctx);
executionEngine_.compile(cctx);

glow::PlaceholderBindings bindings;
for (size_t i = 0; i < inputs.size(); ++i) {
Expand Down
4 changes: 2 additions & 2 deletions torch_glow/src/CachingGraphRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/ir.h>

#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/ExecutionEngine/ExecutionEngine2.h"

/// Responsible for maintaining a mapping from PyTorch subgraphs and their
/// unique input types to compiled Glow Functions.
class CachingGraphRunner {
/// Glow ExecutionEngine.
glow::ExecutionEngine executionEngine_;
glow::ExecutionEngine2 executionEngine_;

public:
CachingGraphRunner() = default;
Expand Down
2 changes: 1 addition & 1 deletion torch_glow/src/PyTorchLoaderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

#include "PyTorchFileLoader.h"
#include "PyTorchModelLoader.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/ExecutionEngine/ExecutionEngine2.h"
#include "gtest/gtest.h"
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
Expand Down

0 comments on commit d32f026

Please sign in to comment.