Skip to content

Commit

Permalink
Added CallFunctionContext to pass a context.Context through to host c…
Browse files Browse the repository at this point in the history
…alls

Signed-off-by: Phil Kedy <[email protected]>
  • Loading branch information
pkedy committed Feb 3, 2022
1 parent 21310be commit 090acd4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 14 deletions.
5 changes: 5 additions & 0 deletions wasm/engine.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package wasm

import "context"

// Engine is the interface implemented by interpreters.
type Engine interface {
// Call invokes a function instance f with given parameters.
// Returns the results from the function.
Call(f *FunctionInstance, params ...uint64) (results []uint64, err error)
// CallContext invokes a function instance f with given parameters.
// Returns the results from the function.
CallContext(ctx context.Context, f *FunctionInstance, params ...uint64) (results []uint64, err error)
// Compile compiles down the function instance.
Compile(f *FunctionInstance) error
}
27 changes: 18 additions & 9 deletions wasm/interpreter/interpreter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package interpreter

import (
"context"
"encoding/binary"
"fmt"
"math"
Expand Down Expand Up @@ -426,6 +427,11 @@ func (it *interpreter) lowerIROps(f *wasm.FunctionInstance,

// Call implements an interpreted wasm.Engine.
func (it *interpreter) Call(f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) {
return it.CallContext(context.Background(), f, params...)
}

// Call implements an interpreted wasm.Engine.
func (it *interpreter) CallContext(ctx context.Context, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) {
prevFrameLen := len(it.frames)

// shouldRecover is true when a panic at the origin of callstack should be recovered
Expand Down Expand Up @@ -478,12 +484,12 @@ func (it *interpreter) Call(f *wasm.FunctionInstance, params ...uint64) (results
}

if g.hostFn != nil {
it.callHostFunc(g, params...)
it.callHostFunc(ctx, g, params...)
} else {
for _, param := range params {
it.push(param)
}
it.callNativeFunc(g)
it.callNativeFunc(ctx, g)
}
results = make([]uint64, len(f.FunctionType.Type.Results))
for i := range results {
Expand All @@ -492,7 +498,7 @@ func (it *interpreter) Call(f *wasm.FunctionInstance, params ...uint64) (results
return
}

func (it *interpreter) callHostFunc(f *interpreterFunction, _ ...uint64) {
func (it *interpreter) callHostFunc(ctx context.Context, f *interpreterFunction, _ ...uint64) {
tp := f.hostFn.Type()
in := make([]reflect.Value, tp.NumIn())
for i := len(in) - 1; i >= 1; i-- {
Expand All @@ -515,7 +521,10 @@ func (it *interpreter) callHostFunc(f *interpreterFunction, _ ...uint64) {
if len(it.frames) > 0 {
memory = it.frames[len(it.frames)-1].f.funcInstance.ModuleInstance.Memory
}
val.Set(reflect.ValueOf(&wasm.HostFunctionCallContext{Memory: memory}))
val.Set(reflect.ValueOf(&wasm.HostFunctionCallContext{
Context: ctx,
Memory: memory,
}))
in[0] = val

frame := &interpreterFrame{f: f}
Expand All @@ -535,7 +544,7 @@ func (it *interpreter) callHostFunc(f *interpreterFunction, _ ...uint64) {
it.popFrame()
}

func (it *interpreter) callNativeFunc(f *interpreterFunction) {
func (it *interpreter) callNativeFunc(ctx context.Context, f *interpreterFunction) {
frame := &interpreterFrame{f: f}
moduleInst := f.funcInstance.ModuleInstance
memoryInst := moduleInst.Memory
Expand Down Expand Up @@ -582,9 +591,9 @@ func (it *interpreter) callNativeFunc(f *interpreterFunction) {
case wazeroir.OperationKindCall:
{
if op.f.hostFn != nil {
it.callHostFunc(op.f, it.stack[len(it.stack)-len(op.f.funcInstance.FunctionType.Type.Params):]...)
it.callHostFunc(ctx, op.f, it.stack[len(it.stack)-len(op.f.funcInstance.FunctionType.Type.Params):]...)
} else {
it.callNativeFunc(op.f)
it.callNativeFunc(ctx, op.f)
}
frame.pc++
}
Expand All @@ -605,9 +614,9 @@ func (it *interpreter) callNativeFunc(f *interpreterFunction) {
target := it.functions[table.Table[offset].FunctionAddress]
// Call in.
if target.hostFn != nil {
it.callHostFunc(target, it.stack[len(it.stack)-len(target.funcInstance.FunctionType.Type.Params):]...)
it.callHostFunc(ctx, target, it.stack[len(it.stack)-len(target.funcInstance.FunctionType.Type.Params):]...)
} else {
it.callNativeFunc(target)
it.callNativeFunc(ctx, target)
}
frame.pc++
}
Expand Down
19 changes: 15 additions & 4 deletions wasm/jit/engine.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jit

import (
"context"
"encoding/hex"
"fmt"
"math"
Expand Down Expand Up @@ -332,6 +333,10 @@ func (c *callFrame) String() string {
}

func (e *engine) Call(f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) {
return e.CallContext(context.Background(), f, params...)
}

func (e *engine) CallContext(ctx context.Context, f *wasm.FunctionInstance, params ...uint64) (results []uint64, err error) {
// We ensure that this Call method never panics as
// this Call method is indirectly invoked by embedders via store.CallFunction,
// and we have to make sure that all the runtime errors, including the one happening inside
Expand Down Expand Up @@ -387,9 +392,12 @@ func (e *engine) Call(f *wasm.FunctionInstance, params ...uint64) (results []uin
}

if compiled.source.IsHostFunction() {
e.execHostFunction(compiled.source.HostFunction, &wasm.HostFunctionCallContext{Memory: f.ModuleInstance.Memory})
e.execHostFunction(compiled.source.HostFunction, &wasm.HostFunctionCallContext{
Context: ctx,
Memory: f.ModuleInstance.Memory,
})
} else {
e.execFunction(compiled)
e.execFunction(ctx, compiled)
}

// Note the top value is the tail of the results,
Expand Down Expand Up @@ -497,7 +505,7 @@ func (e *engine) execHostFunction(f *reflect.Value, ctx *wasm.HostFunctionCallCo
}
}

func (e *engine) execFunction(f *compiledFunction) {
func (e *engine) execFunction(ctx context.Context, f *compiledFunction) {
// We continuously execute functions until we reach the previous top frame
// to support recursive Wasm function executions.
e.globalContext.previousCallFrameStackPointer = e.globalContext.callFrameStackPointer
Expand Down Expand Up @@ -533,7 +541,10 @@ jitentry:
}
}
saved := e.globalContext.previousCallFrameStackPointer
e.execHostFunction(fn.source.HostFunction, &wasm.HostFunctionCallContext{Memory: callerCompiledFunction.source.ModuleInstance.Memory})
e.execHostFunction(fn.source.HostFunction, &wasm.HostFunctionCallContext{
Context: ctx,
Memory: callerCompiledFunction.source.ModuleInstance.Memory,
})
e.globalContext.previousCallFrameStackPointer = saved
goto jitentry
case jitCallStatusCodeCallBuiltInFunction:
Expand Down
8 changes: 7 additions & 1 deletion wasm/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package wasm

import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -314,6 +315,10 @@ func (s *Store) Instantiate(module *Module, name string) error {
}

func (s *Store) CallFunction(moduleName, funcName string, params ...uint64) (results []uint64, resultTypes []ValueType, err error) {
return s.CallFunctionContext(context.Background(), moduleName, funcName, params...)
}

func (s *Store) CallFunctionContext(ctx context.Context, moduleName, funcName string, params ...uint64) (results []uint64, resultTypes []ValueType, err error) {
var exp *ExportInstance
if exp, err = s.getExport(moduleName, funcName, ExportKindFunc); err != nil {
return
Expand All @@ -325,7 +330,7 @@ func (s *Store) CallFunction(moduleName, funcName string, params ...uint64) (res
return
}

results, err = s.engine.Call(f, params...)
results, err = s.engine.CallContext(ctx, f, params...)
resultTypes = f.FunctionType.Type.Results
return
}
Expand Down Expand Up @@ -844,6 +849,7 @@ func DecodeBlockType(types []*TypeInstance, r io.Reader) (*FunctionType, uint64,

// HostFunctionCallContext is the first argument of all host functions.
type HostFunctionCallContext struct {
context.Context
// Memory is the currently used memory instance at the time when the host function call is made.
Memory *MemoryInstance
// TODO: Add others if necessary.
Expand Down
5 changes: 5 additions & 0 deletions wasm/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package wasm

import (
"bytes"
"context"
"encoding/binary"
"math"
"reflect"
Expand Down Expand Up @@ -98,6 +99,10 @@ func (e *nopEngine) Call(_ *FunctionInstance, _ ...uint64) (results []uint64, er
return nil, nil
}

func (e *nopEngine) CallContext(_ context.Context, _ *FunctionInstance, _ ...uint64) (results []uint64, err error) {
return nil, nil
}

func (e *nopEngine) Compile(_ *FunctionInstance) error {
return nil
}
Expand Down

0 comments on commit 090acd4

Please sign in to comment.