fix wrong return value from wmemmove on forward copies
[musl] / src / string / memmem.c
1 #define _GNU_SOURCE
2 #include <string.h>
3 #include <stdlib.h>
4 #include <stdint.h>
5
6 static char *twobyte_memmem(const unsigned char *h, size_t k, const unsigned char *n)
7 {
8         uint16_t nw = n[0]<<8 | n[1], hw = h[0]<<8 | h[1];
9         for (h++, k--; k; k--, hw = hw<<8 | *++h)
10                 if (hw == nw) return (char *)h-1;
11         return 0;
12 }
13
14 static char *threebyte_memmem(const unsigned char *h, size_t k, const unsigned char *n)
15 {
16         uint32_t nw = n[0]<<24 | n[1]<<16 | n[2]<<8;
17         uint32_t hw = h[0]<<24 | h[1]<<16 | h[2]<<8;
18         for (h+=2, k-=2; k; k--, hw = (hw|*++h)<<8)
19                 if (hw == nw) return (char *)h-2;
20         return 0;
21 }
22
23 static char *fourbyte_memmem(const unsigned char *h, size_t k, const unsigned char *n)
24 {
25         uint32_t nw = n[0]<<24 | n[1]<<16 | n[2]<<8 | n[3];
26         uint32_t hw = h[0]<<24 | h[1]<<16 | h[2]<<8 | h[3];
27         for (h+=3, k-=3; k; k--, hw = hw<<8 | *++h)
28                 if (hw == nw) return (char *)h-3;
29         return 0;
30 }
31
32 #define MAX(a,b) ((a)>(b)?(a):(b))
33 #define MIN(a,b) ((a)<(b)?(a):(b))
34
35 #define BITOP(a,b,op) \
36  ((a)[(size_t)(b)/(8*sizeof *(a))] op (size_t)1<<((size_t)(b)%(8*sizeof *(a))))
37
38 static char *twoway_memmem(const unsigned char *h, const unsigned char *z, const unsigned char *n, size_t l)
39 {
40         size_t i, ip, jp, k, p, ms, p0, mem, mem0;
41         size_t byteset[32 / sizeof(size_t)] = { 0 };
42         size_t shift[256];
43
44         /* Computing length of needle and fill shift table */
45         for (i=0; i<l; i++)
46                 BITOP(byteset, n[i], |=), shift[n[i]] = i+1;
47
48         /* Compute maximal suffix */
49         ip = -1; jp = 0; k = p = 1;
50         while (jp+k<l) {
51                 if (n[ip+k] == n[jp+k]) {
52                         if (k == p) {
53                                 jp += p;
54                                 k = 1;
55                         } else k++;
56                 } else if (n[ip+k] > n[jp+k]) {
57                         jp += k;
58                         k = 1;
59                         p = jp - ip;
60                 } else {
61                         ip = jp++;
62                         k = p = 1;
63                 }
64         }
65         ms = ip;
66         p0 = p;
67
68         /* And with the opposite comparison */
69         ip = -1; jp = 0; k = p = 1;
70         while (jp+k<l) {
71                 if (n[ip+k] == n[jp+k]) {
72                         if (k == p) {
73                                 jp += p;
74                                 k = 1;
75                         } else k++;
76                 } else if (n[ip+k] < n[jp+k]) {
77                         jp += k;
78                         k = 1;
79                         p = jp - ip;
80                 } else {
81                         ip = jp++;
82                         k = p = 1;
83                 }
84         }
85         if (ip+1 > ms+1) ms = ip;
86         else p = p0;
87
88         /* Periodic needle? */
89         if (memcmp(n, n+p, ms+1)) {
90                 mem0 = 0;
91                 p = MAX(ms, l-ms-1) + 1;
92         } else mem0 = l-p;
93         mem = 0;
94
95         /* Search loop */
96         for (;;) {
97                 /* If remainder of haystack is shorter than needle, done */
98                 if (z-h < l) return 0;
99
100                 /* Check last byte first; advance by shift on mismatch */
101                 if (BITOP(byteset, h[l-1], &)) {
102                         k = l-shift[h[l-1]];
103                         if (k) {
104                                 if (mem0 && mem && k < p) k = l-p;
105                                 h += k;
106                                 mem = 0;
107                                 continue;
108                         }
109                 } else {
110                         h += l;
111                         mem = 0;
112                         continue;
113                 }
114
115                 /* Compare right half */
116                 for (k=MAX(ms+1,mem); n[k] && n[k] == h[k]; k++);
117                 if (n[k]) {
118                         h += k-ms;
119                         mem = 0;
120                         continue;
121                 }
122                 /* Compare left half */
123                 for (k=ms+1; k>mem && n[k-1] == h[k-1]; k--);
124                 if (k == mem) return (char *)h;
125                 h += p;
126                 mem = mem0;
127         }
128 }
129
130 void *memmem(const void *h0, size_t k, const void *n0, size_t l)
131 {
132         const unsigned char *h = h0, *n = n0;
133
134         /* Return immediately on empty needle */
135         if (!l) return (void *)h;
136
137         /* Return immediately when needle is longer than haystack */
138         if (k<l) return 0;
139
140         /* Use faster algorithms for short needles */
141         h = memchr(h0, *n, k);
142         if (!h || l==1) return (void *)h;
143         if (l==2) return twobyte_memmem(h, k, n);
144         if (l==3) return threebyte_memmem(h, k, n);
145         if (l==4) return fourbyte_memmem(h, k, n);
146
147         return twoway_memmem(h, h+k, n, l);
148 }