regex memory corruption regression test
[libc-test] / src / math / gen / gen.c
1 /*
2 ./gen can generate testcases using an mp lib
3 ./check can test an mp lib compared to the input
4
5 input format:
6 T.<rounding>.<inputs>.<outputs>.<outputerr>.<exceptflags>.
7 where . is a sequence of separators: " \t,(){}"
8 the T prefix and rounding mode are optional (default is RN),
9 so the following are all ok and equivalent input:
10
11  1 2.0 0.1 INEXACT
12  {RN, 1, 2.0, 0.1, INEXACT},
13  T(RN, 1, 2.0, 0.1, INEXACT)
14
15 for gen only rounding and inputs are required (the rest is discarded)
16
17 gen:
18         s = getline()
19         x = scan(s)
20         xy = mpfunc(x)
21         print(xy)
22 check:
23         s = getline()
24         xy = scan(s)
25         xy' = mpfunc(x)
26         check(xy, xy')
27 */
28
29 #include <stdlib.h>
30 #include <stdio.h>
31 #include <string.h>
32 #include "gen.h"
33
34 static int scan(const char *fmt, struct t *t, char *buf);
35 static int print(const char *fmt, struct t *t, char *buf, int n);
36
37 // TODO: many output, fmt->ulp
38 struct fun;
39 static int check(struct t *want, struct t *got, struct fun *f, float ulpthres, float *abserr);
40
41 struct fun {
42         char *name;
43         int (*mpf)(struct t*);
44         char *fmt;
45 } fun[] = {
46 #define T(f,t) {#f, mp##f, #t},
47 #include "functions.h"
48 #undef T
49 };
50
51 int main(int argc, char *argv[])
52 {
53         char buf[512];
54         char *p;
55         int checkmode;
56         int i;
57         struct t t;
58         struct t tread;
59         struct fun *f = 0;
60         double ulpthres = 1.0;
61         float maxerr = 0;
62         float abserr;
63         struct t terr;
64
65         p = strrchr(argv[0], '/');
66         if (!p)
67                 p = argv[0];
68         else
69                 p++;
70         checkmode = strcmp(p, "check") == 0;
71         if (argc < 2) {
72                 fprintf(stderr, "%s func%s\n", argv[0], checkmode ? " ulpthres" : "");
73                 return 1;
74         }
75         if (argc > 2 && checkmode) {
76                 ulpthres = strtod(argv[2], &p);
77                 if (*p) {
78                         fprintf(stderr, "invalid ulperr %s\n", argv[2]);
79                         return 1;
80                 }
81         }
82         for (i = 0; i < sizeof fun/sizeof *fun; i++)
83                 if (strcmp(fun[i].name, argv[1]) == 0) {
84                         f = fun + i;
85                         break;
86                 }
87         if (f == 0) {
88                 fprintf(stderr, "unknown func: %s\n", argv[1]);
89                 return 1;
90         }
91         for (i = 1; fgets(buf, sizeof buf, stdin); i++) {
92                 dropcomm(buf);
93                 if (*buf == 0 || *buf == '\n')
94                         continue;
95                 memset(&t, 0, sizeof t);
96                 if (scan(f->fmt, &t, buf))
97                         fprintf(stderr, "error scan %s, line %d\n", f->name, i);
98                 tread = t;
99                 if (f->mpf(&t))
100                         fprintf(stderr, "error mpf %s, line %d\n", f->name, i);
101                 if (checkmode) {
102                         if (check(&tread, &t, f, ulpthres, &abserr)) {
103                                 print(f->fmt, &tread, buf, sizeof buf);
104                                 fputs(buf, stdout);
105 //                              print(f->fmt, &t, buf, sizeof buf);
106 //                              fputs(buf, stdout);
107                         }
108                         if (abserr > maxerr) {
109                                 maxerr = abserr;
110                                 terr = tread;
111                         }
112                 } else {
113                         if (print(f->fmt, &t, buf, sizeof buf))
114                                 fprintf(stderr, "error fmt %s, line %d\n", f->name, i);
115                         fputs(buf, stdout);
116                 }
117         }
118         if (checkmode && maxerr) {
119                 printf("// maxerr: %f, ", maxerr);
120                 print(f->fmt, &terr, buf, sizeof buf);
121                 fputs(buf, stdout);
122         }
123         return 0;
124 }
125
126 static int check(struct t *want, struct t *got, struct fun *f, float ulpthres, float *abserr)
127 {
128         int err = 0;
129         int m = INEXACT|UNDERFLOW; // TODO: dont check inexact and underflow for now
130
131         if ((got->e|m) != (want->e|m)) {
132                 fprintf(stdout, "//%s %s(%La,%La)==%La except: want %s",
133                         rstr(want->r), f->name, want->x, want->x2, want->y, estr(want->e));
134                 fprintf(stdout, " got %s\n", estr(got->e));
135                 err++;
136         }
137         if (isnan(got->y) && isnan(want->y))
138                 return err;
139         if (got->y != want->y || signbit(got->y) != signbit(want->y)) {
140                 char *p;
141                 int n;
142                 float d;
143
144                 p = strchr(f->fmt, '_');
145                 if (!p)
146                         return -1;
147                 p++;
148                 if (*p == 'd')
149                         n = eulp(want->y);
150                 else if (*p == 'f')
151                         n = eulpf(want->y);
152                 else if (*p == 'l')
153                         n = eulpl(want->y);
154                 else
155                         return -1;
156
157                 d = scalbnl(got->y - want->y, -n);
158                 *abserr = fabsf(d + want->dy);
159                 if (*abserr <= ulpthres)
160                         return err;
161                 fprintf(stdout, "//%s %s(%La,%La) want %La got %La ulperr %.3f = %a + %a\n",
162                         rstr(want->r), f->name, want->x, want->x2, want->y, got->y, d + want->dy, d, want->dy);
163                 err++;
164         }
165         return err;
166 }
167
168 // scan discards suffixes, this may cause rounding issues (eg scanning 0.1f as long double)
169 static int scan1(long double *x, char *s, int fmt)
170 {
171         double d;
172         float f;
173
174         if (fmt == 'd') {
175                 if (sscanf(s, "%lf", &d) != 1)
176                         return -1;
177                 *x = d;
178         } else if (fmt == 'f') {
179                 if (sscanf(s, "%f", &f) != 1)
180                         return -1;
181                 *x = f;
182         } else if (fmt == 'l') {
183                 return sscanf(s, "%Lf", x) != 1;
184         } else
185                 return -1;
186         return 0;
187 }
188
189 static int scan(const char *fmt, struct t *t, char *buf)
190 {
191         char *a[20];
192         long double *b[4];
193         long double dy, dy2;
194         char *end;
195         int n, i=0, j=0;
196
197         buf = skipstr(buf, "T \t\r\n,(){}");
198         n = splitstr(a, sizeof a/sizeof *a, buf, " \t\r\n,(){}");
199         if (n <= 0)
200                 return -1;
201         if (a[0][0] == 'R') {
202                 if (rconv(&t->r, a[i++]))
203                         return -1;
204         } else
205                 t->r = RN;
206
207         b[0] = &t->x;
208         b[1] = &t->x2;
209         b[2] = &t->x3;
210         b[3] = 0;
211         for (; *fmt && *fmt != '_'; fmt++) {
212                 if (i >= n)
213                         return -1;
214                 if (*fmt == 'i') {
215                         t->i = strtoll(a[i++], &end, 0);
216                         if (*end)
217                                 return -1;
218                 } else if (*fmt == 'd' || *fmt == 'f' || *fmt == 'l') {
219                         if (scan1(b[j++], a[i++], *fmt))
220                                 return -1;
221                 } else
222                         return -1;
223         }
224
225         b[0] = &t->y;
226         b[1] = &dy;
227         b[2] = &t->y2;
228         b[3] = &dy2;
229         j = 0;
230         fmt++;
231         for (; *fmt && i < n && j < sizeof b/sizeof *b; fmt++) {
232                 if (*fmt == 'i') {
233                         t->i = strtoll(a[i++], &end, 0);
234                         if (*end)
235                                 return -1;
236                 } else if (*fmt == 'd' || *fmt == 'f' || *fmt == 'l') {
237                         if (scan1(b[j++], a[i++], *fmt))
238                                 return -1;
239                         if (i < n && scan1(b[j++], a[i++], 'f'))
240                                 return -1;
241                 } else
242                         return -1;
243         }
244         t->dy = dy;
245         t->dy2 = dy2;
246         if (i < n)
247                 econv(&t->e, a[i]);
248         return 0;
249 }
250
251 /* assume strlen(old) == strlen(new) */
252 static void replace(char *buf, char *old, char *new)
253 {
254         int n = strlen(new);
255         char *p = buf;
256
257         while ((p = strstr(p, old)))
258                 memcpy(p, new, n);
259 }
260
261 static void fixl(char *buf)
262 {
263         replace(buf, "-infL", " -inf");
264         replace(buf, "infL", " inf");
265         replace(buf, "-nanL", " -nan");
266         replace(buf, "nanL", " nan");
267 }
268
269 static int print1(char *buf, int n, long double x, int fmt)
270 {
271         int k;
272
273         if (fmt == 'd')
274                 k = snprintf(buf, n, ",%24a", (double)x);
275         else if (fmt == 'f')
276                 k = snprintf(buf, n, ",%16a", (double)x);
277         else if (fmt == 'l') {
278 #if LDBL_MANT_DIG == 53
279                 k = snprintf(buf, n, ",%24a", (double)x);
280 #elif LDBL_MANT_DIG == 64
281                 k = snprintf(buf, n, ",%30LaL", x);
282                 fixl(buf);
283 #endif
284         } else
285                 k = -1;
286         return k;
287 }
288
289 static int print(const char *fmt, struct t *t, char *buf, int n)
290 {
291         long double a[4];
292         int k, i=0, out=0;
293
294         k = snprintf(buf, n, "T(%s", rstr(t->r));
295         if (k < 0 || k >= n)
296                 return -1;
297         n -= k;
298         buf += k;
299
300         a[0] = t->x;
301         a[1] = t->x2;
302         a[2] = t->x3;
303         for (; *fmt; fmt++) {
304                 if (*fmt == '_') {
305                         a[0] = t->y;
306                         a[1] = t->dy;
307                         a[2] = t->y2;
308                         a[3] = t->dy2;
309                         i = 0;
310                         out = 1;
311                         continue;
312                 }
313                 if (*fmt == 'i') {
314                         k = snprintf(buf, n, ", %11lld", t->i);
315                         if (k < 0 || k >= n)
316                                 return -1;
317                         n -= k;
318                         buf += k;
319                 } else {
320                         if (i >= sizeof a/sizeof *a)
321                                 return -1;
322                         k = print1(buf, n, a[i++], *fmt);
323                         if (k < 0 || k >= n)
324                                 return -1;
325                         n -= k;
326                         buf += k;
327                         if (out) {
328                                 k = print1(buf, n, a[i++], 'f');
329                                 if (k < 0 || k >= n)
330                                         return -1;
331                                 n -= k;
332                                 buf += k;
333                         }
334                 }
335         }
336         k = snprintf(buf, n, ", %s)\n", estr(t->e));
337         if (k < 0 || k >= n)
338                 return -1;
339         return 0;
340 }