source: trunk/src/lol/base/avl_tree.h @ 3842

Last change on this file since 3842 was 3842, checked in by guite, 7 years ago

map: starting bug fix

  • Property svn:eol-style set to LF
File size: 14.3 KB
Line 
1//
2//  Lol Engine
3//
4//  Copyright © 2010-2015 Sam Hocevar <sam@hocevar.net>
5//            © 2013-2015 Benjamin "Touky" Huet <huet.benjamin@gmail.com>
6//            © 2013-2015 Guillaume Bittoun <guillaume.bittoun@gmail.com>
7//
8//  This program is free software. It comes without any warranty, to
9//  the extent permitted by applicable law. You can redistribute it
10//  and/or modify it under the terms of the Do What the Fuck You Want
11//  to Public License, Version 2, as published by the WTFPL Task Force.
12//  See http://www.wtfpl.net/ for more details.
13//
14
15#pragma once
16
17namespace lol
18{
19
20#include <lol/base/all.h>
21
22template<typename K, typename V>
23class avl_tree
24{
25public:
26    avl_tree() :
27        m_root(nullptr),
28        m_count(0)
29    {
30    }
31
32    avl_tree(avl_tree const & other) :
33        m_root(nullptr),
34        m_count(0)
35    {
36        for (auto it : other)
37            insert(it.key, it.value);
38    }
39
40    avl_tree & operator=(avl_tree const & other)
41    {
42        if (&other != this)
43        {
44            clear();
45
46            for (auto it : other)
47                insert(it.key, it.value);
48        }
49
50        return *this;
51    }
52
53    ~avl_tree()
54    {
55        clear();
56    }
57
58    bool insert(K const & key, V const & value)
59    {
60        if (!m_root)
61        {
62            m_root = new tree_node(key, value, &m_root);
63            ++m_count;
64            return true;
65        }
66
67        if (m_root->insert(key, value))
68        {
69            ++m_count;
70            return true;
71        }
72
73        return false;
74    }
75
76    bool erase(K const & key)
77    {
78        if (!m_root)
79            return false;
80
81        if (m_root->erase(key))
82        {
83            --m_count;
84            return true;
85        }
86
87        return false;
88    }
89
90    bool exists(K const & key)
91    {
92        if (!m_root)
93            return false;
94
95        return m_root->exists(key);
96    }
97
98    void clear()
99    {
100        if (m_root)
101        {
102            tree_node * node = nullptr;
103            m_root->get_min(node);
104
105            while (node)
106            {
107                tree_node * next = node->get_next();
108                delete node;
109                node = next;
110            }
111        }
112
113        m_root = nullptr;
114        m_count = 0;
115    }
116
117    bool try_get(K const & key, V * & value_ptr) const
118    {
119        if (m_root)
120            return m_root->try_get(key, value_ptr);
121
122        return false;
123    }
124
125    bool try_get_min(K const * & key_ptr, V * & value_ptr) const
126    {
127        tree_node * min_node = nullptr;
128
129        if (m_root)
130        {
131            m_root->get_min(min_node);
132            key_ptr = &min_node->get_key();
133            value_ptr = &min_node->get_value();
134
135            return true;
136        }
137
138        return false;
139    }
140
141    bool try_get_max(K const * & key_ptr, V * & value_ptr) const
142    {
143        tree_node * max_node = nullptr;
144
145        if (m_root)
146        {
147            m_root->get_max(max_node);
148            key_ptr = &max_node->get_key();
149            value_ptr = &max_node->get_value();
150
151            return true;
152        }
153
154        return false;
155    }
156
157    class iterator;
158    class const_iterator;
159
160    iterator begin()
161    {
162        tree_node * node = nullptr;
163
164        if (m_root)
165            m_root->get_min(node);
166
167        return iterator(node);
168    }
169
170    const_iterator begin() const
171    {
172        tree_node * node = nullptr;
173
174        if (m_root)
175            m_root->get_min(node);
176
177        return const_iterator(node);
178    }
179
180    int count() const
181    {
182        return m_count;
183    }
184
185    iterator end()
186    {
187        return iterator(nullptr);
188    }
189
190    const_iterator end() const
191    {
192        return const_iterator(nullptr);
193    }
194
195protected:
196
197    class tree_node
198    {
199    public:
200        tree_node(K key, V value, tree_node ** parent_slot) :
201            m_key(key),
202            m_value(value),
203            m_parent_slot(parent_slot)
204        {
205            m_child[0] = m_child[1] = nullptr;
206            m_stairs[0] = m_stairs[1] = 0;
207            m_chain[0] = m_chain[1] = nullptr;
208        }
209
210        K const & get_key()
211        {
212            return m_key;
213        }
214
215        V & get_value()
216        {
217            return m_value;
218        }
219
220        /* Insert a value in tree and return true or update an existing value for
221         * the existing key and return false */
222        bool insert(K const & key, V const & value)
223        {
224            int i = -1 + (key < m_key) + 2 * (m_key < key);
225
226            bool created = false;
227
228            if (i < 0)
229                m_value = value;
230            else if (m_child[i])
231                created = m_child[i]->insert(key, value);
232            else
233            {
234                created = true;
235
236                m_child[i] = new tree_node(key, value, &m_child[i]);
237
238                m_child[i]->m_chain[i] = m_chain[i];
239                m_child[i]->m_chain[i ? 0 : 1] = this;
240
241                if (m_chain[i])
242                    m_chain[i]->m_chain[i ? 0 : 1] = m_child[i];
243                m_chain[i] = m_child[i];
244            }
245
246            if (created)
247            {
248                rebalance_if_needed();
249            }
250
251
252            return created;
253        }
254
255        /* Erase a value in tree and return true or return false */
256        bool erase(K const & key)
257        {
258            int i = -1 + (key < m_key) + 2 * (m_key < key);
259
260            bool erased = false;
261
262            if (i < 0)
263            {
264                erase_self();
265                delete this;
266                erased = true;
267            }
268            else if (m_child[i] && m_child[i]->erase(key))
269            {
270                rebalance_if_needed();
271                erased = true;
272            }
273
274            return erased;
275        }
276
277        bool try_get(K const & key, V * & value_ptr)
278        {
279            int i = -1 + (key < m_key) + 2 * (m_key < key);
280
281            if (i < 0)
282            {
283                value_ptr = &m_value;
284                return true;
285            }
286
287            if (m_child[i])
288                return m_child[i]->try_get(key, value_ptr);
289
290            return false;
291        }
292
293        bool exists(K const & key)
294        {
295            int i = -1 + (key < m_key) + 2 * (m_key < key);
296
297            if (i < 0)
298                return true;
299
300            if (m_child[i])
301                return m_child[i]->exists(key);
302
303            return false;
304        }
305
306        void get_min(tree_node * & min_node)
307        {
308            min_node = this;
309
310            while (min_node->m_child[0])
311                min_node = min_node->m_child[0];
312        }
313
314        void get_max(tree_node * & max_node) const
315        {
316            max_node = this;
317
318            while (max_node->m_child[1])
319                max_node = max_node->m_child[1];
320        }
321
322        int get_balance() const
323        {
324            return m_stairs[1] - m_stairs[0];
325        }
326
327        tree_node * get_previous() const
328        {
329            return m_chain[0];
330        }
331
332        tree_node * get_next() const
333        {
334            return m_chain[1];
335        }
336
337    protected:
338
339        void update_balance()
340        {
341            m_stairs[0] = m_child[0] ? (m_child[0]->m_stairs[0] > m_child[0]->m_stairs[1] ? m_child[0]->m_stairs[0] : m_child[0]->m_stairs[1]) + 1 : 0;
342            m_stairs[1] = m_child[1] ? (m_child[1]->m_stairs[0] > m_child[1]->m_stairs[1] ? m_child[1]->m_stairs[0] : m_child[1]->m_stairs[1]) + 1 : 0;
343        }
344
345        void rebalance_if_needed()
346        {
347            update_balance();
348
349            int i = -1 + (get_balance() == -2) + 2 * (get_balance() == 2);
350
351            if (i != -1)
352            {
353                tree_node * replacement = nullptr;
354
355
356                if (get_balance() / 2 + m_child[i]->get_balance() == 0)
357                {
358                    replacement = m_child[i]->m_child[1 - i];
359                    tree_node * save0 = replacement->m_child[i];
360                    tree_node * save1 = replacement->m_child[1 - i];
361
362                    replacement->m_parent_slot = this->m_parent_slot;
363                    *replacement->m_parent_slot = replacement;
364
365                    replacement->m_child[i] = m_child[i];
366                    m_child[i]->m_parent_slot = &replacement->m_child[i];
367
368                    replacement->m_child[1 - i] = this;
369                    this->m_parent_slot = &replacement->m_child[1 - i];
370
371                    replacement->m_child[i]->m_child[1 - i] = save0;
372                    if (save0)
373                        save0->m_parent_slot = &replacement->m_child[i]->m_child[1 - i];
374
375                    replacement->m_child[1 - i]->m_child[i] = save1;
376                    if (save1)
377                        save1->m_parent_slot = &replacement->m_child[i]->m_child[1 - i];
378                }
379                else
380                {
381
382                    replacement = m_child[i];
383                    tree_node * save = replacement->m_child[1 - i];
384
385                    replacement->m_parent_slot = this->m_parent_slot;
386                    *replacement->m_parent_slot = replacement;
387
388                    replacement->m_child[1 - i] = this;
389                    this->m_parent_slot = &replacement->m_child[1 - i];
390
391                    this->m_child[i] = save;
392                    if (save)
393                        save->m_parent_slot = &this->m_child[i];
394                }
395
396                replacement->m_child[0]->update_balance();
397                replacement->m_child[1]->update_balance();
398                replacement->update_balance();
399            }
400        }
401
402        void erase_self()
403        {
404            int i = (get_balance() == -1);
405
406            tree_node * replacement = m_child[1 - i];
407
408            if (replacement)
409            {
410                while (replacement->m_child[i])
411                    replacement = replacement->m_child[i];
412            }
413
414            if (replacement)
415            {
416                *replacement->m_parent_slot = replacement->m_child[1 - i];
417
418                replacement->m_parent_slot = m_parent_slot;
419                *replacement->m_parent_slot = replacement;
420
421                replacement->m_child[0] = m_child[0];
422                if (replacement->m_child[0])
423                    replacement->m_child[0]->m_parent_slot = &replacement->m_child[0];
424
425                replacement->m_child[1] = m_child[1];
426                if (replacement->m_child[1])
427                    replacement->m_child[1]->m_parent_slot = &replacement->m_child[1];
428
429                if (replacement->m_child[1-i])
430                    replacement->m_child[1-i]->deep_balance(replacement->m_key);
431
432                replacement->update_balance();
433            }
434            else
435                *m_parent_slot = nullptr;
436
437            replace_chain(replacement);
438        }
439
440        void deep_balance(K const & key)
441        {
442            int i = -1 + (key < m_key) + 2 * (m_key < key);
443
444            if (i != -1 && m_child[i])
445                m_child[i]->deep_balance(key);
446
447            update_balance();
448        }
449
450        void replace_chain(tree_node * replacement)
451        {
452            if (replacement)
453            {
454                if (replacement->m_chain[0])
455                    replacement->m_chain[0]->m_chain[1] = replacement->m_chain[1];
456
457                if (replacement->m_chain[1])
458                    replacement->m_chain[1]->m_chain[0] = replacement->m_chain[0];
459
460                replacement->m_chain[0] = m_chain[0];
461                replacement->m_chain[1] = m_chain[1];
462
463                if (replacement->m_chain[0])
464                    replacement->m_chain[0]->m_chain[1] = replacement;
465                if (replacement->m_chain[1])
466                    replacement->m_chain[1]->m_chain[0] = replacement;
467            }
468            else
469            {
470                if (m_chain[0])
471                    m_chain[0]->m_chain[1] = m_chain[1];
472                if (m_chain[1])
473                    m_chain[1]->m_chain[0] = m_chain[0];
474            }
475        }
476
477        K m_key;
478        V m_value;
479
480        tree_node *m_child[2];
481        int m_stairs[2];
482
483        tree_node ** m_parent_slot;
484
485        tree_node * m_chain[2]; // Linked list used to keep order between nodes
486    };
487
488public:
489
490    /* Iterators related */
491
492    struct output_value
493    {
494        output_value(K const & _key, V & _value) :
495            key(_key),
496            value(_value)
497        {
498        }
499
500        K const & key;
501        V & value;
502    };
503
504    class iterator
505    {
506    public:
507
508        iterator(tree_node * node) :
509            m_node(node)
510        {
511        }
512
513        iterator & operator++(int)
514        {
515            m_node = m_node->get_next();
516
517            return *this;
518        }
519
520        iterator & operator--(int)
521        {
522            m_node = m_node->get_previous();
523
524            return *this;
525        }
526
527        iterator operator++()
528        {
529            tree_node * ret = m_node;
530            m_node = m_node->get_next();
531
532            return iterator(ret);
533        }
534
535        iterator operator--()
536        {
537            tree_node * ret = m_node;
538            m_node = m_node->get_previous();
539
540            return iterator(ret);
541        }
542
543        output_value operator*()
544        {
545            return output_value(m_node->get_key(), m_node->get_value());
546        }
547
548        bool operator!=(iterator const & that) const
549        {
550            return m_node != that.m_node;
551        }
552
553    protected:
554
555        tree_node * m_node;
556    };
557
558    struct const_output_value
559    {
560        const_output_value(K const & _key, V const & _value) :
561            key(_key),
562            value(_value)
563        {
564        }
565
566        K const & key;
567        V const & value;
568    };
569
570    class const_iterator
571    {
572    public:
573
574        const_iterator(tree_node * node) :
575            m_node(node)
576        {
577        }
578
579        const_iterator & operator++(int)
580        {
581            m_node = m_node->get_next();
582
583            return *this;
584        }
585
586        const_iterator & operator--(int)
587        {
588            m_node = m_node->get_previous();
589
590            return *this;
591        }
592
593        const_iterator operator++()
594        {
595            tree_node * ret = m_node;
596            m_node = m_node->get_next();
597
598            return const_iterator(ret);
599        }
600
601        const_iterator operator--()
602        {
603            tree_node * ret = m_node;
604            m_node = m_node->get_previous();
605
606            return const_iterator(ret);
607        }
608
609        const_output_value operator*()
610        {
611            return const_output_value(m_node->get_key(), m_node->get_value());
612        }
613
614        bool operator!=(const_iterator const & that) const
615        {
616            return m_node != that.m_node;
617        }
618
619    protected:
620
621        tree_node * m_node;
622    };
623
624protected:
625
626    tree_node * m_root;
627
628    int m_count;
629};
630
631}
632
Note: See TracBrowser for help on using the repository browser.