support looking up thread-local objects with dlsym
[musl] / src / ldso / dynlink.c
index 1489f7d..a6dbaf0 100644 (file)
@@ -19,7 +19,6 @@
 #include <dlfcn.h>
 #include "pthread_impl.h"
 #include "libc.h"
-#undef libc
 
 static int errflag;
 static char errbuf[128];
@@ -90,7 +89,7 @@ struct symdef {
 void __init_ssp(size_t *);
 void *__install_initial_tls(void *);
 
-static struct dso *head, *tail, *libc, *fini_head;
+static struct dso *head, *tail, *ldso, *fini_head;
 static char *env_path, *sys_path, *r_path;
 static int ssp_used;
 static int runtime;
@@ -100,7 +99,7 @@ static jmp_buf rtld_fail;
 static pthread_rwlock_t lock;
 static struct debug debug;
 static size_t *auxv;
-static size_t tls_cnt, tls_offset, tls_start, tls_align = 4*sizeof(size_t);
+static size_t tls_cnt, tls_offset, tls_align = 4*sizeof(size_t);
 static pthread_mutex_t init_fini_lock = { ._m_type = PTHREAD_MUTEX_RECURSIVE };
 
 struct debug *_dl_debug_addr = &debug;
@@ -446,12 +445,12 @@ static struct dso *load_library(const char *name)
                        size_t l = z-name;
                        for (rp=reserved; *rp && memcmp(name+3, rp, l-3); rp+=strlen(rp)+1);
                        if (*rp) {
-                               if (!libc->prev) {
-                                       tail->next = libc;
-                                       libc->prev = tail;
-                                       tail = libc->next ? libc->next : libc;
+                               if (!ldso->prev) {
+                                       tail->next = ldso;
+                                       ldso->prev = tail;
+                                       tail = ldso->next ? ldso->next : ldso;
                                }
-                               return libc;
+                               return ldso;
                        }
                }
        }
@@ -514,7 +513,7 @@ static struct dso *load_library(const char *name)
        if (runtime && temp_dso.tls_image) {
                size_t per_th = temp_dso.tls_size + temp_dso.tls_align
                        + sizeof(void *) * (tls_cnt+3);
-               n_th = __libc.threads_minus_1 + 1;
+               n_th = libc.threads_minus_1 + 1;
                if (n_th > SSIZE_MAX / per_th) alloc_size = SIZE_MAX;
                else alloc_size += n_th * per_th;
        }
@@ -540,10 +539,16 @@ static struct dso *load_library(const char *name)
                }
                p->tls_id = ++tls_cnt;
                tls_align = MAXP2(tls_align, p->tls_align);
+#ifdef TLS_ABOVE_TP
+               p->tls_offset = tls_offset + ( (tls_align-1) &
+                       -(tls_offset + (uintptr_t)p->tls_image) );
+               tls_offset += p->tls_size;
+#else
                tls_offset += p->tls_size + p->tls_align - 1;
                tls_offset -= (tls_offset + (uintptr_t)p->tls_image)
                        & (p->tls_align-1);
                p->tls_offset = tls_offset;
+#endif
                p->new_dtv = (void *)(-sizeof(size_t) &
                        (uintptr_t)(p->name+strlen(p->name)+sizeof(size_t)));
                p->new_tls = (void *)(p->new_dtv + n_th*(tls_cnt+1));
@@ -665,7 +670,7 @@ static void do_fini()
 static void do_init_fini(struct dso *p)
 {
        size_t dyn[DYN_CNT] = {0};
-       int need_locking = __libc.threads_minus_1;
+       int need_locking = libc.threads_minus_1;
        /* Allow recursive calls that arise when a library calls
         * dlopen from one of its constructors, but block any
         * other threads until all ctors have finished. */
@@ -698,9 +703,20 @@ void *__copy_tls(unsigned char *mem)
        void **dtv = (void *)mem;
        dtv[0] = (void *)tls_cnt;
 
-       mem += __libc.tls_size - sizeof(struct pthread);
+#ifdef TLS_ABOVE_TP
+       mem += sizeof(void *) * (tls_cnt+1);
+       mem += -((uintptr_t)mem + sizeof(struct pthread)) & (tls_align-1);
+       td = (pthread_t)mem;
+       mem += sizeof(struct pthread);
+
+       for (p=head; p; p=p->next) {
+               if (!p->tls_id) continue;
+               dtv[p->tls_id] = mem + p->tls_offset;
+               memcpy(dtv[p->tls_id], p->tls_image, p->tls_len);
+       }
+#else
+       mem += libc.tls_size - sizeof(struct pthread);
        mem -= (uintptr_t)mem & (tls_align-1);
-       mem -= tls_start;
        td = (pthread_t)mem;
 
        for (p=head; p; p=p->next) {
@@ -708,6 +724,7 @@ void *__copy_tls(unsigned char *mem)
                dtv[p->tls_id] = mem - p->tls_offset;
                memcpy(dtv[p->tls_id], p->tls_image, p->tls_len);
        }
+#endif
        td->dtv = dtv;
        return td;
 }
@@ -755,9 +772,12 @@ void *__tls_get_addr(size_t *v)
 
 static void update_tls_size()
 {
-       size_t below_tp = (1+tls_cnt) * sizeof(void *) + tls_offset;
-       size_t above_tp = sizeof(struct pthread) + tls_start + tls_align;
-       __libc.tls_size = ALIGN(below_tp + above_tp, tls_align);
+       libc.tls_size = ALIGN(
+               (1+tls_cnt) * sizeof(void *) +
+               tls_offset +
+               sizeof(struct pthread) +
+               tls_align * 2,
+       tls_align);
 }
 
 void *__dynlink(int argc, char **argv)
@@ -868,9 +888,16 @@ void *__dynlink(int argc, char **argv)
        }
        if (app->tls_size) {
                app->tls_id = tls_cnt = 1;
-               tls_offset = app->tls_offset = app->tls_size;
-               tls_start = -((uintptr_t)app->tls_image + app->tls_size)
-                       & (app->tls_align-1);
+#ifdef TLS_ABOVE_TP
+               app->tls_offset = 0;
+               tls_offset = app->tls_size
+                       + ( -((uintptr_t)app->tls_image + app->tls_size)
+                       & (app->tls_align-1) );
+#else
+               tls_offset = app->tls_offset = app->tls_size
+                       + ( -((uintptr_t)app->tls_image + app->tls_size)
+                       & (app->tls_align-1) );
+#endif
                tls_align = MAXP2(tls_align, app->tls_align);
        }
        app->global = 1;
@@ -899,7 +926,7 @@ void *__dynlink(int argc, char **argv)
         * restore the initial chain in preparation for loading third
         * party libraries (preload/needed). */
        head = tail = app;
-       libc = lib;
+       ldso = lib;
        app->next = lib;
        reloc_all(lib);
        app->next = 0;
@@ -926,12 +953,12 @@ void *__dynlink(int argc, char **argv)
        update_tls_size();
        if (tls_cnt) {
                struct dso *p;
-               void *mem = mmap(0, __libc.tls_size, PROT_READ|PROT_WRITE,
+               void *mem = mmap(0, libc.tls_size, PROT_READ|PROT_WRITE,
                        MAP_ANONYMOUS|MAP_PRIVATE, -1, 0);
                if (mem==MAP_FAILED ||
                    !__install_initial_tls(__copy_tls(mem))) {
                        dprintf(2, "%s: Error getting %zu bytes thread-local storage: %m\n",
-                               argv[0], __libc.tls_size);
+                               argv[0], libc.tls_size);
                        _exit(127);
                }
        }
@@ -1053,12 +1080,17 @@ static void *do_dlsym(struct dso *p, const char *s, void *ra)
        uint32_t h = 0, gh = 0;
        Sym *sym;
        if (p == head || p == RTLD_DEFAULT || p == RTLD_NEXT) {
-               if (p == RTLD_NEXT) {
+               if (p == RTLD_DEFAULT) {
+                       p = head;
+               } else if (p == RTLD_NEXT) {
                        for (p=head; p && (unsigned char *)ra-p->map>p->map_len; p=p->next);
                        if (!p) p=head;
+                       p = p->next;
                }
-               struct symdef def = find_sym(p->next, s, 0);
+               struct symdef def = find_sym(p, s, 0);
                if (!def.sym) goto failed;
+               if ((def.sym->st_info&0xf) == STT_TLS)
+                       return __tls_get_addr((size_t []){def.dso->tls_id, def.sym->st_value});
                return def.dso->base + def.sym->st_value;
        }
        if (p->ghashtab) {
@@ -1068,6 +1100,8 @@ static void *do_dlsym(struct dso *p, const char *s, void *ra)
                h = sysv_hash(s);
                sym = sysv_lookup(s, h, p);
        }
+       if (sym && (sym->st_info&0xf) == STT_TLS)
+               return __tls_get_addr((size_t []){p->tls_id, sym->st_value});
        if (sym && sym->st_value && (1<<(sym->st_info&0xf) & OK_TYPES))
                return p->base + sym->st_value;
        if (p->deps) for (i=0; p->deps[i]; i++) {
@@ -1078,6 +1112,8 @@ static void *do_dlsym(struct dso *p, const char *s, void *ra)
                        if (!h) h = sysv_hash(s);
                        sym = sysv_lookup(s, h, p->deps[i]);
                }
+               if (sym && (sym->st_info&0xf) == STT_TLS)
+                       return __tls_get_addr((size_t []){p->deps[i]->tls_id, sym->st_value});
                if (sym && sym->st_value && (1<<(sym->st_info&0xf) & OK_TYPES))
                        return p->deps[i]->base + sym->st_value;
        }