add fma implementation for x86
[musl] / src / math / fma.c
1 #include <fenv.h>
2 #include "libm.h"
3
4 #if LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384
5 union ld80 {
6         long double x;
7         struct {
8                 uint64_t m;
9                 uint16_t e : 15;
10                 uint16_t s : 1;
11                 uint16_t pad;
12         } bits;
13 };
14
15 /* exact add, assumes exponent_x >= exponent_y */
16 static void add(long double *hi, long double *lo, long double x, long double y)
17 {
18         long double r;
19
20         r = x + y;
21         *hi = r;
22         r -= x;
23         *lo = y - r;
24 }
25
26 /*
27 TODO(nsz): probably simpler mul is enough if we assume x and y are doubles
28 so last 11bits are all zeros, no subnormals etc
29 */
30 /* exact mul, assumes no over/underflow */
31 static void mul(long double *hi, long double *lo, long double x, long double y)
32 {
33         static const long double c = 1.0 + 0x1p32L;
34         long double cx, xh, xl, cy, yh, yl;
35
36         cx = c*x;
37         xh = (x - cx) + cx;
38         xl = x - xh;
39         cy = c*y;
40         yh = (y - cy) + cy;
41         yl = y - yh;
42         *hi = x*y;
43         *lo = (xh*yh - *hi) + xh*yl + xl*yh + xl*yl;
44 }
45
46 /*
47 assume (long double)(hi+lo) == hi
48 return an adjusted hi so that rounding it to double is correct
49 */
50 static long double adjust(long double hi, long double lo)
51 {
52         union ld80 uhi, ulo;
53
54         if (lo == 0)
55                 return hi;
56         uhi.x = hi;
57         if (uhi.bits.m & 0x3ff)
58                 return hi;
59         ulo.x = lo;
60         if (uhi.bits.s == ulo.bits.s)
61                 uhi.bits.m++;
62         else
63                 uhi.bits.m--;
64         return uhi.x;
65 }
66
67 static long double dadd(long double x, long double y)
68 {
69         add(&x, &y, x, y);
70         return adjust(x, y);
71 }
72
73 static long double dmul(long double x, long double y)
74 {
75         mul(&x, &y, x, y);
76         return adjust(x, y);
77 }
78
79 static int getexp(long double x)
80 {
81         union ld80 u;
82         u.x = x;
83         return u.bits.e;
84 }
85
86 double fma(double x, double y, double z)
87 {
88         long double hi, lo1, lo2, xy;
89         int round, ez, exy;
90
91         /* handle +-inf,nan */
92         if (!isfinite(x) || !isfinite(y))
93                 return x*y + z;
94         if (!isfinite(z))
95                 return z;
96         /* handle +-0 */
97         if (x == 0.0 || y == 0.0)
98                 return x*y + z;
99         round = fegetround();
100         if (z == 0.0) {
101                 if (round == FE_TONEAREST)
102                         return dmul(x, y);
103                 return x*y;
104         }
105
106         /* exact mul and add require nearest rounding */
107         /* spurious inexact exceptions may be raised */
108         fesetround(FE_TONEAREST);
109         mul(&xy, &lo1, x, y);
110         exy = getexp(xy);
111         ez = getexp(z);
112         if (ez > exy) {
113                 add(&hi, &lo2, z, xy);
114         } else if (ez > exy - 12) {
115                 add(&hi, &lo2, xy, z);
116                 if (hi == 0) {
117                         fesetround(round);
118                         /* TODO: verify that the sign of 0 is always correct */
119                         return (xy + z) + lo1;
120                 }
121         } else {
122                 /*
123                 ez <= exy - 12
124                 the 12 extra bits (1guard, 11round+sticky) are needed so with
125                         lo = dadd(lo1, lo2)
126                 elo <= ehi - 11, and we use the last 10 bits in adjust so
127                         dadd(hi, lo)
128                 gives correct result when rounded to double
129                 */
130                 hi = xy;
131                 lo2 = z;
132         }
133         fesetround(round);
134         if (round == FE_TONEAREST)
135                 return dadd(hi, dadd(lo1, lo2));
136         return hi + (lo1 + lo2);
137 }
138 #else
139 /* origin: FreeBSD /usr/src/lib/msun/src/s_fma.c */
140 /*-
141  * Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
142  * All rights reserved.
143  *
144  * Redistribution and use in source and binary forms, with or without
145  * modification, are permitted provided that the following conditions
146  * are met:
147  * 1. Redistributions of source code must retain the above copyright
148  *    notice, this list of conditions and the following disclaimer.
149  * 2. Redistributions in binary form must reproduce the above copyright
150  *    notice, this list of conditions and the following disclaimer in the
151  *    documentation and/or other materials provided with the distribution.
152  *
153  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
154  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
155  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
156  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
157  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
158  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
159  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
160  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
161  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
162  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
163  * SUCH DAMAGE.
164  */
165
166 /*
167  * A struct dd represents a floating-point number with twice the precision
168  * of a double.  We maintain the invariant that "hi" stores the 53 high-order
169  * bits of the result.
170  */
171 struct dd {
172         double hi;
173         double lo;
174 };
175
176 /*
177  * Compute a+b exactly, returning the exact result in a struct dd.  We assume
178  * that both a and b are finite, but make no assumptions about their relative
179  * magnitudes.
180  */
181 static inline struct dd dd_add(double a, double b)
182 {
183         struct dd ret;
184         double s;
185
186         ret.hi = a + b;
187         s = ret.hi - a;
188         ret.lo = (a - (ret.hi - s)) + (b - s);
189         return (ret);
190 }
191
192 /*
193  * Compute a+b, with a small tweak:  The least significant bit of the
194  * result is adjusted into a sticky bit summarizing all the bits that
195  * were lost to rounding.  This adjustment negates the effects of double
196  * rounding when the result is added to another number with a higher
197  * exponent.  For an explanation of round and sticky bits, see any reference
198  * on FPU design, e.g.,
199  *
200  *     J. Coonen.  An Implementation Guide to a Proposed Standard for
201  *     Floating-Point Arithmetic.  Computer, vol. 13, no. 1, Jan 1980.
202  */
203 static inline double add_adjusted(double a, double b)
204 {
205         struct dd sum;
206         uint64_t hibits, lobits;
207
208         sum = dd_add(a, b);
209         if (sum.lo != 0) {
210                 EXTRACT_WORD64(hibits, sum.hi);
211                 if ((hibits & 1) == 0) {
212                         /* hibits += (int)copysign(1.0, sum.hi * sum.lo) */
213                         EXTRACT_WORD64(lobits, sum.lo);
214                         hibits += 1 - ((hibits ^ lobits) >> 62);
215                         INSERT_WORD64(sum.hi, hibits);
216                 }
217         }
218         return (sum.hi);
219 }
220
221 /*
222  * Compute ldexp(a+b, scale) with a single rounding error. It is assumed
223  * that the result will be subnormal, and care is taken to ensure that
224  * double rounding does not occur.
225  */
226 static inline double add_and_denormalize(double a, double b, int scale)
227 {
228         struct dd sum;
229         uint64_t hibits, lobits;
230         int bits_lost;
231
232         sum = dd_add(a, b);
233
234         /*
235          * If we are losing at least two bits of accuracy to denormalization,
236          * then the first lost bit becomes a round bit, and we adjust the
237          * lowest bit of sum.hi to make it a sticky bit summarizing all the
238          * bits in sum.lo. With the sticky bit adjusted, the hardware will
239          * break any ties in the correct direction.
240          *
241          * If we are losing only one bit to denormalization, however, we must
242          * break the ties manually.
243          */
244         if (sum.lo != 0) {
245                 EXTRACT_WORD64(hibits, sum.hi);
246                 bits_lost = -((int)(hibits >> 52) & 0x7ff) - scale + 1;
247                 if (bits_lost != 1 ^ (int)(hibits & 1)) {
248                         /* hibits += (int)copysign(1.0, sum.hi * sum.lo) */
249                         EXTRACT_WORD64(lobits, sum.lo);
250                         hibits += 1 - (((hibits ^ lobits) >> 62) & 2);
251                         INSERT_WORD64(sum.hi, hibits);
252                 }
253         }
254         return (ldexp(sum.hi, scale));
255 }
256
257 /*
258  * Compute a*b exactly, returning the exact result in a struct dd.  We assume
259  * that both a and b are normalized, so no underflow or overflow will occur.
260  * The current rounding mode must be round-to-nearest.
261  */
262 static inline struct dd dd_mul(double a, double b)
263 {
264         static const double split = 0x1p27 + 1.0;
265         struct dd ret;
266         double ha, hb, la, lb, p, q;
267
268         p = a * split;
269         ha = a - p;
270         ha += p;
271         la = a - ha;
272
273         p = b * split;
274         hb = b - p;
275         hb += p;
276         lb = b - hb;
277
278         p = ha * hb;
279         q = ha * lb + la * hb;
280
281         ret.hi = p + q;
282         ret.lo = p - ret.hi + q + la * lb;
283         return (ret);
284 }
285
286 /*
287  * Fused multiply-add: Compute x * y + z with a single rounding error.
288  *
289  * We use scaling to avoid overflow/underflow, along with the
290  * canonical precision-doubling technique adapted from:
291  *
292  *      Dekker, T.  A Floating-Point Technique for Extending the
293  *      Available Precision.  Numer. Math. 18, 224-242 (1971).
294  *
295  * This algorithm is sensitive to the rounding precision.  FPUs such
296  * as the i387 must be set in double-precision mode if variables are
297  * to be stored in FP registers in order to avoid incorrect results.
298  * This is the default on FreeBSD, but not on many other systems.
299  *
300  * Hardware instructions should be used on architectures that support it,
301  * since this implementation will likely be several times slower.
302  */
303 double fma(double x, double y, double z)
304 {
305         double xs, ys, zs, adj;
306         struct dd xy, r;
307         int oround;
308         int ex, ey, ez;
309         int spread;
310
311         /*
312          * Handle special cases. The order of operations and the particular
313          * return values here are crucial in handling special cases involving
314          * infinities, NaNs, overflows, and signed zeroes correctly.
315          */
316         if (!isfinite(x) || !isfinite(y))
317                 return (x * y + z);
318         if (!isfinite(z))
319                 return (z);
320         if (x == 0.0 || y == 0.0)
321                 return (x * y + z);
322         if (z == 0.0)
323                 return (x * y);
324
325         xs = frexp(x, &ex);
326         ys = frexp(y, &ey);
327         zs = frexp(z, &ez);
328         oround = fegetround();
329         spread = ex + ey - ez;
330
331         /*
332          * If x * y and z are many orders of magnitude apart, the scaling
333          * will overflow, so we handle these cases specially.  Rounding
334          * modes other than FE_TONEAREST are painful.
335          */
336         if (spread < -DBL_MANT_DIG) {
337 #ifdef FE_INEXACT
338                 feraiseexcept(FE_INEXACT);
339 #endif
340 #ifdef FE_UNDERFLOW
341                 if (!isnormal(z))
342                         feraiseexcept(FE_UNDERFLOW);
343 #endif
344                 switch (oround) {
345                 default: /* FE_TONEAREST */
346                         return (z);
347 #ifdef FE_TOWARDZERO
348                 case FE_TOWARDZERO:
349                         if (x > 0.0 ^ y < 0.0 ^ z < 0.0)
350                                 return (z);
351                         else
352                                 return (nextafter(z, 0));
353 #endif
354 #ifdef FE_DOWNWARD
355                 case FE_DOWNWARD:
356                         if (x > 0.0 ^ y < 0.0)
357                                 return (z);
358                         else
359                                 return (nextafter(z, -INFINITY));
360 #endif
361 #ifdef FE_UPWARD
362                 case FE_UPWARD:
363                         if (x > 0.0 ^ y < 0.0)
364                                 return (nextafter(z, INFINITY));
365                         else
366                                 return (z);
367 #endif
368                 }
369         }
370         if (spread <= DBL_MANT_DIG * 2)
371                 zs = ldexp(zs, -spread);
372         else
373                 zs = copysign(DBL_MIN, zs);
374
375         fesetround(FE_TONEAREST);
376
377         /*
378          * Basic approach for round-to-nearest:
379          *
380          *     (xy.hi, xy.lo) = x * y           (exact)
381          *     (r.hi, r.lo)   = xy.hi + z       (exact)
382          *     adj = xy.lo + r.lo               (inexact; low bit is sticky)
383          *     result = r.hi + adj              (correctly rounded)
384          */
385         xy = dd_mul(xs, ys);
386         r = dd_add(xy.hi, zs);
387
388         spread = ex + ey;
389
390         if (r.hi == 0.0) {
391                 /*
392                  * When the addends cancel to 0, ensure that the result has
393                  * the correct sign.
394                  */
395                 fesetround(oround);
396                 volatile double vzs = zs; /* XXX gcc CSE bug workaround */
397                 return (xy.hi + vzs + ldexp(xy.lo, spread));
398         }
399
400         if (oround != FE_TONEAREST) {
401                 /*
402                  * There is no need to worry about double rounding in directed
403                  * rounding modes.
404                  */
405                 fesetround(oround);
406                 adj = r.lo + xy.lo;
407                 return (ldexp(r.hi + adj, spread));
408         }
409
410         adj = add_adjusted(r.lo, xy.lo);
411         if (spread + ilogb(r.hi) > -1023)
412                 return (ldexp(r.hi + adj, spread));
413         else
414                 return (add_and_denormalize(r.hi, adj, spread));
415 }
416 #endif