diff --git a/Ryujinx.Common/Collections/TreeDictionary.cs b/Ryujinx.Common/Collections/TreeDictionary.cs new file mode 100644 index 00000000..a44f650c --- /dev/null +++ b/Ryujinx.Common/Collections/TreeDictionary.cs @@ -0,0 +1,987 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Ryujinx.Common.Collections +{ + /// + /// Dictionary that provides the ability for O(logN) Lookups for keys that exist in the Dictionary, and O(logN) lookups for keys immediately greater than or less than a specified key. + /// + /// Key + /// Value + public class TreeDictionary : IDictionary where K : IComparable + { + private const bool Black = true; + private const bool Red = false; + private Node _root = null; + private int _count = 0; + public TreeDictionary() { } + + #region Public Methods + + /// + /// Returns the value of the node whose key is , or the default value if no such node exists. + /// + /// Key of the node value to get + /// Value associated w/ + /// is null + public V Get(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + Node node = GetNode(key); + + if (node == null) + { + return default; + } + + return node.Value; + } + + /// + /// Adds a new node into the tree whose key is key and value is . + ///

+ /// Note: Adding the same key multiple times will cause the value for that key to be overwritten. + ///
+ /// Key of the node to add + /// Value of the node to add + /// or are null + public void Add(K key, V value) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + if (null == value) + { + throw new ArgumentNullException(nameof(value)); + } + + Insert(key, value); + } + + /// + /// Removes the node whose key is from the tree. + /// + /// Key of the node to remove + /// is null + public void Remove(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + if (Delete(key) != null) + { + _count--; + } + } + + /// + /// Returns the value whose key is equal to or immediately less than . + /// + /// Key for which to find the floor value of + /// Key of node immediately less than + /// is null + public K Floor(K key) + { + Node node = FloorNode(key); + if (node != null) + { + return node.Key; + } + return default; + } + + /// + /// Returns the node whose key is equal to or immediately greater than . + /// + /// Key for which to find the ceiling node of + /// Key of node immediately greater than + /// is null + public K Ceiling(K key) + { + Node node = CeilingNode(key); + if (node != null) + { + return node.Key; + } + return default; + } + + /// + /// Finds the value whose key is immediately greater than . + /// + /// Key to find the successor of + /// Value + public K SuccessorOf(K key) + { + Node node = GetNode(key); + if (node != null) + { + Node successor = SuccessorOf(node); + + return successor != null ? successor.Key : default; + } + return default; + } + + /// + /// Finds the value whose key is immediately less than . + /// + /// Key to find the predecessor of + /// Value + public K PredecessorOf(K key) + { + Node node = GetNode(key); + if (node != null) + { + Node predecessor = PredecessorOf(node); + + return predecessor != null ? predecessor.Key : default; + } + return default; + } + + /// + /// Adds all the nodes in the dictionary as key/value pairs into . + ///

+ /// The key/value pairs will be added in Level Order. + ///
+ /// List to add the tree pairs into + public List> AsLevelOrderList() + { + List> list = new List>(); + + Queue> nodes = new Queue>(); + + if (this._root != null) + { + nodes.Enqueue(this._root); + } + while (nodes.Count > 0) + { + Node node = nodes.Dequeue(); + list.Add(new KeyValuePair(node.Key, node.Value)); + if (node.Left != null) + { + nodes.Enqueue(node.Left); + } + if (node.Right != null) + { + nodes.Enqueue(node.Right); + } + } + return list; + } + + /// + /// Adds all the nodes in the dictionary into . + ///

+ /// The nodes will be added in Sorted by Key Order. + ///
+ public List> AsList() + { + List> list = new List>(); + + Queue> nodes = new Queue>(); + + if (this._root != null) + { + nodes.Enqueue(this._root); + } + while (nodes.Count > 0) + { + Node node = nodes.Dequeue(); + list.Add(new KeyValuePair(node.Key, node.Value)); + if (node.Left != null) + { + nodes.Enqueue(node.Left); + } + if (node.Right != null) + { + nodes.Enqueue(node.Right); + } + } + + return list; + } + #endregion + #region Private Methods (BST) + + /// + /// Retrieve the node reference whose key is , or null if no such node exists. + /// + /// Key of the node to get + /// Node reference in the tree + /// is null + private Node GetNode(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + + Node node = _root; + while (node != null) + { + int cmp = key.CompareTo(node.Key); + if (cmp < 0) + { + node = node.Left; + } + else if (cmp > 0) + { + node = node.Right; + } + else + { + return node; + } + } + return null; + } + + /// + /// Inserts a new node into the tree whose key is and value is . + ///

+ /// Adding the same key multiple times will overwrite the previous value. + ///
+ /// Key of the node to insert + /// Value of the node to insert + private void Insert(K key, V value) + { + Node newNode = BSTInsert(key, value); + RestoreBalanceAfterInsertion(newNode); + } + + /// + /// Insertion Mechanism for a Binary Search Tree (BST). + ///

+ /// Iterates the tree starting from the root and inserts a new node where all children in the left subtree are less than , and all children in the right subtree are greater than . + ///

+ /// Note: If a node whose key is already exists, it's value will be overwritten. + ///
+ /// Key of the node to insert + /// Value of the node to insert + /// The inserted Node + private Node BSTInsert(K key, V value) + { + Node parent = null; + Node node = _root; + + while (node != null) + { + parent = node; + int cmp = key.CompareTo(node.Key); + if (cmp < 0) + { + node = node.Left; + } + else if (cmp > 0) + { + node = node.Right; + } + else + { + node.Value = value; + return node; + } + } + Node newNode = new Node(key, value, parent); + if (newNode.Parent == null) + { + _root = newNode; + } + else if (key.CompareTo(parent.Key) < 0) + { + parent.Left = newNode; + } + else + { + parent.Right = newNode; + } + _count++; + return newNode; + } + + /// + /// Removes from the dictionary, if it exists. + /// + /// Key of the node to delete + /// The deleted Node + private Node Delete(K key) + { + // O(1) Retrieval + Node nodeToDelete = GetNode(key); + + if (nodeToDelete == null) return null; + + Node replacementNode; + + if (LeftOf(nodeToDelete) == null || RightOf(nodeToDelete) == null) + { + replacementNode = nodeToDelete; + } + else + { + replacementNode = PredecessorOf(nodeToDelete); + } + + Node tmp = LeftOf(replacementNode) ?? RightOf(replacementNode); + + if (tmp != null) + { + tmp.Parent = ParentOf(replacementNode); + } + + if (ParentOf(replacementNode) == null) + { + _root = tmp; + } + + else if (replacementNode == LeftOf(ParentOf(replacementNode))) + { + ParentOf(replacementNode).Left = tmp; + } + else + { + ParentOf(replacementNode).Right = tmp; + } + + if (replacementNode != nodeToDelete) + { + nodeToDelete.Key = replacementNode.Key; + nodeToDelete.Value = replacementNode.Value; + } + + if (tmp != null && ColorOf(replacementNode) == Black) + { + RestoreBalanceAfterRemoval(tmp); + } + + return replacementNode; + } + + /// + /// Returns the node with the largest key where is considered the root node. + /// + /// Root Node + /// Node with the maximum key in the tree of + /// is null + private static Node Maximum(Node node) + { + if (node == null) + { + throw new ArgumentNullException(nameof(node)); + } + Node tmp = node; + while (tmp.Right != null) + { + tmp = tmp.Right; + } + + return tmp; + } + + /// + /// Returns the node with the smallest key where is considered the root node. + /// + /// Root Node + /// Node with the minimum key in the tree of + /// is null + private static Node Minimum(Node node) + { + if (node == null) + { + throw new ArgumentNullException(nameof(node)); + } + Node tmp = node; + while (tmp.Left != null) + { + tmp = tmp.Left; + } + + return tmp; + } + + /// + /// Returns the node whose key immediately less than or equal to . + /// + /// Key for which to find the floor node of + /// Node whose key is immediately less than or equal to , or null if no such node is found. + /// is null + private Node FloorNode(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + Node tmp = _root; + + while (tmp != null) + { + int cmp = key.CompareTo(tmp.Key); + if (cmp > 0) + { + if (tmp.Right != null) + { + tmp = tmp.Right; + } + else + { + return tmp; + } + } + else if (cmp < 0) + { + if (tmp.Left != null) + { + tmp = tmp.Left; + } + else + { + Node parent = tmp.Parent; + Node ptr = tmp; + while (parent != null && ptr == parent.Left) + { + ptr = parent; + parent = parent.Parent; + } + return parent; + } + } + else + { + return tmp; + } + } + return null; + } + + /// + /// Returns the node whose key is immediately greater than or equal to than . + /// + /// Key for which to find the ceiling node of + /// Node whose key is immediately greater than or equal to , or null if no such node is found. + /// is null + private Node CeilingNode(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + Node tmp = _root; + + while (tmp != null) + { + int cmp = key.CompareTo(tmp.Key); + if (cmp < 0) + { + if (tmp.Left != null) + { + tmp = tmp.Left; + } + else + { + return tmp; + } + } + else if (cmp > 0) + { + if (tmp.Right != null) + { + tmp = tmp.Right; + } + else + { + Node parent = tmp.Parent; + Node ptr = tmp; + while (parent != null && ptr == parent.Right) + { + ptr = parent; + parent = parent.Parent; + } + return parent; + } + } + else + { + return tmp; + } + } + return null; + } + + /// + /// Finds the node with the key immediately greater than .Key. + /// + /// Node to find the successor of + /// Successor of + private static Node SuccessorOf(Node node) + { + if (node.Right != null) + { + return Minimum(node.Right); + } + Node parent = node.Parent; + while (parent != null && node == parent.Right) + { + node = parent; + parent = parent.Parent; + } + return parent; + } + + /// + /// Finds the node whose key immediately less than .Key. + /// + /// Node to find the predecessor of + /// Predecessor of + private static Node PredecessorOf(Node node) + { + if (node.Left != null) + { + return Maximum(node.Left); + } + Node parent = node.Parent; + while (parent != null && node == parent.Left) + { + node = parent; + parent = parent.Parent; + } + return parent; + } + #endregion + #region Private Methods (RBL) + + private void RestoreBalanceAfterRemoval(Node balanceNode) + { + Node ptr = balanceNode; + + while (ptr != _root && ColorOf(ptr) == Black) + { + if (ptr == LeftOf(ParentOf(ptr))) + { + Node sibling = RightOf(ParentOf(ptr)); + + if (ColorOf(sibling) == Red) + { + SetColor(sibling, Black); + SetColor(ParentOf(ptr), Red); + RotateLeft(ParentOf(ptr)); + sibling = RightOf(ParentOf(ptr)); + } + if (ColorOf(LeftOf(sibling)) == Black && ColorOf(RightOf(sibling)) == Black) + { + SetColor(sibling, Red); + ptr = ParentOf(ptr); + } + else + { + if (ColorOf(RightOf(sibling)) == Black) + { + SetColor(LeftOf(sibling), Black); + SetColor(sibling, Red); + RotateRight(sibling); + sibling = RightOf(ParentOf(ptr)); + } + SetColor(sibling, ColorOf(ParentOf(ptr))); + SetColor(ParentOf(ptr), Black); + SetColor(RightOf(sibling), Black); + RotateLeft(ParentOf(ptr)); + ptr = _root; + } + } + else + { + Node sibling = LeftOf(ParentOf(ptr)); + + if (ColorOf(sibling) == Red) + { + SetColor(sibling, Black); + SetColor(ParentOf(ptr), Red); + RotateRight(ParentOf(ptr)); + sibling = LeftOf(ParentOf(ptr)); + } + if (ColorOf(RightOf(sibling)) == Black && ColorOf(LeftOf(sibling)) == Black) + { + SetColor(sibling, Red); + ptr = ParentOf(ptr); + } + else + { + if (ColorOf(LeftOf(sibling)) == Black) + { + SetColor(RightOf(sibling), Black); + SetColor(sibling, Red); + RotateLeft(sibling); + sibling = LeftOf(ParentOf(ptr)); + } + SetColor(sibling, ColorOf(ParentOf(ptr))); + SetColor(ParentOf(ptr), Black); + SetColor(LeftOf(sibling), Black); + RotateRight(ParentOf(ptr)); + ptr = _root; + } + } + } + SetColor(ptr, Black); + } + + private void RestoreBalanceAfterInsertion(Node balanceNode) + { + SetColor(balanceNode, Red); + while (balanceNode != null && balanceNode != _root && ColorOf(ParentOf(balanceNode)) == Red) + { + if (ParentOf(balanceNode) == LeftOf(ParentOf(ParentOf(balanceNode)))) + { + Node sibling = RightOf(ParentOf(ParentOf(balanceNode))); + + if (ColorOf(sibling) == Red) + { + SetColor(ParentOf(balanceNode), Black); + SetColor(sibling, Black); + SetColor(ParentOf(ParentOf(balanceNode)), Red); + balanceNode = ParentOf(ParentOf(balanceNode)); + } + else + { + if (balanceNode == RightOf(ParentOf(balanceNode))) + { + balanceNode = ParentOf(balanceNode); + RotateLeft(balanceNode); + } + SetColor(ParentOf(balanceNode), Black); + SetColor(ParentOf(ParentOf(balanceNode)), Red); + RotateRight(ParentOf(ParentOf(balanceNode))); + } + } + else + { + Node sibling = LeftOf(ParentOf(ParentOf(balanceNode))); + + if (ColorOf(sibling) == Red) + { + SetColor(ParentOf(balanceNode), Black); + SetColor(sibling, Black); + SetColor(ParentOf(ParentOf(balanceNode)), Red); + balanceNode = ParentOf(ParentOf(balanceNode)); + } + else + { + if (balanceNode == LeftOf(ParentOf(balanceNode))) + { + balanceNode = ParentOf(balanceNode); + RotateRight(balanceNode); + } + SetColor(ParentOf(balanceNode), Black); + SetColor(ParentOf(ParentOf(balanceNode)), Red); + RotateLeft(ParentOf(ParentOf(balanceNode))); + } + } + } + SetColor(_root, Black); + } + + private void RotateLeft(Node node) + { + if (node != null) + { + Node right = RightOf(node); + node.Right = LeftOf(right); + if (LeftOf(right) != null) + { + LeftOf(right).Parent = node; + } + right.Parent = ParentOf(node); + if (ParentOf(node) == null) + { + _root = right; + } + else if (node == LeftOf(ParentOf(node))) + { + ParentOf(node).Left = right; + } + else + { + ParentOf(node).Right = right; + } + right.Left = node; + node.Parent = right; + } + } + + private void RotateRight(Node node) + { + if (node != null) + { + Node left = LeftOf(node); + node.Left = RightOf(left); + if (RightOf(left) != null) + { + RightOf(left).Parent = node; + } + left.Parent = node.Parent; + if (ParentOf(node) == null) + { + _root = left; + } + else if (node == RightOf(ParentOf(node))) + { + ParentOf(node).Right = left; + } + else + { + ParentOf(node).Left = left; + } + left.Right = node; + node.Parent = left; + } + } + #endregion + + #region Safety-Methods + + // These methods save memory by allowing us to forego sentinel nil nodes, as well as serve as protection against nullpointerexceptions. + + /// + /// Returns the color of , or Black if it is null. + /// + /// Node + /// The boolean color of , or black if null + private static bool ColorOf(Node node) + { + return node == null || node.Color; + } + + /// + /// Sets the color of node to . + ///

+ /// This method does nothing if is null. + ///
+ /// Node to set the color of + /// Color (Boolean) + private static void SetColor(Node node, bool color) + { + if (node != null) + { + node.Color = color; + } + } + + /// + /// This method returns the left node of , or null if is null. + /// + /// Node to retrieve the left child from + /// Left child of + private static Node LeftOf(Node node) + { + return node?.Left; + } + + /// + /// This method returns the right node of , or null if is null. + /// + /// Node to retrieve the right child from + /// Right child of + private static Node RightOf(Node node) + { + return node?.Right; + } + + /// + /// Returns the parent node of , or null if is null. + /// + /// Node to retrieve the parent from + /// Parent of + private static Node ParentOf(Node node) + { + return node?.Parent; + } + #endregion + + #region Interface Implementations + + // Method descriptions are not provided as they are already included as part of the interface. + public bool ContainsKey(K key) + { + if (key == null) + { + throw new ArgumentNullException(nameof(key)); + } + return GetNode(key) != null; + } + + bool IDictionary.Remove(K key) + { + int count = _count; + Remove(key); + return count > _count; + } + + public bool TryGetValue(K key, [MaybeNullWhen(false)] out V value) + { + if (null == key) + { + throw new ArgumentNullException(nameof(key)); + } + Node node = GetNode(key); + value = node != null ? node.Value : default; + return node != null; + } + + public void Add(KeyValuePair item) + { + if (item.Key == null) + { + throw new ArgumentNullException(nameof(item.Key)); + } + + Add(item.Key, item.Value); + } + + public void Clear() + { + _root = null; + _count = 0; + } + + public bool Contains(KeyValuePair item) + { + if (item.Key == null) + { + return false; + } + + Node node = GetNode(item.Key); + if (node != null) + { + return node.Key.Equals(item.Key) && node.Value.Equals(item.Value); + } + return false; + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + if (arrayIndex < 0 || array.Length - arrayIndex < this.Count) + { + throw new ArgumentOutOfRangeException(nameof(arrayIndex)); + } + + SortedList list = GetKeyValues(); + + int offset = 0; + + for (int i = arrayIndex; i < array.Length && offset < list.Count; i++) + { + array[i] = new KeyValuePair(list.Keys[i], list.Values[i]); + offset++; + } + } + + public bool Remove(KeyValuePair item) + { + Node node = GetNode(item.Key); + + if (node == null) + { + return false; + } + + if (node.Value.Equals(item.Value)) + { + int count = _count; + Remove(item.Key); + return count > _count; + } + + return false; + } + + public IEnumerator> GetEnumerator() + { + return GetKeyValues().GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetKeyValues().GetEnumerator(); + } + + public int Count => _count; + + public ICollection Keys => GetKeyValues().Keys; + + public ICollection Values => GetKeyValues().Values; + + public bool IsReadOnly => false; + + public V this[K key] + { + get => Get(key); + set => Add(key, value); + } + + #endregion + #region Private Interface Helper Methods + + /// + /// Returns a sorted list of all the node keys / values in the tree. + /// + /// List of node keys + private SortedList GetKeyValues() + { + SortedList set = new SortedList(); + Queue> queue = new Queue>(); + if (_root != null) + { + queue.Enqueue(_root); + } + + while (queue.Count > 0) + { + Node node = queue.Dequeue(); + set.Add(node.Key, node.Value); + if (null != node.Left) + { + queue.Enqueue(node.Left); + } + if (null != node.Right) + { + queue.Enqueue(node.Right); + } + } + + return set; + } + #endregion + } + + /// + /// Represents a node in the TreeDictionary which contains a key and value of generic type K and V, respectively. + /// + /// Key of the node + /// Value of the node + internal class Node + { + internal bool Color = true; + internal Node Left = null; + internal Node Right = null; + internal Node Parent = null; + internal K Key; + internal V Value; + + public Node(K key, V value, Node parent) + { + this.Key = key; + this.Value = value; + this.Parent = parent; + } + } +} diff --git a/Ryujinx.Graphics.Gpu/Image/TextureManager.cs b/Ryujinx.Graphics.Gpu/Image/TextureManager.cs index 6b11a671..993218ce 100644 --- a/Ryujinx.Graphics.Gpu/Image/TextureManager.cs +++ b/Ryujinx.Graphics.Gpu/Image/TextureManager.cs @@ -472,7 +472,7 @@ namespace Ryujinx.Graphics.Gpu.Image { ulong address = _context.MemoryManager.Translate(copyTexture.Address.Pack()); - if (address == MemoryManager.BadAddress) + if (address == MemoryManager.PteUnmapped) { return null; } @@ -533,7 +533,7 @@ namespace Ryujinx.Graphics.Gpu.Image { ulong address = _context.MemoryManager.Translate(colorState.Address.Pack()); - if (address == MemoryManager.BadAddress) + if (address == MemoryManager.PteUnmapped) { return null; } @@ -618,7 +618,7 @@ namespace Ryujinx.Graphics.Gpu.Image { ulong address = _context.MemoryManager.Translate(dsState.Address.Pack()); - if (address == MemoryManager.BadAddress) + if (address == MemoryManager.PteUnmapped) { return null; } @@ -983,7 +983,7 @@ namespace Ryujinx.Graphics.Gpu.Image { ulong address = _context.MemoryManager.Translate(cbp.DstAddress.Pack()); - if (address == MemoryManager.BadAddress) + if (address == MemoryManager.PteUnmapped) { return null; } diff --git a/Ryujinx.Graphics.Gpu/Image/TexturePool.cs b/Ryujinx.Graphics.Gpu/Image/TexturePool.cs index 53d810b9..9c7e849b 100644 --- a/Ryujinx.Graphics.Gpu/Image/TexturePool.cs +++ b/Ryujinx.Graphics.Gpu/Image/TexturePool.cs @@ -54,7 +54,7 @@ namespace Ryujinx.Graphics.Gpu.Image // Bad address. We can't add a texture with a invalid address // to the cache. - if (info.Address == MemoryManager.BadAddress) + if (info.Address == MemoryManager.PteUnmapped) { return null; } diff --git a/Ryujinx.Graphics.Gpu/Memory/BufferManager.cs b/Ryujinx.Graphics.Gpu/Memory/BufferManager.cs index 568133ca..1d48b38c 100644 --- a/Ryujinx.Graphics.Gpu/Memory/BufferManager.cs +++ b/Ryujinx.Graphics.Gpu/Memory/BufferManager.cs @@ -401,7 +401,7 @@ namespace Ryujinx.Graphics.Gpu.Memory ulong address = _context.MemoryManager.Translate(gpuVa); - if (address == MemoryManager.BadAddress) + if (address == MemoryManager.PteUnmapped) { return 0; } diff --git a/Ryujinx.Graphics.Gpu/Memory/MemoryManager.cs b/Ryujinx.Graphics.Gpu/Memory/MemoryManager.cs index 91575e20..2990fb52 100644 --- a/Ryujinx.Graphics.Gpu/Memory/MemoryManager.cs +++ b/Ryujinx.Graphics.Gpu/Memory/MemoryManager.cs @@ -10,10 +10,6 @@ namespace Ryujinx.Graphics.Gpu.Memory /// public class MemoryManager { - private const ulong AddressSpaceSize = 1UL << 40; - - public const ulong BadAddress = ulong.MaxValue; - private const int PtLvl0Bits = 14; private const int PtLvl1Bits = 14; public const int PtPageBits = 12; @@ -29,8 +25,7 @@ namespace Ryujinx.Graphics.Gpu.Memory private const int PtLvl0Bit = PtPageBits + PtLvl1Bits; private const int PtLvl1Bit = PtPageBits; - private const ulong PteUnmapped = 0xffffffff_ffffffff; - private const ulong PteReserved = 0xffffffff_fffffffe; + public const ulong PteUnmapped = 0xffffffff_ffffffff; private readonly ulong[][] _pageTable; @@ -136,116 +131,6 @@ namespace Ryujinx.Graphics.Gpu.Memory return va; } - /// - /// Maps a given range of pages to an allocated GPU virtual address. - /// The memory is automatically allocated by the memory manager. - /// - /// CPU virtual address to map into - /// Size in bytes of the mapping - /// Required alignment of the GPU virtual address in bytes - /// GPU virtual address where the range was mapped, or an all ones mask in case of failure - public ulong MapAllocate(ulong pa, ulong size, ulong alignment) - { - lock (_pageTable) - { - ulong va = GetFreePosition(size, alignment); - - if (va != PteUnmapped) - { - for (ulong offset = 0; offset < size; offset += PageSize) - { - SetPte(va + offset, pa + offset); - } - } - - return va; - } - } - - /// - /// Maps a given range of pages to an allocated GPU virtual address. - /// The memory is automatically allocated by the memory manager. - /// This also ensures that the mapping is always done in the first 4GB of GPU address space. - /// - /// CPU virtual address to map into - /// Size in bytes of the mapping - /// GPU virtual address where the range was mapped, or an all ones mask in case of failure - public ulong MapLow(ulong pa, ulong size) - { - lock (_pageTable) - { - ulong va = GetFreePosition(size, 1, PageSize); - - if (va != PteUnmapped && va <= uint.MaxValue && (va + size) <= uint.MaxValue) - { - for (ulong offset = 0; offset < size; offset += PageSize) - { - SetPte(va + offset, pa + offset); - } - } - else - { - va = PteUnmapped; - } - - return va; - } - } - - /// - /// Reserves memory at a fixed GPU memory location. - /// This prevents the reserved region from being used for memory allocation for map. - /// - /// GPU virtual address to reserve - /// Size in bytes of the reservation - /// GPU virtual address of the reservation, or an all ones mask in case of failure - public ulong ReserveFixed(ulong va, ulong size) - { - lock (_pageTable) - { - MemoryUnmapped?.Invoke(this, new UnmapEventArgs(va, size)); - - for (ulong offset = 0; offset < size; offset += PageSize) - { - if (IsPageInUse(va + offset)) - { - return PteUnmapped; - } - } - - for (ulong offset = 0; offset < size; offset += PageSize) - { - SetPte(va + offset, PteReserved); - } - } - - return va; - } - - /// - /// Reserves memory at any GPU memory location. - /// - /// Size in bytes of the reservation - /// Reservation address alignment in bytes - /// GPU virtual address of the reservation, or an all ones mask in case of failure - public ulong Reserve(ulong size, ulong alignment) - { - lock (_pageTable) - { - ulong address = GetFreePosition(size, alignment); - - if (address != PteUnmapped) - { - for (ulong offset = 0; offset < size; offset += PageSize) - { - SetPte(address + offset, PteReserved); - } - } - - return address; - } - } - /// /// Frees memory that was previously allocated by a map or reserved. /// @@ -265,55 +150,6 @@ namespace Ryujinx.Graphics.Gpu.Memory } } - /// - /// Gets the address of an unused (free) region of the specified size. - /// - /// Size of the region in bytes - /// Required alignment of the region address in bytes - /// Start address of the search on the address space - /// GPU virtual address of the allocation, or an all ones mask in case of failure - private ulong GetFreePosition(ulong size, ulong alignment = 1, ulong start = 1UL << 32) - { - // Note: Address 0 is not considered valid by the driver, - // when 0 is returned it's considered a mapping error. - ulong address = start; - ulong freeSize = 0; - - if (alignment == 0) - { - alignment = 1; - } - - alignment = (alignment + PageMask) & ~PageMask; - - while (address + freeSize < AddressSpaceSize) - { - if (!IsPageInUse(address + freeSize)) - { - freeSize += PageSize; - - if (freeSize >= size) - { - return address; - } - } - else - { - address += freeSize + PageSize; - freeSize = 0; - - ulong remainder = address % alignment; - - if (remainder != 0) - { - address = (address - remainder) + alignment; - } - } - } - - return PteUnmapped; - } - /// /// Checks if a given page is mapped. /// @@ -333,7 +169,7 @@ namespace Ryujinx.Graphics.Gpu.Memory { ulong baseAddress = GetPte(gpuVa); - if (baseAddress == PteUnmapped || baseAddress == PteReserved) + if (baseAddress == PteUnmapped) { return PteUnmapped; } @@ -341,29 +177,6 @@ namespace Ryujinx.Graphics.Gpu.Memory return baseAddress + (gpuVa & PageMask); } - /// - /// Checks if a given memory page is mapped or reserved. - /// - /// GPU virtual address of the page - /// True if the page is mapped or reserved, false otherwise - private bool IsPageInUse(ulong gpuVa) - { - if (gpuVa >> PtLvl0Bits + PtLvl1Bits + PtPageBits != 0) - { - return false; - } - - ulong l0 = (gpuVa >> PtLvl0Bit) & PtLvl0Mask; - ulong l1 = (gpuVa >> PtLvl1Bit) & PtLvl1Mask; - - if (_pageTable[l0] == null) - { - return false; - } - - return _pageTable[l0][l1] != PteUnmapped; - } - /// /// Gets the Page Table entry for a given GPU virtual address. /// diff --git a/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostAsGpu/NvHostAsGpuDeviceFile.cs b/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostAsGpu/NvHostAsGpuDeviceFile.cs index 6c49fd5c..fb973229 100644 --- a/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostAsGpu/NvHostAsGpuDeviceFile.cs +++ b/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostAsGpu/NvHostAsGpuDeviceFile.cs @@ -1,4 +1,5 @@ -using Ryujinx.Common.Logging; +using Ryujinx.Common.Collections; +using Ryujinx.Common.Logging; using Ryujinx.Graphics.Gpu.Memory; using Ryujinx.HLE.HOS.Kernel.Process; using Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu.Types; @@ -12,8 +13,12 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu class NvHostAsGpuDeviceFile : NvDeviceFile { private static ConcurrentDictionary _addressSpaceContextRegistry = new ConcurrentDictionary(); + private NvMemoryAllocator _memoryAllocator; - public NvHostAsGpuDeviceFile(ServiceCtx context, IVirtualMemoryManager memory, long owner) : base(context, owner) { } + public NvHostAsGpuDeviceFile(ServiceCtx context, IVirtualMemoryManager memory, long owner) : base(context, owner) + { + _memoryAllocator = context.Device.MemoryAllocator; + } public override NvInternalResult Ioctl(NvIoctl command, Span arguments) { @@ -92,11 +97,30 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu // the Offset field holds the alignment size instead. if ((arguments.Flags & AddressSpaceFlags.FixedOffset) != 0) { - arguments.Offset = (long)addressSpaceContext.Gmm.ReserveFixed((ulong)arguments.Offset, size); + bool regionInUse = _memoryAllocator.IsRegionInUse((ulong)arguments.Offset, size, out ulong freeAddressStartPosition); + ulong address; + + if (!regionInUse) + { + _memoryAllocator.AllocateRange((ulong)arguments.Offset, size, freeAddressStartPosition); + address = freeAddressStartPosition; + } + else + { + address = NvMemoryAllocator.PteUnmapped; + } + + arguments.Offset = (long)address; } else { - arguments.Offset = (long)addressSpaceContext.Gmm.Reserve((ulong)size, (ulong)arguments.Offset); + ulong address = _memoryAllocator.GetFreeAddress((ulong)size, out ulong freeAddressStartPosition, (ulong)arguments.Offset); + if (address != NvMemoryAllocator.PteUnmapped) + { + _memoryAllocator.AllocateRange(address, (ulong)size, freeAddressStartPosition); + } + + arguments.Offset = unchecked((long)address); } if (arguments.Offset < 0) @@ -128,6 +152,7 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu if (addressSpaceContext.RemoveReservation(arguments.Offset)) { + _memoryAllocator.DeallocateRange((ulong)arguments.Offset, size); addressSpaceContext.Gmm.Free((ulong)arguments.Offset, size); } else @@ -152,6 +177,7 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu { if (size != 0) { + _memoryAllocator.DeallocateRange((ulong)arguments.Offset, (ulong)size); addressSpaceContext.Gmm.Free((ulong)arguments.Offset, (ulong)size); } } @@ -252,7 +278,12 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu } else { - arguments.Offset = (long)addressSpaceContext.Gmm.MapAllocate((ulong)physicalAddress, (ulong)size, pageSize); + ulong va = _memoryAllocator.GetFreeAddress((ulong)size, out ulong freeAddressStartPosition, (ulong) pageSize); + if (va != NvMemoryAllocator.PteUnmapped) + { + _memoryAllocator.AllocateRange(va, (ulong)size, freeAddressStartPosition); + } + arguments.Offset = (long)addressSpaceContext.Gmm.Map((ulong)physicalAddress, va, (ulong)size); } if (arguments.Offset < 0) diff --git a/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostChannel/NvHostChannelDeviceFile.cs b/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostChannel/NvHostChannelDeviceFile.cs index d675ffc7..ca20aab5 100644 --- a/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostChannel/NvHostChannelDeviceFile.cs +++ b/Ryujinx.HLE/HOS/Services/Nv/NvDrvServices/NvHostChannel/NvHostChannelDeviceFile.cs @@ -1,4 +1,5 @@ -using Ryujinx.Common.Logging; +using Ryujinx.Common.Collections; +using Ryujinx.Common.Logging; using Ryujinx.Graphics.Gpu.Memory; using Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostAsGpu; using Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostChannel.Types; @@ -23,6 +24,7 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostChannel private readonly Switch _device; private readonly IVirtualMemoryManager _memory; + private NvMemoryAllocator _memoryAllocator; public enum ResourcePolicy { @@ -45,6 +47,7 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostChannel _timeout = 3000; _submitTimeout = 0; _timeslice = 0; + _memoryAllocator = _device.MemoryAllocator; ChannelSyncpoints = new uint[MaxModuleSyncpoint]; @@ -245,7 +248,17 @@ namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices.NvHostChannel { if (map.DmaMapAddress == 0) { - map.DmaMapAddress = (long)gmm.MapLow((ulong)map.Address, (uint)map.Size); + ulong va = _memoryAllocator.GetFreeAddress((ulong) map.Size, out ulong freeAddressStartPosition, 1, MemoryManager.PageSize); + + if (va != NvMemoryAllocator.PteUnmapped && va <= uint.MaxValue && (va + (uint)map.Size) <= uint.MaxValue) + { + _memoryAllocator.AllocateRange(va, (uint)map.Size, freeAddressStartPosition); + map.DmaMapAddress = (long)gmm.Map((ulong)map.Address, va, (uint)map.Size); + } + else + { + map.DmaMapAddress = unchecked((long)NvMemoryAllocator.PteUnmapped); + } } commandBufferEntry.MapAddress = (int)map.DmaMapAddress; diff --git a/Ryujinx.HLE/HOS/Services/Nv/NvMemoryAllocator.cs b/Ryujinx.HLE/HOS/Services/Nv/NvMemoryAllocator.cs new file mode 100644 index 00000000..46626608 --- /dev/null +++ b/Ryujinx.HLE/HOS/Services/Nv/NvMemoryAllocator.cs @@ -0,0 +1,282 @@ +using Ryujinx.Common.Collections; +using System.Collections.Generic; +using Ryujinx.Common; +using System; +using Ryujinx.Graphics.Gpu.Memory; + +namespace Ryujinx.HLE.HOS.Services.Nv.NvDrvServices +{ + class NvMemoryAllocator + { + private const ulong AddressSpaceSize = 1UL << 40; + + private const ulong DefaultStart = 1UL << 32; + private const ulong InvalidAddress = 0; + + private const ulong PageSize = MemoryManager.PageSize; + private const ulong PageMask = MemoryManager.PageMask; + + public const ulong PteUnmapped = MemoryManager.PteUnmapped; + + // Key --> Start Address of Region + // Value --> End Address of Region + private readonly TreeDictionary _tree = new TreeDictionary(); + + private readonly Dictionary> _dictionary = new Dictionary>(); + private readonly LinkedList _list = new LinkedList(); + + public NvMemoryAllocator() + { + _tree.Add(PageSize, PageSize + AddressSpaceSize); + LinkedListNode node = _list.AddFirst(PageSize); + _dictionary[PageSize] = node; + } + + /// + /// Marks a range of memory as consumed by removing it from the tree. + /// This function will split memory regions if there is available space. + /// + /// Virtual address at which to allocate + /// Size of the allocation in bytes + /// Reference to the address of memory where the allocation can take place + #region Memory Allocation + public void AllocateRange(ulong va, ulong size, ulong referenceAddress = InvalidAddress) + { + lock (_tree) + { + if (referenceAddress != InvalidAddress) + { + ulong endAddress = va + size; + ulong referenceEndAddress = _tree.Get(referenceAddress); + if (va >= referenceAddress) + { + // Need Left Node + if (va > referenceAddress) + { + ulong leftEndAddress = va; + + // Overwrite existing block with its new smaller range. + _tree.Add(referenceAddress, leftEndAddress); + } + else + { + // We need to get rid of the large chunk. + _tree.Remove(referenceAddress); + } + + ulong rightSize = referenceEndAddress - endAddress; + // If leftover space, create a right node. + if (rightSize > 0) + { + _tree.Add(endAddress, referenceEndAddress); + + LinkedListNode node = _list.AddAfter(_dictionary[referenceAddress], endAddress); + _dictionary[endAddress] = node; + } + + if (va == referenceAddress) + { + _list.Remove(_dictionary[referenceAddress]); + _dictionary.Remove(referenceAddress); + } + } + } + } + } + + /// + /// Marks a range of memory as free by adding it to the tree. + /// This function will automatically compact the tree when it determines there are multiple ranges of free memory adjacent to each other. + /// + /// Virtual address at which to deallocate + /// Size of the allocation in bytes + public void DeallocateRange(ulong va, ulong size) + { + lock (_tree) + { + ulong freeAddressStartPosition = _tree.Floor(va); + if (freeAddressStartPosition != InvalidAddress) + { + LinkedListNode node = _dictionary[freeAddressStartPosition]; + ulong targetPrevAddress = _dictionary[freeAddressStartPosition].Previous != null ? _dictionary[_dictionary[freeAddressStartPosition].Previous.Value].Value : InvalidAddress; + ulong targetNextAddress = _dictionary[freeAddressStartPosition].Next != null ? _dictionary[_dictionary[freeAddressStartPosition].Next.Value].Value : InvalidAddress; + ulong expandedStart = va; + ulong expandedEnd = va + size; + + while (targetPrevAddress != InvalidAddress) + { + ulong prevAddress = targetPrevAddress; + ulong prevEndAddress = _tree.Get(targetPrevAddress); + if (prevEndAddress >= expandedStart) + { + expandedStart = targetPrevAddress; + LinkedListNode prevPtr = _dictionary[prevAddress]; + if (prevPtr.Previous != null) + { + targetPrevAddress = prevPtr.Previous.Value; + } + else + { + targetPrevAddress = InvalidAddress; + } + node = node.Previous; + _tree.Remove(prevAddress); + _list.Remove(_dictionary[prevAddress]); + _dictionary.Remove(prevAddress); + } + else + { + break; + } + } + + while (targetNextAddress != InvalidAddress) + { + ulong nextAddress = targetNextAddress; + ulong nextEndAddress = _tree.Get(targetNextAddress); + if (nextAddress <= expandedEnd) + { + expandedEnd = Math.Max(expandedEnd, nextEndAddress); + LinkedListNode nextPtr = _dictionary[nextAddress]; + if (nextPtr.Next != null) + { + targetNextAddress = nextPtr.Next.Value; + } + else + { + targetNextAddress = InvalidAddress; + } + _tree.Remove(nextAddress); + _list.Remove(_dictionary[nextAddress]); + _dictionary.Remove(nextAddress); + } + else + { + break; + } + } + _tree.Add(expandedStart, expandedEnd); + LinkedListNode nodePtr = _list.AddAfter(node, expandedStart); + _dictionary[expandedStart] = nodePtr; + } + } + } + + /// + /// Gets the address of an unused (free) region of the specified size. + /// + /// Size of the region in bytes + /// Position at which memory can be allocated + /// Required alignment of the region address in bytes + /// Start address of the search on the address space + /// GPU virtual address of the allocation, or an all ones mask in case of failure + public ulong GetFreeAddress(ulong size, out ulong freeAddressStartPosition, ulong alignment = 1, ulong start = DefaultStart) + { + // Note: Address 0 is not considered valid by the driver, + // when 0 is returned it's considered a mapping error. + lock (_tree) + { + ulong address = start; + + if (alignment == 0) + { + alignment = 1; + } + + alignment = (alignment + PageMask) & ~PageMask; + if (address < AddressSpaceSize) + { + bool completedFirstPass = false; + ulong targetAddress; + if(start == DefaultStart) + { + targetAddress = _list.Last.Value; + } + else + { + targetAddress = _tree.Floor(address); + if(targetAddress == InvalidAddress) + { + targetAddress = _tree.Ceiling(address); + } + } + while (address < AddressSpaceSize) + { + if (targetAddress != InvalidAddress) + { + if (address >= targetAddress) + { + if (address + size <= _tree.Get(targetAddress)) + { + freeAddressStartPosition = targetAddress; + return address; + } + else + { + LinkedListNode nextPtr = _dictionary[targetAddress]; + if (nextPtr.Next != null) + { + targetAddress = nextPtr.Next.Value; + } + else + { + if (completedFirstPass) + { + break; + } + else + { + completedFirstPass = true; + address = start; + targetAddress = _tree.Floor(address); + } + } + } + } + else + { + address += PageSize * (targetAddress / PageSize - (address / PageSize)); + + ulong remainder = address % alignment; + + if (remainder != 0) + { + address = (address - remainder) + alignment; + } + } + } + else + { + break; + } + } + } + freeAddressStartPosition = InvalidAddress; + } + + return PteUnmapped; + } + + /// + /// Checks if a given memory region is mapped or reserved. + /// + /// GPU virtual address of the page + /// Size of the allocation in bytes + /// Nearest lower address that memory can be allocated + /// True if the page is mapped or reserved, false otherwise + public bool IsRegionInUse(ulong gpuVa, ulong size, out ulong freeAddressStartPosition) + { + lock (_tree) + { + ulong floorAddress = _tree.Floor(gpuVa); + freeAddressStartPosition = floorAddress; + if (floorAddress != InvalidAddress) + { + return !(gpuVa >= floorAddress && ((gpuVa + size) < _tree.Get(floorAddress))); + } + } + return true; + } + #endregion + } +} diff --git a/Ryujinx.HLE/Switch.cs b/Ryujinx.HLE/Switch.cs index d54c64e1..6014ccff 100644 --- a/Ryujinx.HLE/Switch.cs +++ b/Ryujinx.HLE/Switch.cs @@ -12,6 +12,7 @@ using Ryujinx.HLE.HOS; using Ryujinx.HLE.HOS.Services; using Ryujinx.HLE.HOS.Services.Apm; using Ryujinx.HLE.HOS.Services.Hid; +using Ryujinx.HLE.HOS.Services.Nv.NvDrvServices; using Ryujinx.HLE.HOS.SystemState; using Ryujinx.Memory; using System; @@ -26,6 +27,8 @@ namespace Ryujinx.HLE public GpuContext Gpu { get; private set; } + internal NvMemoryAllocator MemoryAllocator { get; private set; } + internal Host1xDevice Host1x { get; } public VirtualFileSystem FileSystem { get; private set; } @@ -69,6 +72,8 @@ namespace Ryujinx.HLE Gpu = new GpuContext(renderer); + MemoryAllocator = new NvMemoryAllocator(); + Host1x = new Host1xDevice(Gpu.Synchronization); var nvdec = new NvdecDevice(Gpu.MemoryManager); var vic = new VicDevice(Gpu.MemoryManager); diff --git a/Ryujinx.Tests/TreeDictionaryTests.cs b/Ryujinx.Tests/TreeDictionaryTests.cs new file mode 100644 index 00000000..610c2f6e --- /dev/null +++ b/Ryujinx.Tests/TreeDictionaryTests.cs @@ -0,0 +1,244 @@ +using NUnit.Framework; +using Ryujinx.Common.Collections; +using System; +using System.Collections.Generic; + +namespace Ryujinx.Tests.Collections +{ + class TreeDictionaryTests + { + [Test] + public void EnsureAddIntegrity() + { + TreeDictionary dictionary = new TreeDictionary(); + + Assert.AreEqual(dictionary.Count, 0); + + dictionary.Add(2, 7); + dictionary.Add(1, 4); + dictionary.Add(10, 2); + dictionary.Add(4, 1); + dictionary.Add(3, 2); + dictionary.Add(11, 2); + dictionary.Add(5, 2); + + Assert.AreEqual(dictionary.Count, 7); + + List> list = dictionary.AsLevelOrderList(); + + /* + * Tree Should Look as Follows After Rotations + * + * 2 + * 1 4 + * 3 10 + * 5 11 + * + */ + + Assert.AreEqual(list.Count, dictionary.Count); + Assert.AreEqual(list[0].Key, 2); + Assert.AreEqual(list[1].Key, 1); + Assert.AreEqual(list[2].Key, 4); + Assert.AreEqual(list[3].Key, 3); + Assert.AreEqual(list[4].Key, 10); + Assert.AreEqual(list[5].Key, 5); + Assert.AreEqual(list[6].Key, 11); + } + + [Test] + public void EnsureRemoveIntegrity() + { + TreeDictionary dictionary = new TreeDictionary(); + + Assert.AreEqual(dictionary.Count, 0); + + dictionary.Add(2, 7); + dictionary.Add(1, 4); + dictionary.Add(10, 2); + dictionary.Add(4, 1); + dictionary.Add(3, 2); + dictionary.Add(11, 2); + dictionary.Add(5, 2); + dictionary.Add(7, 2); + dictionary.Add(9, 2); + dictionary.Add(8, 2); + dictionary.Add(13, 2); + dictionary.Add(24, 2); + dictionary.Add(6, 2); + Assert.AreEqual(dictionary.Count, 13); + + List> list = dictionary.AsLevelOrderList(); + + /* + * Tree Should Look as Follows After Rotations + * + * 4 + * 2 10 + * 1 3 7 13 + * 5 9 11 24 + * 6 8 + */ + + foreach (KeyValuePair node in list) + { + Console.WriteLine($"{node.Key} -> {node.Value}"); + } + Assert.AreEqual(list.Count, dictionary.Count); + Assert.AreEqual(list[0].Key, 4); + Assert.AreEqual(list[1].Key, 2); + Assert.AreEqual(list[2].Key, 10); + Assert.AreEqual(list[3].Key, 1); + Assert.AreEqual(list[4].Key, 3); + Assert.AreEqual(list[5].Key, 7); + Assert.AreEqual(list[6].Key, 13); + Assert.AreEqual(list[7].Key, 5); + Assert.AreEqual(list[8].Key, 9); + Assert.AreEqual(list[9].Key, 11); + Assert.AreEqual(list[10].Key, 24); + Assert.AreEqual(list[11].Key, 6); + Assert.AreEqual(list[12].Key, 8); + + list.Clear(); + + dictionary.Remove(7); + + /* + * Tree Should Look as Follows After Removal + * + * 4 + * 2 10 + * 1 3 6 13 + * 5 9 11 24 + * 8 + */ + + list = dictionary.AsLevelOrderList(); + foreach (KeyValuePair node in list) + { + Console.WriteLine($"{node.Key} -> {node.Value}"); + } + Assert.AreEqual(list[0].Key, 4); + Assert.AreEqual(list[1].Key, 2); + Assert.AreEqual(list[2].Key, 10); + Assert.AreEqual(list[3].Key, 1); + Assert.AreEqual(list[4].Key, 3); + Assert.AreEqual(list[5].Key, 6); + Assert.AreEqual(list[6].Key, 13); + Assert.AreEqual(list[7].Key, 5); + Assert.AreEqual(list[8].Key, 9); + Assert.AreEqual(list[9].Key, 11); + Assert.AreEqual(list[10].Key, 24); + Assert.AreEqual(list[11].Key, 8); + + list.Clear(); + + dictionary.Remove(10); + + list = dictionary.AsLevelOrderList(); + /* + * Tree Should Look as Follows After Removal + * + * 4 + * 2 9 + * 1 3 6 13 + * 5 8 11 24 + * + */ + foreach (KeyValuePair node in list) + { + Console.WriteLine($"{node.Key} -> {node.Value}"); + } + Assert.AreEqual(list[0].Key, 4); + Assert.AreEqual(list[1].Key, 2); + Assert.AreEqual(list[2].Key, 9); + Assert.AreEqual(list[3].Key, 1); + Assert.AreEqual(list[4].Key, 3); + Assert.AreEqual(list[5].Key, 6); + Assert.AreEqual(list[6].Key, 13); + Assert.AreEqual(list[7].Key, 5); + Assert.AreEqual(list[8].Key, 8); + Assert.AreEqual(list[9].Key, 11); + Assert.AreEqual(list[10].Key, 24); + } + + [Test] + public void EnsureOverwriteIntegrity() + { + TreeDictionary dictionary = new TreeDictionary(); + + Assert.AreEqual(dictionary.Count, 0); + + dictionary.Add(2, 7); + dictionary.Add(1, 4); + dictionary.Add(10, 2); + dictionary.Add(4, 1); + dictionary.Add(3, 2); + dictionary.Add(11, 2); + dictionary.Add(5, 2); + dictionary.Add(7, 2); + dictionary.Add(9, 2); + dictionary.Add(8, 2); + dictionary.Add(13, 2); + dictionary.Add(24, 2); + dictionary.Add(6, 2); + Assert.AreEqual(dictionary.Count, 13); + + List> list = dictionary.AsLevelOrderList(); + + foreach (KeyValuePair node in list) + { + Console.WriteLine($"{node.Key} -> {node.Value}"); + } + + /* + * Tree Should Look as Follows After Rotations + * + * 4 + * 2 10 + * 1 3 7 13 + * 5 9 11 24 + * 6 8 + */ + + Assert.AreEqual(list.Count, dictionary.Count); + Assert.AreEqual(list[0].Key, 4); + Assert.AreEqual(list[1].Key, 2); + Assert.AreEqual(list[2].Key, 10); + Assert.AreEqual(list[3].Key, 1); + Assert.AreEqual(list[4].Key, 3); + Assert.AreEqual(list[5].Key, 7); + Assert.AreEqual(list[6].Key, 13); + Assert.AreEqual(list[7].Key, 5); + Assert.AreEqual(list[8].Key, 9); + Assert.AreEqual(list[9].Key, 11); + Assert.AreEqual(list[10].Key, 24); + Assert.AreEqual(list[11].Key, 6); + Assert.AreEqual(list[12].Key, 8); + + Assert.AreEqual(list[4].Value, 2); + + dictionary.Add(3, 4); + + list = dictionary.AsLevelOrderList(); + + Assert.AreEqual(list[4].Value, 4); + + + // Assure that none of the nodes locations have been modified. + Assert.AreEqual(list[0].Key, 4); + Assert.AreEqual(list[1].Key, 2); + Assert.AreEqual(list[2].Key, 10); + Assert.AreEqual(list[3].Key, 1); + Assert.AreEqual(list[4].Key, 3); + Assert.AreEqual(list[5].Key, 7); + Assert.AreEqual(list[6].Key, 13); + Assert.AreEqual(list[7].Key, 5); + Assert.AreEqual(list[8].Key, 9); + Assert.AreEqual(list[9].Key, 11); + Assert.AreEqual(list[10].Key, 24); + Assert.AreEqual(list[11].Key, 6); + Assert.AreEqual(list[12].Key, 8); + } + } +}