getopt: fix null pointer arithmetic ub
[musl] / src / math / sqrtl.c
1 #include <stdint.h>
2 #include <math.h>
3 #include <float.h>
4 #include "libm.h"
5
6 #if LDBL_MANT_DIG == 53 && LDBL_MAX_EXP == 1024
7 long double sqrtl(long double x)
8 {
9         return sqrt(x);
10 }
11 #elif (LDBL_MANT_DIG == 113 || LDBL_MANT_DIG == 64) && LDBL_MAX_EXP == 16384
12 #include "sqrt_data.h"
13
14 #define FENV_SUPPORT 1
15
16 typedef struct {
17         uint64_t hi;
18         uint64_t lo;
19 } u128;
20
21 /* top: 16 bit sign+exponent, x: significand.  */
22 static inline long double mkldbl(uint64_t top, u128 x)
23 {
24         union ldshape u;
25 #if LDBL_MANT_DIG == 113
26         u.i2.hi = x.hi;
27         u.i2.lo = x.lo;
28         u.i2.hi &= 0x0000ffffffffffff;
29         u.i2.hi |= top << 48;
30 #elif LDBL_MANT_DIG == 64
31         u.i.se = top;
32         u.i.m = x.lo;
33         /* force the top bit on non-zero (and non-subnormal) results.  */
34         if (top & 0x7fff)
35                 u.i.m |= 0x8000000000000000;
36 #endif
37         return u.f;
38 }
39
40 /* return: top 16 bit is sign+exp and following bits are the significand.  */
41 static inline u128 asu128(long double x)
42 {
43         union ldshape u = {.f=x};
44         u128 r;
45 #if LDBL_MANT_DIG == 113
46         r.hi = u.i2.hi;
47         r.lo = u.i2.lo;
48 #elif LDBL_MANT_DIG == 64
49         r.lo = u.i.m<<49;
50         /* ignore the top bit: pseudo numbers are not handled. */
51         r.hi = u.i.m>>15;
52         r.hi &= 0x0000ffffffffffff;
53         r.hi |= (uint64_t)u.i.se << 48;
54 #endif
55         return r;
56 }
57
58 /* returns a*b*2^-32 - e, with error 0 <= e < 1.  */
59 static inline uint32_t mul32(uint32_t a, uint32_t b)
60 {
61         return (uint64_t)a*b >> 32;
62 }
63
64 /* returns a*b*2^-64 - e, with error 0 <= e < 3.  */
65 static inline uint64_t mul64(uint64_t a, uint64_t b)
66 {
67         uint64_t ahi = a>>32;
68         uint64_t alo = a&0xffffffff;
69         uint64_t bhi = b>>32;
70         uint64_t blo = b&0xffffffff;
71         return ahi*bhi + (ahi*blo >> 32) + (alo*bhi >> 32);
72 }
73
74 static inline u128 add64(u128 a, uint64_t b)
75 {
76         u128 r;
77         r.lo = a.lo + b;
78         r.hi = a.hi;
79         if (r.lo < a.lo)
80                 r.hi++;
81         return r;
82 }
83
84 static inline u128 add128(u128 a, u128 b)
85 {
86         u128 r;
87         r.lo = a.lo + b.lo;
88         r.hi = a.hi + b.hi;
89         if (r.lo < a.lo)
90                 r.hi++;
91         return r;
92 }
93
94 static inline u128 sub64(u128 a, uint64_t b)
95 {
96         u128 r;
97         r.lo = a.lo - b;
98         r.hi = a.hi;
99         if (a.lo < b)
100                 r.hi--;
101         return r;
102 }
103
104 static inline u128 sub128(u128 a, u128 b)
105 {
106         u128 r;
107         r.lo = a.lo - b.lo;
108         r.hi = a.hi - b.hi;
109         if (a.lo < b.lo)
110                 r.hi--;
111         return r;
112 }
113
114 /* a<<n, 0 <= n <= 127 */
115 static inline u128 lsh(u128 a, int n)
116 {
117         if (n == 0)
118                 return a;
119         if (n >= 64) {
120                 a.hi = a.lo<<(n-64);
121                 a.lo = 0;
122         } else {
123                 a.hi = (a.hi<<n) | (a.lo>>(64-n));
124                 a.lo = a.lo<<n;
125         }
126         return a;
127 }
128
129 /* a>>n, 0 <= n <= 127 */
130 static inline u128 rsh(u128 a, int n)
131 {
132         if (n == 0)
133                 return a;
134         if (n >= 64) {
135                 a.lo = a.hi>>(n-64);
136                 a.hi = 0;
137         } else {
138                 a.lo = (a.lo>>n) | (a.hi<<(64-n));
139                 a.hi = a.hi>>n;
140         }
141         return a;
142 }
143
144 /* returns a*b exactly.  */
145 static inline u128 mul64_128(uint64_t a, uint64_t b)
146 {
147         u128 r;
148         uint64_t ahi = a>>32;
149         uint64_t alo = a&0xffffffff;
150         uint64_t bhi = b>>32;
151         uint64_t blo = b&0xffffffff;
152         uint64_t lo1 = ((ahi*blo)&0xffffffff) + ((alo*bhi)&0xffffffff) + (alo*blo>>32);
153         uint64_t lo2 = (alo*blo)&0xffffffff;
154         r.hi = ahi*bhi + (ahi*blo>>32) + (alo*bhi>>32) + (lo1>>32);
155         r.lo = (lo1<<32) + lo2;
156         return r;
157 }
158
159 /* returns a*b*2^-128 - e, with error 0 <= e < 7.  */
160 static inline u128 mul128(u128 a, u128 b)
161 {
162         u128 hi = mul64_128(a.hi, b.hi);
163         uint64_t m1 = mul64(a.hi, b.lo);
164         uint64_t m2 = mul64(a.lo, b.hi);
165         return add64(add64(hi, m1), m2);
166 }
167
168 /* returns a*b % 2^128.  */
169 static inline u128 mul128_tail(u128 a, u128 b)
170 {
171         u128 lo = mul64_128(a.lo, b.lo);
172         lo.hi += a.hi*b.lo + a.lo*b.hi;
173         return lo;
174 }
175
176
177 /* see sqrt.c for detailed comments.  */
178
179 long double sqrtl(long double x)
180 {
181         u128 ix, ml;
182         uint64_t top;
183
184         ix = asu128(x);
185         top = ix.hi >> 48;
186         if (predict_false(top - 0x0001 >= 0x7fff - 0x0001)) {
187                 /* x < 0x1p-16382 or inf or nan.  */
188                 if (2*ix.hi == 0 && ix.lo == 0)
189                         return x;
190                 if (ix.hi == 0x7fff000000000000 && ix.lo == 0)
191                         return x;
192                 if (top >= 0x7fff)
193                         return __math_invalidl(x);
194                 /* x is subnormal, normalize it.  */
195                 ix = asu128(x * 0x1p112);
196                 top = ix.hi >> 48;
197                 top -= 112;
198         }
199
200         /* x = 4^e m; with int e and m in [1, 4) */
201         int even = top & 1;
202         ml = lsh(ix, 15);
203         ml.hi |= 0x8000000000000000;
204         if (even) ml = rsh(ml, 1);
205         top = (top + 0x3fff) >> 1;
206
207         /* r ~ 1/sqrt(m) */
208         static const uint64_t three = 0xc0000000;
209         uint64_t r, s, d, u, i;
210         i = (ix.hi >> 42) % 128;
211         r = (uint32_t)__rsqrt_tab[i] << 16;
212         /* |r sqrt(m) - 1| < 0x1p-8 */
213         s = mul32(ml.hi>>32, r);
214         d = mul32(s, r);
215         u = three - d;
216         r = mul32(u, r) << 1;
217         /* |r sqrt(m) - 1| < 0x1.7bp-16, switch to 64bit */
218         r = r<<32;
219         s = mul64(ml.hi, r);
220         d = mul64(s, r);
221         u = (three<<32) - d;
222         r = mul64(u, r) << 1;
223         /* |r sqrt(m) - 1| < 0x1.a5p-31 */
224         s = mul64(u, s) << 1;
225         d = mul64(s, r);
226         u = (three<<32) - d;
227         r = mul64(u, r) << 1;
228         /* |r sqrt(m) - 1| < 0x1.c001p-59, switch to 128bit */
229
230         static const u128 threel = {.hi=three<<32, .lo=0};
231         u128 rl, sl, dl, ul;
232         rl.hi = r;
233         rl.lo = 0;
234         sl = mul128(ml, rl);
235         dl = mul128(sl, rl);
236         ul = sub128(threel, dl);
237         sl = mul128(ul, sl); /* repr: 3.125 */
238         /* -0x1p-116 < s - sqrt(m) < 0x3.8001p-125 */
239         sl = rsh(sub64(sl, 4), 125-(LDBL_MANT_DIG-1));
240         /* s < sqrt(m) < s + 1 ULP + tiny */
241
242         long double y;
243         u128 d2, d1, d0;
244         d0 = sub128(lsh(ml, 2*(LDBL_MANT_DIG-1)-126), mul128_tail(sl,sl));
245         d1 = sub128(sl, d0);
246         d2 = add128(add64(sl, 1), d1);
247         sl = add64(sl, d1.hi >> 63);
248         y = mkldbl(top, sl);
249         if (FENV_SUPPORT) {
250                 /* handle rounding modes and inexact exception.  */
251                 top = predict_false((d2.hi|d2.lo)==0) ? 0 : 1;
252                 top |= ((d1.hi^d2.hi)&0x8000000000000000) >> 48;
253                 y += mkldbl(top, (u128){0});
254         }
255         return y;
256 }
257 #else
258 #error unsupported long double format
259 #endif