Mostly as a learning exercise, and partly because I thought it might be useful, I have written an implementation of AVL trees ( https://en.wikipedia.org/wiki/AVL_tree ).
An AVL tree is a height-balanced binary search tree: if the heights of a node's left and right sub-trees ever differ by more than one, a "rotate" operation is applied to re-balance the tree. A rotation keeps the tree nodes in the same order, but promotes a child node upwards and demotes the unbalanced node.
Implementing C# iterators was totally new to me, so that has been a learning exercise too, and I'm not sure whether my style there is quite right.
There are two main classes, a key-only tree AvlTree, and a key/value dictionary AvlDict.
I tend to stick to C# version 2, for reasons I won't go into, and I do not generally use _ for private fields; please don't comment on that. All other comments and reviews are welcome. The way I have implemented dictionary assignment is perhaps a little questionable, but I wanted to separate the key/value pair logic from the AVL tree. Is there a better way to do that?
Here is the code, starting with a usage example:
using Collections = System.Collections;
using Generic = System.Collections.Generic;
using Console = System.Console; // For example usage.
class AvlExample
{
    // Exercises AvlTree (Insert, Range, Remove) and AvlDict (indexer assignment),
    // writing the results to the console so they can be checked by eye.
    public static void Usage()
    {
        const int testSize = 5 * 1000 * 1000;
        AvlTree<long> tree = new AvlTree<long>();

        // Fill the tree with the multiples of ten: 0, 10, 20 ...
        for ( long key = 0; key < testSize; key += 10 )
        {
            tree.Insert( key );
        }
        Console.WriteLine( "Should print 50,60..100" );
        foreach ( long found in tree.Range( 50, 100 ) )
        {
            Console.WriteLine( found );
        }

        // Keep only the multiples of fifty, i.e. remove four keys in every five.
        for ( long key = 0; key < testSize; key += 10 )
        {
            bool keep = key % 50 == 0;
            if ( !keep )
            {
                tree.Remove( key );
            }
        }
        Console.WriteLine( "Should print 50,100..250" );
        foreach ( long found in tree.Range( 50, 250 ) )
        {
            Console.WriteLine( found );
        }

        // Dictionary: a later assignment to an existing key replaces its value.
        AvlDict<int,string> greetings = new AvlDict<int,string>( "" );
        greetings[ 100 ] = "There";
        greetings[ 50 ] = "Hello";
        greetings[ 100 ] = "there";
        foreach ( int key in greetings )
        {
            string word = greetings[ key ];
            Console.WriteLine( word );
        }
    }
}
// Height-balanced (AVL) binary search tree holding a set of distinct keys.
// Each node records the height difference of its sub-trees in Node.Balance; when an
// insert or remove would make that difference two, a rotation restores balance
// without changing the in-order sequence of keys.
// Derived classes can attach data to nodes by overriding NewNode, Found and FreeNode.
class AvlTree<T> : Generic.IEnumerable<T>
{
    public delegate int DComparer( T key1, T key2 );

    public AvlTree() // Initialise with default compare.
    {
        Compare = Generic.Comparer<T>.Default.Compare;
    }

    public AvlTree( DComparer compare ) // Initialise with specific compare function.
    {
        // Reject null here; otherwise the mistake only surfaces later as a
        // NullReferenceException inside the first Insert/Lookup that compares keys.
        if ( compare == null ) throw new System.ArgumentNullException( "compare" );
        Compare = compare;
    }

    public void Insert( T key )
    // Insert key into the tree. If key is already in tree, Found is called instead.
    {
        bool heightIncreased;
        Root = Insert( Root, key, out heightIncreased );
    }

    public bool Contains( T key )
    {
        return Lookup( key ) != null;
    }

    public void Remove( T key )
    // Remove key from the tree, calling FreeNode for the removed node.
    // If key is not present, has no effect.
    {
        bool heightDecreased;
        Root = Remove( Root, key, out heightDecreased );
    }

    public Generic.IEnumerator<T> GetEnumerator()
    // Enumerates every key in ascending order (as defined by the comparer).
    {
        if ( Root != null ) foreach ( T key in Root ) yield return key;
    }

    Collections.IEnumerator Collections.IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }

    public Generic.IEnumerable<T> Range( T start, T end )
    // Enumerates, in ascending order, the keys k with start <= k <= end.
    // Yields nothing when start > end.
    {
        if ( Root != null ) foreach ( T key in Root.Range( start, end, Compare ) ) yield return key;
    }

    protected virtual Node NewNode( T key )
    {
        // Called by Insert if key not found.
        return new Node( key );
    }

    protected virtual void Found( Node x )
    {
        // Called by Insert when an existing key is found.
    }

    protected virtual void FreeNode( Node x )
    {
        // Called by Remove when a Node is removed.
    }

    protected Node Lookup( T key )
    // Returns the node holding key, or null if key is not in the tree.
    {
        Node x = Root;
        while ( x != null )
        {
            int cf = Compare( key, x.Key );
            if ( cf < 0 ) x = x.Left;
            else if ( cf > 0 ) x = x.Right;
            else return x;
        }
        return null;
    }

    protected class Node : Generic.IEnumerable<T>
    {
        public Node Left, Right;
        public readonly T Key;
        public sbyte Balance; // One of LeftHigher, Balanced, RightHigher.

        public Node( T key )
        {
            Key = key;
        }

        public Generic.IEnumerator<T> GetEnumerator()
        // In-order walk of this sub-tree. An explicit stack is used so each key is
        // yielded once, instead of being passed up through one nested iterator per
        // tree level (which made a full walk O(n log n)).
        {
            Generic.Stack<Node> pending = new Generic.Stack<Node>();
            Node x = this;
            while ( x != null || pending.Count > 0 )
            {
                // Push the left spine; the top of the stack is then the next key in order.
                while ( x != null )
                {
                    pending.Push( x );
                    x = x.Left;
                }
                x = pending.Pop();
                yield return x.Key;
                x = x.Right;
            }
        }

        Collections.IEnumerator Collections.IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        public Generic.IEnumerable<T> Range( T start, T end, DComparer compare )
        // In-order walk of the keys k in this sub-tree with start <= k <= end.
        // Sub-trees entirely below start are never entered, and the walk stops at the
        // first key beyond end.
        {
            Generic.Stack<Node> pending = new Generic.Stack<Node>();
            Node x = this;
            while ( x != null || pending.Count > 0 )
            {
                // Descend towards start, stacking only nodes whose key is >= start.
                while ( x != null )
                {
                    int cstart = compare( start, x.Key );
                    if ( cstart > 0 )
                    {
                        x = x.Right; // x and its whole left sub-tree are below start.
                    }
                    else
                    {
                        pending.Push( x );
                        // When x.Key equals start, nothing in its left sub-tree qualifies.
                        x = cstart < 0 ? x.Left : null;
                    }
                }
                if ( pending.Count == 0 ) yield break;
                x = pending.Pop();
                // Keys come out in ascending order, so the first one past end ends the walk.
                if ( compare( end, x.Key ) < 0 ) yield break;
                yield return x.Key;
                x = x.Right;
            }
        }
    } // Node

    // Fields.
    private readonly DComparer Compare;
    private Node Root;

    // Constant values for Node.Balance.
    private const int LeftHigher = -1, Balanced = 0, RightHigher = 1;

    // Private methods.

    private Node Insert( Node x, T key, out bool heightIncreased )
    // Inserts key into the sub-tree rooted at x and returns the (possibly new) root of
    // that sub-tree. heightIncreased reports whether the sub-tree became one level taller.
    {
        if ( x == null )
        {
            x = NewNode( key );
            heightIncreased = true;
        }
        else
        {
            int compare = Compare( key, x.Key );
            if ( compare < 0 )
            {
                x.Left = Insert( x.Left, key, out heightIncreased );
                if ( heightIncreased )
                {
                    if ( x.Balance == Balanced )
                    {
                        x.Balance = LeftHigher;
                    }
                    else
                    {
                        // Either the sub-tree is now balanced, or a rotation brings it back
                        // to its old height; in both cases the height is unchanged.
                        heightIncreased = false;
                        if ( x.Balance == LeftHigher )
                        {
                            bool heightDecreased;
                            return RotateRight( x, out heightDecreased );
                        }
                        x.Balance = Balanced;
                    }
                }
            }
            else if ( compare > 0 )
            {
                x.Right = Insert( x.Right, key, out heightIncreased );
                if ( heightIncreased )
                {
                    if ( x.Balance == Balanced )
                    {
                        x.Balance = RightHigher;
                    }
                    else
                    {
                        heightIncreased = false;
                        if ( x.Balance == RightHigher )
                        {
                            bool heightDecreased;
                            return RotateLeft( x, out heightDecreased );
                        }
                        x.Balance = Balanced;
                    }
                }
            }
            else // compare == 0
            {
                Found( x );
                heightIncreased = false;
            }
        }
        return x;
    }

    private Node Remove( Node x, T key, out bool heightDecreased )
    // Removes key from the sub-tree rooted at x and returns the new root of that
    // sub-tree. heightDecreased reports whether the sub-tree became one level shorter.
    {
        if ( x == null ) // key not found.
        {
            heightDecreased = false;
            return x;
        }
        int compare = Compare( key, x.Key );
        if ( compare == 0 )
        {
            Node deleted = x;
            if ( x.Left == null )
            {
                heightDecreased = true;
                x = x.Right;
            }
            else if ( x.Right == null )
            {
                heightDecreased = true;
                x = x.Left;
            }
            else
            {
                // Remove the smallest element in the right sub-tree and substitute it for x.
                Node right = RemoveLeast( deleted.Right, out x, out heightDecreased );
                x.Left = deleted.Left;
                x.Right = right;
                x.Balance = deleted.Balance;
                if ( heightDecreased )
                {
                    // The right side shrank: same re-balancing as the compare > 0 case below.
                    if ( x.Balance == LeftHigher )
                    {
                        x = RotateRight( x, out heightDecreased );
                    }
                    else if ( x.Balance == RightHigher )
                    {
                        x.Balance = Balanced;
                    }
                    else
                    {
                        x.Balance = LeftHigher;
                        heightDecreased = false;
                    }
                }
            }
            FreeNode( deleted );
        }
        else if ( compare < 0 )
        {
            x.Left = Remove( x.Left, key, out heightDecreased );
            if ( heightDecreased )
            {
                if ( x.Balance == RightHigher )
                {
                    return RotateLeft( x, out heightDecreased );
                }
                if ( x.Balance == LeftHigher )
                {
                    x.Balance = Balanced;
                }
                else
                {
                    x.Balance = RightHigher;
                    heightDecreased = false;
                }
            }
        }
        else
        {
            x.Right = Remove( x.Right, key, out heightDecreased );
            if ( heightDecreased )
            {
                if ( x.Balance == LeftHigher )
                {
                    return RotateRight( x, out heightDecreased );
                }
                if ( x.Balance == RightHigher )
                {
                    x.Balance = Balanced;
                }
                else
                {
                    x.Balance = LeftHigher;
                    heightDecreased = false;
                }
            }
        }
        return x;
    }

    private static Node RemoveLeast( Node x, out Node least, out bool heightDecreased )
    // Detaches the left-most (smallest) node of the non-null sub-tree x into least and
    // returns the new root of the sub-tree; heightDecreased as for Remove.
    {
        if ( x.Left == null )
        {
            heightDecreased = true;
            least = x;
            return x.Right;
        }
        else
        {
            x.Left = RemoveLeast( x.Left, out least, out heightDecreased );
            if ( heightDecreased )
            {
                if ( x.Balance == RightHigher )
                {
                    return RotateLeft( x, out heightDecreased );
                }
                if ( x.Balance == LeftHigher )
                {
                    x.Balance = Balanced;
                }
                else
                {
                    x.Balance = RightHigher;
                    heightDecreased = false;
                }
            }
            return x;
        }
    }

    private static Node RotateRight( Node x, out bool heightDecreased )
    // Re-balances x, whose left sub-tree is two levels higher than its right, and returns
    // the new sub-tree root. heightDecreased is false only when the left child was
    // balanced, in which case the rotated sub-tree keeps the height of the unbalanced one.
    {
        // Left is 2 levels higher than Right.
        heightDecreased = true;
        Node z = x.Left;
        Node y = z.Right;
        if ( z.Balance != RightHigher ) // Single rotation.
        {
            z.Right = x;
            x.Left = y;
            if ( z.Balance == Balanced ) // Can only occur when deleting values.
            {
                x.Balance = LeftHigher;
                z.Balance = RightHigher;
                heightDecreased = false;
            }
            else // z.Balance = LeftHigher
            {
                x.Balance = Balanced;
                z.Balance = Balanced;
            }
            return z;
        }
        else // Double rotation.
        {
            x.Left = y.Right;
            z.Right = y.Left;
            y.Right = x;
            y.Left = z;
            if ( y.Balance == LeftHigher )
            {
                x.Balance = RightHigher;
                z.Balance = Balanced;
            }
            else if ( y.Balance == Balanced )
            {
                x.Balance = Balanced;
                z.Balance = Balanced;
            }
            else // y.Balance == RightHigher
            {
                x.Balance = Balanced;
                z.Balance = LeftHigher;
            }
            y.Balance = Balanced;
            return y;
        }
    }

    private static Node RotateLeft( Node x, out bool heightDecreased )
    // Mirror image of RotateRight: x's right sub-tree is two levels higher than its left.
    {
        // Right is 2 levels higher than Left.
        heightDecreased = true;
        Node z = x.Right;
        Node y = z.Left;
        if ( z.Balance != LeftHigher ) // Single rotation.
        {
            z.Left = x;
            x.Right = y;
            if ( z.Balance == Balanced ) // Can only occur when deleting values.
            {
                x.Balance = RightHigher;
                z.Balance = LeftHigher;
                heightDecreased = false;
            }
            else // z.Balance = RightHigher
            {
                x.Balance = Balanced;
                z.Balance = Balanced;
            }
            return z;
        }
        else // Double rotation
        {
            x.Right = y.Left;
            z.Left = y.Right;
            y.Left = x;
            y.Right = z;
            if ( y.Balance == RightHigher )
            {
                x.Balance = LeftHigher;
                z.Balance = Balanced;
            }
            else if ( y.Balance == Balanced )
            {
                x.Balance = Balanced;
                z.Balance = Balanced;
            }
            else // y.Balance == LeftHigher
            {
                x.Balance = Balanced;
                z.Balance = RightHigher;
            }
            y.Balance = Balanced;
            return y;
        }
    }
}
// Ordered key/value dictionary built on AvlTree: every node is a Pair that also carries
// a value. Enumerating the dictionary yields its keys in ascending order; reading a key
// that is absent returns the default value supplied to the constructor.
class AvlDict<TKey,TValue> : AvlTree<TKey>
{
    public AvlDict( TValue def ) : base() { Default = def; Value = def; }
    public AvlDict( TValue def, DComparer compare ) : base( compare ) { Default = def; Value = def; }

    public TValue this [ TKey key ]
    {
        get
        {
            Pair p = (Pair) Lookup( key );
            return p != null ? p.Value : Default;
        }
        set
        {
            // Insert only takes a key, so the value is staged in fields that the
            // NewNode / Found overrides read during this one call.
            Value = value;
            Assigning = true;
            try
            {
                Insert( key );
            }
            finally
            {
                // Drop the staged value so the dictionary does not keep the last assigned
                // object alive, and so a later plain Insert( key ) cannot pick it up.
                Value = Default;
                Assigning = false;
            }
        }
    }

    private readonly TValue Default;
    private TValue Value;    // Value given to a newly created node; Default outside the setter.
    private bool Assigning;  // True only while the indexer setter is inside Insert.

    private class Pair : AvlTree<TKey>.Node
    {
        public TValue Value;
        public Pair( TKey key, TValue value ) : base( key ) { Value = value; }
    }

    protected override Node NewNode( TKey key )
    {
        // Via the indexer this is the assigned value; via a plain Insert it is Default.
        return new Pair( key, Value );
    }

    protected override void Found( Node x )
    {
        // Only an indexer assignment replaces an existing value; a plain Insert of a key
        // that is already present leaves its value untouched.
        if ( Assigning )
        {
            Pair p = (Pair) x;
            p.Value = Value;
        }
    }
}