From 0ca1c0560cee8fd31ed02025fc24bf7fa11b3d50 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Sun, 17 Sep 2017 05:28:05 +0100 Subject: [PATCH] Improve Dictionary CQ --- .../System/Collections/Generic/Dictionary.cs | 952 +++++++++--------- .../src/System/Collections/Hashtable.cs | 1 + 2 files changed, 501 insertions(+), 452 deletions(-) diff --git a/src/mscorlib/src/System/Collections/Generic/Dictionary.cs b/src/mscorlib/src/System/Collections/Generic/Dictionary.cs index 761f775905a7..322d077a07d6 100644 --- a/src/mscorlib/src/System/Collections/Generic/Dictionary.cs +++ b/src/mscorlib/src/System/Collections/Generic/Dictionary.cs @@ -30,23 +30,30 @@ namespace System.Collections.Generic /// /// Used internally to control behavior of insertion into a . /// - internal enum InsertionBehavior : byte - { - /// - /// The default insertion behavior. - /// - None = 0, - - /// - /// Specifies that an existing entry with the same key should be overwritten if encountered. - /// - OverwriteExisting = 1, - - /// - /// Specifies that if an existing entry with the same key is encountered, an exception should be thrown. - /// - ThrowOnExisting = 2 - } + internal interface IInsertionBehavior { } + + /// + /// The default insertion behavior. + /// + internal struct RejectIfExisting : IInsertionBehavior { } + + /// + /// Specifies that an existing entry with the same key should be overwritten if encountered. + /// + internal struct OverwriteExisting : IInsertionBehavior { } + + /// + /// Specifies that if an existing entry with the same key is encountered, an exception should be thrown. + /// + internal struct ThrowOnExisting : IInsertionBehavior { } + + internal interface IComparerType { } + internal struct CustomComparer : IComparerType { } + internal struct DefaultComparer : IComparerType { } + + internal interface IResizeBehavior { } + internal struct GenerateNewHashcodes : IResizeBehavior { } + internal struct KeepHashcodes : IResizeBehavior { } [DebuggerTypeProxy(typeof(IDictionaryDebugView<,>))] [DebuggerDisplay("Count = {Count}")] @@ -62,15 +69,19 @@ private struct Entry public TValue value; // Value of entry } - private int[] buckets; - private Entry[] entries; - private int count; - private int version; - private int freeList; - private int freeCount; - private IEqualityComparer comparer; - private KeyCollection keys; - private ValueCollection values; + private static Entry s_nullEntry; + + private int[] _buckets; + private Entry[] _entries; + private int _count; + private int _version; + private int _freeList; + private int _freeCount; + + private IEqualityComparer _customComparer; + + private KeyCollection _keys; + private ValueCollection _values; private Object _syncRoot; // constants for serialization @@ -89,18 +100,24 @@ public Dictionary(int capacity, IEqualityComparer comparer) { if (capacity < 0) ThrowHelper.ThrowArgumentOutOfRangeException(ExceptionArgument.capacity); if (capacity > 0) Initialize(capacity); - this.comparer = comparer ?? EqualityComparer.Default; + _customComparer = comparer; - if (this.comparer == EqualityComparer.Default) + // String has a more nuanced comparer as its GetHashCode is randomised for security + if (typeof(TKey) == typeof(string)) { - this.comparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.Default; + // If TKey is a string, we move off the default comparer to a non-randomized comparer + // Later if collisions become too high we will move back onto the default randomized comparer + if (comparer == null || ReferenceEquals(comparer, EqualityComparer.Default)) + { + _customComparer = (IEqualityComparer)NonRandomizedStringEqualityComparer.Default; + } } } public Dictionary(IDictionary dictionary) : this(dictionary, null) { } public Dictionary(IDictionary dictionary, IEqualityComparer comparer) : - this(dictionary != null ? dictionary.Count : 0, comparer) + this(dictionary?.Count ?? 0, comparer) { if (dictionary == null) { @@ -114,13 +131,14 @@ public Dictionary(IDictionary dictionary, IEqualityComparer if (dictionary.GetType() == typeof(Dictionary)) { Dictionary d = (Dictionary)dictionary; - int count = d.count; - Entry[] entries = d.entries; + int count = d._count; + Entry[] entries = d._entries; for (int i = 0; i < count; i++) { - if (entries[i].hashCode >= 0) + ref Entry entry = ref entries[i]; + if (entry.hashCode >= 0) { - Add(entries[i].key, entries[i].value); + Add(entry.key, entry.value); } } return; @@ -158,92 +176,56 @@ protected Dictionary(SerializationInfo info, StreamingContext context) HashHelpers.SerializationInfoTable.Add(this, info); } - public IEqualityComparer Comparer - { - get - { - return comparer; - } - } + public IEqualityComparer Comparer => _customComparer ?? EqualityComparer.Default; - public int Count - { - get { return count - freeCount; } - } + public int Count => _count - _freeCount; public KeyCollection Keys { get { - if (keys == null) keys = new KeyCollection(this); - return keys; + return _keys ?? (_keys = new KeyCollection(this)); } } - ICollection IDictionary.Keys - { - get - { - if (keys == null) keys = new KeyCollection(this); - return keys; - } - } + ICollection IDictionary.Keys => _keys ?? (_keys = new KeyCollection(this)); - IEnumerable IReadOnlyDictionary.Keys - { - get - { - if (keys == null) keys = new KeyCollection(this); - return keys; - } - } + IEnumerable IReadOnlyDictionary.Keys => _keys ?? (_keys = new KeyCollection(this)); public ValueCollection Values { get { - if (values == null) values = new ValueCollection(this); - return values; + return _values ?? (_values = new ValueCollection(this)); } } - ICollection IDictionary.Values - { - get - { - if (values == null) values = new ValueCollection(this); - return values; - } - } + ICollection IDictionary.Values => _values ?? (_values = new ValueCollection(this)); - IEnumerable IReadOnlyDictionary.Values - { - get - { - if (values == null) values = new ValueCollection(this); - return values; - } - } + IEnumerable IReadOnlyDictionary.Values => _values ?? (_values = new ValueCollection(this)); public TValue this[TKey key] { get { - int i = FindEntry(key); - if (i >= 0) return entries[i].value; + ref Entry entry = ref FindEntry(key, out bool found); + if (found) + { + return entry.value; + } ThrowHelper.ThrowKeyNotFoundException(); return default(TValue); } set { - bool modified = TryInsert(key, value, InsertionBehavior.OverwriteExisting); + bool modified = TryInsert(key, value); Debug.Assert(modified); } } public void Add(TKey key, TValue value) { - bool modified = TryInsert(key, value, InsertionBehavior.ThrowOnExisting); + bool modified = TryInsert(key, value); Debug.Assert(modified); // If there was an existing key and the Add failed, an exception will already have been thrown. } @@ -254,18 +236,14 @@ void ICollection>.Add(KeyValuePair keyV bool ICollection>.Contains(KeyValuePair keyValuePair) { - int i = FindEntry(keyValuePair.Key); - if (i >= 0 && EqualityComparer.Default.Equals(entries[i].value, keyValuePair.Value)) - { - return true; - } - return false; + ref Entry entry = ref FindEntry(keyValuePair.Key, out bool found); + return found && EqualityComparer.Default.Equals(entry.value, keyValuePair.Value); } bool ICollection>.Remove(KeyValuePair keyValuePair) { - int i = FindEntry(keyValuePair.Key); - if (i >= 0 && EqualityComparer.Default.Equals(entries[i].value, keyValuePair.Value)) + ref Entry entry = ref FindEntry(keyValuePair.Key, out bool found); + if (found && EqualityComparer.Default.Equals(entry.value, keyValuePair.Value)) { Remove(keyValuePair.Key); return true; @@ -275,29 +253,38 @@ bool ICollection>.Remove(KeyValuePair k public void Clear() { + int count = _count; if (count > 0) { - for (int i = 0; i < buckets.Length; i++) buckets[i] = -1; - Array.Clear(entries, 0, count); - freeList = -1; - count = 0; - freeCount = 0; - version++; + int[] buckets = _buckets; + for (int i = 0; i < buckets.Length; i++) + { + buckets[i] = -1; + } + Array.Clear(_entries, 0, count); + _freeList = -1; + _count = 0; + _freeCount = 0; + _version++; } } public bool ContainsKey(TKey key) { - return FindEntry(key) >= 0; + FindEntry(key, out bool found); + return found; } public bool ContainsValue(TValue value) { + Entry[] entries = _entries; + int count = _count; if (value == null) { for (int i = 0; i < count; i++) { - if (entries[i].hashCode >= 0 && entries[i].value == null) return true; + ref Entry entry = ref entries[i]; + if (entry.hashCode >= 0 && entry.value == null) return true; } } else @@ -305,7 +292,8 @@ public bool ContainsValue(TValue value) EqualityComparer c = EqualityComparer.Default; for (int i = 0; i < count; i++) { - if (entries[i].hashCode >= 0 && c.Equals(entries[i].value, value)) return true; + ref Entry entry = ref entries[i]; + if (entry.hashCode >= 0 && c.Equals(entry.value, value)) return true; } } return false; @@ -328,26 +316,21 @@ private void CopyTo(KeyValuePair[] array, int index) ThrowHelper.ThrowArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall); } - int count = this.count; - Entry[] entries = this.entries; + int count = this._count; + Entry[] entries = this._entries; for (int i = 0; i < count; i++) { - if (entries[i].hashCode >= 0) + ref Entry entry = ref entries[i]; + if (entry.hashCode >= 0) { - array[index++] = new KeyValuePair(entries[i].key, entries[i].value); + array[index++] = new KeyValuePair(entry.key, entry.value); } } } - public Enumerator GetEnumerator() - { - return new Enumerator(this, Enumerator.KeyValuePair); - } + public Enumerator GetEnumerator() => new Enumerator(this, Enumerator.KeyValuePair); - IEnumerator> IEnumerable>.GetEnumerator() - { - return new Enumerator(this, Enumerator.KeyValuePair); - } + IEnumerator> IEnumerable>.GetEnumerator() => new Enumerator(this, Enumerator.KeyValuePair); public virtual void GetObjectData(SerializationInfo info, StreamingContext context) { @@ -355,10 +338,10 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.info); } - info.AddValue(VersionName, version); - info.AddValue(ComparerName, comparer, typeof(IEqualityComparer)); - info.AddValue(HashSizeName, buckets == null ? 0 : buckets.Length); //This is the length of the bucket array. - if (buckets != null) + info.AddValue(VersionName, _version); + info.AddValue(ComparerName, _customComparer ?? EqualityComparer.Default, typeof(IEqualityComparer)); + info.AddValue(HashSizeName, _buckets == null ? 0 : _buckets.Length); //This is the length of the bucket array. + if (_buckets != null) { KeyValuePair[] array = new KeyValuePair[Count]; CopyTo(array, 0); @@ -366,98 +349,227 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte } } - private int FindEntry(TKey key) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private ref Entry FindEntry(TKey key, out bool found) { if (key == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } + IEqualityComparer comparer = _customComparer; + if (comparer == null) + { + return ref FindEntry(key, out found, null); + } + else + { + return ref FindEntry(key, out found, comparer); + } + } + + private ref Entry FindEntry(TKey key, out bool found, IEqualityComparer customComparer) where TComparer : struct, IComparerType + { + Debug.Assert(typeof(TComparer) == typeof(DefaultComparer) || typeof(TComparer) == typeof(CustomComparer)); + Debug.Assert(key != null); + + found = true; + int[] buckets = _buckets; if (buckets != null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; - for (int i = buckets[hashCode % buckets.Length]; i >= 0; i = entries[i].next) + int hashCode = 0; + if (typeof(TComparer) == typeof(DefaultComparer)) + { + // Keys are never null + hashCode = key.GetHashCode() & 0x7FFFFFFF; + } + else if (typeof(TComparer) == typeof(CustomComparer)) { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) return i; + hashCode = customComparer.GetHashCode(key) & 0x7FFFFFFF; + } + + Entry[] entries = _entries; + int i = buckets[HashHelpers.FindBucket((uint)hashCode, (uint)buckets.Length)]; + while (i >= 0) + { + ref Entry entry = ref entries[i]; + if (entry.hashCode == hashCode) + { + if (typeof(TComparer) == typeof(DefaultComparer)) + { + if (EqualityComparer.Default.Equals(entry.key, key)) + { + return ref entry; + } + } + else if (typeof(TComparer) == typeof(CustomComparer)) + { + if (customComparer.Equals(entry.key, key)) + { + return ref entry; + } + } + } + + i = entry.next; } } - return -1; + + found = false; + return ref NotFound; + } + + private ref Entry NotFound + { + [MethodImpl(MethodImplOptions.NoInlining)] + get => ref s_nullEntry; } private void Initialize(int capacity) { int size = HashHelpers.GetPrime(capacity); - buckets = new int[size]; - for (int i = 0; i < buckets.Length; i++) buckets[i] = -1; - entries = new Entry[size]; - freeList = -1; + int[] buckets = new int[size]; + for (int i = 0; i < buckets.Length; i++) + { + buckets[i] = -1; + } + _entries = new Entry[size]; + _freeList = -1; + + _buckets = buckets; } - private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private bool TryInsert(TKey key, TValue value) where TInsertionBehavior : struct, IInsertionBehavior { + Debug.Assert(typeof(TInsertionBehavior) == typeof(RejectIfExisting) || typeof(TInsertionBehavior) == typeof(OverwriteExisting) || typeof(TInsertionBehavior) == typeof(ThrowOnExisting)); + if (key == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } - if (buckets == null) Initialize(0); - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; - int targetBucket = hashCode % buckets.Length; + IEqualityComparer customComparer = _customComparer; + if (customComparer == null) + { + return TryInsert(key, value, null); + } + else + { + return TryInsert(key, value, customComparer); + } + } + + private bool TryInsert(TKey key, TValue value, IEqualityComparer customComparer) where TInsertionBehavior : struct, IInsertionBehavior where TComparer : struct, IComparerType + { + Debug.Assert(typeof(TComparer) == typeof(DefaultComparer) || typeof(TComparer) == typeof(CustomComparer)); + Debug.Assert(typeof(TInsertionBehavior) == typeof(RejectIfExisting) || typeof(TInsertionBehavior) == typeof(OverwriteExisting) || typeof(TInsertionBehavior) == typeof(ThrowOnExisting)); + Debug.Assert(key != null); + + if (_buckets == null) + { + Initialize(0); + } + int[] buckets = _buckets; + // Keys are never null + int hashCode = 0; + if (typeof(TComparer) == typeof(DefaultComparer)) + { + // Keys are never null + hashCode = key.GetHashCode() & 0x7FFFFFFF; + } + else if (typeof(TComparer) == typeof(CustomComparer)) + { + hashCode = customComparer.GetHashCode(key) & 0x7FFFFFFF; + } + + uint targetBucket = HashHelpers.FindBucket((uint)hashCode, (uint)buckets.Length); + + // Count collisions to see if we need to move to randomized hashing for string keys int collisionCount = 0; + Entry[] entries = _entries; - for (int i = buckets[targetBucket]; i >= 0; i = entries[i].next) + int i = buckets[targetBucket]; + while (i >= 0) { - if (entries[i].hashCode == hashCode && comparer.Equals(entries[i].key, key)) + ref Entry candidateEntry = ref entries[i]; + if (candidateEntry.hashCode == hashCode) { - if (behavior == InsertionBehavior.OverwriteExisting) + bool keysEqual = false; + if (typeof(TComparer) == typeof(DefaultComparer)) { - entries[i].value = value; - version++; - return true; + keysEqual = EqualityComparer.Default.Equals(candidateEntry.key, key); } - - if (behavior == InsertionBehavior.ThrowOnExisting) + else if (typeof(TComparer) == typeof(CustomComparer)) { - ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(key); + keysEqual = customComparer.Equals(candidateEntry.key, key); } - return false; + if (keysEqual) + { + if (typeof(TInsertionBehavior) == typeof(OverwriteExisting)) + { + candidateEntry.value = value; + _version++; + return true; + } + else if (typeof(TInsertionBehavior) == typeof(RejectIfExisting)) + { + return false; + } + else if (typeof(TInsertionBehavior) == typeof(ThrowOnExisting)) + { + ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(key); + } + } } - collisionCount++; + i = candidateEntry.next; + if (typeof(TComparer) == typeof(CustomComparer)) + { + collisionCount++; + } } + int index; - if (freeCount > 0) - { - index = freeList; - freeList = entries[index].next; - freeCount--; - } - else + if (_freeCount == 0) { + int count = _count; if (count == entries.Length) { - Resize(); - targetBucket = hashCode % buckets.Length; + Resize(HashHelpers.ExpandPrime(count)); + // Update local cached items + buckets = _buckets; + entries = _entries; + targetBucket = HashHelpers.FindBucket((uint)hashCode, (uint)buckets.Length); } index = count; - count++; + _count = count + 1; + } + else + { + index = _freeList; + _freeList = entries[index].next; + _freeCount--; } - entries[index].hashCode = hashCode; - entries[index].next = buckets[targetBucket]; - entries[index].key = key; - entries[index].value = value; + ref Entry entry = ref entries[index]; + entry.hashCode = hashCode; + entry.next = buckets[targetBucket]; buckets[targetBucket] = index; - version++; - - // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing - // i.e. EqualityComparer.Default. + entry.key = key; + entry.value = value; + _version++; - if (collisionCount > HashHelpers.HashCollisionThreshold && comparer == NonRandomizedStringEqualityComparer.Default) + if (typeof(TComparer) == typeof(CustomComparer)) { - comparer = (IEqualityComparer)EqualityComparer.Default; - Resize(entries.Length, true); + // If we hit the collision threshold we'll need to switch to the comparer which is using randomized string hashing + // i.e. EqualityComparer.Default. + if (collisionCount > HashHelpers.HashCollisionThreshold && ReferenceEquals(_customComparer, NonRandomizedStringEqualityComparer.Default)) + { + _customComparer = null; // Use default comparer + Resize(_entries.Length); + } } return true; @@ -479,14 +591,18 @@ public virtual void OnDeserialization(Object sender) int realVersion = siInfo.GetInt32(VersionName); int hashsize = siInfo.GetInt32(HashSizeName); - comparer = (IEqualityComparer)siInfo.GetValue(ComparerName, typeof(IEqualityComparer)); + _customComparer = (IEqualityComparer)siInfo.GetValue(ComparerName, typeof(IEqualityComparer)); if (hashsize != 0) { - buckets = new int[hashsize]; - for (int i = 0; i < buckets.Length; i++) buckets[i] = -1; - entries = new Entry[hashsize]; - freeList = -1; + int[] buckets = new int[hashsize]; + for (int i = 0; i < buckets.Length; i++) + { + buckets[i] = -1; + } + _buckets = buckets; + _entries = new Entry[hashsize]; + _freeList = -1; KeyValuePair[] array = (KeyValuePair[]) siInfo.GetValue(KeyValuePairsName, typeof(KeyValuePair[])); @@ -507,51 +623,67 @@ public virtual void OnDeserialization(Object sender) } else { - buckets = null; + _buckets = null; } - version = realVersion; + _version = realVersion; HashHelpers.SerializationInfoTable.Remove(this); } - private void Resize() + private void Resize(int newSize) where TResizeBehavior : struct, IResizeBehavior { - Resize(HashHelpers.ExpandPrime(count), false); - } + Debug.Assert(typeof(TResizeBehavior) == typeof(KeepHashcodes) || typeof(TResizeBehavior) == typeof(GenerateNewHashcodes)); + // Should only be rehashing when switching from custom NonRandomised string to default randomised + Debug.Assert(typeof(TResizeBehavior) == typeof(KeepHashcodes) || _customComparer == null); + Debug.Assert(newSize >= _entries.Length); - private void Resize(int newSize, bool forceNewHashCodes) - { - Debug.Assert(newSize >= entries.Length); int[] newBuckets = new int[newSize]; - for (int i = 0; i < newBuckets.Length; i++) newBuckets[i] = -1; + for (int i = 0; i < newBuckets.Length; i++) + { + newBuckets[i] = -1; + } + + int count = _count; Entry[] newEntries = new Entry[newSize]; - Array.Copy(entries, 0, newEntries, 0, count); - if (forceNewHashCodes) + Array.Copy(_entries, 0, newEntries, 0, count); + + // If the Jit eliminates bounds checks for a loop not limited by Length + // if the variable has been pre-confirmed + // add a check that (uint)count < (uint)newEntries.Length + + for (int i = 0; i < count; i++) { - for (int i = 0; i < count; i++) + if (typeof(TResizeBehavior) == typeof(GenerateNewHashcodes)) { - if (newEntries[i].hashCode != -1) + ref Entry entry = ref newEntries[i]; + int hashCode = entry.hashCode; + if (hashCode >= 0) { - newEntries[i].hashCode = (comparer.GetHashCode(newEntries[i].key) & 0x7FFFFFFF); + uint targetBucket = HashHelpers.FindBucket((uint)hashCode, (uint)newBuckets.Length); + // Keys are never null + hashCode = entry.key.GetHashCode() & 0x7FFFFFFF; + entry.hashCode = hashCode; + entry.next = newBuckets[targetBucket]; + newBuckets[targetBucket] = i; } } - } - for (int i = 0; i < count; i++) - { - if (newEntries[i].hashCode >= 0) + else { - int bucket = newEntries[i].hashCode % newSize; - newEntries[i].next = newBuckets[bucket]; - newBuckets[bucket] = i; + int hashCode = newEntries[i].hashCode; + if (hashCode >= 0) + { + uint targetBucket = HashHelpers.FindBucket((uint)hashCode, (uint)newBuckets.Length); + newEntries[i].next = newBuckets[targetBucket]; + newBuckets[targetBucket] = i; + } } } - buckets = newBuckets; - entries = newEntries; + + _buckets = newBuckets; + _entries = newEntries; } - // The overload Remove(TKey key, out TValue value) is a copy of this method with one additional - // statement to copy the value for entry being removed into the output parameter. - // Code has been intentionally duplicated for performance reasons. + [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool Remove(TKey key) { if (key == null) @@ -559,53 +691,30 @@ public bool Remove(TKey key) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } - if (buckets != null) + IEqualityComparer customComparer = _customComparer; + bool success; + // Compiler doesn't support ref ternary yet https://github.com/dotnet/roslyn/issues/17797 + if (customComparer == null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; - int bucket = hashCode % buckets.Length; - int last = -1; - int i = buckets[bucket]; - while (i >= 0) + ref Entry entry = ref Remove(key, out success, null); + if (success && RuntimeHelpers.IsReferenceOrContainsReferences()) { - ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && comparer.Equals(entry.key, key)) - { - if (last < 0) - { - buckets[bucket] = entry.next; - } - else - { - entries[last].next = entry.next; - } - entry.hashCode = -1; - entry.next = freeList; - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.key = default(TKey); - } - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.value = default(TValue); - } - freeList = i; - freeCount++; - version++; - return true; - } - - last = i; - i = entry.next; + entry.value = default(TValue); } } - return false; + else + { + ref Entry entry = ref Remove(key, out success, customComparer); + if (success && RuntimeHelpers.IsReferenceOrContainsReferences()) + { + entry.value = default(TValue); + } + } + + return success; } - // This overload is a copy of the overload Remove(TKey key) with one additional - // statement to copy the value for entry being removed into the output parameter. - // Code has been intentionally duplicated for performance reasons. + [MethodImpl(MethodImplOptions.AggressiveInlining)] public bool Remove(TKey key, out TValue value) { if (key == null) @@ -613,76 +722,119 @@ public bool Remove(TKey key, out TValue value) ThrowHelper.ThrowArgumentNullException(ExceptionArgument.key); } + IEqualityComparer customComparer = _customComparer; + bool success; + // Compiler doesn't support ref ternary yet https://github.com/dotnet/roslyn/issues/17797 + if (customComparer == null) + { + ref Entry entry = ref Remove(key, out success, null); + value = entry.value; + if (success && RuntimeHelpers.IsReferenceOrContainsReferences()) + { + entry.value = default(TValue); + } + } + else + { + ref Entry entry = ref Remove(key, out success, customComparer); + value = entry.value; + if (success && RuntimeHelpers.IsReferenceOrContainsReferences()) + { + entry.value = default(TValue); + } + } + + return success; + } + + private ref Entry Remove(TKey key, out bool success, IEqualityComparer customComparer) where TComparer : struct, IComparerType + { + Debug.Assert(typeof(TComparer) == typeof(DefaultComparer) || typeof(TComparer) == typeof(CustomComparer)); + Debug.Assert(key != null); + + int[] buckets = _buckets; if (buckets != null) { - int hashCode = comparer.GetHashCode(key) & 0x7FFFFFFF; - int bucket = hashCode % buckets.Length; + int hashCode = 0; + if (typeof(TComparer) == typeof(DefaultComparer)) + { + // Keys are never null + hashCode = key.GetHashCode() & 0x7FFFFFFF; + } + else if (typeof(TComparer) == typeof(CustomComparer)) + { + hashCode = customComparer.GetHashCode(key) & 0x7FFFFFFF; + } + int last = -1; - int i = buckets[bucket]; + ref int bucket = ref buckets[HashHelpers.FindBucket((uint)hashCode, (uint)buckets.Length)]; + + Entry[] entries = _entries; + int i = bucket; while (i >= 0) { ref Entry entry = ref entries[i]; - - if (entry.hashCode == hashCode && comparer.Equals(entry.key, key)) + if (entry.hashCode == hashCode) { - if (last < 0) + bool keysEqual = false; + if (typeof(TComparer) == typeof(DefaultComparer)) { - buckets[bucket] = entry.next; + keysEqual = EqualityComparer.Default.Equals(entry.key, key); } - else + else if (typeof(TComparer) == typeof(CustomComparer)) { - entries[last].next = entry.next; + keysEqual = customComparer.Equals(entry.key, key); } - value = entry.value; - - entry.hashCode = -1; - entry.next = freeList; - - if (RuntimeHelpers.IsReferenceOrContainsReferences()) - { - entry.key = default(TKey); - } - if (RuntimeHelpers.IsReferenceOrContainsReferences()) + if (keysEqual) { - entry.value = default(TValue); + if (last < 0) + { + bucket = entry.next; + } + else + { + entries[last].next = entry.next; + } + + entry.hashCode = -1; + entry.next = _freeList; + + if (RuntimeHelpers.IsReferenceOrContainsReferences()) + { + entry.key = default(TKey); + } + + _freeList = i; + _freeCount++; + _version++; + success = true; + + return ref entry; } - freeList = i; - freeCount++; - version++; - return true; } last = i; i = entry.next; } } - value = default(TValue); - return false; + + success = false; + return ref NotFound; } public bool TryGetValue(TKey key, out TValue value) { - int i = FindEntry(key); - if (i >= 0) - { - value = entries[i].value; - return true; - } - value = default(TValue); - return false; + ref Entry entry = ref FindEntry(key, out bool found); + value = found ? entry.value : default(TValue); + return found; } - public bool TryAdd(TKey key, TValue value) => TryInsert(key, value, InsertionBehavior.None); - bool ICollection>.IsReadOnly - { - get { return false; } - } + public bool TryAdd(TKey key, TValue value) => TryInsert(key, value); - void ICollection>.CopyTo(KeyValuePair[] array, int index) - { - CopyTo(array, index); - } + bool ICollection>.IsReadOnly => false; + + void ICollection>.CopyTo(KeyValuePair[] array, int index) => CopyTo(array, index); void ICollection.CopyTo(Array array, int index) { @@ -719,8 +871,8 @@ void ICollection.CopyTo(Array array, int index) else if (array is DictionaryEntry[]) { DictionaryEntry[] dictEntryArray = array as DictionaryEntry[]; - Entry[] entries = this.entries; - for (int i = 0; i < count; i++) + Entry[] entries = this._entries; + for (int i = 0; i < _count; i++) { if (entries[i].hashCode >= 0) { @@ -738,8 +890,8 @@ void ICollection.CopyTo(Array array, int index) try { - int count = this.count; - Entry[] entries = this.entries; + int count = this._count; + Entry[] entries = this._entries; for (int i = 0; i < count; i++) { if (entries[i].hashCode >= 0) @@ -755,15 +907,9 @@ void ICollection.CopyTo(Array array, int index) } } - IEnumerator IEnumerable.GetEnumerator() - { - return new Enumerator(this, Enumerator.KeyValuePair); - } + IEnumerator IEnumerable.GetEnumerator() => new Enumerator(this, Enumerator.KeyValuePair); - bool ICollection.IsSynchronized - { - get { return false; } - } + bool ICollection.IsSynchronized => false; object ICollection.SyncRoot { @@ -777,25 +923,13 @@ object ICollection.SyncRoot } } - bool IDictionary.IsFixedSize - { - get { return false; } - } + bool IDictionary.IsFixedSize => false; - bool IDictionary.IsReadOnly - { - get { return false; } - } + bool IDictionary.IsReadOnly => false; - ICollection IDictionary.Keys - { - get { return (ICollection)Keys; } - } + ICollection IDictionary.Keys => (ICollection)Keys; - ICollection IDictionary.Values - { - get { return (ICollection)Values; } - } + ICollection IDictionary.Values => (ICollection)Values; object IDictionary.this[object key] { @@ -803,10 +937,10 @@ object IDictionary.this[object key] { if (IsCompatibleKey(key)) { - int i = FindEntry((TKey)key); - if (i >= 0) + ref Entry entry = ref FindEntry((TKey)key, out bool found); + if (found) { - return entries[i].value; + return entry.value; } } return null; @@ -874,20 +1008,9 @@ void IDictionary.Add(object key, object value) } } - bool IDictionary.Contains(object key) - { - if (IsCompatibleKey(key)) - { - return ContainsKey((TKey)key); - } + bool IDictionary.Contains(object key) => IsCompatibleKey(key) && ContainsKey((TKey)key); - return false; - } - - IDictionaryEnumerator IDictionary.GetEnumerator() - { - return new Enumerator(this, Enumerator.DictEntry); - } + IDictionaryEnumerator IDictionary.GetEnumerator() => new Enumerator(this, Enumerator.DictEntry); void IDictionary.Remove(object key) { @@ -912,7 +1035,7 @@ public struct Enumerator : IEnumerator>, internal Enumerator(Dictionary dictionary, int getEnumeratorRetType) { this.dictionary = dictionary; - version = dictionary.version; + version = dictionary._version; index = 0; this.getEnumeratorRetType = getEnumeratorRetType; current = new KeyValuePair(); @@ -920,16 +1043,16 @@ internal Enumerator(Dictionary dictionary, int getEnumeratorRetTyp public bool MoveNext() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } // Use unsigned comparison since we set index to dictionary.count+1 when the enumeration ends. // dictionary.count+1 could be negative if dictionary.count is Int32.MaxValue - while ((uint)index < (uint)dictionary.count) + while ((uint)index < (uint)dictionary._count) { - ref Entry entry = ref dictionary.entries[index++]; + ref Entry entry = ref dictionary._entries[index++]; if (entry.hashCode >= 0) { @@ -938,15 +1061,12 @@ public bool MoveNext() } } - index = dictionary.count + 1; + index = dictionary._count + 1; current = new KeyValuePair(); return false; } - public KeyValuePair Current - { - get { return current; } - } + public KeyValuePair Current => current; public void Dispose() { @@ -956,7 +1076,7 @@ object IEnumerator.Current { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -974,7 +1094,7 @@ object IEnumerator.Current void IEnumerator.Reset() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } @@ -987,7 +1107,7 @@ DictionaryEntry IDictionaryEnumerator.Entry { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -1000,7 +1120,7 @@ object IDictionaryEnumerator.Key { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -1013,7 +1133,7 @@ object IDictionaryEnumerator.Value { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -1038,10 +1158,7 @@ public KeyCollection(Dictionary dictionary) this.dictionary = dictionary; } - public Enumerator GetEnumerator() - { - return new Enumerator(dictionary); - } + public Enumerator GetEnumerator() => new Enumerator(dictionary); public void CopyTo(TKey[] array, int index) { @@ -1060,38 +1177,23 @@ public void CopyTo(TKey[] array, int index) ThrowHelper.ThrowArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall); } - int count = dictionary.count; - Entry[] entries = dictionary.entries; + int count = dictionary._count; + Entry[] entries = dictionary._entries; for (int i = 0; i < count; i++) { if (entries[i].hashCode >= 0) array[index++] = entries[i].key; } } - public int Count - { - get { return dictionary.Count; } - } + public int Count => dictionary.Count; - bool ICollection.IsReadOnly - { - get { return true; } - } + bool ICollection.IsReadOnly => true; - void ICollection.Add(TKey item) - { - ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_KeyCollectionSet); - } + void ICollection.Add(TKey item) => ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_KeyCollectionSet); - void ICollection.Clear() - { - ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_KeyCollectionSet); - } + void ICollection.Clear() => ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_KeyCollectionSet); - bool ICollection.Contains(TKey item) - { - return dictionary.ContainsKey(item); - } + bool ICollection.Contains(TKey item) => dictionary.ContainsKey(item); bool ICollection.Remove(TKey item) { @@ -1099,15 +1201,9 @@ bool ICollection.Remove(TKey item) return false; } - IEnumerator IEnumerable.GetEnumerator() - { - return new Enumerator(dictionary); - } + IEnumerator IEnumerable.GetEnumerator() => new Enumerator(dictionary); - IEnumerator IEnumerable.GetEnumerator() - { - return new Enumerator(dictionary); - } + IEnumerator IEnumerable.GetEnumerator() => new Enumerator(dictionary); void ICollection.CopyTo(Array array, int index) { @@ -1149,8 +1245,8 @@ void ICollection.CopyTo(Array array, int index) ThrowHelper.ThrowArgumentException_Argument_InvalidArrayType(); } - int count = dictionary.count; - Entry[] entries = dictionary.entries; + int count = dictionary._count; + Entry[] entries = dictionary._entries; try { for (int i = 0; i < count; i++) @@ -1165,15 +1261,9 @@ void ICollection.CopyTo(Array array, int index) } } - bool ICollection.IsSynchronized - { - get { return false; } - } + bool ICollection.IsSynchronized => false; - Object ICollection.SyncRoot - { - get { return ((ICollection)dictionary).SyncRoot; } - } + Object ICollection.SyncRoot => ((ICollection)dictionary).SyncRoot; public struct Enumerator : IEnumerator, System.Collections.IEnumerator { @@ -1185,7 +1275,7 @@ public struct Enumerator : IEnumerator, System.Collections.IEnumerator internal Enumerator(Dictionary dictionary) { this.dictionary = dictionary; - version = dictionary.version; + version = dictionary._version; index = 0; currentKey = default(TKey); } @@ -1196,14 +1286,14 @@ public void Dispose() public bool MoveNext() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } - while ((uint)index < (uint)dictionary.count) + while ((uint)index < (uint)dictionary._count) { - ref Entry entry = ref dictionary.entries[index++]; + ref Entry entry = ref dictionary._entries[index++]; if (entry.hashCode >= 0) { @@ -1212,24 +1302,18 @@ public bool MoveNext() } } - index = dictionary.count + 1; + index = dictionary._count + 1; currentKey = default(TKey); return false; } - public TKey Current - { - get - { - return currentKey; - } - } + public TKey Current => currentKey; Object System.Collections.IEnumerator.Current { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -1240,7 +1324,7 @@ Object System.Collections.IEnumerator.Current void System.Collections.IEnumerator.Reset() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } @@ -1266,10 +1350,7 @@ public ValueCollection(Dictionary dictionary) this.dictionary = dictionary; } - public Enumerator GetEnumerator() - { - return new Enumerator(dictionary); - } + public Enumerator GetEnumerator() => new Enumerator(dictionary); public void CopyTo(TValue[] array, int index) { @@ -1288,28 +1369,19 @@ public void CopyTo(TValue[] array, int index) ThrowHelper.ThrowArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall); } - int count = dictionary.count; - Entry[] entries = dictionary.entries; + int count = dictionary._count; + Entry[] entries = dictionary._entries; for (int i = 0; i < count; i++) { if (entries[i].hashCode >= 0) array[index++] = entries[i].value; } } - public int Count - { - get { return dictionary.Count; } - } + public int Count => dictionary.Count; - bool ICollection.IsReadOnly - { - get { return true; } - } + bool ICollection.IsReadOnly => true; - void ICollection.Add(TValue item) - { - ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_ValueCollectionSet); - } + void ICollection.Add(TValue item) => ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_ValueCollectionSet); bool ICollection.Remove(TValue item) { @@ -1317,25 +1389,13 @@ bool ICollection.Remove(TValue item) return false; } - void ICollection.Clear() - { - ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_ValueCollectionSet); - } + void ICollection.Clear() => ThrowHelper.ThrowNotSupportedException(ExceptionResource.NotSupported_ValueCollectionSet); - bool ICollection.Contains(TValue item) - { - return dictionary.ContainsValue(item); - } + bool ICollection.Contains(TValue item) => dictionary.ContainsValue(item); - IEnumerator IEnumerable.GetEnumerator() - { - return new Enumerator(dictionary); - } + IEnumerator IEnumerable.GetEnumerator() => new Enumerator(dictionary); - IEnumerator IEnumerable.GetEnumerator() - { - return new Enumerator(dictionary); - } + IEnumerator IEnumerable.GetEnumerator() => new Enumerator(dictionary); void ICollection.CopyTo(Array array, int index) { @@ -1375,8 +1435,8 @@ void ICollection.CopyTo(Array array, int index) ThrowHelper.ThrowArgumentException_Argument_InvalidArrayType(); } - int count = dictionary.count; - Entry[] entries = dictionary.entries; + int count = dictionary._count; + Entry[] entries = dictionary._entries; try { for (int i = 0; i < count; i++) @@ -1391,15 +1451,9 @@ void ICollection.CopyTo(Array array, int index) } } - bool ICollection.IsSynchronized - { - get { return false; } - } + bool ICollection.IsSynchronized => false; - Object ICollection.SyncRoot - { - get { return ((ICollection)dictionary).SyncRoot; } - } + Object ICollection.SyncRoot => ((ICollection)dictionary).SyncRoot; public struct Enumerator : IEnumerator, System.Collections.IEnumerator { @@ -1411,7 +1465,7 @@ public struct Enumerator : IEnumerator, System.Collections.IEnumerator internal Enumerator(Dictionary dictionary) { this.dictionary = dictionary; - version = dictionary.version; + version = dictionary._version; index = 0; currentValue = default(TValue); } @@ -1422,14 +1476,14 @@ public void Dispose() public bool MoveNext() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } - while ((uint)index < (uint)dictionary.count) + while ((uint)index < (uint)dictionary._count) { - ref Entry entry = ref dictionary.entries[index++]; + ref Entry entry = ref dictionary._entries[index++]; if (entry.hashCode >= 0) { @@ -1437,24 +1491,18 @@ public bool MoveNext() return true; } } - index = dictionary.count + 1; + index = dictionary._count + 1; currentValue = default(TValue); return false; } - public TValue Current - { - get - { - return currentValue; - } - } + public TValue Current => currentValue; Object System.Collections.IEnumerator.Current { get { - if (index == 0 || (index == dictionary.count + 1)) + if (index == 0 || (index == dictionary._count + 1)) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen(); } @@ -1465,7 +1513,7 @@ Object System.Collections.IEnumerator.Current void System.Collections.IEnumerator.Reset() { - if (version != dictionary.version) + if (version != dictionary._version) { ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion(); } diff --git a/src/mscorlib/src/System/Collections/Hashtable.cs b/src/mscorlib/src/System/Collections/Hashtable.cs index 0550030e7ca6..87a7141686c5 100644 --- a/src/mscorlib/src/System/Collections/Hashtable.cs +++ b/src/mscorlib/src/System/Collections/Hashtable.cs @@ -1472,6 +1472,7 @@ public static int ExpandPrime(int oldSize) return GetPrime(newSize); } + public static uint FindBucket(uint hashCode, uint limit) => hashCode % limit; // This is the maximum prime smaller than Array.MaxArrayLength public const int MaxPrimeArrayLength = 0x7FEFFFFD;