diff --git a/internal/lsp/cache/maps.go b/internal/lsp/cache/maps.go
new file mode 100644
index 00000000000..70f8039bdac
--- /dev/null
+++ b/internal/lsp/cache/maps.go
@@ -0,0 +1,112 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cache
+
+import (
+	"golang.org/x/tools/internal/persistent"
+	"golang.org/x/tools/internal/span"
+)
+
+// TODO(euroelessar): Use generics once support for go1.17 is dropped.
+
+type goFilesMap struct {
+	impl *persistent.Map
+}
+
+func newGoFilesMap() *goFilesMap {
+	return &goFilesMap{
+		impl: persistent.NewMap(func(a, b interface{}) bool {
+			return parseKeyLess(a.(parseKey), b.(parseKey))
+		}),
+	}
+}
+
+func parseKeyLess(a, b parseKey) bool {
+	if a.mode != b.mode {
+		return a.mode < b.mode
+	}
+	if a.file.Hash != b.file.Hash {
+		return a.file.Hash.Less(b.file.Hash)
+	}
+	return a.file.URI < b.file.URI
+}
+
+func (m *goFilesMap) Clone() *goFilesMap {
+	return &goFilesMap{
+		impl: m.impl.Clone(),
+	}
+}
+
+func (m *goFilesMap) Destroy() {
+	m.impl.Destroy()
+}
+
+func (m *goFilesMap) Load(key parseKey) (*parseGoHandle, bool) {
+	value, ok := m.impl.Load(key)
+	if !ok {
+		return nil, false
+	}
+	return value.(*parseGoHandle), true
+}
+
+func (m *goFilesMap) Range(do func(key parseKey, value *parseGoHandle)) {
+	m.impl.Range(func(key, value interface{}) {
+		do(key.(parseKey), value.(*parseGoHandle))
+	})
+}
+
+func (m *goFilesMap) Store(key parseKey, value *parseGoHandle, release func()) {
+	m.impl.Store(key, value, func(key, value interface{}) {
+		release()
+	})
+}
+
+func (m *goFilesMap) Delete(key parseKey) {
+	m.impl.Delete(key)
+}
+
+type parseKeysByURIMap struct {
+	impl *persistent.Map
+}
+
+func newParseKeysByURIMap() *parseKeysByURIMap {
+	return &parseKeysByURIMap{
+		impl: persistent.NewMap(func(a, b interface{}) bool {
+			return a.(span.URI) < b.(span.URI)
+		}),
+	}
+}
+
+func (m *parseKeysByURIMap) Clone() *parseKeysByURIMap {
+	return &parseKeysByURIMap{
+		impl: m.impl.Clone(),
+	}
+}
+
+func (m *parseKeysByURIMap) Destroy() {
+	m.impl.Destroy()
+}
+
+func (m *parseKeysByURIMap) Load(key span.URI) ([]parseKey, bool) {
+	value, ok := m.impl.Load(key)
+	if !ok {
+		return nil, false
+	}
+	return value.([]parseKey), true
+}
+
+func (m *parseKeysByURIMap) Range(do func(key span.URI, value []parseKey)) {
+	m.impl.Range(func(key, value interface{}) {
+		do(key.(span.URI), value.([]parseKey))
+	})
+}
+
+func (m *parseKeysByURIMap) Store(key span.URI, value []parseKey) {
+	m.impl.Store(key, value, nil)
+}
+
+func (m *parseKeysByURIMap) Delete(key span.URI) {
+	m.impl.Delete(key)
+}
diff --git a/internal/lsp/cache/parse.go b/internal/lsp/cache/parse.go
index ab55743ccf0..074724861e9 100644
--- a/internal/lsp/cache/parse.go
+++ b/internal/lsp/cache/parse.go
@@ -58,7 +58,7 @@ func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode
 	if pgh := s.getGoFile(key); pgh != nil {
 		return pgh
 	}
-	parseHandle := s.generation.Bind(key, func(ctx context.Context, arg memoize.Arg) interface{} {
+	parseHandle, release := s.generation.NewHandle(key, func(ctx context.Context, arg memoize.Arg) interface{} {
 		snapshot := arg.(*snapshot)
 		return parseGo(ctx, snapshot.FileSet(), fh, mode)
 	}, nil)
@@ -68,7 +68,7 @@ func (s *snapshot) parseGoHandle(ctx context.Context, fh source.FileHandle, mode
 		file: fh,
 		mode: mode,
 	}
-	return s.addGoFile(key, pgh)
+	return s.addGoFile(key, pgh, release)
 }
 
 func (pgh *parseGoHandle) String() string {
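Aside on the TODO in maps.go above: once go1.17 support is dropped, the two hand-written wrappers could collapse into a single type-parameterized one. A rough sketch of that future shape (illustrative only, not part of this change; persistentMap and newPersistentMap are hypothetical names):

	// A single generic wrapper could replace both goFilesMap and
	// parseKeysByURIMap, keeping the type assertions in one place.
	type persistentMap[K, V any] struct {
		impl *persistent.Map
	}

	func newPersistentMap[K, V any](less func(a, b K) bool) *persistentMap[K, V] {
		return &persistentMap[K, V]{
			impl: persistent.NewMap(func(a, b interface{}) bool {
				return less(a.(K), b.(K))
			}),
		}
	}

	func (m *persistentMap[K, V]) Load(key K) (V, bool) {
		if value, ok := m.impl.Load(key); ok {
			return value.(V), true
		}
		var zero V
		return zero, false
	}

The remaining methods (Clone, Destroy, Range, Store, Delete) would wrap the untyped implementation in the same way.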
diff --git a/internal/lsp/cache/session.go b/internal/lsp/cache/session.go
index 286d8f12c46..7d5f2e859d2 100644
--- a/internal/lsp/cache/session.go
+++ b/internal/lsp/cache/session.go
@@ -234,7 +234,8 @@ func (s *Session) createView(ctx context.Context, name string, folder span.URI,
 		packages:          make(map[packageKey]*packageHandle),
 		meta:              &metadataGraph{},
 		files:             make(map[span.URI]source.VersionedFileHandle),
-		goFiles:           newGoFileMap(),
+		goFiles:           newGoFilesMap(),
+		parseKeysByURI:    newParseKeysByURIMap(),
 		symbols:           make(map[span.URI]*symbolHandle),
 		actions:           make(map[actionKey]*actionHandle),
 		workspacePackages: make(map[PackageID]PackagePath),
diff --git a/internal/lsp/cache/snapshot.go b/internal/lsp/cache/snapshot.go
index 32681735b28..e575c4f0be5 100644
--- a/internal/lsp/cache/snapshot.go
+++ b/internal/lsp/cache/snapshot.go
@@ -77,7 +77,8 @@ type snapshot struct {
 	files map[span.URI]source.VersionedFileHandle
 
 	// goFiles maps a parseKey to its parseGoHandle.
-	goFiles *goFileMap
+	goFiles        *goFilesMap
+	parseKeysByURI *parseKeysByURIMap
 
 	// TODO(rfindley): consider merging this with files to reduce burden on clone.
 	symbols map[span.URI]*symbolHandle
@@ -133,6 +134,12 @@ type actionKey struct {
 	analyzer *analysis.Analyzer
 }
 
+func (s *snapshot) Destroy(destroyedBy string) {
+	s.generation.Destroy(destroyedBy)
+	s.goFiles.Destroy()
+	s.parseKeysByURI.Destroy()
+}
+
 func (s *snapshot) ID() uint64 {
 	return s.id
 }
@@ -665,17 +672,23 @@ func (s *snapshot) transitiveReverseDependencies(id PackageID, ids map[PackageID
 func (s *snapshot) getGoFile(key parseKey) *parseGoHandle {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-	return s.goFiles.get(key)
+	if result, ok := s.goFiles.Load(key); ok {
+		return result
+	}
+	return nil
 }
 
-func (s *snapshot) addGoFile(key parseKey, pgh *parseGoHandle) *parseGoHandle {
+func (s *snapshot) addGoFile(key parseKey, pgh *parseGoHandle, release func()) *parseGoHandle {
 	s.mu.Lock()
 	defer s.mu.Unlock()
-
-	if prev := s.goFiles.get(key); prev != nil {
-		return prev
-	}
-	s.goFiles.set(key, pgh)
+	if result, ok := s.goFiles.Load(key); ok {
+		release()
+		return result
+	}
+	s.goFiles.Store(key, pgh, release)
+	keys, _ := s.parseKeysByURI.Load(key.file.URI)
+	keys = append([]parseKey{key}, keys...)
+	s.parseKeysByURI.Store(key.file.URI, keys)
 	return pgh
 }
@@ -1661,6 +1674,9 @@ func (ac *unappliedChanges) GetFile(ctx context.Context, uri span.URI) (source.F
 }
 
 func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileChange, forceReloadMetadata bool) *snapshot {
+	ctx, done := event.Start(ctx, "snapshot.clone")
+	defer done()
+
 	var vendorChanged bool
 	newWorkspace, workspaceChanged, workspaceReload := s.workspace.invalidate(ctx, changes, &unappliedChanges{
 		originalSnapshot: s,
@@ -1684,7 +1700,8 @@ func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileC
 		packages:          make(map[packageKey]*packageHandle, len(s.packages)),
 		actions:           make(map[actionKey]*actionHandle, len(s.actions)),
 		files:             make(map[span.URI]source.VersionedFileHandle, len(s.files)),
-		goFiles:           s.goFiles.clone(),
+		goFiles:           s.goFiles.Clone(),
+		parseKeysByURI:    s.parseKeysByURI.Clone(),
 		symbols:           make(map[span.URI]*symbolHandle, len(s.symbols)),
 		workspacePackages: make(map[PackageID]PackagePath, len(s.workspacePackages)),
 		unloadableFiles:   make(map[span.URI]struct{}, len(s.unloadableFiles)),
@@ -1729,27 +1746,14 @@ func (s *snapshot) clone(ctx, bgCtx context.Context, changes map[span.URI]*fileC
 		result.parseWorkHandles[k] = v
 	}
 
-	// Copy the handles of all Go source files.
-	// There may be tens of thousands of files,
-	// but changes are typically few, so we
-	// use a striped map optimized for this case
-	// and visit its stripes in parallel.
-	var (
-		toDeleteMu sync.Mutex
-		toDelete   []parseKey
-	)
-	s.goFiles.forEachConcurrent(func(k parseKey, v *parseGoHandle) {
-		if changes[k.file.URI] == nil {
-			// no change (common case)
-			newGen.Inherit(v.handle)
-		} else {
-			toDeleteMu.Lock()
-			toDelete = append(toDelete, k)
-			toDeleteMu.Unlock()
-		}
-	})
-	for _, k := range toDelete {
-		result.goFiles.delete(k)
+	for uri := range changes {
+		keys, ok := result.parseKeysByURI.Load(uri)
+		if ok {
+			for _, key := range keys {
+				result.goFiles.Delete(key)
+			}
+			result.parseKeysByURI.Delete(uri)
+		}
 	}
 
 	// Copy all of the go.mod-related handles. They may be invalidated later,
@@ -2206,7 +2210,7 @@ func metadataChanges(ctx context.Context, lockedSnapshot *snapshot, oldFH, newFH
 // lockedSnapshot must be locked.
 func peekOrParse(ctx context.Context, lockedSnapshot *snapshot, fh source.FileHandle, mode source.ParseMode) (*source.ParsedGoFile, error) {
 	key := parseKey{file: fh.FileIdentity(), mode: mode}
-	if pgh := lockedSnapshot.goFiles.get(key); pgh != nil {
+	if pgh, ok := lockedSnapshot.goFiles.Load(key); ok {
 		cached := pgh.handle.Cached(lockedSnapshot.generation)
 		if cached != nil {
 			cached := cached.(*parseGoData)
@@ -2494,89 +2498,3 @@ func readGoSum(dst map[module.Version][]string, file string, data []byte) error
 	}
 	return nil
 }
-
-// -- goFileMap --
-
-// A goFileMap is conceptually a map[parseKey]*parseGoHandle,
-// optimized for cloning all or nearly all entries.
-type goFileMap struct {
-	// The map is represented as a map of 256 stripes, one per
-	// distinct value of the top 8 bits of key.file.Hash.
-	// Each stripe has an associated boolean indicating whether it
-	// is shared, and thus immutable, and thus must be copied before any update.
-	// (The bits could be packed but it hasn't been worth it yet.)
-	stripes   [256]map[parseKey]*parseGoHandle
-	exclusive [256]bool // exclusive[i] means stripe[i] is not shared and may be safely mutated
-}
-
-// newGoFileMap returns a new empty goFileMap.
-func newGoFileMap() *goFileMap {
-	return new(goFileMap) // all stripes are shared (non-exclusive) nil maps
-}
-
-// clone returns a copy of m.
-// For concurrency, it counts as an update to m.
-func (m *goFileMap) clone() *goFileMap {
-	m.exclusive = [256]bool{} // original and copy are now nonexclusive
-	copy := *m
-	return &copy
-}
-
-// get returns the value for key k.
-func (m *goFileMap) get(k parseKey) *parseGoHandle {
-	return m.stripes[m.hash(k)][k]
-}
-
-// set updates the value for key k to v.
-func (m *goFileMap) set(k parseKey, v *parseGoHandle) {
-	m.unshare(k)[k] = v
-}
-
-// delete deletes the value for key k, if any.
-func (m *goFileMap) delete(k parseKey) {
-	// TODO(adonovan): opt?: skip unshare if k isn't present.
-	delete(m.unshare(k), k)
-}
-
-// forEachConcurrent calls f for each entry in the map.
-// Calls may be concurrent.
-// f must not modify m.
-func (m *goFileMap) forEachConcurrent(f func(parseKey, *parseGoHandle)) {
-	// Visit stripes in parallel chunks.
-	const p = 16 // concurrency level
-	var wg sync.WaitGroup
-	wg.Add(p)
-	for i := 0; i < p; i++ {
-		chunk := m.stripes[i*p : (i+1)*p]
-		go func() {
-			for _, stripe := range chunk {
-				for k, v := range stripe {
-					f(k, v)
-				}
-			}
-			wg.Done()
-		}()
-	}
-	wg.Wait()
-}
-
-// -- internal --
-
-// hash returns 8 bits from the key's file digest.
-func (*goFileMap) hash(k parseKey) byte { return k.file.Hash[0] }
-
-// unshare makes k's stripe exclusive, allocating a copy if needed, and returns it.
-func (m *goFileMap) unshare(k parseKey) map[parseKey]*parseGoHandle {
-	i := m.hash(k)
-	if !m.exclusive[i] {
-		m.exclusive[i] = true
-
-		// Copy the map.
-		copy := make(map[parseKey]*parseGoHandle, len(m.stripes[i]))
-		for k, v := range m.stripes[i] {
-			copy[k] = v
-		}
-		m.stripes[i] = copy
-	}
-	return m.stripes[i]
-}
diff --git a/internal/lsp/cache/view.go b/internal/lsp/cache/view.go
index 0ed9883451b..d84e7ea2249 100644
--- a/internal/lsp/cache/view.go
+++ b/internal/lsp/cache/view.go
@@ -524,7 +524,7 @@ func (v *View) shutdown(ctx context.Context) {
 	v.mu.Unlock()
 	v.snapshotMu.Lock()
 	if v.snapshot != nil {
-		go v.snapshot.generation.Destroy("View.shutdown")
+		go v.snapshot.Destroy("View.shutdown")
 		v.snapshot = nil
 	}
 	v.snapshotMu.Unlock()
@@ -721,7 +721,7 @@ func (v *View) invalidateContent(ctx context.Context, changes map[span.URI]*file
 	oldSnapshot := v.snapshot
 
 	v.snapshot = oldSnapshot.clone(ctx, v.baseCtx, changes, forceReloadMetadata)
-	go oldSnapshot.generation.Destroy("View.invalidateContent")
+	go oldSnapshot.Destroy("View.invalidateContent")
 
 	return v.snapshot, v.snapshot.generation.Acquire()
 }
diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go
index 0d8d661d60c..73e1b7f89ed 100644
--- a/internal/lsp/source/view.go
+++ b/internal/lsp/source/view.go
@@ -551,6 +551,11 @@ func (h Hash) String() string {
 	return fmt.Sprintf("%64x", [sha256.Size]byte(h))
 }
 
+// Less reports whether h is less than other.
+func (h Hash) Less(other Hash) bool {
+	return bytes.Compare(h[:], other[:]) < 0
+}
+
 // FileIdentity uniquely identifies a file at a version from a FileSystem.
 type FileIdentity struct {
 	URI span.URI
diff --git a/internal/memoize/memoize.go b/internal/memoize/memoize.go
index 480b87f5ce9..5c1e5d5d126 100644
--- a/internal/memoize/memoize.go
+++ b/internal/memoize/memoize.go
@@ -83,19 +83,10 @@ func (g *Generation) Destroy(destroyedBy string) {
 
 	g.store.mu.Lock()
 	defer g.store.mu.Unlock()
-	for k, e := range g.store.handles {
-		e.mu.Lock()
-		if _, ok := e.generations[g]; ok {
-			delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
-			if len(e.generations) == 0 {
-				delete(g.store.handles, k)
-				e.state = stateDestroyed
-				if e.cleanup != nil && e.value != nil {
-					e.cleanup(e.value)
-				}
-			}
+	for _, e := range g.store.handles {
+		if e.trackGenerations {
+			e.decrementRef(g, g.store)
 		}
-		e.mu.Unlock()
 	}
 	delete(g.store.generations, g)
 }
@@ -161,6 +152,9 @@ type Handle struct {
 	// cleanup, if non-nil, is used to perform any necessary clean-up on values
 	// produced by function.
 	cleanup func(interface{})
+
+	trackGenerations bool
+	refCounter       int32
 }
 
 // Bind returns a handle for the given key and function.
@@ -173,7 +167,31 @@ type Handle struct {
 //
 // If cleanup is non-nil, it will be called on any non-nil values produced by
 // function when they are no longer referenced.
+//
+// It is the responsibility of the caller to call Inherit on the handle
+// whenever it should remain accessible to the next generation.
 func (g *Generation) Bind(key interface{}, function Function, cleanup func(interface{})) *Handle {
+	return g.newHandle(key, function, cleanup, true)
+}
+
+// NewHandle returns a handle for the given key and function with the same
+// properties and behavior as Bind.
+//
+// Unlike Bind, it also returns a release callback that must be called once
+// this reference to the handle is no longer needed.
+func (g *Generation) NewHandle(key interface{}, function Function, cleanup func(interface{})) (*Handle, func()) {
+	handle := g.newHandle(key, function, cleanup, false)
+	store := g.store
+	release := func() {
+		store.mu.Lock()
+		defer store.mu.Unlock()
+
+		handle.decrementRef(nil, store)
+	}
+	return handle, release
+}
+
+func (g *Generation) newHandle(key interface{}, function Function, cleanup func(interface{}), trackGenerations bool) *Handle {
 	// panic early if the function is nil
 	// it would panic later anyway, but in a way that was much harder to debug
 	if function == nil {
@@ -186,20 +204,19 @@ func (g *Generation) Bind(key interface{}, function Function, cleanup func(inter
 	defer g.store.mu.Unlock()
 	h, ok := g.store.handles[key]
 	if !ok {
-		h := &Handle{
-			key:         key,
-			function:    function,
-			generations: map[*Generation]struct{}{g: {}},
-			cleanup:     cleanup,
+		h = &Handle{
+			key:              key,
+			function:         function,
+			cleanup:          cleanup,
+			trackGenerations: trackGenerations,
+		}
+		if trackGenerations {
+			h.generations = make(map[*Generation]struct{}, 1)
 		}
 		g.store.handles[key] = h
-		return h
-	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	if _, ok := h.generations[g]; !ok {
-		h.generations[g] = struct{}{}
 	}
+
+	h.incrementRef(g)
 	return h
 }
@@ -240,13 +257,68 @@ func (g *Generation) Inherit(h *Handle) {
 	if atomic.LoadUint32(&g.destroyed) != 0 {
 		panic("inherit on generation " + g.name + " destroyed by " + g.destroyedBy)
 	}
+	if !h.trackGenerations {
+		panic("called Inherit on handle not created by Generation.Bind")
+	}
+
+	h.incrementRef(g)
+}
 
+func (h *Handle) destroy(store *Store) {
+	h.state = stateDestroyed
+	if h.cleanup != nil && h.value != nil {
+		h.cleanup(h.value)
+	}
+	delete(store.handles, h.key)
+}
+
+func (h *Handle) incrementRef(g *Generation) {
 	h.mu.Lock()
+	defer h.mu.Unlock()
+
 	if h.state == stateDestroyed {
 		panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
 	}
-	h.generations[g] = struct{}{}
-	h.mu.Unlock()
+
+	if h.trackGenerations {
+		h.generations[g] = struct{}{}
+	} else {
+		h.refCounter++
+	}
+}
+
+func (h *Handle) decrementRef(g *Generation, store *Store) {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+
+	if h.trackGenerations {
+		if g == nil {
+			panic("passed nil generation to Handle.decrementRef")
+		}
+		if _, ok := h.generations[g]; ok {
+			delete(h.generations, g) // delete even if it's dead, in case of dangling references to the entry.
+			if len(h.generations) == 0 {
+				h.destroy(store)
+			}
+		}
+	} else {
+		if g != nil {
+			panic(fmt.Sprintf("passed non-nil generation to decrementRef of a reference-counted handle: %v", g))
+		}
+		h.refCounter--
+		if h.refCounter == 0 {
+			h.destroy(store)
+		}
+	}
+}
+
+func (h *Handle) hasRefLocked(g *Generation) bool {
+	if !h.trackGenerations {
+		return true
+	}
+
+	_, ok := h.generations[g]
+	return ok
+}
 
 // Cached returns the value associated with a handle.
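To make the two lifetime modes concrete, here is a minimal sketch of the NewHandle/release contract introduced above (illustrative only; memoize is an internal package, so this compiles only inside x/tools, and the key, value, and printed messages are made up):

	package main

	import (
		"context"
		"fmt"

		"golang.org/x/tools/internal/memoize"
	)

	func main() {
		s := &memoize.Store{}
		g := s.Generation("demo")

		// Unlike Bind, NewHandle ties the handle's lifetime to an explicit
		// release callback instead of generation membership.
		h, release := g.NewHandle("key", func(ctx context.Context, arg memoize.Arg) interface{} {
			return "computed"
		}, func(v interface{}) {
			fmt.Println("cleanup:", v) // runs once the last reference is released
		})

		v, err := h.Get(context.Background(), g, nil)
		fmt.Println(v, err) // computed <nil>

		release() // drops the only reference; cleanup fires here
	}

Under Bind, the cleanup would instead run when the last generation holding the handle is destroyed; TestHandleRefCounting below exercises the interaction between the two modes.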
@@ -256,7 +328,7 @@ func (g *Generation) Inherit(h *Handle) {
 func (h *Handle) Cached(g *Generation) interface{} {
 	h.mu.Lock()
 	defer h.mu.Unlock()
-	if _, ok := h.generations[g]; !ok {
+	if !h.hasRefLocked(g) {
 		return nil
 	}
 	if h.state == stateCompleted {
@@ -277,7 +349,7 @@ func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{},
 		return nil, ctx.Err()
 	}
 	h.mu.Lock()
-	if _, ok := h.generations[g]; !ok {
+	if !h.hasRefLocked(g) {
 		h.mu.Unlock()
 		err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
diff --git a/internal/memoize/memoize_test.go b/internal/memoize/memoize_test.go
index ee0fd23ea1d..6c52c3f8c6d 100644
--- a/internal/memoize/memoize_test.go
+++ b/internal/memoize/memoize_test.go
@@ -106,3 +106,58 @@ func TestCleanup(t *testing.T) {
 		t.Error("after destroying g2, v2 is not cleaned up")
 	}
 }
+
+func TestHandleRefCounting(t *testing.T) {
+	s := &memoize.Store{}
+	g1 := s.Generation("g1")
+	v1 := false
+	v2 := false
+	cleanup := func(v interface{}) {
+		*(v.(*bool)) = true
+	}
+	h1, release1 := g1.NewHandle("key1", func(context.Context, memoize.Arg) interface{} {
+		return &v1
+	}, nil)
+	h2, release2 := g1.NewHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v2
+	}, cleanup)
+	expectGet(t, h1, g1, &v1)
+	expectGet(t, h2, g1, &v2)
+
+	g2 := s.Generation("g2")
+	expectGet(t, h1, g2, &v1)
+	g1.Destroy("by test")
+	expectGet(t, h2, g2, &v2)
+
+	h2Copy, release2Copy := g2.NewHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v1
+	}, nil)
+	if h2 != h2Copy {
+		t.Error("NewHandle returned a new handle while the old one is not yet destroyed")
+	}
+	expectGet(t, h2Copy, g2, &v2)
+	g2.Destroy("by test")
+
+	release2()
+	if got, want := v2, false; got != want {
+		t.Error("after releasing first v2 ref, v2 is cleaned up")
+	}
+	release2Copy()
+	if got, want := v2, true; got != want {
+		t.Error("after releasing second v2 ref, v2 is not cleaned up")
+	}
+	if got, want := v1, false; got != want {
+		t.Error("after destroying v2, v1 is cleaned up")
+	}
+	release1()
+
+	g3 := s.Generation("g3")
+	h2Copy, release2Copy = g3.NewHandle("key2", func(context.Context, memoize.Arg) interface{} {
+		return &v2
+	}, cleanup)
+	if h2 == h2Copy {
+		t.Error("NewHandle returned a previously destroyed handle")
+	}
+	release2Copy()
+	g3.Destroy("by test")
+}
diff --git a/internal/persistent/map.go b/internal/persistent/map.go
new file mode 100644
index 00000000000..56716f0a33e
--- /dev/null
+++ b/internal/persistent/map.go
@@ -0,0 +1,257 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The persistent package defines various persistent data structures;
+// that is, data structures that can be efficiently copied and modified
+// in sublinear time.
+package persistent
+
+import (
+	"math/rand"
+	"sync/atomic"
+)
+
+// Map is an associative mapping from keys to values, both represented as
+// interface{}. Key comparison and iteration order is defined by a
+// client-provided function that implements a strict weak order.
+//
+// Maps can be Cloned in constant time.
+//
+// Values are reference counted, and a client-supplied release function
+// is called when a value is no longer referenced by a map or any clone.
+//
+// Internally the implementation is based on a randomized persistent treap:
+// https://en.wikipedia.org/wiki/Treap.
+type Map struct {
+	less func(a, b interface{}) bool
+	root *mapNode
+}
+
+type mapNode struct {
+	key         interface{}
+	value       *refValue
+	weight      uint64
+	refCount    int32
+	left, right *mapNode
+}
+
+type refValue struct {
+	refCount int32
+	value    interface{}
+	release  func(key, value interface{})
+}
+
+func newNodeWithRef(key, value interface{}, release func(key, value interface{})) *mapNode {
+	return &mapNode{
+		key: key,
+		value: &refValue{
+			value:    value,
+			release:  release,
+			refCount: 1,
+		},
+		refCount: 1,
+		weight:   rand.Uint64(),
+	}
+}
+
+func (node *mapNode) shallowCloneWithRef() *mapNode {
+	atomic.AddInt32(&node.value.refCount, 1)
+	return &mapNode{
+		key:      node.key,
+		value:    node.value,
+		weight:   node.weight,
+		refCount: 1,
+	}
+}
+
+func (node *mapNode) incref() *mapNode {
+	if node != nil {
+		atomic.AddInt32(&node.refCount, 1)
+	}
+	return node
+}
+
+func (node *mapNode) decref() {
+	if node == nil {
+		return
+	}
+	if atomic.AddInt32(&node.refCount, -1) == 0 {
+		if atomic.AddInt32(&node.value.refCount, -1) == 0 {
+			if node.value.release != nil {
+				node.value.release(node.key, node.value.value)
+			}
+			node.value.value = nil
+			node.value.release = nil
+		}
+		node.left.decref()
+		node.right.decref()
+	}
+}
+
+// NewMap returns a new map whose keys are ordered by the given comparison
+// function (a strict weak order). It is the responsibility of the caller to
+// Destroy it at a later time.
+func NewMap(less func(a, b interface{}) bool) *Map {
+	return &Map{
+		less: less,
+	}
+}
+
+// Clone returns a copy of the given map. It is the responsibility of the
+// caller to Destroy it at a later time.
+func (pm *Map) Clone() *Map {
+	return &Map{
+		less: pm.less,
+		root: pm.root.incref(),
+	}
+}
+
+// Destroy destroys the persistent map.
+//
+// After Destroy, the Map should not be used again.
+func (pm *Map) Destroy() {
+	pm.root.decref()
+	pm.root = nil
+}
+
+// Range calls f sequentially in ascending key order for all entries in the map.
+func (pm *Map) Range(f func(key, value interface{})) {
+	pm.root.forEach(f)
+}
+
+func (node *mapNode) forEach(f func(key, value interface{})) {
+	if node == nil {
+		return
+	}
+	node.left.forEach(f)
+	f(node.key, node.value.value)
+	node.right.forEach(f)
+}
+
+// Load returns the value stored in the map for a key, or nil if no entry is
+// present. The ok result indicates whether an entry was found in the map.
+func (pm *Map) Load(key interface{}) (interface{}, bool) {
+	node := pm.root
+	for node != nil {
+		if pm.less(key, node.key) {
+			node = node.left
+		} else if pm.less(node.key, key) {
+			node = node.right
+		} else {
+			return node.value.value, true
+		}
+	}
+	return nil, false
+}
+
+// Store sets the value for a key.
+// If release is non-nil, it will be called with the entry's key and value once
+// the key is no longer contained in the map or any clone.
+func (pm *Map) Store(key, value interface{}, release func(key, value interface{})) {
+	first := pm.root
+	second := newNodeWithRef(key, value, release)
+	pm.root = union(first, second, pm.less, true)
+	first.decref()
+	second.decref()
+}
+
+// union returns a new tree that is the union of the first and second trees.
+// If overwrite is true, the value from the second tree wins for any duplicate keys.
+//
+// union(left:-0, right:-0) (result:+1)
+// Union borrows both subtrees without affecting their refcount and returns a
+// new reference that the caller is expected to dispose of.
+func union(first, second *mapNode, less func(a, b interface{}) bool, overwrite bool) *mapNode {
+	if first == nil {
+		return second.incref()
+	}
+	if second == nil {
+		return first.incref()
+	}
+
+	if first.weight < second.weight {
+		second, first, overwrite = first, second, !overwrite
+	}
+
+	left, mid, right := split(second, first.key, less)
+	var result *mapNode
+	if overwrite && mid != nil {
+		result = mid.shallowCloneWithRef()
+	} else {
+		result = first.shallowCloneWithRef()
+	}
+	result.weight = first.weight
+	result.left = union(first.left, left, less, overwrite)
+	result.right = union(first.right, right, less, overwrite)
+	left.decref()
+	mid.decref()
+	right.decref()
+	return result
+}
+
+// split splits the tree midway by the key into three different ones.
+// It returns three new trees: left with all nodes smaller than the key, mid
+// with the node matching the key, and right with all nodes larger than the
+// key. If one of the trees is empty, nil is returned instead of it.
+//
+// split(n:-0) (left:+1, mid:+1, right:+1)
+// Split borrows n without affecting its refcount, and returns three
+// new references that the caller is expected to dispose of.
+func split(n *mapNode, key interface{}, less func(a, b interface{}) bool) (left, mid, right *mapNode) {
+	if n == nil {
+		return nil, nil, nil
+	}
+
+	if less(n.key, key) {
+		left, mid, right := split(n.right, key, less)
+		newN := n.shallowCloneWithRef()
+		newN.left = n.left.incref()
+		newN.right = left
+		return newN, mid, right
+	} else if less(key, n.key) {
+		left, mid, right := split(n.left, key, less)
+		newN := n.shallowCloneWithRef()
+		newN.left = right
+		newN.right = n.right.incref()
+		return left, mid, newN
+	}
+	mid = n.shallowCloneWithRef()
+	return n.left.incref(), mid, n.right.incref()
+}
+
+// Delete deletes the value for a key.
+func (pm *Map) Delete(key interface{}) {
+	root := pm.root
+	left, mid, right := split(root, key, pm.less)
+	pm.root = merge(left, right)
+	left.decref()
+	mid.decref()
+	right.decref()
+	root.decref()
+}
+
+// merge merges two trees while preserving the weight invariant.
+// All nodes in left must have smaller keys than any node in right.
+//
+// merge(left:-0, right:-0) (result:+1)
+// Merge borrows its arguments without affecting their refcount
+// and returns a new reference that the caller is expected to dispose of.
+func merge(left, right *mapNode) *mapNode {
+	switch {
+	case left == nil:
+		return right.incref()
+	case right == nil:
+		return left.incref()
+	case left.weight > right.weight:
+		root := left.shallowCloneWithRef()
+		root.left = left.left.incref()
+		root.right = merge(left.right, right)
+		return root
+	default:
+		root := right.shallowCloneWithRef()
+		root.left = merge(left, right.left)
+		root.right = right.right.incref()
+		return root
+	}
+}
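Before the tests, a short usage sketch of the Map API defined above (illustrative only; persistent is internal to x/tools, and the key, value, and printed output are made up):

	package main

	import (
		"fmt"

		"golang.org/x/tools/internal/persistent"
	)

	func main() {
		m := persistent.NewMap(func(a, b interface{}) bool {
			return a.(string) < b.(string)
		})

		m.Store("k", 1, func(key, value interface{}) {
			fmt.Println("released", key, value) // called when no map or clone references "k"
		})

		clone := m.Clone() // constant time: the clone shares the treap nodes
		m.Delete("k")      // the original no longer references "k"...

		v, ok := clone.Load("k")
		fmt.Println(v, ok) // 1 true: ...but the clone still does

		clone.Destroy() // last reference dropped: the release callback fires now
		m.Destroy()
	}

The tests below validate exactly this refcounting behavior by recording every release callback and comparing it against a plain map.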
diff --git a/internal/persistent/map_test.go b/internal/persistent/map_test.go
new file mode 100644
index 00000000000..9585956100b
--- /dev/null
+++ b/internal/persistent/map_test.go
@@ -0,0 +1,316 @@
+// Copyright 2022 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package persistent
+
+import (
+	"fmt"
+	"math/rand"
+	"reflect"
+	"sync/atomic"
+	"testing"
+)
+
+type mapEntry struct {
+	key   int
+	value int
+}
+
+type validatedMap struct {
+	impl     *Map
+	expected map[int]int
+	deleted  map[mapEntry]struct{}
+	seen     map[mapEntry]struct{}
+}
+
+func TestSimpleMap(t *testing.T) {
+	deletedEntries := make(map[mapEntry]struct{})
+	seenEntries := make(map[mapEntry]struct{})
+
+	m1 := &validatedMap{
+		impl: NewMap(func(a, b interface{}) bool {
+			return a.(int) < b.(int)
+		}),
+		expected: make(map[int]int),
+		deleted:  deletedEntries,
+		seen:     seenEntries,
+	}
+
+	m3 := m1.clone()
+	validateRef(t, m1, m3)
+	m3.insert(t, 8, 8)
+	validateRef(t, m1, m3)
+	m3.destroy()
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 8, value: 8}: {},
+	})
+
+	validateRef(t, m1)
+	m1.insert(t, 1, 1)
+	validateRef(t, m1)
+	m1.insert(t, 2, 2)
+	validateRef(t, m1)
+	m1.insert(t, 3, 3)
+	validateRef(t, m1)
+	m1.remove(t, 2)
+	validateRef(t, m1)
+	m1.insert(t, 6, 6)
+	validateRef(t, m1)
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 2, value: 2}: {},
+		{key: 8, value: 8}: {},
+	})
+
+	m2 := m1.clone()
+	validateRef(t, m1, m2)
+	m1.insert(t, 6, 60)
+	validateRef(t, m1, m2)
+	m1.remove(t, 1)
+	validateRef(t, m1, m2)
+
+	for i := 10; i < 14; i++ {
+		m1.insert(t, i, i)
+		validateRef(t, m1, m2)
+	}
+
+	m1.insert(t, 10, 100)
+	validateRef(t, m1, m2)
+
+	m1.remove(t, 12)
+	validateRef(t, m1, m2)
+
+	m2.insert(t, 4, 4)
+	validateRef(t, m1, m2)
+	m2.insert(t, 5, 5)
+	validateRef(t, m1, m2)
+
+	m1.destroy()
+
+	assertSameMap(t, deletedEntries, map[mapEntry]struct{}{
+		{key: 2, value: 2}:    {},
+		{key: 6, value: 60}:   {},
+		{key: 8, value: 8}:    {},
+		{key: 10, value: 10}:  {},
+		{key: 10, value: 100}: {},
+		{key: 11, value: 11}:  {},
+		{key: 12, value: 12}:  {},
+		{key: 13, value: 13}:  {},
+	})
+
+	m2.insert(t, 7, 7)
+	validateRef(t, m2)
+
+	m2.destroy()
+
+	assertSameMap(t, seenEntries, deletedEntries)
+}
+
+func TestRandomMap(t *testing.T) {
+	deletedEntries := make(map[mapEntry]struct{})
+	seenEntries := make(map[mapEntry]struct{})
+
+	m := &validatedMap{
+		impl: NewMap(func(a, b interface{}) bool {
+			return a.(int) < b.(int)
+		}),
+		expected: make(map[int]int),
+		deleted:  deletedEntries,
+		seen:     seenEntries,
+	}
+
+	keys := make([]int, 0, 1000)
+	for i := 0; i < 1000; i++ {
+		key := rand.Int()
+		m.insert(t, key, key)
+		keys = append(keys, key)
+
+		if i%10 == 1 {
+			index := rand.Intn(len(keys))
+			last := len(keys) - 1
+			key = keys[index]
+			keys[index], keys[last] = keys[last], keys[index]
+			keys = keys[:last]
+
+			m.remove(t, key)
+		}
+	}
+
+	m.destroy()
+	assertSameMap(t, seenEntries, deletedEntries)
+}
+
+func (vm *validatedMap) onDelete(t *testing.T, key, value int) {
+	entry := mapEntry{key: key, value: value}
+	if _, ok := vm.deleted[entry]; ok {
+		t.Fatalf("tried to delete entry twice, key: %d, value: %d", key, value)
+	}
+	vm.deleted[entry] = struct{}{}
+}
+
+func validateRef(t *testing.T, maps ...*validatedMap) {
+	t.Helper()
+
+	actualCountByEntry := make(map[mapEntry]int32)
+	nodesByEntry := make(map[mapEntry]map[*mapNode]struct{})
+	expectedCountByEntry := make(map[mapEntry]int32)
+	for i, m := range maps {
+		dfsRef(m.impl.root, actualCountByEntry, nodesByEntry)
+		dumpMap(t, fmt.Sprintf("%d:", i), m.impl.root)
+	}
+	for entry, nodes := range nodesByEntry {
+		expectedCountByEntry[entry] = int32(len(nodes))
+	}
+	assertSameMap(t, expectedCountByEntry, actualCountByEntry)
+}
+func dfsRef(node *mapNode, countByEntry map[mapEntry]int32, nodesByEntry map[mapEntry]map[*mapNode]struct{}) {
+	if node == nil {
+		return
+	}
+
+	entry := mapEntry{key: node.key.(int), value: node.value.value.(int)}
+	countByEntry[entry] = atomic.LoadInt32(&node.value.refCount)
+
+	nodes, ok := nodesByEntry[entry]
+	if !ok {
+		nodes = make(map[*mapNode]struct{})
+		nodesByEntry[entry] = nodes
+	}
+	nodes[node] = struct{}{}
+
+	dfsRef(node.left, countByEntry, nodesByEntry)
+	dfsRef(node.right, countByEntry, nodesByEntry)
+}
+
+func dumpMap(t *testing.T, prefix string, n *mapNode) {
+	if n == nil {
+		t.Logf("%s nil", prefix)
+		return
+	}
+	t.Logf("%s {key: %v, value: %v (ref: %v), ref: %v, weight: %v}", prefix, n.key, n.value.value, n.value.refCount, n.refCount, n.weight)
+	dumpMap(t, prefix+"l", n.left)
+	dumpMap(t, prefix+"r", n.right)
+}
+
+func (vm *validatedMap) validate(t *testing.T) {
+	t.Helper()
+
+	validateNode(t, vm.impl.root, vm.impl.less)
+
+	for key, value := range vm.expected {
+		entry := mapEntry{key: key, value: value}
+		if _, ok := vm.deleted[entry]; ok {
+			t.Fatalf("entry is deleted prematurely, key: %d, value: %d", key, value)
+		}
+	}
+
+	actualMap := make(map[int]int, len(vm.expected))
+	vm.impl.Range(func(key, value interface{}) {
+		if other, ok := actualMap[key.(int)]; ok {
+			t.Fatalf("key is present twice, key: %d, first value: %d, second value: %d", key, other, value)
+		}
+		actualMap[key.(int)] = value.(int)
+	})
+
+	assertSameMap(t, actualMap, vm.expected)
+}
+
+func validateNode(t *testing.T, node *mapNode, less func(a, b interface{}) bool) {
+	if node == nil {
+		return
+	}
+
+	if node.left != nil {
+		if less(node.key, node.left.key) {
+			t.Fatalf("left child has larger key: %v vs %v", node.left.key, node.key)
+		}
+		if node.left.weight > node.weight {
+			t.Fatalf("left child has larger weight: %v vs %v", node.left.weight, node.weight)
+		}
+	}
+
+	if node.right != nil {
+		if less(node.right.key, node.key) {
+			t.Fatalf("right child has smaller key: %v vs %v", node.right.key, node.key)
+		}
+		if node.right.weight > node.weight {
+			t.Fatalf("right child has larger weight: %v vs %v", node.right.weight, node.weight)
+		}
+	}
+
+	validateNode(t, node.left, less)
+	validateNode(t, node.right, less)
+}
+
+func (vm *validatedMap) insert(t *testing.T, key, value int) {
+	vm.seen[mapEntry{key: key, value: value}] = struct{}{}
+	vm.impl.Store(key, value, func(deletedKey, deletedValue interface{}) {
+		if deletedKey != key || deletedValue != value {
+			t.Fatalf("unexpected passed in deleted entry: %v/%v, expected: %v/%v", deletedKey, deletedValue, key, value)
+		}
+		vm.onDelete(t, key, value)
+	})
+	vm.expected[key] = value
+	vm.validate(t)
+
+	loadValue, ok := vm.impl.Load(key)
+	if !ok || loadValue != value {
+		t.Fatalf("unexpected load result after insertion, key: %v, expected: %v, got: %v (%v)", key, value, loadValue, ok)
+	}
+}
+
+func (vm *validatedMap) remove(t *testing.T, key int) {
+	vm.impl.Delete(key)
+	delete(vm.expected, key)
+	vm.validate(t)
+
+	loadValue, ok := vm.impl.Load(key)
+	if ok {
+		t.Fatalf("unexpected load result after removal, key: %v, got: %v", key, loadValue)
+	}
+}
+
+func (vm *validatedMap) clone() *validatedMap {
+	expected := make(map[int]int, len(vm.expected))
+	for key, value := range vm.expected {
+		expected[key] = value
+	}
+
+	return &validatedMap{
+		impl:     vm.impl.Clone(),
+		expected: expected,
+		deleted:  vm.deleted,
+		seen:     vm.seen,
+	}
+}
+
+func (vm *validatedMap) destroy() {
+	vm.impl.Destroy()
+}
+
+func assertSameMap(t *testing.T, map1, map2 interface{}) {
+	t.Helper()
+
+	if !reflect.DeepEqual(map1, map2) {
+		t.Fatalf("different maps:\n%v\nvs\n%v", map1, map2)
+	}
+}
+
+func isSameMap(map1, map2 reflect.Value) bool {
+	if map1.Len() != map2.Len() {
+		return false
+	}
+	iter := map1.MapRange()
+	for iter.Next() {
+		key := iter.Key()
+		value1 := iter.Value()
+		value2 := map2.MapIndex(key)
+		if !value2.IsValid() || !reflect.DeepEqual(value1.Interface(), value2.Interface()) {
+			return false
+		}
+	}
+	return true
+}