Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ExecutionEngine] Create runctx and execute(ctx) functions #1907

Merged
merged 1 commit into from
Oct 24, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/char-rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ int main(int argc, char **argv) {
for (unsigned i = 0; i < generateChars; i++) {
// Generate a char:
updateInputPlaceholders(ctx, {X}, {&currCharInfer});
EE.run();
EE.run(ctx);

// Pick a char at random from the softmax distribution.
char c = getPredictedChar(*T, 0, numSteps - 1);
Expand Down
2 changes: 1 addition & 1 deletion examples/cifar10.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ void testCIFAR10() {
Tensor sample(ElemKind::FloatTy, {minibatchSize, 32, 32, 3});
sample.copyConsecutiveSlices(&images, minibatchSize * i);
updateInputPlaceholders(ctx, {A}, {&sample});
EE.run();
EE.run(ctx);

for (unsigned int iter = 0; iter < minibatchSize; iter++) {
auto T = result->getHandle<>().extractSlice(iter);
Expand Down
2 changes: 1 addition & 1 deletion examples/fr2en.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ void Model::translate(const std::vector<std::string> &batch) {
}

updateInputPlaceholders(ctx, {input_, seqLength_}, {&input, &seqLength});
EE_.run();
EE_.run(ctx);

auto OH = ctx.get(output_)->getHandle<int64_t>();
for (unsigned j = 0; j < batch.size(); j++) {
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ void testMNIST() {

for (int iter = numIterations; iter < numIterations + 10; iter++) {
inputTensor->copyConsecutiveSlices(&imageInputs, minibatchSize * iter);
EE.run();
EE.run(ctx);

for (unsigned i = 0; i < minibatchSize; i++) {
auto T = resultTensor->getHandle<>().extractSlice(i);
Expand Down
7 changes: 5 additions & 2 deletions include/glow/Backends/CompiledFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@

namespace glow {

class Context;

/// Interface for executing a compiled function.
class CompiledFunction {
public:
/// Dtor.
virtual ~CompiledFunction() = default;

/// Execute the network.
virtual void execute() = 0;
/// Execute the network and allocate Placeholder memory with given
/// \p ctx providing mapping between Placeholder and populated tensor.
virtual void execute(Context &ctx) = 0;

This comment was marked as off-topic.

This comment was marked as off-topic.

};

} // end namespace glow
Expand Down
4 changes: 2 additions & 2 deletions include/glow/ExecutionEngine/ExecutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ class ExecutionEngine final {
void save(CompilationMode mode, Function *F, llvm::StringRef outputDir,
llvm::StringRef networkName);

/// Runs a single execution of the function.
void run();
/// Context aware single execution of the function.
void run(Context &ctx);
};

//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/CPU/CPUFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ CPUFunction::CPUFunction(std::unique_ptr<llvm::orc::GlowJIT> JIT, void *heap)

CPUFunction::~CPUFunction() { alignedFree(heap_); }

void CPUFunction::execute() {
void CPUFunction::execute(Context &ctx) {
auto sym = JIT_->findSymbol("jitmain");
assert(sym && "Unable to JIT the code!");
using JitFuncType = void (*)(void);
Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/CPU/CPUFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CPUFunction final : public CompiledFunction {
///@{
~CPUFunction() override;

void execute() override;
void execute(Context &ctx) override;
///@}
};

Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/Interpreter/InterpreterFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ void InterpreterFunction::deleteTensor(const Value *v) {
tensors_.erase(it);
}

void InterpreterFunction::execute() {
void InterpreterFunction::execute(Context &ctx) {
// Do the forward pass.
#define DEF_VALUE(CLASS, NAME)
#define DEF_INSTR(CLASS, NAME) \
Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/Interpreter/InterpreterFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class InterpreterFunction final : public CompiledFunction {
///@{
~InterpreterFunction() override;

void execute() override;
void execute(Context &ctx) override;
///@}

private:
Expand Down
3 changes: 1 addition & 2 deletions lib/Backends/OpenCL/OpenCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,7 @@ static void topK(Tensor &outW, Tensor &indW, Tensor &inW, size_t k) {
}
}
}

void OpenCLFunction::execute() {
void OpenCLFunction::execute(Context &ctx) {
auto copiedToDeviceBytes = copyMutableWeightsToDevice();
(void)copiedToDeviceBytes;
DEBUG_GLOW(llvm::dbgs() << "Copied " << copiedToDeviceBytes
Expand Down
2 changes: 1 addition & 1 deletion lib/Backends/OpenCL/OpenCL.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class OpenCLFunction final : public CompiledFunction {
///@{
~OpenCLFunction() override;

void execute() override;
void execute(Context &ctx) override;
///@}

private:
Expand Down
6 changes: 3 additions & 3 deletions lib/ExecutionEngine/ExecutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ void glow::updateInputPlaceholdersByName(Context &ctx, Module *mod,
}
}

void ExecutionEngine::run() {
void ExecutionEngine::run(Context &ctx) {
assert(function_ && "No function has been compiled");
function_->execute();
function_->execute(ctx);
}

void glow::runBatch(ExecutionEngine &EE, Context &ctx, size_t iterations,
Expand Down Expand Up @@ -109,7 +109,7 @@ void glow::runBatch(ExecutionEngine &EE, Context &ctx, size_t iterations,
}

// Run the network.
EE.run();
EE.run(ctx);
sampleCounter += batchSize;
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Onnxifi/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ onnxStatus Graph::run() {
// Run inference.
auto &EE = backendPtr_->getEE();
updateInputPlaceholders(ctx_, phs, tensors);
EE.run();
EE.run(ctx_);

// Copy outputs to the addresses specified in the outputNodeToBuffer_.
for (auto outputVar : outputNodeToBuffer_) {
Expand Down
2 changes: 1 addition & 1 deletion tests/googletest
Submodule googletest updated 105 files
8 changes: 4 additions & 4 deletions tests/unittests/BackendCorrectnessTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ profileAndGetNodeQuantizationInfo(Context &ctx, ExecutionEngine &EE,
Function *profileF = glow::profileQuantization(ctx, origF);
EE.compile(CompilationMode::Infer, profileF, ctx);

EE.run();
EE.run(ctx);

return quantization::generateNodeQuantizationInfos(ctx, profileF);
}
Expand Down Expand Up @@ -92,8 +92,8 @@ compareAgainstInterpreter(BackendKind backendKind,
IEE.compile(CompilationMode::Infer, IF, ICtx);
BEE.compile(CompilationMode::Infer, BF, BCtx);

IEE.run();
BEE.run();
IEE.run(ICtx);
BEE.run(BCtx);

return IFT.second->isEqual(*BFT.second, allowedError);
}
Expand Down Expand Up @@ -309,7 +309,7 @@ TEST_P(CPUOnly, dataParallelStackingTest) {
}

MockCPUBackend backend;
backend.compileIR(std::move(M), ctx)->execute();
backend.compileIR(std::move(M), ctx)->execute(ctx);
auto H = outputTensor->getHandle();
EXPECT_EQ(H.at(0), 3);
EXPECT_EQ(H.at(1), 4);
Expand Down
12 changes: 6 additions & 6 deletions tests/unittests/BackendTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ TEST(Interpreter, profileQuantizationForANetwork) {
// TODO: Verify histogram itself, for now just verify min and max.
// Run inference first time and capture tensor stats.
updateInputPlaceholders(ctx, {A}, {&inputs});
EE.run();
EE.run(ctx);

QuantizationProfileNode *profile{nullptr};
// Find QPN for node A.
Expand All @@ -96,7 +96,7 @@ TEST(Interpreter, profileQuantizationForANetwork) {
// Run inference for the second time with new min and max.
inputs.getHandle() = {0.2f, 1.6f, 0.5f, 1.3f};
updateInputPlaceholders(ctx, {A}, {&inputs});
EE.run();
EE.run(ctx);
min = CI.raw(0);
max = CI.raw(1);
EXPECT_NEAR(0.2, min, 0.00001);
Expand Down Expand Up @@ -138,7 +138,7 @@ TEST_P(BackendTest, simpleInference) {
EE_.compile(CompilationMode::Infer, F, ctx);

updateInputPlaceholders(ctx, {input}, {&inputs});
EE_.run();
EE_.run(ctx);
}

/// Test that the DebugPrint instruction works correctly for the backend. Note
Expand All @@ -161,7 +161,7 @@ TEST_P(BackendTest, debugPrint) {
std::unique_ptr<BackendUsingGlowIR> backend(
static_cast<BackendUsingGlowIR *>(createBackend(GetParam())));
auto function = backend->compileIR(std::move(IR), ctx);
function->execute();
function->execute(ctx);
}

/// This test checks that we can compile a function without depending on the
Expand All @@ -186,7 +186,7 @@ TEST_P(BackendTest, decoupleCodegenFromGraph) {

// We can run the compiled code without having the graph representation
// around.
EE_.run();
EE_.run(ctx);

auto HX = saveTensor->getHandle();
EXPECT_NEAR(HX.at({0}), 1, 1E-5);
Expand All @@ -206,7 +206,7 @@ TEST_P(BackendTest, simplePlaceholderValue) {
auto *STensor = ctx.allocate(S->getPlaceholder());

EE_.compile(CompilationMode::Infer, F, ctx);
EE_.run();
EE_.run(ctx);
EXPECT_TRUE(STensor->isEqual(data));
}

Expand Down
Loading