math: fix two fma issues (only affects non-nearest rounding mode, x86)
[musl] / src / math / fma.c
1 #include <fenv.h>
2 #include "libm.h"
3
4 #if LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384
5 union ld80 {
6         long double x;
7         struct {
8                 uint64_t m;
9                 uint16_t e : 15;
10                 uint16_t s : 1;
11                 uint16_t pad;
12         } bits;
13 };
14
15 /* exact add, assumes exponent_x >= exponent_y */
16 static void add(long double *hi, long double *lo, long double x, long double y)
17 {
18         long double r;
19
20         r = x + y;
21         *hi = r;
22         r -= x;
23         *lo = y - r;
24 }
25
26 /* exact mul, assumes no over/underflow */
27 static void mul(long double *hi, long double *lo, long double x, long double y)
28 {
29         static const long double c = 1.0 + 0x1p32L;
30         long double cx, xh, xl, cy, yh, yl;
31
32         cx = c*x;
33         xh = (x - cx) + cx;
34         xl = x - xh;
35         cy = c*y;
36         yh = (y - cy) + cy;
37         yl = y - yh;
38         *hi = x*y;
39         *lo = (xh*yh - *hi) + xh*yl + xl*yh + xl*yl;
40 }
41
42 /*
43 assume (long double)(hi+lo) == hi
44 return an adjusted hi so that rounding it to double (or less) precision is correct
45 */
46 static long double adjust(long double hi, long double lo)
47 {
48         union ld80 uhi, ulo;
49
50         if (lo == 0)
51                 return hi;
52         uhi.x = hi;
53         if (uhi.bits.m & 0x3ff)
54                 return hi;
55         ulo.x = lo;
56         if (uhi.bits.s == ulo.bits.s)
57                 uhi.bits.m++;
58         else {
59                 uhi.bits.m--;
60                 /* handle underflow and take care of ld80 implicit msb */
61                 if (uhi.bits.m == (uint64_t)-1/2) {
62                         uhi.bits.m *= 2;
63                         uhi.bits.e--;
64                 }
65         }
66         return uhi.x;
67 }
68
69 /* adjusted add so the result is correct when rounded to double (or less) precision */
70 static long double dadd(long double x, long double y)
71 {
72         add(&x, &y, x, y);
73         return adjust(x, y);
74 }
75
76 /* adjusted mul so the result is correct when rounded to double (or less) precision */
77 static long double dmul(long double x, long double y)
78 {
79         mul(&x, &y, x, y);
80         return adjust(x, y);
81 }
82
83 static int getexp(long double x)
84 {
85         union ld80 u;
86         u.x = x;
87         return u.bits.e;
88 }
89
90 double fma(double x, double y, double z)
91 {
92         #pragma STDC FENV_ACCESS ON
93         long double hi, lo1, lo2, xy;
94         int round, ez, exy;
95
96         /* handle +-inf,nan */
97         if (!isfinite(x) || !isfinite(y))
98                 return x*y + z;
99         if (!isfinite(z))
100                 return z;
101         /* handle +-0 */
102         if (x == 0.0 || y == 0.0)
103                 return x*y + z;
104         round = fegetround();
105         if (z == 0.0) {
106                 if (round == FE_TONEAREST)
107                         return dmul(x, y);
108                 return x*y;
109         }
110
111         /* exact mul and add require nearest rounding */
112         /* spurious inexact exceptions may be raised */
113         fesetround(FE_TONEAREST);
114         mul(&xy, &lo1, x, y);
115         exy = getexp(xy);
116         ez = getexp(z);
117         if (ez > exy) {
118                 add(&hi, &lo2, z, xy);
119         } else if (ez > exy - 12) {
120                 add(&hi, &lo2, xy, z);
121                 if (hi == 0) {
122                         /*
123                         xy + z is 0, but it should be calculated with the
124                         original rounding mode so the sign is correct, if the
125                         compiler does not support FENV_ACCESS ON it does not
126                         know about the changed rounding mode and eliminates
127                         the xy + z below without the volatile memory access
128                         */
129                         volatile double z_;
130                         fesetround(round);
131                         z_ = z;
132                         return (xy + z_) + lo1;
133                 }
134         } else {
135                 /*
136                 ez <= exy - 12
137                 the 12 extra bits (1guard, 11round+sticky) are needed so with
138                         lo = dadd(lo1, lo2)
139                 elo <= ehi - 11, and we use the last 10 bits in adjust so
140                         dadd(hi, lo)
141                 gives correct result when rounded to double
142                 */
143                 hi = xy;
144                 lo2 = z;
145         }
146         /*
147         the result is stored before return for correct precision and exceptions
148
149         one corner case is when the underflow flag should be raised because
150         the precise result is an inexact subnormal double, but the calculated
151         long double result is an exact subnormal double
152         (so rounding to double does not raise exceptions)
153
154         in nearest rounding mode dadd takes care of this: the last bit of the
155         result is adjusted so rounding sees an inexact value when it should
156
157         in non-nearest rounding mode fenv is used for the workaround
158         */
159         fesetround(round);
160         if (round == FE_TONEAREST)
161                 z = dadd(hi, dadd(lo1, lo2));
162         else {
163 #if defined(FE_INEXACT) && defined(FE_UNDERFLOW)
164                 int e = fetestexcept(FE_INEXACT);
165                 feclearexcept(FE_INEXACT);
166 #endif
167                 z = hi + (lo1 + lo2);
168 #if defined(FE_INEXACT) && defined(FE_UNDERFLOW)
169                 if (getexp(z) < 0x3fff-1022 && fetestexcept(FE_INEXACT))
170                         feraiseexcept(FE_UNDERFLOW);
171                 else if (e)
172                         feraiseexcept(FE_INEXACT);
173 #endif
174         }
175         return z;
176 }
177 #else
178 /* origin: FreeBSD /usr/src/lib/msun/src/s_fma.c */
179 /*-
180  * Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
181  * All rights reserved.
182  *
183  * Redistribution and use in source and binary forms, with or without
184  * modification, are permitted provided that the following conditions
185  * are met:
186  * 1. Redistributions of source code must retain the above copyright
187  *    notice, this list of conditions and the following disclaimer.
188  * 2. Redistributions in binary form must reproduce the above copyright
189  *    notice, this list of conditions and the following disclaimer in the
190  *    documentation and/or other materials provided with the distribution.
191  *
192  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
193  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
194  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
195  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
196  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
197  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
198  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
199  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
200  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
201  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
202  * SUCH DAMAGE.
203  */
204
205 /*
206  * A struct dd represents a floating-point number with twice the precision
207  * of a double.  We maintain the invariant that "hi" stores the 53 high-order
208  * bits of the result.
209  */
210 struct dd {
211         double hi;
212         double lo;
213 };
214
215 /*
216  * Compute a+b exactly, returning the exact result in a struct dd.  We assume
217  * that both a and b are finite, but make no assumptions about their relative
218  * magnitudes.
219  */
220 static inline struct dd dd_add(double a, double b)
221 {
222         struct dd ret;
223         double s;
224
225         ret.hi = a + b;
226         s = ret.hi - a;
227         ret.lo = (a - (ret.hi - s)) + (b - s);
228         return (ret);
229 }
230
231 /*
232  * Compute a+b, with a small tweak:  The least significant bit of the
233  * result is adjusted into a sticky bit summarizing all the bits that
234  * were lost to rounding.  This adjustment negates the effects of double
235  * rounding when the result is added to another number with a higher
236  * exponent.  For an explanation of round and sticky bits, see any reference
237  * on FPU design, e.g.,
238  *
239  *     J. Coonen.  An Implementation Guide to a Proposed Standard for
240  *     Floating-Point Arithmetic.  Computer, vol. 13, no. 1, Jan 1980.
241  */
242 static inline double add_adjusted(double a, double b)
243 {
244         struct dd sum;
245         uint64_t hibits, lobits;
246
247         sum = dd_add(a, b);
248         if (sum.lo != 0) {
249                 EXTRACT_WORD64(hibits, sum.hi);
250                 if ((hibits & 1) == 0) {
251                         /* hibits += (int)copysign(1.0, sum.hi * sum.lo) */
252                         EXTRACT_WORD64(lobits, sum.lo);
253                         hibits += 1 - ((hibits ^ lobits) >> 62);
254                         INSERT_WORD64(sum.hi, hibits);
255                 }
256         }
257         return (sum.hi);
258 }
259
260 /*
261  * Compute ldexp(a+b, scale) with a single rounding error. It is assumed
262  * that the result will be subnormal, and care is taken to ensure that
263  * double rounding does not occur.
264  */
265 static inline double add_and_denormalize(double a, double b, int scale)
266 {
267         struct dd sum;
268         uint64_t hibits, lobits;
269         int bits_lost;
270
271         sum = dd_add(a, b);
272
273         /*
274          * If we are losing at least two bits of accuracy to denormalization,
275          * then the first lost bit becomes a round bit, and we adjust the
276          * lowest bit of sum.hi to make it a sticky bit summarizing all the
277          * bits in sum.lo. With the sticky bit adjusted, the hardware will
278          * break any ties in the correct direction.
279          *
280          * If we are losing only one bit to denormalization, however, we must
281          * break the ties manually.
282          */
283         if (sum.lo != 0) {
284                 EXTRACT_WORD64(hibits, sum.hi);
285                 bits_lost = -((int)(hibits >> 52) & 0x7ff) - scale + 1;
286                 if (bits_lost != 1 ^ (int)(hibits & 1)) {
287                         /* hibits += (int)copysign(1.0, sum.hi * sum.lo) */
288                         EXTRACT_WORD64(lobits, sum.lo);
289                         hibits += 1 - (((hibits ^ lobits) >> 62) & 2);
290                         INSERT_WORD64(sum.hi, hibits);
291                 }
292         }
293         return scalbn(sum.hi, scale);
294 }
295
296 /*
297  * Compute a*b exactly, returning the exact result in a struct dd.  We assume
298  * that both a and b are normalized, so no underflow or overflow will occur.
299  * The current rounding mode must be round-to-nearest.
300  */
301 static inline struct dd dd_mul(double a, double b)
302 {
303         static const double split = 0x1p27 + 1.0;
304         struct dd ret;
305         double ha, hb, la, lb, p, q;
306
307         p = a * split;
308         ha = a - p;
309         ha += p;
310         la = a - ha;
311
312         p = b * split;
313         hb = b - p;
314         hb += p;
315         lb = b - hb;
316
317         p = ha * hb;
318         q = ha * lb + la * hb;
319
320         ret.hi = p + q;
321         ret.lo = p - ret.hi + q + la * lb;
322         return (ret);
323 }
324
325 /*
326  * Fused multiply-add: Compute x * y + z with a single rounding error.
327  *
328  * We use scaling to avoid overflow/underflow, along with the
329  * canonical precision-doubling technique adapted from:
330  *
331  *      Dekker, T.  A Floating-Point Technique for Extending the
332  *      Available Precision.  Numer. Math. 18, 224-242 (1971).
333  *
334  * This algorithm is sensitive to the rounding precision.  FPUs such
335  * as the i387 must be set in double-precision mode if variables are
336  * to be stored in FP registers in order to avoid incorrect results.
337  * This is the default on FreeBSD, but not on many other systems.
338  *
339  * Hardware instructions should be used on architectures that support it,
340  * since this implementation will likely be several times slower.
341  */
342 double fma(double x, double y, double z)
343 {
344         #pragma STDC FENV_ACCESS ON
345         double xs, ys, zs, adj;
346         struct dd xy, r;
347         int oround;
348         int ex, ey, ez;
349         int spread;
350
351         /*
352          * Handle special cases. The order of operations and the particular
353          * return values here are crucial in handling special cases involving
354          * infinities, NaNs, overflows, and signed zeroes correctly.
355          */
356         if (!isfinite(x) || !isfinite(y))
357                 return (x * y + z);
358         if (!isfinite(z))
359                 return (z);
360         if (x == 0.0 || y == 0.0)
361                 return (x * y + z);
362         if (z == 0.0)
363                 return (x * y);
364
365         xs = frexp(x, &ex);
366         ys = frexp(y, &ey);
367         zs = frexp(z, &ez);
368         oround = fegetround();
369         spread = ex + ey - ez;
370
371         /*
372          * If x * y and z are many orders of magnitude apart, the scaling
373          * will overflow, so we handle these cases specially.  Rounding
374          * modes other than FE_TONEAREST are painful.
375          */
376         if (spread < -DBL_MANT_DIG) {
377 #ifdef FE_INEXACT
378                 feraiseexcept(FE_INEXACT);
379 #endif
380 #ifdef FE_UNDERFLOW
381                 if (!isnormal(z))
382                         feraiseexcept(FE_UNDERFLOW);
383 #endif
384                 switch (oround) {
385                 default: /* FE_TONEAREST */
386                         return (z);
387 #ifdef FE_TOWARDZERO
388                 case FE_TOWARDZERO:
389                         if (x > 0.0 ^ y < 0.0 ^ z < 0.0)
390                                 return (z);
391                         else
392                                 return (nextafter(z, 0));
393 #endif
394 #ifdef FE_DOWNWARD
395                 case FE_DOWNWARD:
396                         if (x > 0.0 ^ y < 0.0)
397                                 return (z);
398                         else
399                                 return (nextafter(z, -INFINITY));
400 #endif
401 #ifdef FE_UPWARD
402                 case FE_UPWARD:
403                         if (x > 0.0 ^ y < 0.0)
404                                 return (nextafter(z, INFINITY));
405                         else
406                                 return (z);
407 #endif
408                 }
409         }
410         if (spread <= DBL_MANT_DIG * 2)
411                 zs = scalbn(zs, -spread);
412         else
413                 zs = copysign(DBL_MIN, zs);
414
415         fesetround(FE_TONEAREST);
416
417         /*
418          * Basic approach for round-to-nearest:
419          *
420          *     (xy.hi, xy.lo) = x * y           (exact)
421          *     (r.hi, r.lo)   = xy.hi + z       (exact)
422          *     adj = xy.lo + r.lo               (inexact; low bit is sticky)
423          *     result = r.hi + adj              (correctly rounded)
424          */
425         xy = dd_mul(xs, ys);
426         r = dd_add(xy.hi, zs);
427
428         spread = ex + ey;
429
430         if (r.hi == 0.0) {
431                 /*
432                  * When the addends cancel to 0, ensure that the result has
433                  * the correct sign.
434                  */
435                 fesetround(oround);
436                 volatile double vzs = zs; /* XXX gcc CSE bug workaround */
437                 return xy.hi + vzs + scalbn(xy.lo, spread);
438         }
439
440         if (oround != FE_TONEAREST) {
441                 /*
442                  * There is no need to worry about double rounding in directed
443                  * rounding modes.
444                  */
445                 fesetround(oround);
446                 adj = r.lo + xy.lo;
447                 return scalbn(r.hi + adj, spread);
448         }
449
450         adj = add_adjusted(r.lo, xy.lo);
451         if (spread + ilogb(r.hi) > -1023)
452                 return scalbn(r.hi + adj, spread);
453         else
454                 return add_and_denormalize(r.hi, adj, spread);
455 }
456 #endif