prevent invalid reads of nl_arg in printf_core
[musl] / src / stdio / vfprintf.c
index 0be7549..4555795 100644 (file)
@@ -4,6 +4,8 @@
 #include <limits.h>
 #include <string.h>
 #include <stdarg.h>
+#include <stddef.h>
+#include <stdlib.h>
 #include <wchar.h>
 #include <inttypes.h>
 #include <math.h>
@@ -13,8 +15,6 @@
 
 #define MAX(a,b) ((a)>(b) ? (a) : (b))
 #define MIN(a,b) ((a)<(b) ? (a) : (b))
-#define CONCAT2(x,y) x ## y
-#define CONCAT(x,y) CONCAT2(x,y)
 
 /* Convenient bit representation for modifier flags, which all fall
  * within 31 codepoints of the space character. */
 
 #define FLAGMASK (ALT_FORM|ZERO_PAD|LEFT_ADJ|PAD_POS|MARK_POS|GROUPED)
 
-#if UINT_MAX == ULONG_MAX
-#define LONG_IS_INT
-#endif
-
-#if SIZE_MAX != ULONG_MAX || UINTMAX_MAX != ULLONG_MAX
-#define ODD_TYPES
-#endif
-
 /* State machine to accept length modifiers + conversion specifiers.
  * Result is 0 on failure, or an argument type to pop on success. */
 
@@ -44,23 +36,9 @@ enum {
        ZTPRE, JPRE,
        STOP,
        PTR, INT, UINT, ULLONG,
-#ifndef LONG_IS_INT
        LONG, ULONG,
-#else
-#define LONG INT
-#define ULONG UINT
-#endif
        SHORT, USHORT, CHAR, UCHAR,
-#ifdef ODD_TYPES
        LLONG, SIZET, IMAX, UMAX, PDIFF, UIPTR,
-#else
-#define LLONG ULLONG
-#define SIZET ULONG
-#define IMAX LLONG
-#define UMAX ULLONG
-#define PDIFF LONG
-#define UIPTR ULONG
-#endif
        DBL, LDBL,
        NOARG,
        MAXSTATE
@@ -130,29 +108,23 @@ union arg
 
 static void pop_arg(union arg *arg, int type, va_list *ap)
 {
-       /* Give the compiler a hint for optimizing the switch. */
-       if ((unsigned)type > MAXSTATE) return;
        switch (type) {
               case PTR:        arg->p = va_arg(*ap, void *);
        break; case INT:        arg->i = va_arg(*ap, int);
        break; case UINT:       arg->i = va_arg(*ap, unsigned int);
-#ifndef LONG_IS_INT
        break; case LONG:       arg->i = va_arg(*ap, long);
        break; case ULONG:      arg->i = va_arg(*ap, unsigned long);
-#endif
        break; case ULLONG:     arg->i = va_arg(*ap, unsigned long long);
        break; case SHORT:      arg->i = (short)va_arg(*ap, int);
        break; case USHORT:     arg->i = (unsigned short)va_arg(*ap, int);
        break; case CHAR:       arg->i = (signed char)va_arg(*ap, int);
        break; case UCHAR:      arg->i = (unsigned char)va_arg(*ap, int);
-#ifdef ODD_TYPES
        break; case LLONG:      arg->i = va_arg(*ap, long long);
        break; case SIZET:      arg->i = va_arg(*ap, size_t);
        break; case IMAX:       arg->i = va_arg(*ap, intmax_t);
        break; case UMAX:       arg->i = va_arg(*ap, uintmax_t);
        break; case PDIFF:      arg->i = va_arg(*ap, ptrdiff_t);
        break; case UIPTR:      arg->i = (uintptr_t)va_arg(*ap, void *);
-#endif
        break; case DBL:        arg->f = va_arg(*ap, double);
        break; case LDBL:       arg->f = va_arg(*ap, long double);
        }
@@ -160,7 +132,7 @@ static void pop_arg(union arg *arg, int type, va_list *ap)
 
 static void out(FILE *f, const char *s, size_t l)
 {
-       __fwritex((void *)s, l, f);
+       if (!(f->flags & F_ERR)) __fwritex((void *)s, l, f);
 }
 
 static void pad(FILE *f, char c, int w, int l, int fl)
@@ -227,7 +199,7 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
 
        if (!isfinite(y)) {
                char *s = (t&32)?"inf":"INF";
-               if (y!=y) s=(t&32)?"nan":"NAN", pl=0;
+               if (y!=y) s=(t&32)?"nan":"NAN";
                pad(f, ' ', w, 3+pl, fl&~ZERO_PAD);
                out(f, prefix, pl);
                out(f, s, 3);
@@ -249,6 +221,7 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
                else re=LDBL_MANT_DIG/4-1-p;
 
                if (re) {
+                       round *= 1<<(LDBL_MANT_DIG%4);
                        while (re--) round*=16;
                        if (*prefix=='-') {
                                y=-y;
@@ -274,6 +247,8 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
                        if (s-buf==1 && (y||p>0||(fl&ALT_FORM))) *s++='.';
                } while (y);
 
+               if (p > INT_MAX-2-(ebuf-estr)-pl)
+                       return -1;
                if (p && s-buf-2 < p)
                        l = (p+2) + (ebuf-estr);
                else
@@ -308,13 +283,13 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
                        *d = x % 1000000000;
                        carry = x / 1000000000;
                }
-               if (!z[-1] && z>a) z--;
                if (carry) *--a = carry;
+               while (z>a && !z[-1]) z--;
                e2-=sh;
        }
        while (e2<0) {
                uint32_t carry=0, *b;
-               int sh=MIN(9,-e2), need=1+(p+LDBL_MANT_DIG/3+8)/9;
+               int sh=MIN(9,-e2), need=1+(p+LDBL_MANT_DIG/3U+8)/9;
                for (d=a; d<z; d++) {
                        uint32_t rm = *d & (1<<sh)-1;
                        *d = (*d>>sh) + carry;
@@ -343,9 +318,10 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
                x = *d % i;
                /* Are there any significant digits past j? */
                if (x || d+1!=z) {
-                       long double round = CONCAT(0x1p,LDBL_MANT_DIG);
+                       long double round = 2/LDBL_EPSILON;
                        long double small;
-                       if (*d/i & 1) round += 2;
+                       if ((*d/i & 1) || (i==1000000000 && d>a && (d[-1]&1)))
+                               round += 2;
                        if (x<i/2) small=0x0.8p0;
                        else if (x==i/2 && d+1==z) small=0x1.0p0;
                        else small=0x1.8p0;
@@ -385,17 +361,22 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
                                p = MIN(p,MAX(0,9*(z-r-1)+e-j));
                }
        }
+       if (p > INT_MAX-1-(p || (fl&ALT_FORM)))
+               return -1;
        l = 1 + p + (p || (fl&ALT_FORM));
        if ((t|32)=='f') {
+               if (e > INT_MAX-l) return -1;
                if (e>0) l+=e;
        } else {
                estr=fmt_u(e<0 ? -e : e, ebuf);
                while(ebuf-estr<2) *--estr='0';
                *--estr = (e<0 ? '-' : '+');
                *--estr = t;
+               if (ebuf-estr > INT_MAX-l) return -1;
                l += ebuf-estr;
        }
 
+       if (l > INT_MAX-pl) return -1;
        pad(f, ' ', w, pl+l, fl);
        out(f, prefix, pl);
        pad(f, '0', w, pl+l, fl^ZERO_PAD);
@@ -439,8 +420,10 @@ static int fmt_fp(FILE *f, long double y, int w, int p, int fl, int t)
 
 static int getint(char **s) {
        int i;
-       for (i=0; isdigit(**s); (*s)++)
-               i = 10*i + (**s-'0');
+       for (i=0; isdigit(**s); (*s)++) {
+               if (i > INT_MAX/10U || **s-'0' > INT_MAX-10*i) i = -1;
+               else i = 10*i + (**s-'0');
+       }
        return i;
 }
 
@@ -448,12 +431,12 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
 {
        char *a, *z, *s=(char *)fmt;
        unsigned l10n=0, fl;
-       int w, p;
+       int w, p, xp;
        union arg arg;
        int argpos;
        unsigned st, ps;
        int cnt=0, l=0;
-       int i;
+       size_t i;
        char buf[sizeof(uintmax_t)*3+3+LDBL_MANT_DIG/4];
        const char *prefix;
        int t, pl;
@@ -461,18 +444,19 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
        char mb[4];
 
        for (;;) {
+               /* This error is only specified for snprintf, but since it's
+                * unspecified for other forms, do the same. Stop immediately
+                * on overflow; otherwise %n could produce wrong results. */
+               if (l > INT_MAX - cnt) goto overflow;
+
                /* Update output count, end loop when fmt is exhausted */
-               if (cnt >= 0) {
-                       if (l > INT_MAX - cnt) {
-                               errno = EOVERFLOW;
-                               cnt = -1;
-                       } else cnt += l;
-               }
+               cnt += l;
                if (!*s) break;
 
                /* Handle literal text and %% format specifiers */
                for (a=s; *s && *s!='%'; s++);
                for (z=s; s[0]=='%' && s[1]=='%'; z++, s+=2);
+               if (z-a > INT_MAX-cnt) goto overflow;
                l = z-a;
                if (f) out(f, a, l);
                if (l) continue;
@@ -494,46 +478,53 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                if (*s=='*') {
                        if (isdigit(s[1]) && s[2]=='$') {
                                l10n=1;
-                               nl_type[s[1]-'0'] = INT;
-                               w = nl_arg[s[1]-'0'].i;
+                               if (!f) nl_type[s[1]-'0'] = INT, w = 0;
+                               else w = nl_arg[s[1]-'0'].i;
                                s+=3;
                        } else if (!l10n) {
                                w = f ? va_arg(*ap, int) : 0;
                                s++;
-                       } else return -1;
+                       } else goto inval;
                        if (w<0) fl|=LEFT_ADJ, w=-w;
-               } else if ((w=getint(&s))<0) return -1;
+               } else if ((w=getint(&s))<0) goto overflow;
 
                /* Read precision */
                if (*s=='.' && s[1]=='*') {
                        if (isdigit(s[2]) && s[3]=='$') {
-                               nl_type[s[2]-'0'] = INT;
-                               p = nl_arg[s[2]-'0'].i;
+                               if (!f) nl_type[s[2]-'0'] = INT, p = 0;
+                               else p = nl_arg[s[2]-'0'].i;
                                s+=4;
                        } else if (!l10n) {
                                p = f ? va_arg(*ap, int) : 0;
                                s+=2;
-                       } else return -1;
+                       } else goto inval;
+                       xp = (p>=0);
                } else if (*s=='.') {
                        s++;
                        p = getint(&s);
-               } else p = -1;
+                       xp = 1;
+               } else {
+                       p = -1;
+                       xp = 0;
+               }
 
                /* Format specifier state machine */
                st=0;
                do {
-                       if (OOB(*s)) return -1;
+                       if (OOB(*s)) goto inval;
                        ps=st;
                        st=states[st]S(*s++);
                } while (st-1<STOP);
-               if (!st) return -1;
+               if (!st) goto inval;
 
                /* Check validity of argument type (nl/normal) */
                if (st==NOARG) {
-                       if (argpos>=0) return -1;
+                       if (argpos>=0) goto inval;
                } else {
-                       if (argpos>=0) nl_type[argpos]=st, arg=nl_arg[argpos];
-                       else if (f) pop_arg(&arg, st, ap);
+                       if (argpos>=0) {
+                               if (!f) nl_type[argpos]=st;
+                               else arg=nl_arg[argpos];
+                       } else if (f) pop_arg(&arg, st, ap);
                        else return 0;
                }
 
@@ -572,7 +563,7 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                        if (0) {
                case 'o':
                        a = fmt_o(arg.i, z);
-                       if ((fl&ALT_FORM) && arg.i) prefix+=5, pl=1;
+                       if ((fl&ALT_FORM) && p<z-a+1) p=z-a+1;
                        } if (0) {
                case 'd': case 'i':
                        pl=1;
@@ -586,7 +577,8 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                case 'u':
                        a = fmt_u(arg.i, z);
                        }
-                       if (p>=0) fl &= ~ZERO_PAD;
+                       if (xp && p<0) goto overflow;
+                       if (xp) fl &= ~ZERO_PAD;
                        if (!arg.i && !p) {
                                a=z;
                                break;
@@ -601,9 +593,9 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                        if (1) a = strerror(errno); else
                case 's':
                        a = arg.p ? arg.p : "(null)";
-                       z = memchr(a, 0, p);
-                       if (!z) z=a+p;
-                       else p=z-a;
+                       z = a + strnlen(a, p<0 ? INT_MAX : p);
+                       if (p<0 && *z) goto overflow;
+                       p = z-a;
                        fl &= ~ZERO_PAD;
                        break;
                case 'C':
@@ -613,8 +605,9 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                        p = -1;
                case 'S':
                        ws = arg.p;
-                       for (i=l=0; i<0U+p && *ws && (l=wctomb(mb, *ws++))>=0 && l<=0U+p-i; i+=l);
+                       for (i=l=0; i<p && *ws && (l=wctomb(mb, *ws++))>=0 && l<=p-i; i+=l);
                        if (l<0) return -1;
+                       if (i > INT_MAX) goto overflow;
                        p = i;
                        pad(f, ' ', w, p, fl);
                        ws = arg.p;
@@ -625,12 +618,16 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
                        continue;
                case 'e': case 'f': case 'g': case 'a':
                case 'E': case 'F': case 'G': case 'A':
+                       if (xp && p<0) goto overflow;
                        l = fmt_fp(f, arg.f, w, p, fl, t);
+                       if (l<0) goto overflow;
                        continue;
                }
 
                if (p < z-a) p = z-a;
+               if (p > INT_MAX-pl) goto overflow;
                if (w < pl+p) w = pl+p;
+               if (w > INT_MAX-cnt) goto overflow;
 
                pad(f, ' ', w, pl+p, fl);
                out(f, prefix, pl);
@@ -648,8 +645,15 @@ static int printf_core(FILE *f, const char *fmt, va_list *ap, union arg *nl_arg,
        for (i=1; i<=NL_ARGMAX && nl_type[i]; i++)
                pop_arg(nl_arg+i, nl_type[i], ap);
        for (; i<=NL_ARGMAX && !nl_type[i]; i++);
-       if (i<=NL_ARGMAX) return -1;
+       if (i<=NL_ARGMAX) goto inval;
        return 1;
+
+inval:
+       errno = EINVAL;
+       return -1;
+overflow:
+       errno = EOVERFLOW;
+       return -1;
 }
 
 int vfprintf(FILE *restrict f, const char *restrict fmt, va_list ap)
@@ -658,6 +662,7 @@ int vfprintf(FILE *restrict f, const char *restrict fmt, va_list ap)
        int nl_type[NL_ARGMAX+1] = {0};
        union arg nl_arg[NL_ARGMAX+1];
        unsigned char internal_buf[80], *saved_buf = 0;
+       int olderr;
        int ret;
 
        /* the copy allows passing va_list* even if va_list is an array */
@@ -668,13 +673,16 @@ int vfprintf(FILE *restrict f, const char *restrict fmt, va_list ap)
        }
 
        FLOCK(f);
+       olderr = f->flags & F_ERR;
+       if (f->mode < 1) f->flags &= ~F_ERR;
        if (!f->buf_size) {
                saved_buf = f->buf;
-               f->wpos = f->wbase = f->buf = internal_buf;
+               f->buf = internal_buf;
                f->buf_size = sizeof internal_buf;
-               f->wend = internal_buf + sizeof internal_buf;
+               f->wpos = f->wbase = f->wend = 0;
        }
-       ret = printf_core(f, fmt, &ap2, nl_arg, nl_type);
+       if (!f->wend && __towrite(f)) ret = -1;
+       else ret = printf_core(f, fmt, &ap2, nl_arg, nl_type);
        if (saved_buf) {
                f->write(f, 0, 0);
                if (!f->wpos) ret = -1;
@@ -682,6 +690,8 @@ int vfprintf(FILE *restrict f, const char *restrict fmt, va_list ap)
                f->buf_size = 0;
                f->wpos = f->wbase = f->wend = 0;
        }
+       if (f->flags & F_ERR) ret = -1;
+       f->flags |= olderr;
        FUNLOCK(f);
        va_end(ap2);
        return ret;