diff --git a/Microsoft.Toolkit.HighPerformance/Buffers/StringPool.cs b/Microsoft.Toolkit.HighPerformance/Buffers/StringPool.cs new file mode 100644 index 00000000000..d2a98ae8c94 --- /dev/null +++ b/Microsoft.Toolkit.HighPerformance/Buffers/StringPool.cs @@ -0,0 +1,826 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +#if NETCOREAPP3_1 +using System.Numerics; +#endif +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.Toolkit.HighPerformance.Extensions; +#if !NETSTANDARD1_4 +using Microsoft.Toolkit.HighPerformance.Helpers; +#endif + +#nullable enable + +namespace Microsoft.Toolkit.HighPerformance.Buffers +{ + /// + /// A configurable pool for instances. This can be used to minimize allocations + /// when creating multiple instances from buffers of values. + /// The method provides a best-effort alternative to just creating + /// a new instance every time, in order to minimize the number of duplicated instances. + /// The type will internally manage a highly efficient priority queue for the + /// cached instances, so that when the full capacity is reached, the least frequently + /// used values will be automatically discarded to leave room for new values to cache. + /// + public sealed class StringPool + { + /// + /// The size used by default by the parameterless constructor. + /// + private const int DefaultSize = 2048; + + /// + /// The minimum size for instances. + /// + private const int MinimumSize = 32; + + /// + /// The current array of instances in use. + /// + private readonly FixedSizePriorityMap[] maps; + + /// + /// The total number of maps in use. + /// + private readonly int numberOfMaps; + + /// + /// Initializes a new instance of the class. 
///
public StringPool()
    : this(DefaultSize)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="StringPool"/> class.
/// </summary>
/// <param name="minimumSize">The minimum size for the pool to create.</param>
/// <exception cref="ArgumentOutOfRangeException">
/// Thrown when <paramref name="minimumSize"/> is not greater than 0.
/// </exception>
public StringPool(int minimumSize)
{
    if (minimumSize <= 0)
    {
        ThrowArgumentOutOfRangeException();
    }

    // Never create a pool smaller than the minimum supported size
    minimumSize = Math.Max(minimumSize, MinimumSize);

    // Calculates the rounded up factors for a specific size/factor pair:
    // x is the number of maps, y is the capacity of each map (y ~= factor * x).
    static void FindFactors(int size, int factor, out int x, out int y)
    {
        double
            a = Math.Sqrt((double)size / factor),
            b = factor * a;

        x = RoundUpPowerOfTwo((int)a);
        y = RoundUpPowerOfTwo((int)b);
    }

    // We want to find two powers of 2 factors that produce a number
    // that is at least equal to the requested size. In order to find the
    // combination producing the optimal factors (with the product being as
    // close as possible to the requested size), we test a number of ratios
    // that we consider acceptable, and pick the best results produced.
    // The ratio between maps influences the number of objects being allocated,
    // as well as the multithreading performance when locking on maps.
    // We still want to constrain this number to avoid situations where we
    // have a way too high number of maps compared to total size.
    FindFactors(minimumSize, 2, out int x2, out int y2);
    FindFactors(minimumSize, 3, out int x3, out int y3);
    FindFactors(minimumSize, 4, out int x4, out int y4);

    int
        p2 = x2 * y2,
        p3 = x3 * y3,
        p4 = x4 * y4;

    // Keep the candidate with the smallest product in (p2, x2, y2)
    if (p3 < p2)
    {
        p2 = p3;
        x2 = x3;
        y2 = y3;
    }

    if (p4 < p2)
    {
        p2 = p4;
        x2 = x4;
        y2 = y4;
    }

    Span<FixedSizePriorityMap> span = this.maps = new FixedSizePriorityMap[x2];

    // All the maps are created up front, so that every slot in the array is
    // fully initialized before the pool is used. This lets us lock on each
    // individual map when retrieving a string instance.
    foreach (ref FixedSizePriorityMap map in span)
    {
        map = new FixedSizePriorityMap(y2);
    }

    this.numberOfMaps = x2;

    Size = p2;
}

/// <summary>
/// Rounds up an <see cref="int"/> value to a power of 2.
/// </summary>
/// <param name="x">The input value to round up.</param>
/// <returns>The smallest power of two greater than or equal to <paramref name="x"/>.</returns>
[Pure]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int RoundUpPowerOfTwo(int x)
{
#if NETCOREAPP3_1
    return 1 << (32 - BitOperations.LeadingZeroCount((uint)(x - 1)));
#else
    // Portable fallback: smear the highest set bit of (x - 1) into all
    // the lower bits, then add 1 to reach the next power of two.
    x--;
    x |= x >> 1;
    x |= x >> 2;
    x |= x >> 4;
    x |= x >> 8;
    x |= x >> 16;
    x++;

    return x;
#endif
}

/// <summary>
/// Gets the shared <see cref="StringPool"/> instance.
/// </summary>
/// <remarks>
/// The shared pool provides a singleton, reusable <see cref="StringPool"/> instance that
/// can be accessed directly, and that pools <see cref="string"/> instances for the entire
/// process. Since <see cref="StringPool"/> is thread-safe, the shared instance can be used
/// concurrently by multiple threads without the need for manual synchronization.
/// </remarks>
public static StringPool Shared { get; } = new StringPool();

/// <summary>
/// Gets the total number of <see cref="string"/> that can be stored in the current instance.
/// </summary>
public int Size { get; }

/// <summary>
/// Stores a <see cref="string"/> instance in the internal cache.
/// </summary>
/// <param name="value">The input <see cref="string"/> instance to cache.</param>
public void Add(string value)
{
    // Empty strings are never stored
    if (value.Length == 0)
    {
        return;
    }

    // The number of maps is a power of two, so masking selects the target map
    int
        hashcode = GetHashCode(value.AsSpan()),
        bucketIndex = hashcode & (this.numberOfMaps - 1);

    ref FixedSizePriorityMap map = ref this.maps.DangerousGetReferenceAt(bucketIndex);

    lock (map.SyncRoot)
    {
        map.Add(value, hashcode);
    }
}

/// <summary>
/// Gets a cached <see cref="string"/> instance matching the input content, or stores the input one.
/// </summary>
/// <param name="value">The input <see cref="string"/> instance with the contents to use.</param>
/// <returns>A <see cref="string"/> instance with the contents of <paramref name="value"/>, cached if possible.</returns>
public string GetOrAdd(string value)
{
    // Empty inputs always map to the shared empty string and are never stored
    if (value.Length == 0)
    {
        return string.Empty;
    }

    // The number of maps is a power of two, so masking selects the target map
    int
        hashcode = GetHashCode(value.AsSpan()),
        bucketIndex = hashcode & (this.numberOfMaps - 1);

    ref FixedSizePriorityMap map = ref this.maps.DangerousGetReferenceAt(bucketIndex);

    lock (map.SyncRoot)
    {
        return map.GetOrAdd(value, hashcode);
    }
}

/// <summary>
/// Gets a cached <see cref="string"/> instance matching the input content, or creates a new one.
/// </summary>
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use.</param>
/// <returns>A <see cref="string"/> instance with the contents of <paramref name="span"/>, cached if possible.</returns>
public string GetOrAdd(ReadOnlySpan<char> span)
{
    if (span.IsEmpty)
    {
        return string.Empty;
    }

    int
        hashcode = GetHashCode(span),
        bucketIndex = hashcode & (this.numberOfMaps - 1);

    ref FixedSizePriorityMap map = ref this.maps.DangerousGetReferenceAt(bucketIndex);

    lock (map.SyncRoot)
    {
        return map.GetOrAdd(span, hashcode);
    }
}

/// <summary>
/// Gets a cached <see cref="string"/> instance matching the input content (converted to Unicode), or creates a new one.
/// </summary>
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use, in a specified encoding.</param>
/// <param name="encoding">The <see cref="Encoding"/> instance to use to decode the contents of <paramref name="span"/>.</param>
/// <returns>A <see cref="string"/> instance with the contents of <paramref name="span"/>, cached if possible.</returns>
public unsafe string GetOrAdd(ReadOnlySpan<byte> span, Encoding encoding)
{
    if (span.IsEmpty)
    {
        return string.Empty;
    }

    // Upper bound on the decoded length, used to size the temporary buffer
    int maxLength = encoding.GetMaxCharCount(span.Length);

    using SpanOwner<char> buffer = SpanOwner<char>.Allocate(maxLength);

    fixed (byte* source = span)
    fixed (char* destination = &buffer.DangerousGetReference())
    {
        // Decode into the temporary buffer, then look up only the characters
        // that were actually produced (which can be fewer than maxLength).
        int effectiveLength = encoding.GetChars(source, span.Length, destination, maxLength);

        return GetOrAdd(new ReadOnlySpan<char>(destination, effectiveLength));
    }
}

/// <summary>
/// Tries to get a cached <see cref="string"/> instance matching the input content, if present.
/// </summary>
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use.</param>
/// <param name="value">The resulting cached <see cref="string"/> instance, if present</param>
/// <returns>Whether or not the target <see cref="string"/> instance was found.</returns>
public bool TryGet(ReadOnlySpan<char> span, [NotNullWhen(true)] out string? value)
{
    // An empty input always matches the shared empty string
    if (span.IsEmpty)
    {
        value = string.Empty;

        return true;
    }

    // The number of maps is a power of two, so masking selects the target map
    int
        hashcode = GetHashCode(span),
        bucketIndex = hashcode & (this.numberOfMaps - 1);

    ref FixedSizePriorityMap map = ref this.maps.DangerousGetReferenceAt(bucketIndex);

    lock (map.SyncRoot)
    {
        return map.TryGet(span, hashcode, out value);
    }
}

/// <summary>
/// Resets the current instance and its associated maps.
/// </summary>
public void Reset()
{
    // Each map is cleared under its own lock, one map at a time
    foreach (ref FixedSizePriorityMap map in this.maps.AsSpan())
    {
        lock (map.SyncRoot)
        {
            map.Reset();
        }
    }
}

/// <summary>
/// A configurable map containing a group of cached <see cref="string"/> instances.
/// </summary>
/// <remarks>
/// Instances of this type are stored in an array within <see cref="StringPool"/> and they are
/// always accessed by reference - essentially as if this type had been a class. The type is
/// also private, so there's no risk for users to directly access it and accidentally copy an
/// instance, which would lead to bugs due to values becoming out of sync with the internal state
/// (that is, because instances would be copied by value, so primitive fields would not be shared).
/// The reason why we're using a struct here is to remove an indirection level and improve cache
/// locality when accessing individual buckets from the methods in the <see cref="StringPool"/> type.
/// </remarks>
private struct FixedSizePriorityMap
{
    /// <summary>
    /// The index representing the end of a given list.
    /// </summary>
    private const int EndOfList = -1;

    /// <summary>
    /// The array of 1-based indices for the items stored in <see cref="mapEntries"/>.
    /// A value of 0 means the bucket is empty.
    /// </summary>
    private readonly int[] buckets;

    /// <summary>
    /// The array of currently cached entries (ie. the lists for each hash group).
    /// </summary>
    private readonly MapEntry[] mapEntries;

    /// <summary>
    /// The array of priority values associated to each item stored in <see cref="mapEntries"/>.
    /// </summary>
    private readonly HeapEntry[] heapEntries;

    /// <summary>
    /// The current number of items stored in the map.
    /// </summary>
///
private int count;

/// <summary>
/// The current incremental timestamp for the items stored in <see cref="heapEntries"/>.
/// </summary>
private uint timestamp;

/// <summary>
/// A type representing a map entry, ie. a node in a given list.
/// </summary>
private struct MapEntry
{
    /// <summary>
    /// The precomputed hashcode for <see cref="Value"/>.
    /// </summary>
    public int HashCode;

    /// <summary>
    /// The <see cref="string"/> instance cached in this entry.
    /// </summary>
    public string? Value;

    /// <summary>
    /// The 0-based index for the next node in the current list.
    /// </summary>
    public int NextIndex;

    /// <summary>
    /// The 0-based index for the heap entry corresponding to the current node.
    /// </summary>
    public int HeapIndex;
}

/// <summary>
/// A type representing a heap entry, used to associate priority to each item.
/// </summary>
private struct HeapEntry
{
    /// <summary>
    /// The timestamp for the current entry (ie. the priority for the item).
    /// </summary>
    public uint Timestamp;

    /// <summary>
    /// The 0-based index for the map entry corresponding to the current item.
    /// </summary>
    public int MapIndex;
}

/// <summary>
/// Initializes a new instance of the <see cref="FixedSizePriorityMap"/> struct.
/// </summary>
/// <param name="capacity">The fixed capacity of the current map.</param>
public FixedSizePriorityMap(int capacity)
{
    // All three arrays share the same length, which never changes afterwards
    this.buckets = new int[capacity];
    this.mapEntries = new MapEntry[capacity];
    this.heapEntries = new HeapEntry[capacity];
    this.count = 0;
    this.timestamp = 0;
}

/// <summary>
/// Gets an <see cref="object"/> that can be used to synchronize access to the current instance.
/// </summary>
public object SyncRoot
{
    // The buckets array is private to this map and never replaced,
    // so it doubles as the lock object without an extra allocation.
    [MethodImpl(MethodImplOptions.AggressiveInlining)]
    get => this.buckets;
}

/// <summary>
/// Implements <see cref="StringPool.Add"/> for the current instance.
/// </summary>
/// <param name="value">The input <see cref="string"/> instance to cache.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="value"/>.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe void Add(string value, int hashcode)
{
    ref string target = ref TryGet(value.AsSpan(), hashcode);

    // A null reference is the "not found" sentinel returned by TryGet
    if (Unsafe.AreSame(ref target, ref Unsafe.AsRef<string>(null)))
    {
        Insert(value, hashcode);
    }
    else
    {
        // An equal string is already cached: overwrite it with the new instance
        target = value;
    }
}

/// <summary>
/// Implements <see cref="StringPool.GetOrAdd(string)"/> for the current instance.
/// </summary>
///
/// <param name="value">The input <see cref="string"/> instance with the contents to use.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="value"/>.</param>
/// <returns>A <see cref="string"/> instance with the contents of <paramref name="value"/>.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe string GetOrAdd(string value, int hashcode)
{
    ref string result = ref TryGet(value.AsSpan(), hashcode);

    // A null reference is the "not found" sentinel returned by TryGet
    if (!Unsafe.AreSame(ref result, ref Unsafe.AsRef<string>(null)))
    {
        return result;
    }

    // Not cached yet: store the input instance itself and hand it back
    Insert(value, hashcode);

    return value;
}

/// <summary>
/// Implements <see cref="StringPool.GetOrAdd(ReadOnlySpan{char})"/> for the current instance.
/// </summary>
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="span"/>.</param>
/// <returns>A <see cref="string"/> instance with the contents of <paramref name="span"/>, cached if possible.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe string GetOrAdd(ReadOnlySpan<char> span, int hashcode)
{
    ref string result = ref TryGet(span, hashcode);

    if (!Unsafe.AreSame(ref result, ref Unsafe.AsRef<string>(null)))
    {
        return result;
    }

    // Only allocate a new string when no cached instance matches
    string value = span.ToString();

    Insert(value, hashcode);

    return value;
}

/// <summary>
/// Implements <see cref="StringPool.TryGet"/> for the current instance.
/// </summary>
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="span"/>.</param>
/// <param name="value">The resulting cached <see cref="string"/> instance, if present</param>
/// <returns>Whether or not the target <see cref="string"/> instance was found.</returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe bool TryGet(ReadOnlySpan<char> span, int hashcode, [NotNullWhen(true)] out string? value)
{
    ref string result = ref TryGet(span, hashcode);

    if (!Unsafe.AreSame(ref result, ref Unsafe.AsRef<string>(null)))
    {
        value = result;

        return true;
    }

    value = null;

    return false;
}

/// <summary>
/// Resets the current instance and throws away all the cached values.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Reset()
{
    // Zeroed buckets mark every list as empty, and clearing the entries
    // drops the references to the cached strings.
    this.buckets.AsSpan().Clear();
    this.mapEntries.AsSpan().Clear();
    this.heapEntries.AsSpan().Clear();
    this.count = 0;
    this.timestamp = 0;
}

/// <summary>
/// Tries to get a target <see cref="string"/> instance, if it exists, and returns a reference to it.
/// </summary>
///
/// <param name="span">The input <see cref="ReadOnlySpan{T}"/> with the contents to use.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="span"/>.</param>
/// <returns>
/// A reference to the slot where the target <see cref="string"/> instance could be.
/// This is a null reference when no matching entry exists.
/// </returns>
[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe ref string TryGet(ReadOnlySpan<char> span, int hashcode)
{
    ref MapEntry mapEntriesRef = ref this.mapEntries.DangerousGetReference();
    ref MapEntry entry = ref Unsafe.AsRef<MapEntry>(null);
    int
        length = this.buckets.Length,
        bucketIndex = hashcode & (length - 1);

    // Bucket values are 1-based, so subtracting 1 gives either a valid 0-based
    // entry index or -1 for an empty bucket. The unsigned comparison rejects
    // -1 (both here and when NextIndex reaches the end of a list).
    for (int i = this.buckets.DangerousGetReferenceAt(bucketIndex) - 1;
         (uint)i < (uint)length;
         i = entry.NextIndex)
    {
        entry = ref Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)i);

        // Compare the cheap precomputed hashcode first, then the contents
        if (entry.HashCode == hashcode &&
            entry.Value!.AsSpan().SequenceEqual(span))
        {
            // A successful lookup refreshes the priority of the entry
            UpdateTimestamp(ref entry.HeapIndex);

            return ref entry.Value!;
        }
    }

    return ref Unsafe.AsRef<string>(null);
}

/// <summary>
/// Inserts a new <see cref="string"/> instance in the current map, freeing up a space if needed.
/// </summary>
/// <param name="value">The new <see cref="string"/> instance to store.</param>
/// <param name="hashcode">The precomputed hashcode for <paramref name="value"/>.</param>
[MethodImpl(MethodImplOptions.NoInlining)]
private unsafe void Insert(string value, int hashcode)
{
    ref int bucketsRef = ref this.buckets.DangerousGetReference();
    ref MapEntry mapEntriesRef = ref this.mapEntries.DangerousGetReference();
    ref HeapEntry heapEntriesRef = ref this.heapEntries.DangerousGetReference();
    int entryIndex, heapIndex;

    // If the current map is full, first get the oldest value, which is
    // always the first item in the heap. Then, free up a slot by
    // removing that, and insert the new value in that empty location.
    if (this.count == this.mapEntries.Length)
    {
        entryIndex = heapEntriesRef.MapIndex;
        heapIndex = 0;

        ref MapEntry removedEntry = ref Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)entryIndex);

        // The removal logic can be extremely optimized in this case, as we
        // can retrieve the precomputed hashcode for the target entry by doing
        // a lookup on the target map node, and we can also skip all the comparisons
        // while traversing the target chain since we know in advance the index of
        // the target node which will contain the item to remove from the map.
        Remove(removedEntry.HashCode, entryIndex);
    }
    else
    {
        // The map still has room. Entries are stored contiguously, so the
        // first unused slot of both the map and the heap is at index count.
        entryIndex = this.count;
        heapIndex = this.count;
    }

    int bucketIndex = hashcode & (this.buckets.Length - 1);
    ref int targetBucket = ref Unsafe.Add(ref bucketsRef, (IntPtr)(void*)(uint)bucketIndex);
    ref MapEntry targetMapEntry = ref Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)entryIndex);
    ref HeapEntry targetHeapEntry = ref Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)heapIndex);

    // Assign the values in the new map entry. The new node becomes the head of
    // its list: it links to the previous head (or to -1 if the bucket was empty).
    targetMapEntry.HashCode = hashcode;
    targetMapEntry.Value = value;
    targetMapEntry.NextIndex = targetBucket - 1;
    targetMapEntry.HeapIndex = heapIndex;

    // Update the bucket slot (1-based) and the current count
    targetBucket = entryIndex + 1;
    this.count++;

    // Link the heap node with the current entry
    targetHeapEntry.MapIndex = entryIndex;

    // Update the timestamp and balance the heap again
    UpdateTimestamp(ref targetMapEntry.HeapIndex);
}

/// <summary>
/// Removes a specified <see cref="string"/> instance from the map to free up one slot.
/// </summary>
/// <param name="hashcode">The precomputed hashcode of the instance to remove.</param>
/// <param name="mapIndex">The index of the target map node to remove.</param>
/// <remarks>The input <see cref="string"/> instance needs to already exist in the map.</remarks>
+ [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe void Remove(int hashcode, int mapIndex) + { + ref MapEntry mapEntriesRef = ref this.mapEntries.DangerousGetReference(); + int + bucketIndex = hashcode & (this.buckets.Length - 1), + entryIndex = this.buckets.DangerousGetReferenceAt(bucketIndex) - 1, + lastIndex = EndOfList; + + // We can just have an undefined loop, as the input + // value we're looking for is guaranteed to be present + while (true) + { + ref MapEntry candidate = ref Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)entryIndex); + + // Check the current value for a match + if (entryIndex == mapIndex) + { + // If this was not the first list node, update the parent as well + if (lastIndex != EndOfList) + { + ref MapEntry lastEntry = ref Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)lastIndex); + + lastEntry.NextIndex = candidate.NextIndex; + } + else + { + // Otherwise, update the target index from the bucket slot + this.buckets.DangerousGetReferenceAt(bucketIndex) = candidate.NextIndex + 1; + } + + this.count--; + + return; + } + + // Move to the following node in the current list + lastIndex = entryIndex; + entryIndex = candidate.NextIndex; + } + } + + /// + /// Updates the timestamp of a heap node at the specified index (which is then synced back). + /// + /// The index of the target heap node to update. + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe void UpdateTimestamp(ref int heapIndex) + { + int + currentIndex = heapIndex, + count = this.count; + ref MapEntry mapEntriesRef = ref this.mapEntries.DangerousGetReference(); + ref HeapEntry heapEntriesRef = ref this.heapEntries.DangerousGetReference(); + ref HeapEntry root = ref Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)currentIndex); + uint timestamp = this.timestamp; + + // Check if incrementing the current timestamp for the heap node to update + // would result in an overflow. 
If that happened, we could end up violating + // the min-heap property (the value of each node has to always be <= than that + // of its child nodes), eg. if we were updating a node that was not the root. + // In that scenario, we could end up with a node somewhere along the heap with + // a value lower than that of its parent node (as the timestamp would be 0). + // To guard against this, we just check the current timestamp value, and if + // the maximum value has been reached, we reinitialize the entire heap. This + // is done in a non-inlined call, so we don't increase the codegen size in this + // method. The reinitialization simply traverses the heap in breadth-first order + // (ie. level by level), and assigns incrementing timestamps to all nodes starting + // from 0. The value of the current timestamp is then just set to the current size. + if (timestamp == uint.MaxValue) + { + // We use a goto here as this path is very rarely taken. Doing so + // causes the generated asm to contain a forward jump to the fallback + // path if this branch is taken, whereas the normal execution path will + // not need to execute any jumps at all. This is done to reduce the overhead + // introduced by this check in all the invocations where this point is not reached. + goto Fallback; + } + + Downheap: + + // Assign a new timestamp to the target heap node. We use a + // local incremental timestamp instead of using the system timer + // as this greatly reduces the overhead and the time spent in system calls. + // The uint type provides a large enough range and it's unlikely users would ever + // exhaust it anyway (especially considering each map has a separate counter). + root.Timestamp = this.timestamp = timestamp + 1; + + // Once the timestamp is updated (which will cause the heap to become + // unbalanced), start a sift down loop to balance the heap again. 
+ while (true) + { + // The heap is 0-based (so that the array length can remain the same + // as the power of 2 value used for the other arrays in this type). + // This means that children of each node are at positions: + // - left: (2 * n) + 1 + // - right: (2 * n) + 2 + ref HeapEntry minimum = ref root; + int + left = (currentIndex * 2) + 1, + right = (currentIndex * 2) + 2, + targetIndex = currentIndex; + + // Check and update the left child, if necessary + if (left < count) + { + ref HeapEntry child = ref Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)left); + + if (child.Timestamp < minimum.Timestamp) + { + minimum = ref child; + targetIndex = left; + } + } + + // Same check as above for the right child + if (right < count) + { + ref HeapEntry child = ref Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)right); + + if (child.Timestamp < minimum.Timestamp) + { + minimum = ref child; + targetIndex = right; + } + } + + // If no swap is pending, we can just stop here. + // Before returning, we update the target index as well. + if (Unsafe.AreSame(ref root, ref minimum)) + { + heapIndex = targetIndex; + + return; + } + + // Update the indices in the respective map entries (accounting for the swap) + Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)root.MapIndex).HeapIndex = targetIndex; + Unsafe.Add(ref mapEntriesRef, (IntPtr)(void*)(uint)minimum.MapIndex).HeapIndex = currentIndex; + + currentIndex = targetIndex; + + // Swap the parent and child (so that the minimum value bubbles up) + HeapEntry temp = root; + + root = minimum; + minimum = temp; + + // Update the reference to the root node + root = ref Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)currentIndex); + } + + Fallback: + + UpdateAllTimestamps(); + + // After having updated all the timestamps, if the heap contains N items, the + // node in the bottom right corner will have a value of N - 1. 
Since the timestamp + // is incremented by 1 before starting the downheap execution, here we simply + // update the local timestamp to be N - 1, so that the code above will set the + // timestamp of the node currently being updated to exactly N. + timestamp = (uint)(count - 1); + + goto Downheap; + } + + /// + /// Updates the timestamp of all the current heap nodes in incrementing order. + /// The heap is always guaranteed to be complete binary tree, so when it contains + /// a given number of nodes, those are all contiguous from the start of the array. + /// + [MethodImpl(MethodImplOptions.NoInlining)] + private unsafe void UpdateAllTimestamps() + { + int count = this.count; + ref HeapEntry heapEntriesRef = ref this.heapEntries.DangerousGetReference(); + + for (int i = 0; i < count; i++) + { + Unsafe.Add(ref heapEntriesRef, (IntPtr)(void*)(uint)i).Timestamp = (uint)i; + } + } + } + + /// + /// Gets the (positive) hashcode for a given instance. + /// + /// The input instance. + /// The hashcode for . + [Pure] + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int GetHashCode(ReadOnlySpan span) + { +#if NETSTANDARD1_4 + return span.GetDjb2HashCode(); +#else + return HashCode.Combine(span); +#endif + } + + /// + /// Throws an when the requested size exceeds the capacity. + /// + private static void ThrowArgumentOutOfRangeException() + { + throw new ArgumentOutOfRangeException("minimumSize", "The requested size must be greater than 0"); + } + } +} diff --git a/UnitTests/UnitTests.HighPerformance.Shared/Buffers/Test_StringPool.cs b/UnitTests/UnitTests.HighPerformance.Shared/Buffers/Test_StringPool.cs new file mode 100644 index 00000000000..b679609d40a --- /dev/null +++ b/UnitTests/UnitTests.HighPerformance.Shared/Buffers/Test_StringPool.cs @@ -0,0 +1,344 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using System.Text; +using Microsoft.Toolkit.HighPerformance.Buffers; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +#nullable enable + +namespace UnitTests.HighPerformance.Buffers +{ + [TestClass] + public class Test_StringPool + { + [TestCategory("StringPool")] + [TestMethod] + [DataRow(44, 4, 16, 64)] + [DataRow(76, 8, 16, 128)] + [DataRow(128, 8, 16, 128)] + [DataRow(179, 8, 32, 256)] + [DataRow(366, 16, 32, 512)] + [DataRow(512, 16, 32, 512)] + [DataRow(890, 16, 64, 1024)] + [DataRow(1280, 32, 64, 2048)] + [DataRow(2445, 32, 128, 4096)] + [DataRow(5000, 64, 128, 8192)] + [DataRow(8000, 64, 128, 8192)] + [DataRow(12442, 64, 256, 16384)] + [DataRow(234000, 256, 1024, 262144)] + public void Test_StringPool_Cctor_Ok(int minimumSize, int x, int y, int size) + { + var pool = new StringPool(minimumSize); + + Assert.AreEqual(size, pool.Size); + + Array maps = (Array)typeof(StringPool).GetField("maps", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(pool); + + Assert.AreEqual(x, maps.Length); + + Type bucketType = Type.GetType("Microsoft.Toolkit.HighPerformance.Buffers.StringPool+FixedSizePriorityMap, Microsoft.Toolkit.HighPerformance"); + + int[] buckets = (int[])bucketType.GetField("buckets", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(maps.GetValue(0)); + + Assert.AreEqual(y, buckets.Length); + } + + [TestCategory("StringPool")] + [TestMethod] + [DataRow(0)] + [DataRow(-3248234)] + [DataRow(int.MinValue)] + public void Test_StringPool_Cctor_Fail(int size) + { + try + { + var pool = new StringPool(size); + + Assert.Fail(); + } + catch (ArgumentOutOfRangeException e) + { + var cctor = typeof(StringPool).GetConstructor(new[] { typeof(int) }); + + Assert.AreEqual(cctor.GetParameters()[0].Name, e.ParamName); + } + } + + [TestCategory("StringPool")] + [TestMethod] + public 
void Test_StringPool_Add_Empty() + { + StringPool.Shared.Add(string.Empty); + + bool found = StringPool.Shared.TryGet(ReadOnlySpan.Empty, out string? text); + + Assert.IsTrue(found); + Assert.AreSame(string.Empty, text); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_Add_Single() + { + var pool = new StringPool(); + + string hello = nameof(hello); + + Assert.IsFalse(pool.TryGet(hello.AsSpan(), out _)); + + pool.Add(hello); + + Assert.IsTrue(pool.TryGet(hello.AsSpan(), out string? hello2)); + + Assert.AreSame(hello, hello2); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_Add_Misc() + { + var pool = new StringPool(); + + string + hello = nameof(hello), + helloworld = nameof(helloworld), + windowsCommunityToolkit = nameof(windowsCommunityToolkit); + + Assert.IsFalse(pool.TryGet(hello.AsSpan(), out _)); + Assert.IsFalse(pool.TryGet(helloworld.AsSpan(), out _)); + Assert.IsFalse(pool.TryGet(windowsCommunityToolkit.AsSpan(), out _)); + + pool.Add(hello); + pool.Add(helloworld); + pool.Add(windowsCommunityToolkit); + + Assert.IsTrue(pool.TryGet(hello.AsSpan(), out string? hello2)); + Assert.IsTrue(pool.TryGet(helloworld.AsSpan(), out string? world2)); + Assert.IsTrue(pool.TryGet(windowsCommunityToolkit.AsSpan(), out string? windowsCommunityToolkit2)); + + Assert.AreSame(hello, hello2); + Assert.AreSame(helloworld, world2); + Assert.AreSame(windowsCommunityToolkit, windowsCommunityToolkit2); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_Add_Overwrite() + { + var pool = new StringPool(); + + var today = DateTime.Today; + + var text1 = ToStringNoInlining(today); + + pool.Add(text1); + + Assert.IsTrue(pool.TryGet(text1.AsSpan(), out string? 
result)); + + Assert.AreSame(text1, result); + + var text2 = ToStringNoInlining(today); + + pool.Add(text2); + + Assert.IsTrue(pool.TryGet(text2.AsSpan(), out result)); + + Assert.AreNotSame(text1, result); + Assert.AreSame(text2, result); + } + + // Separate method just to ensure the JIT can't optimize things away + // and make the test fail because different string instances are interned + [MethodImpl(MethodImplOptions.NoInlining)] + private static string ToStringNoInlining(object obj) + { + return obj.ToString(); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_String_Empty() + { + string empty = StringPool.Shared.GetOrAdd(string.Empty); + + Assert.AreSame(string.Empty, empty); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_String_Misc() + { + var pool = new StringPool(); + + string helloworld = nameof(helloworld); + + string cached = pool.GetOrAdd(helloworld); + + Assert.AreSame(helloworld, cached); + + Span span = stackalloc char[helloworld.Length]; + + helloworld.AsSpan().CopyTo(span); + + string helloworld2 = span.ToString(); + + cached = pool.GetOrAdd(helloworld2); + + Assert.AreSame(helloworld, cached); + + cached = pool.GetOrAdd(new string(helloworld.ToCharArray())); + + Assert.AreSame(helloworld, cached); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_ReadOnlySpan_Empty() + { + string empty = StringPool.Shared.GetOrAdd(ReadOnlySpan.Empty); + + Assert.AreSame(string.Empty, empty); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_ReadOnlySpan_Misc() + { + var pool = new StringPool(); + + string + hello = pool.GetOrAdd(nameof(hello).AsSpan()), + helloworld = pool.GetOrAdd(nameof(helloworld).AsSpan()), + windowsCommunityToolkit = pool.GetOrAdd(nameof(windowsCommunityToolkit).AsSpan()); + + Assert.AreEqual(nameof(hello), hello); + Assert.AreEqual(nameof(helloworld), helloworld); + 
Assert.AreEqual(nameof(windowsCommunityToolkit), windowsCommunityToolkit); + + Assert.AreSame(hello, pool.GetOrAdd(hello.AsSpan())); + Assert.AreSame(helloworld, pool.GetOrAdd(helloworld.AsSpan())); + Assert.AreSame(windowsCommunityToolkit, pool.GetOrAdd(windowsCommunityToolkit.AsSpan())); + + pool.Reset(); + + Assert.AreEqual(nameof(hello), hello); + Assert.AreEqual(nameof(helloworld), helloworld); + Assert.AreEqual(nameof(windowsCommunityToolkit), windowsCommunityToolkit); + + Assert.AreNotSame(hello, pool.GetOrAdd(hello.AsSpan())); + Assert.AreNotSame(helloworld, pool.GetOrAdd(helloworld.AsSpan())); + Assert.AreNotSame(windowsCommunityToolkit, pool.GetOrAdd(windowsCommunityToolkit.AsSpan())); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_Encoding_Empty() + { + string empty = StringPool.Shared.GetOrAdd(ReadOnlySpan.Empty, Encoding.ASCII); + + Assert.AreSame(string.Empty, empty); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_Encoding_Misc() + { + var pool = new StringPool(); + + string helloworld = nameof(helloworld); + + pool.Add(helloworld); + + Span span = Encoding.UTF8.GetBytes(nameof(helloworld)); + + string helloworld2 = pool.GetOrAdd(span, Encoding.UTF8); + + Assert.AreSame(helloworld, helloworld2); + + string windowsCommunityToolkit = nameof(windowsCommunityToolkit); + + Span span2 = Encoding.ASCII.GetBytes(windowsCommunityToolkit); + + string + windowsCommunityToolkit2 = pool.GetOrAdd(span2, Encoding.ASCII), + windowsCommunityToolkit3 = pool.GetOrAdd(windowsCommunityToolkit); + + Assert.AreSame(windowsCommunityToolkit2, windowsCommunityToolkit3); + } + + [TestCategory("StringPool")] + [TestMethod] + public void Test_StringPool_GetOrAdd_Overflow() + { + var pool = new StringPool(32); + + // Fill the pool + for (int i = 0; i < 4096; i++) + { + _ = pool.GetOrAdd(i.ToString()); + } + + // Get the buckets + Array maps = (Array)typeof(StringPool).GetField("maps", 
BindingFlags.Instance | BindingFlags.NonPublic).GetValue(pool); + + Type bucketType = Type.GetType("Microsoft.Toolkit.HighPerformance.Buffers.StringPool+FixedSizePriorityMap, Microsoft.Toolkit.HighPerformance"); + FieldInfo timestampInfo = bucketType.GetField("timestamp", BindingFlags.Instance | BindingFlags.NonPublic); + + // Force the timestamp to be the maximum value, or the test would take too long + for (int i = 0; i < maps.LongLength; i++) + { + object map = maps.GetValue(i); + + timestampInfo.SetValue(map, uint.MaxValue); + + maps.SetValue(map, i); + } + + // Force an overflow + string text = "Hello world"; + + _ = pool.GetOrAdd(text); + + Type heapEntryType = Type.GetType("Microsoft.Toolkit.HighPerformance.Buffers.StringPool+FixedSizePriorityMap+HeapEntry, Microsoft.Toolkit.HighPerformance"); + + foreach (var map in maps) + { + // Get the heap for each bucket + Array heapEntries = (Array)bucketType.GetField("heapEntries", BindingFlags.Instance | BindingFlags.NonPublic).GetValue(map); + FieldInfo fieldInfo = heapEntryType.GetField("Timestamp"); + + // Extract the array with the timestamps in the heap nodes + uint[] array = heapEntries.Cast().Select(entry => (uint)fieldInfo.GetValue(entry)).ToArray(); + + static bool IsMinHeap(uint[] array) + { + for (int i = 0; i < array.Length; i++) + { + int + left = (i * 2) + 1, + right = (i * 2) + 2; + + if ((left < array.Length && + array[left] <= array[i]) || + (right < array.Length && + array[right] <= array[i])) + { + return false; + } + } + + return true; + } + + // Verify that the current heap is indeed valid after the overflow + Assert.IsTrue(IsMinHeap(array)); + } + } + } +} diff --git a/UnitTests/UnitTests.HighPerformance.Shared/UnitTests.HighPerformance.Shared.projitems b/UnitTests/UnitTests.HighPerformance.Shared/UnitTests.HighPerformance.Shared.projitems index ef69154b89f..cb1c3a5a85f 100644 --- a/UnitTests/UnitTests.HighPerformance.Shared/UnitTests.HighPerformance.Shared.projitems +++ 
b/UnitTests/UnitTests.HighPerformance.Shared/UnitTests.HighPerformance.Shared.projitems @@ -13,6 +13,7 @@ +