fix illegal goto out of cleanup context in dns lookups
[musl] / src / network / __dns.c
index cdd6429..45d9b43 100644 (file)
@@ -5,20 +5,25 @@
 #include <limits.h>
 #include <string.h>
 #include <sys/socket.h>
-#include <sys/select.h>
-#include <sys/time.h>
+#include <poll.h>
 #include <netinet/in.h>
 #include <time.h>
 #include <ctype.h>
 #include <unistd.h>
+#include <pthread.h>
 #include "__dns.h"
 #include "stdio_impl.h"
 
 #define TIMEOUT 5
-#define RETRY 1
+#define RETRY 1000
 #define PACKET_MAX 512
 #define PTR_MAX (64 + sizeof ".in-addr.arpa")
 
+static void cleanup(void *p)
+{
+       close((intptr_t)p);
+}
+
 int __dns_doqueries(unsigned char *dest, const char *name, int *rr, int rrcnt)
 {
        time_t t0 = time(0);
@@ -39,25 +44,28 @@ int __dns_doqueries(unsigned char *dest, const char *name, int *rr, int rrcnt)
        int got = 0, failed = 0;
        int errcode = EAI_AGAIN;
        int i, j;
-       struct timeval tv;
-       fd_set fds;
+       struct timespec ts;
+       struct pollfd pfd;
        int id;
+       int cs;
+
+       pthread_setcancelstate(PTHREAD_CANCEL_DISABLE, &cs);
 
        /* Construct query template - RR and ID will be filled later */
-       if (strlen(name)-1 >= 254U) return -1;
+       if (strlen(name)-1 >= 254U) return EAI_NONAME;
        q[2] = q[5] = 1;
        strcpy((char *)q+13, name);
        for (i=13; q[i]; i=j+1) {
                for (j=i; q[j] && q[j] != '.'; j++);
-               if (j-i-1u > 62u) return -1;
+               if (j-i-1u > 62u) return EAI_NONAME;
                q[i-1] = j-i;
        }
        q[i+3] = 1;
        ql = i+4;
 
        /* Make a reasonably unpredictable id */
-       gettimeofday(&tv, 0);
-       id = tv.tv_usec + tv.tv_usec/256 & 0xffff;
+       clock_gettime(CLOCK_REALTIME, &ts);
+       id = ts.tv_nsec + ts.tv_nsec/65536UL & 0xffff;
 
        /* Get nameservers from resolv.conf, fallback to localhost */
        f = __fopen_rb_ca("/etc/resolv.conf", &_f, _buf, sizeof _buf);
@@ -80,16 +88,22 @@ int __dns_doqueries(unsigned char *dest, const char *name, int *rr, int rrcnt)
                sl = sizeof sa.sin;
        }
 
+       pthread_cleanup_push(cleanup, (void *)(intptr_t)fd);
+       pthread_setcancelstate(cs, 0);
+
        /* Get local address and open/bind a socket */
        sa.sin.sin_family = family;
        fd = socket(family, SOCK_DGRAM, 0);
        if (bind(fd, (void *)&sa, sl) < 0) {
-               close(fd);
-               return -1;
+               errcode = EAI_SYSTEM;
+               goto out;
        }
        /* Nonblocking to work around Linux UDP select bug */
        fcntl(fd, F_SETFL, fcntl(fd, F_GETFL, 0) | O_NONBLOCK);
 
+       pfd.fd = fd;
+       pfd.events = POLLIN;
+
        /* Loop until we timeout; break early on success */
        for (; time(0)-t0 < TIMEOUT; ) {
 
@@ -102,11 +116,7 @@ int __dns_doqueries(unsigned char *dest, const char *name, int *rr, int rrcnt)
                }
 
                /* Wait for a response, or until time to retry */
-               FD_ZERO(&fds);
-               FD_SET(fd, &fds);
-               tv.tv_sec = RETRY;
-               tv.tv_usec = 0;
-               if (select(fd+1, &fds, 0, 0, &tv) <= 0) continue;
+               if (poll(&pfd, 1, RETRY) <= 0) continue;
 
                /* Process any and all replies */
                while (got+failed < rrcnt && (rlen = recvfrom(fd, r, 512, 0,
@@ -140,7 +150,8 @@ int __dns_doqueries(unsigned char *dest, const char *name, int *rr, int rrcnt)
                /* Check to see if we have answers to all queries */
                if (got+failed == rrcnt) break;
        }
-       close(fd);
+out:
+       pthread_cleanup_pop(1);
 
        /* Return the number of results, or an error code if none */
        if (got) return got;
@@ -257,10 +268,12 @@ int __dns_count_addrs(const unsigned char *r, int cnt)
        int found=0, res, i;
        static const int p[2][2] = { { 4, RR_A }, { 16, RR_AAAA } };
 
-       while (cnt--) for (i=0; i<2; i++) {
-               res = __dns_get_rr(0, 0, p[i][0], -1, r, p[i][1], 0);
-               if (res < 0) return res;
-               found += res;
+       while (cnt--) {
+               for (i=0; i<2; i++) {
+                       res = __dns_get_rr(0, 0, p[i][0], -1, r, p[i][1], 0);
+                       if (res < 0) return res;
+                       found += res;
+               }
                r += 512;
        }
        return found;