ecere/com/Containers/CustomAVLTree: Fixed key class confusion caused by a47aec0a06c2c...
[sdk] / ecere / src / com / containers / CustomAVLTree.ec
1 namespace com;
2
3 import "Container"
4
5 default:
6
7 extern int __ecereVMethodID_class_OnCompare;
8 extern int __ecereVMethodID_class_OnCopy;
9 extern int __ecereVMethodID_class_OnFree;
10 private:
11
12 enum AddSide : int { compare = 0, left = -1, right = 1};
13
14 public class AVLNode<class T> : IteratorPointer
15 {
16    class_fixed
17
18    thisclass parent, left, right;
19    int depth;
20 public:
21    T key;
22
23    property thisclass prev
24    {
25       get
26       {
27          if(left)
28             return left.maximum;
29          while(this)
30          {
31             if(parent && this == parent.right)
32                return parent;
33             else
34                this = parent;
35          }
36          return this;
37       }
38    }
39
40    property thisclass next
41    {
42       get
43       {
44          thisclass right = this.right;
45          if(right)
46             return right.minimum;
47          while(this)
48          {
49             thisclass parent = this.parent;
50             if(parent && this == parent.left)
51                return parent;
52             else
53                this = parent;
54          }
55          return null;
56       }
57    }
58
59    property thisclass minimum
60    {
61       get { while(left) this = left; return this; }
62    }
63
64    property thisclass maximum
65    {
66       get { while(right) this = right; return this; }
67    }
68
69    property int count
70    {
71       get { return 1 + (left ? left.count : 0) + (right ? right.count : 0); }
72    }
73    property int depthProp
74    {
75       get
76       {
77          int leftDepth = left ? (left.depthProp+1) : 0;
78          int rightDepth = right ? (right.depthProp+1) : 0;
79          return Max(leftDepth, rightDepth);
80       }
81    }
82 private:
83
84    void Free()
85    {
86         if (left) left.Free();
87         if (right) right.Free();
88       delete this;
89    }
90
91    bool Add(Class Tclass, thisclass node, AddSide addSide)
92    {
93       ClassType t;
94       int (* onCompare)(void *, void *, void *);
95       uint offset = 0;
96       bool reference = false;
97       byte * a;
98
99       if(!Tclass)
100          Tclass = class(uint64);
101       t = Tclass.type;
102       onCompare = (void *)Tclass._vTbl[__ecereVMethodID_class_OnCompare];
103       if((t == systemClass && !Tclass.byValueSystemClass) || t == bitClass || t == enumClass || t == unitClass || t == structClass)
104       {
105          reference = true;
106          offset = __ENDIAN_PAD((t == structClass) ? sizeof(void *) : Tclass.typeSize);
107       }
108       a = reference ? ((byte *)&node.key + offset) : ((byte *)(uintptr)node.key);
109
110       while(true)
111       {
112          int result;
113          if(addSide)
114             result = addSide;
115          else
116          {
117             byte * b = reference ? ((byte *)&key + offset) : (byte *)(uintptr)key;
118             result = onCompare(Tclass, a, b);
119          }
120          if(!result)
121             return false;
122          else if(result > 0)
123          {
124             if(right)
125                this = right;
126             else
127             {
128                         right = node;
129                break;
130                 }
131         }
132          else
133          {
134             if(left)
135                this = left;
136             else
137             {
138                left = node;
139                break;
140             }
141         }
142       }
143       node.parent = this;
144       node.depth = 0;
145       {
146          AVLNode<T> n;
147          for(n = this; n; n = n.parent)
148          {
149             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
150             if(newDepth == n.depth)
151                break;
152             n.depth = newDepth;
153          }
154       }
155       return true;
156    }
157
158    public thisclass Find(Class Tclass, const T key)
159    {
160       byte * a;
161       bool reference = false;
162       uint offset = 0;
163       ClassType t = Tclass.type;
164       int (* onCompare)(void *, void *, void *) = (void *)Tclass._vTbl[__ecereVMethodID_class_OnCompare];
165
166       reference = (t == systemClass && !Tclass.byValueSystemClass) || t == bitClass || t == enumClass || t == unitClass;
167       offset = __ENDIAN_PAD(Tclass.typeSize);
168       a = reference ? ((byte *)&(uint64)key) + offset : (byte *)(uintptr)key;
169       if(t == structClass)
170       {
171          reference = true;
172          offset = __ENDIAN_PAD(sizeof(void *));
173       }
174
175       while(this)
176       {
177          // *** NEED COMPARISON OPERATOR SUPPORT HERE INVOKING OnCompare, AS WELL AS TYPE INFORMATION PASSED ***
178          byte * b = reference ? ((byte *)&this.key) + offset : (byte *)(uintptr)this.key;
179          int result = onCompare(Tclass, a, b);
180          if(result < 0)
181             this = left;
182          else if(result > 0)
183             this = right;
184          else
185             break;
186       }
187       return this;
188    }
189
190    thisclass FindEx(Class Tclass, const T key, AVLNode */*thisclass **/ addTo, AddSide * addSide)
191    {
192       byte * a;
193       bool reference = false;
194       uint offset = 0;
195       ClassType t = Tclass.type;
196       int (* onCompare)(void *, void *, void *) = (void *)Tclass._vTbl[__ecereVMethodID_class_OnCompare];
197       bool isInt64 = onCompare == (void *)class(int64).OnCompare;
198
199       reference = (t == systemClass && !Tclass.byValueSystemClass) || t == bitClass || t == enumClass || t == unitClass;
200       offset = __ENDIAN_PAD(Tclass.typeSize);
201       a = reference ? ((byte *)&(uint64)key) + offset : (byte *)(uintptr)key;
202       if(t == structClass)
203       {
204          reference = true;
205          offset = __ENDIAN_PAD(sizeof(void *));
206       }
207
208       if(Tclass == class(uint))
209       {
210          uint ia = *(uint *)a;
211          while(this)
212          {
213             uint ib = *(uint *)(reference ? ((byte *)&this.key) + offset : (byte *)(uintptr)this.key);
214             int result = ia > ib ? 1 : ia < ib ? -1 : 0;
215             if(result)
216             {
217                thisclass node = result < 0 ? left : right;
218                if(!node)
219                {
220                   if(addTo) *addTo = this;
221                   if(addSide) *addSide = (AddSide)result;
222                }
223                this = node;
224             }
225             else
226                break;
227          }
228       }
229       else
230       {
231          int64 a64;
232          if(isInt64)
233             a64 = *(int64 *)a;
234          while(this)
235          {
236             byte * b = reference ? ((byte *)&this.key) + offset : (byte *)(uintptr)this.key;
237             int result;
238             if(isInt64)
239             {
240                int64 b64 = *(int64 *)b;
241                     if(a64 > b64) result = 1;
242                else if(a64 < b64) result = -1;
243                else result = 0;
244             }
245             else
246                result = onCompare(Tclass, a, b);
247             if(result)
248             {
249                thisclass node = result < 0 ? left : right;
250                if(!node)
251                {
252                   if(addTo) *addTo = this;
253                   if(addSide) *addSide = (AddSide)result;
254                }
255                this = node;
256             }
257             else
258                break;
259          }
260       }
261       return this;
262    }
263
264    thisclass FindAll(const T key)
265    {
266       AVLNode<T> result = null;
267       // *** FIND ALL COMPARES KEY FOR EQUALITY, NOT USING OnCompare ***
268       if(this.key == key) result = this;
269       if(!result && left) result = left.FindAll(key);
270       if(!result && right) result = right.FindAll(key);
271       return result;
272    }
273
274    void RemoveSwap(thisclass swap)
275    {
276       if(swap.left)
277       {
278          swap.left.parent = swap.parent;
279          if(swap == swap.parent.left)
280             swap.parent.left = swap.left;
281          else if(swap == swap.parent.right)
282             swap.parent.right = swap.left;
283          swap.left = null;
284       }
285       if(swap.right)
286       {
287          swap.right.parent = swap.parent;
288          if(swap == swap.parent.left)
289             swap.parent.left = swap.right;
290          else if(swap == swap.parent.right)
291             swap.parent.right = swap.right;
292          swap.right = null;
293       }
294
295       if(swap == swap.parent.left) swap.parent.left = null;
296       else if(swap == swap.parent.right) swap.parent.right = null;
297
298       {
299          AVLNode<T> n;
300          for(n = swap.parent; n; n = n.parent)
301          {
302             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
303             if(newDepth == n.depth)
304                break;
305             n.depth = newDepth;
306             if(n == this) break;
307          }
308       }
309
310       swap.left = left;
311       if(left)
312          left.parent = swap;
313
314        swap.right = right;
315        if (right)
316             right.parent = swap;
317
318       swap.parent = parent;
319       left = null;
320       right = null;
321       if(parent)
322       {
323          if(this == parent.left) parent.left = swap;
324          else if(this == parent.right) parent.right = swap;
325       }
326    }
327
328    thisclass RemoveSwapLeft()
329    {
330       thisclass swap = left ? left.maximum : right;
331       thisclass swapParent = null;
332       if(swap) { swapParent = swap.parent; RemoveSwap(swap); }
333       if(parent)
334       {
335          if(this == parent.left) parent.left = null;
336          else if(this == parent.right) parent.right = null;
337       }
338       {
339          AVLNode<T> n;
340          for(n = swap ? swap : parent; n; n = n.parent)
341          {
342             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
343             if(newDepth == n.depth && n != swap)
344                break;
345             n.depth = newDepth;
346          }
347       }
348       if(swapParent && swapParent != this)
349          return swapParent.Rebalance();
350       else if(swap)
351          return swap.Rebalance();
352       else if(parent)
353          return parent.Rebalance();
354       else
355          return null;
356    }
357
358    thisclass RemoveSwapRight()
359    {
360       thisclass result;
361       thisclass swap = right ? right.minimum : left;
362       thisclass swapParent = null;
363       if(swap) { swapParent = swap.parent; RemoveSwap(swap); }
364       if(parent)
365       {
366          if(this == parent.left) parent.left = null;
367          else if(this == parent.right) parent.right = null;
368       }
369       {
370          AVLNode<T> n;
371          for(n = swap ? swap : parent; n; n = n.parent)
372          {
373             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
374
375             if(newDepth == n.depth && n != swap)
376                break;
377             n.depth = newDepth;
378          }
379       }
380       if(swapParent && swapParent != this)
381          result = swapParent.Rebalance();
382       else if(swap)
383          result = swap.Rebalance();
384       else if(parent)
385          result = parent.Rebalance();
386       else
387          result = null;
388       return result;
389    }
390
391    property int balanceFactor
392    {
393       get
394       {
395          int leftDepth = left ? (left.depth+1) : 0;
396          int rightDepth = right ? (right.depth+1) : 0;
397          return rightDepth - leftDepth;
398       }
399    }
400
401    thisclass Rebalance()
402    {
403       while(true)
404       {
405          int factor = balanceFactor;
406          if (factor < -1)
407          {
408             if(left.balanceFactor == 1)
409                DoubleRotateRight();
410             else
411                SingleRotateRight();
412          }
413          else if (factor > 1)
414          {
415             if(right.balanceFactor == -1)
416                DoubleRotateLeft();
417             else
418                SingleRotateLeft();
419          }
420          if(parent)
421             this = parent;
422          else
423             return this;
424       }
425    }
426
427    void SingleRotateRight()
428    {
429       if(parent)
430       {
431          if(this == parent.left)
432             parent.left = left;
433          else if(this == parent.right)
434             parent.right = left;
435       }
436       left.parent = parent;
437       parent = left;
438       left = parent.right;
439       if(left) left.parent = this;
440       parent.right = this;
441
442       depth = Max(left ? (left.depth+1) : 0, right ? (right.depth+1) : 0);
443       parent.depth = Max(parent.left ? (parent.left.depth+1) : 0, parent.right ? (parent.right.depth+1) : 0);
444       {
445          AVLNode<T> n;
446          for(n = parent.parent; n; n = n.parent)
447          {
448             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
449             if(newDepth == n.depth)
450                break;
451             n.depth = newDepth;
452          }
453       }
454    }
455
456    void SingleRotateLeft()
457    {
458       if(parent)
459       {
460          if(this == parent.right)
461             parent.right = right;
462          else if(this == parent.left)
463             parent.left = right;
464       }
465       right.parent = parent;
466       parent = right;
467       right = parent.left;
468       if(right) right.parent = this;
469       parent.left = this;
470
471       depth = Max(left ? (left.depth+1) : 0, right ? (right.depth+1) : 0);
472       parent.depth = Max(parent.left ? (parent.left.depth+1) : 0, parent.right ? (parent.right.depth+1) : 0);
473       {
474          AVLNode<T> n;
475          for(n = parent.parent; n; n = n.parent)
476          {
477             int newDepth = Max(n.left ? (n.left.depth+1) : 0, n.right ? (n.right.depth+1) : 0);
478             if(newDepth == n.depth)
479                break;
480             n.depth = newDepth;
481          }
482       }
483    }
484
485    void DoubleRotateRight()
486    {
487       left.SingleRotateLeft();
488       SingleRotateRight();
489    }
490
491    void DoubleRotateLeft()
492    {
493       right.SingleRotateRight();
494       SingleRotateLeft();
495    }
496 }
497
498 public class CustomAVLTree<class BT:AVLNode, class KT = uint64> : Container<BT, I = KT>
499 {
500    class_fixed
501
502 public:
503    BT root;
504    int count;
505
506    IteratorPointer GetFirst() { return (IteratorPointer) (root ? root.minimum : null); }
507    IteratorPointer GetLast()  { return (IteratorPointer) (root ? root.maximum : null); }
508    IteratorPointer GetPrev(IteratorPointer node) { return ((BT)node).prev; }
509    IteratorPointer GetNext(IteratorPointer node) { return ((BT)node).next; }
510    BT GetData(IteratorPointer node) { return (BT)node; }
511    bool SetData(IteratorPointer node, BT data)
512    {
513       // Not supported for CustomAVLTree
514       return false;
515    }
516
517    IteratorPointer Add(BT node)
518    {
519       if(!root)
520          root = node;
521       else
522       {
523          Class btClass = class(BT);
524          Class Tclass = btClass.templateArgs[0].dataTypeClass;
525          if(!Tclass)
526          {
527             Tclass = btClass.templateArgs[0].dataTypeClass =
528                eSystem_FindClass(__thisModule.application, btClass.templateArgs[0].dataTypeString);
529          }
530          if(root.Add(Tclass, node, 0))
531             root = node.Rebalance();
532          else
533             return null;
534       }
535       count++;
536       return (IteratorPointer)node;
537    }
538
539    private IteratorPointer AddEx(BT node, BT addNode, AddSide addSide)
540    {
541       if(!root)
542          root = node;
543       else
544       {
545          Class Tclass = class(BT).templateArgs[0].dataTypeClass;
546          if(!Tclass)
547          {
548             Tclass = class(BT).templateArgs[0].dataTypeClass =
549                eSystem_FindClass(__thisModule.application, class(BT).templateArgs[0].dataTypeString);
550          }
551          if(addNode.Add(Tclass, node, addSide))
552             root = node.Rebalance();
553          else
554             return null;
555       }
556       count++;
557       return (IteratorPointer)node;
558    }
559
560    void Remove(IteratorPointer node)
561    {
562       BT parent = ((BT)node).parent;
563
564       if(parent || root == (BT)node)
565       {
566          root = ((BT)node).RemoveSwapRight();
567          count--;
568          ((BT)node).parent = null;
569       }
570    }
571
572    void Delete(IteratorPointer _item)
573    {
574       BT item = (BT)_item;
575       // THIS SHOULDN'T BE CALLING THE VIRTUAL FUNCTION
576       CustomAVLTree::Remove(_item);
577       FreeKey((BT)item);
578       delete item;
579    }
580
581    void FreeKey(BT item)
582    {
583       if(class(BT).type == structClass)
584       {
585          // TODO: Make this easier...
586          Class Tclass = class(BT);
587          ((void (*)(void *, void *))(void *)Tclass._vTbl[__ecereVMethodID_class_OnFree])(Tclass, (((byte *)&item.key) + __ENDIAN_PAD(sizeof(void *))));
588       }
589       else
590       {
591          // TOFIX: delete key; // This indexes the wrong templateArg (BT instead of KT)
592          KT k = item.key;
593          delete k;
594          item.key = (KT)0;
595       }
596    }
597
598    void Free()
599    {
600       BT item;
601       item = root;
602       while(item)
603       {
604          if(item.left)
605          {
606             BT left = item.left;
607             item.left = null;
608             item = left;
609          }
610          else if(item.right)
611          {
612             BT right = item.right;
613             item.right = null;
614             item = right;
615          }
616          else
617          {
618             BT parent = item.parent;
619             FreeKey((BT)item);
620             delete item;
621             item = parent;
622          }
623       }
624       root = null;
625       count = 0;
626    }
627
628    IteratorPointer Find(BT value)
629    {
630       return (IteratorPointer)value;
631    }
632
633    BT GetAtPosition(const KT pos, bool create, bool * justAdded)
634    {
635       // TODO: FindEx / AddEx & create nodes if create is true?
636       return root ? root.Find(class(KT), pos) : null;
637    }
638 }