fix tsearch to avoid crash on oom
[musl] / src / search / tsearch_avl.c
1 #include <stdlib.h>
2 #include <search.h>
3
4 /*
5 avl tree implementation using recursive functions
6 the height of an n node tree is less than 1.44*log2(n+2)-1
7 (so the max recursion depth in case of a tree with 2^32 nodes is 45)
8 */
9
10 struct node {
11         const void *key;
12         struct node *left;
13         struct node *right;
14         int height;
15 };
16
17 static int delta(struct node *n) {
18         return (n->left ? n->left->height:0) - (n->right ? n->right->height:0);
19 }
20
21 static void updateheight(struct node *n) {
22         n->height = 0;
23         if (n->left && n->left->height > n->height)
24                 n->height = n->left->height;
25         if (n->right && n->right->height > n->height)
26                 n->height = n->right->height;
27         n->height++;
28 }
29
30 static struct node *rotl(struct node *n) {
31         struct node *r = n->right;
32         n->right = r->left;
33         r->left = n;
34         updateheight(n);
35         updateheight(r);
36         return r;
37 }
38
39 static struct node *rotr(struct node *n) {
40         struct node *l = n->left;
41         n->left = l->right;
42         l->right = n;
43         updateheight(n);
44         updateheight(l);
45         return l;
46 }
47
48 static struct node *balance(struct node *n) {
49         int d = delta(n);
50
51         if (d < -1) {
52                 if (delta(n->right) > 0)
53                         n->right = rotr(n->right);
54                 return rotl(n);
55         } else if (d > 1) {
56                 if (delta(n->left) < 0)
57                         n->left = rotl(n->left);
58                 return rotr(n);
59         }
60         updateheight(n);
61         return n;
62 }
63
64 static struct node *find(struct node *n, const void *k,
65         int (*cmp)(const void *, const void *))
66 {
67         int c;
68
69         if (!n)
70                 return 0;
71         c = cmp(k, n->key);
72         if (c == 0)
73                 return n;
74         if (c < 0)
75                 return find(n->left, k, cmp);
76         else
77                 return find(n->right, k, cmp);
78 }
79
80 static struct node *insert(struct node **n, const void *k,
81         int (*cmp)(const void *, const void *), int *new)
82 {
83         struct node *r = *n;
84         int c;
85
86         if (!r) {
87                 *n = r = malloc(sizeof **n);
88                 if (r) {
89                         r->key = k;
90                         r->left = r->right = 0;
91                         r->height = 1;
92                         *new = 1;
93                 }
94                 return r;
95         }
96         c = cmp(k, r->key);
97         if (c == 0)
98                 return r;
99         if (c < 0)
100                 r = insert(&r->left, k, cmp, new);
101         else
102                 r = insert(&r->right, k, cmp, new);
103         if (*new)
104                 *n = balance(*n);
105         return r;
106 }
107
108 static struct node *remove_rightmost(struct node *n, struct node **rightmost)
109 {
110         if (!n->right) {
111                 *rightmost = n;
112                 return n->left;
113         }
114         n->right = remove_rightmost(n->right, rightmost);
115         return balance(n);
116 }
117
118 static struct node *remove(struct node **n, const void *k,
119         int (*cmp)(const void *, const void *), struct node *parent)
120 {
121         int c;
122
123         if (!*n)
124                 return 0;
125         c = cmp(k, (*n)->key);
126         if (c == 0) {
127                 struct node *r = *n;
128                 if (r->left) {
129                         r->left = remove_rightmost(r->left, n);
130                         (*n)->left = r->left;
131                         (*n)->right = r->right;
132                         *n = balance(*n);
133                 } else
134                         *n = r->right;
135                 free(r);
136                 return parent;
137         }
138         if (c < 0)
139                 parent = remove(&(*n)->left, k, cmp, *n);
140         else
141                 parent = remove(&(*n)->right, k, cmp, *n);
142         if (parent)
143                 *n = balance(*n);
144         return parent;
145 }
146
147 void *tdelete(const void *restrict key, void **restrict rootp,
148         int(*compar)(const void *, const void *))
149 {
150         struct node *n = *rootp;
151         struct node *ret;
152         /* last argument is arbitrary non-null pointer
153            which is returned when the root node is deleted */
154         ret = remove(&n, key, compar, n);
155         *rootp = n;
156         return ret;
157 }
158
159 void *tfind(const void *key, void *const *rootp,
160         int(*compar)(const void *, const void *))
161 {
162         return find(*rootp, key, compar);
163 }
164
165 void *tsearch(const void *key, void **rootp,
166         int (*compar)(const void *, const void *))
167 {
168         int new = 0;
169         struct node *n = *rootp;
170         struct node *ret;
171         ret = insert(&n, key, compar, &new);
172         *rootp = n;
173         return ret;
174 }
175
176 static void walk(const struct node *r, void (*action)(const void *, VISIT, int), int d)
177 {
178         if (r == 0)
179                 return;
180         if (r->left == 0 && r->right == 0)
181                 action(r, leaf, d);
182         else {
183                 action(r, preorder, d);
184                 walk(r->left, action, d+1);
185                 action(r, postorder, d);
186                 walk(r->right, action, d+1);
187                 action(r, endorder, d);
188         }
189 }
190
191 void twalk(const void *root, void (*action)(const void *, VISIT, int))
192 {
193         walk(root, action, 0);
194 }