Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

zstd: x86 assembler implementation of sequenceDecs.executeSimple #531

Merged
merged 3 commits into from
Mar 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions zstd/_generate/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ func main() {
o.genDecodeSeqAsm("sequenceDecs_decode_amd64")
o.bmi2 = true
o.genDecodeSeqAsm("sequenceDecs_decode_bmi2")

exec := executeSimple{}
exec.generateProcedure("sequenceDecs_executeSimple_amd64")

Generate()
b, err := ioutil.ReadFile(out.Value.String())
if err != nil {
Expand Down Expand Up @@ -550,3 +554,200 @@ func (o options) adjustOffset(name string, moP, llP Mem, offsetB reg.GPVirtual)
Label(name + "_end")
return offset
}

// executeSimple groups the generator routines that emit the assembly
// implementation of sequenceDecs.executeSimple.
type executeSimple struct{}

// copySize reports how many bytes the fast copy loop advances per iteration.
//
// See copyMemory, which uses this value as its stride.
func (e executeSimple) copySize() int {
	const bytesPerCopy = 16
	return bytesPerCopy
}

// generateProcedure emits the assembly function called name with the Go
// signature `func (ctx *executeAsmContext) bool`.
//
// The emitted code walks ctx.seqs starting at ctx.seqIndex. For every sequence
// it copies ll literal bytes and then ml match bytes (taken mo bytes back in
// the output) to the current output position. It returns true when all
// sequences were processed and false when a match offset exceeds either the
// current output position or the window size. Except for the empty-input
// shortcut, seqIndex, outPosition and litPosition are written back to ctx
// before returning.
func (e executeSimple) generateProcedure(name string) {
	Package("github.com/klauspost/compress/zstd")
	TEXT(name, 0, "func (ctx *executeAsmContext) bool")
	// The generated routine implements executeSimple, not decode.
	Doc(name+" implements the main loop of sequenceDecs.executeSimple in x86 asm", "")
	Pragma("noescape")

	seqsBase := GP64()
	seqsLen := GP64()
	seqIndex := GP64()
	outBase := GP64()
	outLen := GP64() // loaded below but not otherwise used by the emitted code
	literals := GP64()
	outPosition := GP64()
	windowSize := GP64()

	{
		ctx := Dereference(Param("ctx"))
		// Nothing to do for an empty sequence list; skip all other loads.
		Load(ctx.Field("seqs").Len(), seqsLen)
		TESTQ(seqsLen, seqsLen)
		JZ(LabelRef("empty_seqs"))
		Load(ctx.Field("seqs").Base(), seqsBase)
		Load(ctx.Field("seqIndex"), seqIndex)
		Load(ctx.Field("out").Base(), outBase)
		Load(ctx.Field("out").Len(), outLen)
		Load(ctx.Field("literals").Base(), literals)
		Load(ctx.Field("outPosition"), outPosition)
		Load(ctx.Field("windowSize"), windowSize)

		// Each sequence entry is addressed as three 8-byte fields (24 bytes),
		// so advance the base pointer to the first entry to process.
		tmp := GP64()
		Comment("seqsBase += 24 * seqIndex")
		LEAQ(Mem{Base: seqIndex, Index: seqIndex, Scale: 2}, tmp) // * 3
		SHLQ(U8(3), tmp)                                          // * 8
		ADDQ(tmp, seqsBase)

		// From here on outBase points at the current write position.
		Comment("outBase += outPosition")
		ADDQ(outPosition, outBase)
	}

	Label("main_loop")

	ml := GP64()
	mo := GP64()
	ll := GP64()

	// Field offsets within the current sequence entry.
	moPtr := Mem{Base: seqsBase, Disp: 2 * 8}
	mlPtr := Mem{Base: seqsBase, Disp: 1 * 8}
	llPtr := Mem{Base: seqsBase, Disp: 0 * 8}

	MOVQ(mlPtr, ml)
	MOVQ(llPtr, ll)

	Comment("Copy literals")
	Label("copy_literals")
	{
		// copyMemory always copies at least one chunk, so skip it for ll == 0.
		TESTQ(ll, ll)
		JZ(LabelRef("copy_match"))
		e.copyMemory("1", literals, outBase, ll)

		ADDQ(ll, literals)
		ADDQ(ll, outBase)
		ADDQ(ll, outPosition)
	}

	Comment("Copy match")
	Label("copy_match")
	{
		TESTQ(ml, ml)
		JZ(LabelRef("handle_loop"))

		// The offset is only needed (and validated) when there is a match.
		MOVQ(moPtr, mo)

		// NOTE(review): JG is a signed comparison, so a negative mo would pass
		// both checks — confirm the sequence decoder never produces one.
		Comment("Malformed input if seq.mo > t || seq.mo > s.windowSize)")
		CMPQ(mo, outPosition)
		JG(LabelRef("error_match_off_to_big"))
		CMPQ(mo, windowSize)
		JG(LabelRef("error_match_off_to_big"))

		src := GP64()
		MOVQ(outBase, src)
		SUBQ(mo, src) // src = &s.out[t - mo]

		// start := t - mo
		// if ml <= t-start {
		//     // no overlap
		// } else {
		//     // overlapping copy
		// }
		//
		// Note: ml <= t - start
		//       ml <= t - (t - mo)
		//       ml <= mo
		Comment("ml <= mo")
		CMPQ(ml, mo)
		JA(LabelRef("copy_overalapping_match"))

		Comment("Copy non-overlapping match")
		Label("copy_non_overalapping_match")
		{
			e.copyMemory("2", src, outBase, ml)
			ADDQ(ml, outBase)
			ADDQ(ml, outPosition)
			JMP(LabelRef("handle_loop"))
		}

		// Source and destination overlap, so the chunked copy cannot be used;
		// fall back to the byte-by-byte loop.
		Comment("Copy overlapping match")
		Label("copy_overalapping_match")
		{
			e.copyOverlappedMemory("3", src, outBase, ml)
			ADDQ(ml, outBase)
			ADDQ(ml, outPosition)
		}
	}

	Label("handle_loop")
	ADDQ(U8(24), seqsBase) // seqs += sizeof(seqVals)
	INCQ(seqIndex)
	CMPQ(seqIndex, seqsLen)
	JB(LabelRef("main_loop"))

	ret, err := ReturnIndex(0).Resolve()
	if err != nil {
		panic(err)
	}

	// returnValue emits the epilogue shared by the success and error exits:
	// it sets the boolean result and writes the loop state back to ctx.
	returnValue := func(val int) {

		Comment("Return value")
		MOVB(U8(val), ret.Addr)

		Comment("Update the context")
		ctx := Dereference(Param("ctx"))
		Store(seqIndex, ctx.Field("seqIndex"))
		Store(outPosition, ctx.Field("outPosition"))

		// compute litPosition
		tmp := GP64()
		Load(ctx.Field("literals").Base(), tmp)
		SUBQ(tmp, literals) // litPosition := current - initial literals pointer
		Store(literals, ctx.Field("litPosition"))
	}
	returnValue(1)
	RET()

	// On this exit seqIndex has not been incremented yet, so the value stored
	// to ctx is the index of the offending sequence.
	Label("error_match_off_to_big")
	returnValue(0)
	RET()

	// Empty input: report success without touching ctx.
	Label("empty_seqs")
	Comment("Return value")
	MOVB(U8(1), ret.Addr)
	RET()
}

// copyMemory emits a loop that copies length bytes from src to dst in chunks
// of copySize() bytes, moving each chunk through an XMM register.
//
// The loop condition is tested after each chunk, so at least one chunk is
// always transferred and up to copySize()-1 bytes beyond length may be read
// from src and written to dst. suffix is appended to the loop label so that
// several copies can be emitted into the same procedure.
func (e executeSimple) copyMemory(suffix string, src, dst, length reg.GPVirtual) {
	offset := GP64()
	loop := "copy_" + suffix

	XORQ(offset, offset)
	Label(loop)
	{
		chunk := XMM()
		MOVUPS(Mem{Base: src, Index: offset, Scale: 1}, chunk)
		MOVUPS(chunk, Mem{Base: dst, Index: offset, Scale: 1})
		ADDQ(U8(e.copySize()), offset)
		CMPQ(offset, length)
		JB(LabelRef(loop))
	}
}

// copyOverlappedMemory emits a loop that copies length bytes from src to dst
// one byte at a time, in increasing address order.
//
// The loop condition is tested after each byte, so at least one byte is always
// copied. suffix is appended to the loop label to keep it unique.
func (e executeSimple) copyOverlappedMemory(suffix string, src, dst, length reg.GPVirtual) {
	offset := GP64()
	scratch := GP64()
	loop := "copy_slow_" + suffix
	b := scratch.As8()

	XORQ(offset, offset)
	Label(loop)
	{
		MOVB(Mem{Base: src, Index: offset, Scale: 1}, b)
		MOVB(b, Mem{Base: dst, Index: offset, Scale: 1})
		INCQ(offset)
		CMPQ(offset, length)
		JB(LabelRef(loop))
	}
}
2 changes: 1 addition & 1 deletion zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1151,7 +1151,7 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error)
if want == "" {
want = "<error>"
}
t.Error("Did not get expected error", want, "- got ", len(got), "bytes")
t.Error("Did not get expected error", want, "- got", len(got), "bytes")
return
}
if errMap[tt.Name] == "" {
Expand Down
4 changes: 4 additions & 0 deletions zstd/seqdec.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ func (s *sequenceDecs) initialize(br *bitReader, hist *history, out []byte) erro
// execute will execute the decoded sequence with the provided history.
// The sequence must be evaluated before being sent.
func (s *sequenceDecs) execute(seqs []seqVals, hist []byte) error {
if len(hist) == 0 && len(s.dict) == 0 {
return s.executeSimple(seqs)
}

// Ensure we have enough output size...
if len(s.out)+s.seqSize > cap(s.out) {
addBytes := s.seqSize + len(s.out)
Expand Down
67 changes: 67 additions & 0 deletions zstd/seqdec_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,70 @@ func (s *sequenceDecs) decode(seqs []seqVals) error {
}
return err
}

// executeAsmContext carries the inputs and outputs of
// sequenceDecs_executeSimple_amd64. The assembly reads the fields by name,
// so it must be regenerated whenever this layout changes.
type executeAsmContext struct {
// seqs holds the sequences to execute.
seqs []seqVals
// seqIndex is the index of the first sequence to execute; on return it is
// the index the assembly stopped at (the offending sequence on failure).
seqIndex int
// out is the output buffer that literals and matches are written to.
out []byte
// literals is the source of the literal bytes.
literals []byte
// outPosition is the write position within out; updated on return.
outPosition int
// litPosition is set on return to the number of bytes consumed from literals.
litPosition int
// windowSize is the largest match offset accepted.
windowSize int
}

// sequenceDecs_executeSimple_amd64 implements the main loop of sequenceDecs.executeSimple in x86 asm.
//
// It executes ctx.seqs starting at ctx.seqIndex, writing literals and matches
// to ctx.out from ctx.outPosition onwards. On return ctx.seqIndex,
// ctx.outPosition and ctx.litPosition are updated, except when ctx.seqs is
// empty, in which case ctx is left untouched and true is returned.
//
// Returns false if a match offset is too big, i.e. larger than the current
// output position or than ctx.windowSize; ctx.seqIndex then holds the index
// of the offending sequence.
//
// Please refer to seqdec_generic.go for the reference implementation.
//go:noescape
func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool

// overwriteSize is the extra capacity executeSimple reserves beyond the
// expected output size, because the assembly copies data in fixed-size chunks
// and may write past the end of the range it was asked to copy.
const overwriteSize = 16

// executeSimple runs the decoded sequences through the assembly
// implementation. It is used when neither history nor a dictionary is present.
//
// On success s.out is extended by s.seqSize bytes and s.literals is advanced
// past the literals that were consumed. An error is returned when a sequence
// carries a match offset that is too big.
func (s *sequenceDecs) executeSimple(seqs []seqVals) error {
	// Make sure the output buffer has room for the result plus the slack the
	// assembly may overwrite, while keeping the current length.
	if required := len(s.out) + s.seqSize + overwriteSize; required > cap(s.out) {
		s.out = append(s.out, make([]byte, required)...)
		s.out = s.out[:len(s.out)-required]
	}

	if debugDecoder {
		printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize)
	}

	start := len(s.out)
	out := s.out[:start+s.seqSize]

	// seqIndex and litPosition start at their zero values.
	ctx := executeAsmContext{
		seqs:        seqs,
		out:         out,
		literals:    s.literals,
		outPosition: start,
		windowSize:  s.windowSize,
	}

	if !sequenceDecs_executeSimple_amd64(&ctx) {
		return fmt.Errorf("match offset (%d) bigger than current history (%d)",
			seqs[ctx.seqIndex].mo, ctx.outPosition)
	}

	// Drop the literals consumed by the sequences and append whatever is left.
	s.literals = s.literals[ctx.litPosition:]
	written := ctx.outPosition
	copy(out[written:], s.literals)

	if debugDecoder {
		if end := written + len(s.literals); end != len(out) {
			panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), end, s.seqSize))
		}
	}

	s.out = out
	return nil
}
Loading