source: trunk/src/lol/base/avl_tree.h @ 3847

Last change on this file since 3847 was 3847, checked in by guite, 7 years ago

map: more bug fixes (still not fully fixed…)

  • Property svn:eol-style set to LF
File size: 14.6 KB
Line 
1//
2//  Lol Engine
3//
4//  Copyright © 2010-2015 Sam Hocevar <sam@hocevar.net>
5//            © 2013-2015 Benjamin "Touky" Huet <huet.benjamin@gmail.com>
6//            © 2013-2015 Guillaume Bittoun <guillaume.bittoun@gmail.com>
7//
8//  This program is free software. It comes without any warranty, to
9//  the extent permitted by applicable law. You can redistribute it
10//  and/or modify it under the terms of the Do What the Fuck You Want
11//  to Public License, Version 2, as published by the WTFPL Task Force.
12//  See http://www.wtfpl.net/ for more details.
13//
14
15#pragma once
16
17namespace lol
18{
19
20#include <lol/base/all.h>
21
22template<typename K, typename V>
23class avl_tree
24{
25public:
26    avl_tree() :
27        m_root(nullptr),
28        m_count(0)
29    {
30    }
31
32    avl_tree(avl_tree const & other) :
33        m_root(nullptr),
34        m_count(0)
35    {
36        for (auto it : other)
37            insert(it.key, it.value);
38    }
39
40    avl_tree & operator=(avl_tree const & other)
41    {
42        if (&other != this)
43        {
44            clear();
45
46            for (auto it : other)
47                insert(it.key, it.value);
48        }
49
50        return *this;
51    }
52
53    ~avl_tree()
54    {
55        clear();
56    }
57
58    bool insert(K const & key, V const & value)
59    {
60        if (!m_root)
61        {
62            m_root = new tree_node(key, value, &m_root);
63            ++m_count;
64            return true;
65        }
66
67        if (m_root->insert(key, value))
68        {
69            ++m_count;
70            return true;
71        }
72
73        return false;
74    }
75
76    bool erase(K const & key)
77    {
78        if (!m_root)
79            return false;
80
81        if (m_root->erase(key))
82        {
83            --m_count;
84            return true;
85        }
86
87        return false;
88    }
89
90    bool exists(K const & key)
91    {
92        if (!m_root)
93            return false;
94
95        return m_root->exists(key);
96    }
97
98    void clear()
99    {
100        if (m_root)
101        {
102            tree_node * node = nullptr;
103            m_root->get_min(node);
104
105            while (node)
106            {
107                tree_node * next = node->get_next();
108                delete node;
109                node = next;
110            }
111        }
112
113        m_root = nullptr;
114        m_count = 0;
115    }
116
117    bool try_get(K const & key, V * & value_ptr) const
118    {
119        if (m_root)
120            return m_root->try_get(key, value_ptr);
121
122        return false;
123    }
124
125    bool try_get_min(K const * & key_ptr, V * & value_ptr) const
126    {
127        tree_node * min_node = nullptr;
128
129        if (m_root)
130        {
131            m_root->get_min(min_node);
132            key_ptr = &min_node->get_key();
133            value_ptr = &min_node->get_value();
134
135            return true;
136        }
137
138        return false;
139    }
140
141    bool try_get_max(K const * & key_ptr, V * & value_ptr) const
142    {
143        tree_node * max_node = nullptr;
144
145        if (m_root)
146        {
147            m_root->get_max(max_node);
148            key_ptr = &max_node->get_key();
149            value_ptr = &max_node->get_value();
150
151            return true;
152        }
153
154        return false;
155    }
156
157    class iterator;
158    class const_iterator;
159
160    iterator begin()
161    {
162        tree_node * node = nullptr;
163
164        if (m_root)
165            m_root->get_min(node);
166
167        return iterator(node);
168    }
169
170    const_iterator begin() const
171    {
172        tree_node * node = nullptr;
173
174        if (m_root)
175            m_root->get_min(node);
176
177        return const_iterator(node);
178    }
179
180    int count() const
181    {
182        return m_count;
183    }
184
185    iterator end()
186    {
187        return iterator(nullptr);
188    }
189
190    const_iterator end() const
191    {
192        return const_iterator(nullptr);
193    }
194
195protected:
196
197    class tree_node
198    {
199    public:
200        tree_node(K key, V value, tree_node ** parent_slot) :
201            m_key(key),
202            m_value(value),
203            m_parent_slot(parent_slot)
204        {
205            m_child[0] = m_child[1] = nullptr;
206            m_stairs[0] = m_stairs[1] = 0;
207            m_chain[0] = m_chain[1] = nullptr;
208        }
209
210        K const & get_key()
211        {
212            return m_key;
213        }
214
215        V & get_value()
216        {
217            return m_value;
218        }
219
220        /* Insert a value in tree and return true or update an existing value for
221         * the existing key and return false */
222        bool insert(K const & key, V const & value)
223        {
224            int i = -1 + (key < m_key) + 2 * (m_key < key);
225
226            bool created = false;
227
228            if (i < 0)
229                m_value = value;
230            else if (m_child[i])
231                created = m_child[i]->insert(key, value);
232            else
233            {
234                created = true;
235
236                m_child[i] = new tree_node(key, value, &m_child[i]);
237
238                m_child[i]->m_chain[i] = m_chain[i];
239                m_child[i]->m_chain[i ? 0 : 1] = this;
240
241                if (m_chain[i])
242                    m_chain[i]->m_chain[i ? 0 : 1] = m_child[i];
243                m_chain[i] = m_child[i];
244            }
245
246            if (created)
247            {
248                rebalance_if_needed();
249            }
250
251            return created;
252        }
253
254        /* Erase a value in tree and return true or return false */
255        bool erase(K const & key)
256        {
257            int i = -1 + (key < m_key) + 2 * (m_key < key);
258
259            bool erased = false;
260
261            if (i < 0)
262            {
263                erase_self();
264                delete this;
265                erased = true;
266            }
267            else if (m_child[i] && m_child[i]->erase(key))
268            {
269                rebalance_if_needed();
270                erased = true;
271            }
272
273            return erased;
274        }
275
276        bool try_get(K const & key, V * & value_ptr)
277        {
278            int i = -1 + (key < m_key) + 2 * (m_key < key);
279
280            if (i < 0)
281            {
282                value_ptr = &m_value;
283                return true;
284            }
285
286            if (m_child[i])
287                return m_child[i]->try_get(key, value_ptr);
288
289            return false;
290        }
291
292        bool exists(K const & key)
293        {
294            int i = -1 + (key < m_key) + 2 * (m_key < key);
295
296            if (i < 0)
297                return true;
298
299            if (m_child[i])
300                return m_child[i]->exists(key);
301
302            return false;
303        }
304
305        void get_min(tree_node * & min_node)
306        {
307            min_node = this;
308
309            while (min_node->m_child[0])
310                min_node = min_node->m_child[0];
311        }
312
313        void get_max(tree_node * & max_node) const
314        {
315            max_node = this;
316
317            while (max_node->m_child[1])
318                max_node = max_node->m_child[1];
319        }
320
321        int get_balance() const
322        {
323            return m_stairs[1] - m_stairs[0];
324        }
325
326        tree_node * get_previous() const
327        {
328            return m_chain[0];
329        }
330
331        tree_node * get_next() const
332        {
333            return m_chain[1];
334        }
335
336    protected:
337
338        void update_balance()
339        {
340            m_stairs[0] = m_child[0] ? (m_child[0]->m_stairs[0] > m_child[0]->m_stairs[1] ? m_child[0]->m_stairs[0] : m_child[0]->m_stairs[1]) + 1 : 0;
341            m_stairs[1] = m_child[1] ? (m_child[1]->m_stairs[0] > m_child[1]->m_stairs[1] ? m_child[1]->m_stairs[0] : m_child[1]->m_stairs[1]) + 1 : 0;
342        }
343
344        void rebalance_if_needed()
345        {
346            update_balance();
347
348            int i = -1 + (get_balance() == -2) + 2 * (get_balance() == 2);
349
350            if (i != -1)
351            {
352                tree_node * replacement = nullptr;
353
354
355                if (get_balance() / 2 + m_child[i]->get_balance() == 0)
356                {
357                    replacement = m_child[i]->m_child[1 - i];
358                    tree_node * save0 = replacement->m_child[i];
359                    tree_node * save1 = replacement->m_child[1 - i];
360
361                    replacement->m_parent_slot = this->m_parent_slot;
362                    *replacement->m_parent_slot = replacement;
363
364                    replacement->m_child[i] = m_child[i];
365                    m_child[i]->m_parent_slot = &replacement->m_child[i];
366
367                    replacement->m_child[1 - i] = this;
368                    this->m_parent_slot = &replacement->m_child[1 - i];
369
370                    replacement->m_child[i]->m_child[1 - i] = save0;
371                    if (save0)
372                        save0->m_parent_slot = &replacement->m_child[i]->m_child[1 - i];
373
374                    replacement->m_child[1 - i]->m_child[i] = save1;
375                    if (save1)
376                        save1->m_parent_slot = &replacement->m_child[1 - i]->m_child[i];
377                }
378                else
379                {
380                    replacement = m_child[i];
381                    tree_node * save = replacement->m_child[1 - i];
382
383                    replacement->m_parent_slot = this->m_parent_slot;
384                    *replacement->m_parent_slot = replacement;
385
386                    replacement->m_child[1 - i] = this;
387                    this->m_parent_slot = &replacement->m_child[1 - i];
388
389                    this->m_child[i] = save;
390                    if (save)
391                        save->m_parent_slot = &this->m_child[i];
392                }
393
394                replacement->m_child[0]->update_balance();
395                replacement->m_child[1]->update_balance();
396                replacement->update_balance();
397            }
398        }
399
400        void erase_self()
401        {
402            int i = (get_balance() == -1);
403
404            tree_node * replacement = m_child[1 - i];
405
406            if (replacement)
407            {
408                while (replacement->m_child[i])
409                    replacement = replacement->m_child[i];
410            }
411
412            if (replacement)
413            {
414                *replacement->m_parent_slot = replacement->m_child[1 - i];
415                if (replacement->m_child[1 - i])
416                    replacement->m_child[1 - i]->m_parent_slot = replacement->m_parent_slot;
417
418                replacement->m_parent_slot = m_parent_slot;
419                *replacement->m_parent_slot = replacement;
420
421                replacement->m_child[0] = m_child[0];
422                if (replacement->m_child[0])
423                    replacement->m_child[0]->m_parent_slot = &replacement->m_child[0];
424
425                replacement->m_child[1] = m_child[1];
426                if (replacement->m_child[1])
427                    replacement->m_child[1]->m_parent_slot = &replacement->m_child[1];
428
429                if (replacement->m_child[1-i])
430                    replacement->m_child[1-i]->deep_balance(replacement->m_key);
431
432                replacement->update_balance();
433            }
434            else
435            {
436                *m_parent_slot = m_child[i];
437                if (m_child[i])
438                    m_child[i]->m_parent_slot = m_parent_slot;
439
440                replacement = m_child[i];
441            }
442
443            replace_chain(replacement);
444        }
445
446        void deep_balance(K const & key)
447        {
448            int i = -1 + (key < m_key) + 2 * (m_key < key);
449
450            if (i != -1 && m_child[i])
451                m_child[i]->deep_balance(key);
452
453            update_balance();
454        }
455
456        void replace_chain(tree_node * replacement)
457        {
458            if (replacement)
459            {
460                if (replacement->m_chain[0])
461                    replacement->m_chain[0]->m_chain[1] = replacement->m_chain[1];
462
463                if (replacement->m_chain[1])
464                    replacement->m_chain[1]->m_chain[0] = replacement->m_chain[0];
465
466                replacement->m_chain[0] = m_chain[0];
467                replacement->m_chain[1] = m_chain[1];
468
469                if (replacement->m_chain[0])
470                    replacement->m_chain[0]->m_chain[1] = replacement;
471                if (replacement->m_chain[1])
472                    replacement->m_chain[1]->m_chain[0] = replacement;
473            }
474            else
475            {
476                if (m_chain[0])
477                    m_chain[0]->m_chain[1] = m_chain[1];
478                if (m_chain[1])
479                    m_chain[1]->m_chain[0] = m_chain[0];
480            }
481        }
482
483        K m_key;
484        V m_value;
485
486        tree_node *m_child[2];
487        int m_stairs[2];
488
489        tree_node ** m_parent_slot;
490
491        tree_node * m_chain[2]; // Linked list used to keep order between nodes
492    };
493
494public:
495
496    /* Iterators related */
497
498    struct output_value
499    {
500        output_value(K const & _key, V & _value) :
501            key(_key),
502            value(_value)
503        {
504        }
505
506        K const & key;
507        V & value;
508    };
509
510    class iterator
511    {
512    public:
513
514        iterator(tree_node * node) :
515            m_node(node)
516        {
517        }
518
519        iterator & operator++(int)
520        {
521            m_node = m_node->get_next();
522
523            return *this;
524        }
525
526        iterator & operator--(int)
527        {
528            m_node = m_node->get_previous();
529
530            return *this;
531        }
532
533        iterator operator++()
534        {
535            tree_node * ret = m_node;
536            m_node = m_node->get_next();
537
538            return iterator(ret);
539        }
540
541        iterator operator--()
542        {
543            tree_node * ret = m_node;
544            m_node = m_node->get_previous();
545
546            return iterator(ret);
547        }
548
549        output_value operator*()
550        {
551            return output_value(m_node->get_key(), m_node->get_value());
552        }
553
554        bool operator!=(iterator const & that) const
555        {
556            return m_node != that.m_node;
557        }
558
559    protected:
560
561        tree_node * m_node;
562    };
563
564    struct const_output_value
565    {
566        const_output_value(K const & _key, V const & _value) :
567            key(_key),
568            value(_value)
569        {
570        }
571
572        K const & key;
573        V const & value;
574    };
575
576    class const_iterator
577    {
578    public:
579
580        const_iterator(tree_node * node) :
581            m_node(node)
582        {
583        }
584
585        const_iterator & operator++(int)
586        {
587            m_node = m_node->get_next();
588
589            return *this;
590        }
591
592        const_iterator & operator--(int)
593        {
594            m_node = m_node->get_previous();
595
596            return *this;
597        }
598
599        const_iterator operator++()
600        {
601            tree_node * ret = m_node;
602            m_node = m_node->get_next();
603
604            return const_iterator(ret);
605        }
606
607        const_iterator operator--()
608        {
609            tree_node * ret = m_node;
610            m_node = m_node->get_previous();
611
612            return const_iterator(ret);
613        }
614
615        const_output_value operator*()
616        {
617            return const_output_value(m_node->get_key(), m_node->get_value());
618        }
619
620        bool operator!=(const_iterator const & that) const
621        {
622            return m_node != that.m_node;
623        }
624
625    protected:
626
627        tree_node * m_node;
628    };
629
630protected:
631
632    tree_node * m_root;
633
634    int m_count;
635};
636
637}
638
Note: See TracBrowser for help on using the repository browser.