From: Rich Felker Date: Tue, 13 Dec 2022 23:39:44 +0000 (-0500) Subject: semaphores: fix missed wakes from ABA bug in waiter count logic X-Git-Url: http://nsz.repo.hu/git/?a=commitdiff_plain;h=159d1f6c02569091c7a48bdb2e2e824b844a1902;p=musl semaphores: fix missed wakes from ABA bug in waiter count logic because the has-waiters state in the semaphore value futex word is only representable when the value is zero (the special value -1 represents "0 with potential new waiters"), it's lost if intervening operations make the semaphore value positive again. this creates an ABA issue in sem_post, whereby the post uses a stale waiters count rather than re-evaluating it, skipping the futex wake if the stale count was zero. the fix here is based on a proposal by Alexey Izbyshev, with minor changes to eliminate costly new spurious wake syscalls. the basic idea is to replace the special value -1 with a sticky waiters bit (repurposing the sign bit) preserved under both wait and post. any post that takes place with the waiters bit set will perform a futex wake. to be useful, the waiters bit needs to be removable, and to remove it safely, we perform a broadcast wake instead of a normal single-task wake whenever removing the bit. this lets any un-accounted-for waiters wake and re-add the waiters bit if they still need it. there are multiple possible choices for when to perform this broadcast, but the optimal choice seems to be doing it whenever the observed waiters count is less than two (semantically, this means exactly one, but we might see a stale count of zero). in this case, the expected number of threads to be woken is one, with exactly the same cost as a non-broadcast wake. --- diff --git a/src/thread/sem_getvalue.c b/src/thread/sem_getvalue.c index d9d83071..c0b7762d 100644 --- a/src/thread/sem_getvalue.c +++ b/src/thread/sem_getvalue.c @@ -1,8 +1,9 @@ #include +#include int sem_getvalue(sem_t *restrict sem, int *restrict valp) { int val = sem->__val[0]; - *valp = val < 0 ? 0 : val; + *valp = val & SEM_VALUE_MAX; return 0; } diff --git a/src/thread/sem_post.c b/src/thread/sem_post.c index 31e3293d..5c2517f2 100644 --- a/src/thread/sem_post.c +++ b/src/thread/sem_post.c @@ -1,17 +1,21 @@ #include +#include #include "pthread_impl.h" int sem_post(sem_t *sem) { - int val, waiters, priv = sem->__val[2]; + int val, new, waiters, priv = sem->__val[2]; do { val = sem->__val[0]; waiters = sem->__val[1]; - if (val == SEM_VALUE_MAX) { + if ((val & SEM_VALUE_MAX) == SEM_VALUE_MAX) { errno = EOVERFLOW; return -1; } - } while (a_cas(sem->__val, val, val+1+(val<0)) != val); - if (val<0 || waiters) __wake(sem->__val, 1, priv); + new = val + 1; + if (waiters <= 1) + new &= ~0x80000000; + } while (a_cas(sem->__val, val, new) != val); + if (val<0) __wake(sem->__val, waiters>1 ? 1 : -1, priv); return 0; } diff --git a/src/thread/sem_timedwait.c b/src/thread/sem_timedwait.c index 58d3ebfe..aa67376c 100644 --- a/src/thread/sem_timedwait.c +++ b/src/thread/sem_timedwait.c @@ -1,4 +1,5 @@ #include +#include #include "pthread_impl.h" static void cleanup(void *p) @@ -13,14 +14,15 @@ int sem_timedwait(sem_t *restrict sem, const struct timespec *restrict at) if (!sem_trywait(sem)) return 0; int spins = 100; - while (spins-- && sem->__val[0] <= 0 && !sem->__val[1]) a_spin(); + while (spins-- && !(sem->__val[0] & SEM_VALUE_MAX) && !sem->__val[1]) + a_spin(); while (sem_trywait(sem)) { - int r; + int r, priv = sem->__val[2]; a_inc(sem->__val+1); - a_cas(sem->__val, 0, -1); + a_cas(sem->__val, 0, 0x80000000); pthread_cleanup_push(cleanup, (void *)(sem->__val+1)); - r = __timedwait_cp(sem->__val, -1, CLOCK_REALTIME, at, sem->__val[2]); + r = __timedwait_cp(sem->__val, 0x80000000, CLOCK_REALTIME, at, priv); pthread_cleanup_pop(1); if (r) { errno = r; diff --git a/src/thread/sem_trywait.c b/src/thread/sem_trywait.c index 04edf46b..beb435da 100644 --- a/src/thread/sem_trywait.c +++ b/src/thread/sem_trywait.c @@ -1,12 +1,12 @@ #include +#include #include "pthread_impl.h" int sem_trywait(sem_t *sem) { int val; - while ((val=sem->__val[0]) > 0) { - int new = val-1-(val==1 && sem->__val[1]); - if (a_cas(sem->__val, val, new)==val) return 0; + while ((val=sem->__val[0]) & SEM_VALUE_MAX) { + if (a_cas(sem->__val, val, val-1)==val) return 0; } errno = EAGAIN; return -1;