Skip to content

Commit

Permalink
cmd/compile: add wasm stack optimization
Browse files Browse the repository at this point in the history
Go's SSA instructions only operate on registers. For example, an add
instruction would read two registers, do the addition and then write
to a register. WebAssembly's instructions, on the other hand, operate
on the stack. The add instruction first pops two values from the stack,
does the addition, then pushes the result to the stack. To fulfill
Go's semantics, one needs to map Go's single add instruction to
4 WebAssembly instructions:
- Push the value of local variable A to the stack
- Push the value of local variable B to the stack
- Do addition
- Write value from stack to local variable C

Now consider that B was set to the constant 42 before the addition:
- Push constant 42 to the stack
- Write value from stack to local variable B

This works, but is inefficient. Instead, the stack is used directly
by inlining instructions if possible. With inlining it becomes:
- Push the value of local variable A to the stack (add)
- Push constant 42 to the stack (constant)
- Do addition (add)
- Write value from stack to local variable C (add)

Note that the two SSA instructions can not be generated sequentially
anymore, because their WebAssembly instructions are interleaved.

Design doc: https://docs.google.com/document/d/131vjr4DH6JFnb-blm_uRdaC0_Nv3OUwjEY5qVCxCup4

Updates golang#18892

Change-Id: Ie35e1c0bebf4985fddda0d6330eb2066f9ad6dec
  • Loading branch information
neelance committed May 22, 2018
1 parent 54d7b46 commit ee6dd4a
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 36 deletions.
3 changes: 3 additions & 0 deletions src/cmd/compile/internal/gc/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -4701,6 +4701,9 @@ type SSAGenState struct {
// within a single block sharing the same line number
// Used to move statement marks to the beginning of such runs.
lineRunStart *obj.Prog

// wasm: The number of values on the WebAssembly stack. This is only used as a safeguard.
WasmStackSize int
}

// Prog appends a new Prog.
Expand Down
71 changes: 69 additions & 2 deletions src/cmd/compile/internal/ssa/regalloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ import (
"cmd/compile/internal/types"
"cmd/internal/objabi"
"cmd/internal/src"
"cmd/internal/sys"
"fmt"
"unsafe"
)
Expand Down Expand Up @@ -372,6 +373,10 @@ func (s *regAllocState) assignReg(r register, v *Value, c *Value) {
// If there is no unused register, a Value will be kicked out of
// a register to make room.
func (s *regAllocState) allocReg(mask regMask, v *Value) register {
if v.OnWasmStack {
return noRegister
}

mask &= s.allocatable
mask &^= s.nospill
if mask == 0 {
Expand Down Expand Up @@ -411,6 +416,14 @@ func (s *regAllocState) allocReg(mask regMask, v *Value) register {
s.f.Fatalf("couldn't find register to spill")
}

if s.f.Config.ctxt.Arch.Arch == sys.ArchWasm {
// TODO(neelance): In theory this should never happen, because all wasm registers are equal.
// So if there is still a free register, the allocation should have picked that one in the first place insead of
// trying to kick some other value out. In practice, this case does happen and it breaks the stack optimization.
s.freeReg(r)
return r
}

// Try to move it around before kicking out, if there is a free register.
// We generate a Copy and record it. It will be deleted if never used.
v2 := s.regs[r].v
Expand Down Expand Up @@ -458,6 +471,16 @@ func (s *regAllocState) makeSpill(v *Value, b *Block) *Value {
// undone until the caller allows it by clearing nospill. Returns a
// *Value which is either v or a copy of v allocated to the chosen register.
func (s *regAllocState) allocValToReg(v *Value, mask regMask, nospill bool, pos src.XPos) *Value {
if s.f.Config.ctxt.Arch.Arch == sys.ArchWasm && v.rematerializeable() {
c := v.copyIntoWithXPos(s.curBlock, pos)
c.OnWasmStack = true
s.setOrig(c, v)
return c
}
if v.OnWasmStack {
return v
}

vi := &s.values[v.ID]
pos = pos.WithNotStmt()
// Check if v is already in a requested register.
Expand All @@ -472,8 +495,13 @@ func (s *regAllocState) allocValToReg(v *Value, mask regMask, nospill bool, pos
return s.regs[r].c
}

// Allocate a register.
r := s.allocReg(mask, v)
var r register
// If nospill is set, the value is used immedately, so it can live on the WebAssembly stack.
onWasmStack := nospill && s.f.Config.ctxt.Arch.Arch == sys.ArchWasm
if !onWasmStack {
// Allocate a register.
r = s.allocReg(mask, v)
}

// Allocate v to the new register.
var c *Value
Expand All @@ -495,6 +523,12 @@ func (s *regAllocState) allocValToReg(v *Value, mask regMask, nospill bool, pos
}
c = s.curBlock.NewValue1(pos, OpLoadReg, v.Type, spill)
}

if onWasmStack {
c.OnWasmStack = true
return c
}

s.setOrig(c, v)
s.assignReg(r, v, c)
if nospill {
Expand Down Expand Up @@ -656,6 +690,39 @@ func (s *regAllocState) init(f *Func) {
s.startRegs = make([][]startReg, f.NumBlocks())
s.spillLive = make([][]ID, f.NumBlocks())
s.sdom = f.sdom()

// wasm: Mark instructions that can be optimized to have their values only on the WebAssembly stack.
if f.Config.ctxt.Arch.Arch == sys.ArchWasm {
canLiveOnStack := f.newSparseSet(f.NumValues())
defer f.retSparseSet(canLiveOnStack)
for _, b := range f.Blocks {
// New block. Clear candidate set.
canLiveOnStack.clear()
if b.Control != nil && b.Control.Uses == 1 && !opcodeTable[b.Control.Op].generic {
canLiveOnStack.add(b.Control.ID)
}
// Walking backwards.
for i := len(b.Values) - 1; i >= 0; i-- {
v := b.Values[i]
if canLiveOnStack.contains(v.ID) {
v.OnWasmStack = true
} else {
// Value can not live on stack. Values are not allowed to be reordered, so clear candidate set.
canLiveOnStack.clear()
}
for _, arg := range v.Args {
// Value can live on the stack if:
// - it is only used once
// - it is used in the same basic block
// - it is not a "mem" value
// - it is a WebAssembly op
if arg.Uses == 1 && arg.Block == v.Block && !arg.Type.IsMemory() && !opcodeTable[arg.Op].generic {
canLiveOnStack.add(arg.ID)
}
}
}
}
}
}

// Adds a use record for id at distance dist from the start of the block.
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/compile/internal/ssa/sizeof_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestSizeof(t *testing.T) {
_32bit uintptr // size on 32bit platforms
_64bit uintptr // size on 64bit platforms
}{
{Value{}, 68, 112},
{Value{}, 72, 112},
{Block{}, 152, 288},
{LocalSlot{}, 32, 48},
{valState{}, 28, 40},
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/compile/internal/ssa/stackalloc.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (s *stackAllocState) init(f *Func, spillLive [][]ID) {
for _, b := range f.Blocks {
for _, v := range b.Values {
s.values[v.ID].typ = v.Type
s.values[v.ID].needSlot = !v.Type.IsMemory() && !v.Type.IsVoid() && !v.Type.IsFlags() && f.getHome(v.ID) == nil && !v.rematerializeable()
s.values[v.ID].needSlot = !v.Type.IsMemory() && !v.Type.IsVoid() && !v.Type.IsFlags() && f.getHome(v.ID) == nil && !v.rematerializeable() && !v.OnWasmStack
s.values[v.ID].isArg = v.Op == OpArg
if f.pass.debug > stackDebug && s.values[v.ID].needSlot {
fmt.Printf("%s needs a stack slot\n", v)
Expand Down
4 changes: 4 additions & 0 deletions src/cmd/compile/internal/ssa/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ type Value struct {
// Use count. Each appearance in Value.Args and Block.Control counts once.
Uses int32

// wasm: Value stays on the WebAssembly stack. This value will not get a "register" (WebAssembly variable)
// nor a slot on Go stack, and the generation of this value is delayed to its use time.
OnWasmStack bool

// Storage for the first three args
argstorage [3]*Value
}
Expand Down
90 changes: 58 additions & 32 deletions src/cmd/compile/internal/wasm/ssa.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func ssaGenBlock(s *gc.SSAGenState, b, next *ssa.Block) {
goToBlock(b.Succs[0].Block(), true)

case ssa.BlockIf:
getReg32(s, b.Control)
getValue32(s, b.Control)
s.Prog(wasm.AI32Eqz)
s.Prog(wasm.AIf)
goToBlock(b.Succs[1].Block(), false)
Expand Down Expand Up @@ -113,6 +113,10 @@ func ssaGenBlock(s *gc.SSAGenState, b, next *ssa.Block) {

// Entry point for the next block. Used by the JMP in goToBlock.
s.Prog(wasm.ARESUMEPOINT)

if s.WasmStackSize != 0 {
panic("wasm: bad stack")
}
}

func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {
Expand All @@ -124,33 +128,33 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {
s.Prog(wasm.ARESUMEPOINT)
}
if v.Op == ssa.OpWasmLoweredClosureCall {
getReg64(s, v.Args[1])
getValue64(s, v.Args[1])
setReg(s, wasm.REG_CTXT)
}
if sym, ok := v.Aux.(*obj.LSym); ok {
p := s.Prog(obj.ACALL)
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: sym}
} else {
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
p := s.Prog(obj.ACALL)
p.To = obj.Addr{Type: obj.TYPE_NONE}
}

case ssa.OpWasmLoweredMove:
getReg32(s, v.Args[0])
getReg32(s, v.Args[1])
getValue32(s, v.Args[0])
getValue32(s, v.Args[1])
i32Const(s, int32(v.AuxInt))
p := s.Prog(wasm.ACall)
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: gc.WasmMove}

case ssa.OpWasmLoweredZero:
getReg32(s, v.Args[0])
getValue32(s, v.Args[0])
i32Const(s, int32(v.AuxInt))
p := s.Prog(wasm.ACall)
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: gc.WasmZero}

case ssa.OpWasmLoweredNilCheck:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
s.Prog(wasm.AI64Eqz)
s.Prog(wasm.AIf)
p := s.Prog(wasm.ACALLNORESUME)
Expand All @@ -161,14 +165,14 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {
}

case ssa.OpWasmLoweredWB:
getReg64(s, v.Args[0])
getReg64(s, v.Args[1])
getValue64(s, v.Args[0])
getValue64(s, v.Args[1])
p := s.Prog(wasm.ACALLNORESUME) // TODO(neelance): If possible, turn this into a simple wasm.ACall).
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: v.Aux.(*obj.LSym)}

case ssa.OpWasmI64Store8, ssa.OpWasmI64Store16, ssa.OpWasmI64Store32, ssa.OpWasmI64Store, ssa.OpWasmF32Store, ssa.OpWasmF64Store:
getReg32(s, v.Args[0])
getReg64(s, v.Args[1])
getValue32(s, v.Args[0])
getValue64(s, v.Args[1])
if v.Op == ssa.OpWasmF32Store {
s.Prog(wasm.AF32DemoteF64)
}
Expand All @@ -177,7 +181,7 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {

case ssa.OpStoreReg:
getReg(s, wasm.REG_SP)
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
if v.Type.Etype == types.TFLOAT32 {
s.Prog(wasm.AF32DemoteF64)
}
Expand All @@ -188,7 +192,16 @@ func ssaGenValue(s *gc.SSAGenState, v *ssa.Value) {
if v.Type.IsMemory() {
return
}
if v.OnWasmStack {
s.WasmStackSize++
// If a Value is marked OnWasmStack, we don't generate the value and store it to a register now.
// Instead, we delay the generation to when the value is used and then directly generate it on the WebAssembly stack.
return
}
ssaGenValueOnStack(s, v)
if s.WasmStackSize != 0 {
panic("wasm: bad stack")
}
setReg(s, v.Reg())
}
}
Expand Down Expand Up @@ -237,22 +250,22 @@ func ssaGenValueOnStack(s *gc.SSAGenState, v *ssa.Value) {
}

case ssa.OpWasmLoweredRound32F:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
s.Prog(wasm.AF32DemoteF64)
s.Prog(wasm.AF64PromoteF32)

case ssa.OpWasmLoweredConvert:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])

case ssa.OpWasmSelect:
getReg64(s, v.Args[0])
getReg64(s, v.Args[1])
getReg64(s, v.Args[2])
getValue64(s, v.Args[0])
getValue64(s, v.Args[1])
getValue64(s, v.Args[2])
s.Prog(wasm.AI32WrapI64)
s.Prog(v.Op.Asm())

case ssa.OpWasmI64AddConst:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
i64Const(s, v.AuxInt)
s.Prog(v.Op.Asm())

Expand All @@ -263,32 +276,32 @@ func ssaGenValueOnStack(s *gc.SSAGenState, v *ssa.Value) {
f64Const(s, v.AuxFloat())

case ssa.OpWasmI64Load8U, ssa.OpWasmI64Load8S, ssa.OpWasmI64Load16U, ssa.OpWasmI64Load16S, ssa.OpWasmI64Load32U, ssa.OpWasmI64Load32S, ssa.OpWasmI64Load, ssa.OpWasmF32Load, ssa.OpWasmF64Load:
getReg32(s, v.Args[0])
getValue32(s, v.Args[0])
p := s.Prog(v.Op.Asm())
p.From = obj.Addr{Type: obj.TYPE_CONST, Offset: v.AuxInt}
if v.Op == ssa.OpWasmF32Load {
s.Prog(wasm.AF64PromoteF32)
}

case ssa.OpWasmI64Eqz:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
s.Prog(v.Op.Asm())
s.Prog(wasm.AI64ExtendUI32)

case ssa.OpWasmI64Eq, ssa.OpWasmI64Ne, ssa.OpWasmI64LtS, ssa.OpWasmI64LtU, ssa.OpWasmI64GtS, ssa.OpWasmI64GtU, ssa.OpWasmI64LeS, ssa.OpWasmI64LeU, ssa.OpWasmI64GeS, ssa.OpWasmI64GeU, ssa.OpWasmF64Eq, ssa.OpWasmF64Ne, ssa.OpWasmF64Lt, ssa.OpWasmF64Gt, ssa.OpWasmF64Le, ssa.OpWasmF64Ge:
getReg64(s, v.Args[0])
getReg64(s, v.Args[1])
getValue64(s, v.Args[0])
getValue64(s, v.Args[1])
s.Prog(v.Op.Asm())
s.Prog(wasm.AI64ExtendUI32)

case ssa.OpWasmI64Add, ssa.OpWasmI64Sub, ssa.OpWasmI64Mul, ssa.OpWasmI64DivU, ssa.OpWasmI64RemS, ssa.OpWasmI64RemU, ssa.OpWasmI64And, ssa.OpWasmI64Or, ssa.OpWasmI64Xor, ssa.OpWasmI64Shl, ssa.OpWasmI64ShrS, ssa.OpWasmI64ShrU, ssa.OpWasmF64Add, ssa.OpWasmF64Sub, ssa.OpWasmF64Mul, ssa.OpWasmF64Div:
getReg64(s, v.Args[0])
getReg64(s, v.Args[1])
getValue64(s, v.Args[0])
getValue64(s, v.Args[1])
s.Prog(v.Op.Asm())

case ssa.OpWasmI64DivS:
getReg64(s, v.Args[0])
getReg64(s, v.Args[1])
getValue64(s, v.Args[0])
getValue64(s, v.Args[1])
if v.Type.Size() == 8 {
// Division of int64 needs helper function wasmDiv to handle the MinInt64 / -1 case.
p := s.Prog(wasm.ACall)
Expand All @@ -298,17 +311,17 @@ func ssaGenValueOnStack(s *gc.SSAGenState, v *ssa.Value) {
s.Prog(wasm.AI64DivS)

case ssa.OpWasmI64TruncSF64:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
p := s.Prog(wasm.ACall)
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: gc.WasmTruncS}

case ssa.OpWasmI64TruncUF64:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
p := s.Prog(wasm.ACall)
p.To = obj.Addr{Type: obj.TYPE_MEM, Name: obj.NAME_EXTERN, Sym: gc.WasmTruncU}

case ssa.OpWasmF64Neg, ssa.OpWasmF64ConvertSI64, ssa.OpWasmF64ConvertUI64:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])
s.Prog(v.Op.Asm())

case ssa.OpLoadReg:
Expand All @@ -319,23 +332,36 @@ func ssaGenValueOnStack(s *gc.SSAGenState, v *ssa.Value) {
}

case ssa.OpCopy:
getReg64(s, v.Args[0])
getValue64(s, v.Args[0])

default:
v.Fatalf("unexpected op: %s", v.Op)

}
}

func getReg32(s *gc.SSAGenState, v *ssa.Value) {
func getValue32(s *gc.SSAGenState, v *ssa.Value) {
if v.OnWasmStack {
s.WasmStackSize--
ssaGenValueOnStack(s, v)
s.Prog(wasm.AI32WrapI64)
return
}

reg := v.Reg()
getReg(s, reg)
if reg != wasm.REG_SP {
s.Prog(wasm.AI32WrapI64)
}
}

func getReg64(s *gc.SSAGenState, v *ssa.Value) {
func getValue64(s *gc.SSAGenState, v *ssa.Value) {
if v.OnWasmStack {
s.WasmStackSize--
ssaGenValueOnStack(s, v)
return
}

reg := v.Reg()
getReg(s, reg)
if reg == wasm.REG_SP {
Expand Down

0 comments on commit ee6dd4a

Please sign in to comment.