math: fix two fma issues (only affects non-nearest rounding mode, x86)
[musl] / src / math / fma.c
index c53f314..89def79 100644 (file)
@@ -1,3 +1,180 @@
+#include <fenv.h>
+#include "libm.h"
+
+#if LDBL_MANT_DIG==64 && LDBL_MAX_EXP==16384
+union ld80 {
+       long double x;
+       struct {
+               uint64_t m;
+               uint16_t e : 15;
+               uint16_t s : 1;
+               uint16_t pad;
+       } bits;
+};
+
+/* exact add, assumes exponent_x >= exponent_y */
+static void add(long double *hi, long double *lo, long double x, long double y)
+{
+       long double r;
+
+       r = x + y;
+       *hi = r;
+       r -= x;
+       *lo = y - r;
+}
+
+/* exact mul, assumes no over/underflow */
+static void mul(long double *hi, long double *lo, long double x, long double y)
+{
+       static const long double c = 1.0 + 0x1p32L;
+       long double cx, xh, xl, cy, yh, yl;
+
+       cx = c*x;
+       xh = (x - cx) + cx;
+       xl = x - xh;
+       cy = c*y;
+       yh = (y - cy) + cy;
+       yl = y - yh;
+       *hi = x*y;
+       *lo = (xh*yh - *hi) + xh*yl + xl*yh + xl*yl;
+}
+
+/*
+assume (long double)(hi+lo) == hi
+return an adjusted hi so that rounding it to double (or less) precision is correct
+*/
+static long double adjust(long double hi, long double lo)
+{
+       union ld80 uhi, ulo;
+
+       if (lo == 0)
+               return hi;
+       uhi.x = hi;
+       if (uhi.bits.m & 0x3ff)
+               return hi;
+       ulo.x = lo;
+       if (uhi.bits.s == ulo.bits.s)
+               uhi.bits.m++;
+       else {
+               uhi.bits.m--;
+               /* handle underflow and take care of ld80 implicit msb */
+               if (uhi.bits.m == (uint64_t)-1/2) {
+                       uhi.bits.m *= 2;
+                       uhi.bits.e--;
+               }
+       }
+       return uhi.x;
+}
+
+/* adjusted add so the result is correct when rounded to double (or less) precision */
+static long double dadd(long double x, long double y)
+{
+       add(&x, &y, x, y);
+       return adjust(x, y);
+}
+
+/* adjusted mul so the result is correct when rounded to double (or less) precision */
+static long double dmul(long double x, long double y)
+{
+       mul(&x, &y, x, y);
+       return adjust(x, y);
+}
+
+static int getexp(long double x)
+{
+       union ld80 u;
+       u.x = x;
+       return u.bits.e;
+}
+
+double fma(double x, double y, double z)
+{
+       #pragma STDC FENV_ACCESS ON
+       long double hi, lo1, lo2, xy;
+       int round, ez, exy;
+
+       /* handle +-inf,nan */
+       if (!isfinite(x) || !isfinite(y))
+               return x*y + z;
+       if (!isfinite(z))
+               return z;
+       /* handle +-0 */
+       if (x == 0.0 || y == 0.0)
+               return x*y + z;
+       round = fegetround();
+       if (z == 0.0) {
+               if (round == FE_TONEAREST)
+                       return dmul(x, y);
+               return x*y;
+       }
+
+       /* exact mul and add require nearest rounding */
+       /* spurious inexact exceptions may be raised */
+       fesetround(FE_TONEAREST);
+       mul(&xy, &lo1, x, y);
+       exy = getexp(xy);
+       ez = getexp(z);
+       if (ez > exy) {
+               add(&hi, &lo2, z, xy);
+       } else if (ez > exy - 12) {
+               add(&hi, &lo2, xy, z);
+               if (hi == 0) {
+                       /*
+                       xy + z is 0, but it should be calculated with the
+                       original rounding mode so the sign is correct, if the
+                       compiler does not support FENV_ACCESS ON it does not
+                       know about the changed rounding mode and eliminates
+                       the xy + z below without the volatile memory access
+                       */
+                       volatile double z_;
+                       fesetround(round);
+                       z_ = z;
+                       return (xy + z_) + lo1;
+               }
+       } else {
+               /*
+               ez <= exy - 12
+               the 12 extra bits (1guard, 11round+sticky) are needed so with
+                       lo = dadd(lo1, lo2)
+               elo <= ehi - 11, and we use the last 10 bits in adjust so
+                       dadd(hi, lo)
+               gives correct result when rounded to double
+               */
+               hi = xy;
+               lo2 = z;
+       }
+       /*
+       the result is stored before return for correct precision and exceptions
+
+       one corner case is when the underflow flag should be raised because
+       the precise result is an inexact subnormal double, but the calculated
+       long double result is an exact subnormal double
+       (so rounding to double does not raise exceptions)
+
+       in nearest rounding mode dadd takes care of this: the last bit of the
+       result is adjusted so rounding sees an inexact value when it should
+
+       in non-nearest rounding mode fenv is used for the workaround
+       */
+       fesetround(round);
+       if (round == FE_TONEAREST)
+               z = dadd(hi, dadd(lo1, lo2));
+       else {
+#if defined(FE_INEXACT) && defined(FE_UNDERFLOW)
+               int e = fetestexcept(FE_INEXACT);
+               feclearexcept(FE_INEXACT);
+#endif
+               z = hi + (lo1 + lo2);
+#if defined(FE_INEXACT) && defined(FE_UNDERFLOW)
+               if (getexp(z) < 0x3fff-1022 && fetestexcept(FE_INEXACT))
+                       feraiseexcept(FE_UNDERFLOW);
+               else if (e)
+                       feraiseexcept(FE_INEXACT);
+#endif
+       }
+       return z;
+}
+#else
 /* origin: FreeBSD /usr/src/lib/msun/src/s_fma.c */
 /*-
  * Copyright (c) 2005-2011 David Schultz <das@FreeBSD.ORG>
  * SUCH DAMAGE.
  */
 
-#include <fenv.h>
-#include "libm.h"
-
 /*
  * A struct dd represents a floating-point number with twice the precision
  * of a double.  We maintain the invariant that "hi" stores the 53 high-order
@@ -116,7 +290,7 @@ static inline double add_and_denormalize(double a, double b, int scale)
                        INSERT_WORD64(sum.hi, hibits);
                }
        }
-       return (ldexp(sum.hi, scale));
+       return scalbn(sum.hi, scale);
 }
 
 /*
@@ -167,6 +341,7 @@ static inline struct dd dd_mul(double a, double b)
  */
 double fma(double x, double y, double z)
 {
+       #pragma STDC FENV_ACCESS ON
        double xs, ys, zs, adj;
        struct dd xy, r;
        int oround;
@@ -178,14 +353,14 @@ double fma(double x, double y, double z)
         * return values here are crucial in handling special cases involving
         * infinities, NaNs, overflows, and signed zeroes correctly.
         */
-       if (x == 0.0 || y == 0.0)
-               return (x * y + z);
-       if (z == 0.0)
-               return (x * y);
        if (!isfinite(x) || !isfinite(y))
                return (x * y + z);
        if (!isfinite(z))
                return (z);
+       if (x == 0.0 || y == 0.0)
+               return (x * y + z);
+       if (z == 0.0)
+               return (x * y);
 
        xs = frexp(x, &ex);
        ys = frexp(y, &ey);
@@ -199,31 +374,41 @@ double fma(double x, double y, double z)
         * modes other than FE_TONEAREST are painful.
         */
        if (spread < -DBL_MANT_DIG) {
+#ifdef FE_INEXACT
                feraiseexcept(FE_INEXACT);
+#endif
+#ifdef FE_UNDERFLOW
                if (!isnormal(z))
                        feraiseexcept(FE_UNDERFLOW);
+#endif
                switch (oround) {
-               case FE_TONEAREST:
+               default: /* FE_TONEAREST */
                        return (z);
+#ifdef FE_TOWARDZERO
                case FE_TOWARDZERO:
                        if (x > 0.0 ^ y < 0.0 ^ z < 0.0)
                                return (z);
                        else
                                return (nextafter(z, 0));
+#endif
+#ifdef FE_DOWNWARD
                case FE_DOWNWARD:
                        if (x > 0.0 ^ y < 0.0)
                                return (z);
                        else
                                return (nextafter(z, -INFINITY));
-               default:        /* FE_UPWARD */
+#endif
+#ifdef FE_UPWARD
+               case FE_UPWARD:
                        if (x > 0.0 ^ y < 0.0)
                                return (nextafter(z, INFINITY));
                        else
                                return (z);
+#endif
                }
        }
        if (spread <= DBL_MANT_DIG * 2)
-               zs = ldexp(zs, -spread);
+               zs = scalbn(zs, -spread);
        else
                zs = copysign(DBL_MIN, zs);
 
@@ -249,7 +434,7 @@ double fma(double x, double y, double z)
                 */
                fesetround(oround);
                volatile double vzs = zs; /* XXX gcc CSE bug workaround */
-               return (xy.hi + vzs + ldexp(xy.lo, spread));
+               return xy.hi + vzs + scalbn(xy.lo, spread);
        }
 
        if (oround != FE_TONEAREST) {
@@ -259,12 +444,13 @@ double fma(double x, double y, double z)
                 */
                fesetround(oround);
                adj = r.lo + xy.lo;
-               return (ldexp(r.hi + adj, spread));
+               return scalbn(r.hi + adj, spread);
        }
 
        adj = add_adjusted(r.lo, xy.lo);
        if (spread + ilogb(r.hi) > -1023)
-               return (ldexp(r.hi + adj, spread));
+               return scalbn(r.hi + adj, spread);
        else
-               return (add_and_denormalize(r.hi, adj, spread));
+               return add_and_denormalize(r.hi, adj, spread);
 }
+#endif