math/gen: fix ilogb, logb and modf in mp, change integer print fmt
[libc-test] / src / math / gen / gen.c
1 /*
2 ./gen can generate testcases using an mp lib
3 ./check can test an mp lib compared to the input
4
5 input format:
6 T.<rounding>.<inputs>.<outputs>.<outputerr>.<exceptflags>.
7 where . is a sequence of separators: " \t,(){}"
8 the T prefix and rounding mode are optional (default is RN),
9 so the following are all ok and equivalent input:
10
11  1 2.0 0.1 INEXACT
12  {RN, 1, 2.0, 0.1, INEXACT},
13  T(RN, 1, 2.0, 0.1, INEXACT)
14
15 for gen only rounding and inputs are required (the rest is discarded)
16
17 gen:
18         s = getline()
19         x = scan(s)
20         xy = mpfunc(x)
21         print(xy)
22 check:
23         s = getline()
24         xy = scan(s)
25         xy' = mpfunc(x)
26         check(xy, xy')
27 */
28
29 #include <stdlib.h>
30 #include <stdio.h>
31 #include <string.h>
32 #include "gen.h"
33
34 static int scan(const char *fmt, struct t *t, char *buf);
35 static int print(const char *fmt, struct t *t, char *buf, int n);
36
37 // TODO: many output, fmt->ulp
38 struct fun;
39 static int check(struct t *want, struct t *got, struct fun *f, float ulpthres);
40
41 struct fun {
42         char *name;
43         int (*mpf)(struct t*);
44         char *fmt;
45 } fun[] = {
46 #define T(f,t) {#f, mp##f, #t},
47 #include "functions.h"
48 #undef T
49 };
50
51 int main(int argc, char *argv[])
52 {
53         char buf[512];
54         char *p;
55         int checkmode;
56         int i;
57         struct t t;
58         struct t tread;
59         struct fun *f = 0;
60
61         if (argc < 2) {
62                 fprintf(stderr, "%s func\n", argv[0]);
63                 return 1;
64         }
65         p = strrchr(argv[0], '/');
66         if (!p)
67                 p = argv[0];
68         else
69                 p++;
70         checkmode = strcmp(p, "check") == 0;
71         for (i = 0; i < sizeof fun/sizeof *fun; i++)
72                 if (strcmp(fun[i].name, argv[1]) == 0) {
73                         f = fun + i;
74                         break;
75                 }
76         if (f == 0) {
77                 fprintf(stderr, "unknown func: %s\n", argv[1]);
78                 return 1;
79         }
80         for (i = 1; fgets(buf, sizeof buf, stdin); i++) {
81                 dropcomm(buf);
82                 if (*buf == 0 || *buf == '\n')
83                         continue;
84                 memset(&t, 0, sizeof t);
85                 if (scan(f->fmt, &t, buf))
86                         fprintf(stderr, "error scan %s, line %d\n", f->name, i);
87                 tread = t;
88                 if (f->mpf(&t))
89                         fprintf(stderr, "error mpf %s, line %d\n", f->name, i);
90                 if (checkmode) {
91                         if (check(&tread, &t, f, 1.0)) {
92                                 print(f->fmt, &tread, buf, sizeof buf);
93                                 fputs(buf, stdout);
94                                 print(f->fmt, &t, buf, sizeof buf);
95                                 fputs(buf, stdout);
96                         }
97                 } else {
98                         if (print(f->fmt, &t, buf, sizeof buf))
99                                 fprintf(stderr, "error fmt %s, line %d\n", f->name, i);
100                         fputs(buf, stdout);
101                 }
102         }
103         return 0;
104 }
105
106 static int check(struct t *want, struct t *got, struct fun *f, float ulpthres)
107 {
108         int err = 0;
109         int m = INEXACT|UNDERFLOW; // TODO: dont check inexact and underflow for now
110
111         if ((got->e|m) != (want->e|m)) {
112                 fprintf(stdout, "%s %s(%La,%La)==%La except: want %s",
113                         rstr(want->r), f->name, want->x, want->x2, want->y, estr(want->e));
114                 fprintf(stdout, " got %s\n", estr(got->e));
115                 err++;
116         }
117         if (isnan(got->y) && isnan(want->y))
118                 return err;
119         if (got->y != want->y || signbit(got->y) != signbit(want->y)) {
120                 char *p;
121                 int n;
122                 float d;
123
124                 p = strchr(f->fmt, '_');
125                 if (!p)
126                         return -1;
127                 p++;
128                 if (*p == 'd')
129                         n = eulp(want->y);
130                 else if (*p == 'f')
131                         n = eulpf(want->y);
132                 else if (*p == 'l')
133                         n = eulpl(want->y);
134                 else
135                         return -1;
136
137                 d = scalbnl(got->y - want->y, -n);
138                 if (fabsf(d + want->dy) <= ulpthres)
139                         return err;
140                 fprintf(stdout, "%s %s(%La,%La) want %La got %La ulperr %.3f = %a + %a\n",
141                         rstr(want->r), f->name, want->x, want->x2, want->y, got->y, d + want->dy, d, want->dy);
142                 err++;
143         }
144         return err;
145 }
146
147 // scan discards suffixes, this may cause rounding issues (eg scanning 0.1f as long double)
148 static int scan1(long double *x, char *s, int fmt)
149 {
150         double d;
151         float f;
152
153         if (fmt == 'd') {
154                 if (sscanf(s, "%lf", &d) != 1)
155                         return -1;
156                 *x = d;
157         } else if (fmt == 'f') {
158                 if (sscanf(s, "%f", &f) != 1)
159                         return -1;
160                 *x = f;
161         } else if (fmt == 'l') {
162                 return sscanf(s, "%Lf", x) != 1;
163         } else
164                 return -1;
165         return 0;
166 }
167
168 static int scan(const char *fmt, struct t *t, char *buf)
169 {
170         char *a[20];
171         long double *b[4];
172         long double dy, dy2;
173         char *end;
174         int n, i=0, j=0;
175
176         buf = skipstr(buf, "T \t\r\n,(){}");
177         n = splitstr(a, sizeof a/sizeof *a, buf, " \t\r\n,(){}");
178         if (n <= 0)
179                 return -1;
180         if (a[0][0] == 'R') {
181                 if (rconv(&t->r, a[i++]))
182                         return -1;
183         } else
184                 t->r = RN;
185
186         b[0] = &t->x;
187         b[1] = &t->x2;
188         b[2] = &t->x3;
189         b[3] = 0;
190         for (; *fmt && *fmt != '_'; fmt++) {
191                 if (i >= n)
192                         return -1;
193                 if (*fmt == 'i') {
194                         t->i = strtoll(a[i++], &end, 0);
195                         if (*end)
196                                 return -1;
197                 } else if (*fmt == 'd' || *fmt == 'f' || *fmt == 'l') {
198                         if (scan1(b[j++], a[i++], *fmt))
199                                 return -1;
200                 } else
201                         return -1;
202         }
203
204         b[0] = &t->y;
205         b[1] = &dy;
206         b[2] = &t->y2;
207         b[3] = &dy2;
208         j = 0;
209         fmt++;
210         for (; *fmt && i < n && j < sizeof b/sizeof *b; fmt++) {
211                 if (*fmt == 'i') {
212                         t->i = strtoll(a[i++], &end, 0);
213                         if (*end)
214                                 return -1;
215                 } else if (*fmt == 'd' || *fmt == 'f' || *fmt == 'l') {
216                         if (scan1(b[j++], a[i++], *fmt))
217                                 return -1;
218                         if (i < n && scan1(b[j++], a[i++], 'f'))
219                                 return -1;
220                 } else
221                         return -1;
222         }
223         t->dy = dy;
224         t->dy2 = dy2;
225         if (i < n)
226                 econv(&t->e, a[i]);
227         return 0;
228 }
229
230 /* assume strlen(old) == strlen(new) */
231 static void replace(char *buf, char *old, char *new)
232 {
233         int n = strlen(new);
234         char *p = buf;
235
236         while ((p = strstr(p, old)))
237                 memcpy(p, new, n);
238 }
239
240 static void fixl(char *buf)
241 {
242         replace(buf, "-infL", " -inf");
243         replace(buf, "infL", " inf");
244         replace(buf, "-nanL", " -nan");
245         replace(buf, "nanL", " nan");
246 }
247
248 static int print1(char *buf, int n, long double x, int fmt)
249 {
250         int k;
251
252         if (fmt == 'd')
253                 k = snprintf(buf, n, ",%24a", (double)x);
254         else if (fmt == 'f')
255                 k = snprintf(buf, n, ",%16a", (double)x);
256         else if (fmt == 'l') {
257 #if LDBL_MANT_DIG == 53
258                 k = snprintf(buf, n, ",%24a", (double)x);
259 #elif LDBL_MANT_DIG == 64
260                 k = snprintf(buf, n, ",%30LaL", x);
261                 fixl(buf);
262 #endif
263         } else
264                 k = -1;
265         return k;
266 }
267
268 static int print(const char *fmt, struct t *t, char *buf, int n)
269 {
270         long double a[4];
271         int k, i=0, out=0;
272
273         k = snprintf(buf, n, "T(%s", rstr(t->r));
274         if (k < 0 || k >= n)
275                 return -1;
276         n -= k;
277         buf += k;
278
279         a[0] = t->x;
280         a[1] = t->x2;
281         a[2] = t->x3;
282         for (; *fmt; fmt++) {
283                 if (*fmt == '_') {
284                         a[0] = t->y;
285                         a[1] = t->dy;
286                         a[2] = t->y2;
287                         a[3] = t->dy2;
288                         i = 0;
289                         out = 1;
290                         continue;
291                 }
292                 if (*fmt == 'i') {
293                         k = snprintf(buf, n, ", %11lld", t->i);
294                         if (k < 0 || k >= n)
295                                 return -1;
296                         n -= k;
297                         buf += k;
298                 } else {
299                         if (i >= sizeof a/sizeof *a)
300                                 return -1;
301                         k = print1(buf, n, a[i++], *fmt);
302                         if (k < 0 || k >= n)
303                                 return -1;
304                         n -= k;
305                         buf += k;
306                         if (out) {
307                                 k = print1(buf, n, a[i++], 'f');
308                                 if (k < 0 || k >= n)
309                                         return -1;
310                                 n -= k;
311                                 buf += k;
312                         }
313                 }
314         }
315         k = snprintf(buf, n, ", %s)\n", estr(t->e));
316         if (k < 0 || k >= n)
317                 return -1;
318         return 0;
319 }