XSI search.h API implementation by Szabolcs Nagy
[musl] / src / search / tsearch_avl.c
1 #include <stdlib.h>
2 #include <search.h>
3
4 struct node {
5         const void *key;
6         struct node *left;
7         struct node *right;
8         int height;
9 };
10
11 static int delta(struct node *n) {
12         return (n->left ? n->left->height:0) - (n->right ? n->right->height:0);
13 }
14
15 static void updateheight(struct node *n) {
16         n->height = 0;
17         if (n->left && n->left->height > n->height)
18                 n->height = n->left->height;
19         if (n->right && n->right->height > n->height)
20                 n->height = n->right->height;
21         n->height++;
22 }
23
24 static struct node *rotl(struct node *n) {
25         struct node *r = n->right;
26         n->right = r->left;
27         r->left = n;
28         updateheight(n);
29         updateheight(r);
30         return r;
31 }
32
33 static struct node *rotr(struct node *n) {
34         struct node *l = n->left;
35         n->left = l->right;
36         l->right = n;
37         updateheight(n);
38         updateheight(l);
39         return l;
40 }
41
42 static struct node *balance(struct node *n) {
43         int d = delta(n);
44
45         if (d < -1) {
46                 if (delta(n->right) > 0)
47                         n->right = rotr(n->right);
48                 return rotl(n);
49         } else if (d > 1) {
50                 if (delta(n->left) < 0)
51                         n->left = rotl(n->left);
52                 return rotr(n);
53         }
54         updateheight(n);
55         return n;
56 }
57
58 static struct node *find(struct node *n, const void *k,
59         int (*cmp)(const void *, const void *))
60 {
61         int c;
62
63         if (!n)
64                 return 0;
65         c = cmp(k, n->key);
66         if (c == 0)
67                 return n;
68         if (c < 0)
69                 return find(n->left, k, cmp);
70         else
71                 return find(n->right, k, cmp);
72 }
73
74 static struct node *insert(struct node **n, const void *k,
75         int (*cmp)(const void *, const void *), int *new)
76 {
77         struct node *r = *n;
78         int c;
79
80         if (!r) {
81                 *n = r = malloc(sizeof **n);
82                 if (r) {
83                         r->key = k;
84                         r->left = r->right = 0;
85                         r->height = 1;
86                 }
87                 *new = 1;
88                 return r;
89         }
90         c = cmp(k, r->key);
91         if (c == 0)
92                 return r;
93         if (c < 0)
94                 r = insert(&r->left, k, cmp, new);
95         else
96                 r = insert(&r->right, k, cmp, new);
97         if (*new)
98                 *n = balance(*n);
99         return r;
100 }
101
102 static struct node *movr(struct node *n, struct node *r) {
103         if (!n)
104                 return r;
105         n->right = movr(n->right, r);
106         return balance(n);
107 }
108
109 static struct node *remove(struct node **n, const void *k,
110         int (*cmp)(const void *, const void *), struct node *parent)
111 {
112         int c;
113
114         if (!*n)
115                 return 0;
116         c = cmp(k, (*n)->key);
117         if (c == 0) {
118                 struct node *r = *n;
119                 *n = movr(r->left, r->right);
120                 free(r);
121                 return parent;
122         }
123         if (c < 0)
124                 parent = remove(&(*n)->left, k, cmp, *n);
125         else
126                 parent = remove(&(*n)->right, k, cmp, *n);
127         if (parent)
128                 *n = balance(*n);
129         return parent;
130 }
131
132 void *tdelete(const void *restrict key, void **restrict rootp,
133         int(*compar)(const void *, const void *))
134 {
135         /* last argument is arbitrary non-null pointer
136            which is returned when the root node is deleted */
137         return remove((void*)rootp, key, compar, *rootp);
138 }
139
140 void *tfind(const void *key, void *const *rootp,
141         int(*compar)(const void *, const void *))
142 {
143         return find(*rootp, key, compar);
144 }
145
146 void *tsearch(const void *key, void **rootp,
147         int (*compar)(const void *, const void *))
148 {
149         int new = 0;
150         return insert((void*)rootp, key, compar, &new);
151 }
152
153 static void walk(const struct node *r, void (*action)(const void *, VISIT, int), int d)
154 {
155         if (r == 0)
156                 return;
157         if (r->left == 0 && r->right == 0)
158                 action(r, leaf, d);
159         else {
160                 action(r, preorder, d);
161                 walk(r->left, action, d+1);
162                 action(r, postorder, d);
163                 walk(r->right, action, d+1);
164                 action(r, endorder, d);
165         }
166 }
167
168 void twalk(const void *root, void (*action)(const void *, VISIT, int))
169 {
170         walk(root, action, 0);
171 }