From db913fd5b3a96f2991d093bc96837acb6c6fcd2e Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sun, 13 Apr 2025 13:28:54 -0500 Subject: [PATCH 01/14] Initial attempt at enriching the interrupt state machine to avoid polling --- .../cats/effect/unsafe/ParkedSignal.scala | 28 ++++ .../unsafe/WorkStealingThreadPool.scala | 39 +++-- .../cats/effect/unsafe/WorkerThread.scala | 140 +++++++++++++----- 3 files changed, 156 insertions(+), 51 deletions(-) create mode 100644 core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala new file mode 100644 index 0000000000..6e48117a36 --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala @@ -0,0 +1,28 @@ +/* + * Copyright 2020-2025 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect.unsafe + +sealed trait ParkedSignal extends Product with Serializable + +object ParkedSignal { + case object Unparked extends ParkedSignal + + case object ParkedPolling extends ParkedSignal + case object ParkedSimple extends ParkedSignal + + case object Interrupting extends ParkedSignal +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index a7139724e9..2b1f830e1d 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -47,6 +47,7 @@ import java.util.concurrent.atomic.{ AtomicReference, AtomicReferenceArray } +import java.util.concurrent.locks.LockSupport import WorkStealingThreadPool._ @@ -93,7 +94,8 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( new AtomicReferenceArray(threadCount) private[unsafe] val localQueues: Array[LocalQueue] = new Array(threadCount) private[unsafe] val sleepers: Array[TimerHeap] = new Array(threadCount) - private[unsafe] val parkedSignals: Array[AtomicBoolean] = new Array(threadCount) + private[unsafe] val parkedSignals: Array[AtomicReference[ParkedSignal]] = new Array( + threadCount) private[unsafe] val fiberBags: Array[WeakBag[Runnable]] = new Array(threadCount) private[unsafe] val pollers: Array[P] = new Array[AnyRef](threadCount).asInstanceOf[Array[P]] @@ -151,7 +153,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( localQueues(i) = queue val sleepersHeap = new TimerHeap() sleepers(i) = sleepersHeap - val parkedSignal = new AtomicBoolean(false) + val parkedSignal = new AtomicReference[ParkedSignal](ParkedSignal.Unparked) parkedSignals(i) = parkedSignal val index = i val fiberBag = new WeakBag[Runnable]() @@ -241,7 +243,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( if (isStackTracing) { destWorker.active = fiber - parkedSignals(dest).lazySet(false) + parkedSignals(dest).lazySet(ParkedSignal.Unparked) } fiber @@ -306,14 +308,12 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( val index = (from + i) % threadCount val signal = parkedSignals(index) - if (signal.getAndSet(false)) { - // Update the state so that a thread can be unparked. - // Here we are updating the 16 most significant bits, which hold the - // number of active threads, as well as incrementing the number of - // searching worker threads (unparked worker threads are implicitly - // allowed to search for work in the local queues of other worker - // threads). - state.getAndAdd(DeltaSearching) + val st = signal.get() + + if ((st eq ParkedSignal.ParkedPolling) || (st eq ParkedSignal.ParkedSimple) || signal + .compareAndSet(st, ParkedSignal.Interrupting)) { + doneSleeping() + // Fetch the latest references to the worker threads before doing the // actual unparking. There is no danger of a race condition where the // parked signal has been successfully marked as unparked but the @@ -322,7 +322,14 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( // point it is already unparked and entering this code region is thus // impossible. val worker = workerThreads.get(index) - system.interrupt(worker, pollers(index)) + + if (st eq ParkedSignal.ParkedPolling) { + system.interrupt(worker, pollers(index)) + } else { + LockSupport.unpark(worker) + } + signal.set(ParkedSignal.Unparked) + return true } @@ -446,6 +453,12 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( } private[unsafe] def doneSleeping(): Unit = { + // Update the state so that a thread can be unparked. + // Here we are updating the 16 most significant bits, which hold the + // number of active threads, as well as incrementing the number of + // searching worker threads (unparked worker threads are implicitly + // allowed to search for work in the local queues of other worker + // threads). state.getAndAdd(DeltaSearching) () } @@ -710,6 +723,8 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( while (i < threadCount) { val workerThread = workerThreads.get(i) if (workerThread ne currentThread) { + // we don't know which state we're in, so just try both interruptions + LockSupport.unpark(workerThread) system.interrupt(workerThread, pollers(i)) workerThread.interrupt() } diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 615fe8804e..9d901fd43f 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -27,7 +27,8 @@ import scala.concurrent.duration.{Duration, FiniteDuration} import java.lang.Long.MIN_VALUE import java.util.concurrent.{ArrayBlockingQueue, ThreadLocalRandom} -import java.util.concurrent.atomic.AtomicBoolean +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.locks.LockSupport import WorkerThread.{Metrics, TransferState} @@ -48,7 +49,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( // Local queue instance with exclusive write access. private[this] var queue: LocalQueue, // The state of the `WorkerThread` (parked/unparked). - private[unsafe] var parked: AtomicBoolean, + private[unsafe] var parked: AtomicReference[ParkedSignal], // External queue used by the local queue for offloading excess fibers, as well as // for drawing fibers when the local queue is exhausted. private[this] val external: ScalQueue[AnyRef], @@ -386,7 +387,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( if (isStackTracing) { _active = fiber - parked.lazySet(false) + parked.lazySet(ParkedSignal.Unparked) } // The dequeued element is a single fiber. Execute it immediately. @@ -411,11 +412,13 @@ private[effect] final class WorkerThread[P <: AnyRef]( _active = null } - parked.lazySet(true) + val needsPoll = system.needsPoll(_poller) + flagForParking(needsPoll) + // Announce that the worker thread is parking. pool.transitionWorkerToParked() // Park the thread. - if (park()) + if (park(needsPoll)) return // Work found, transition to executing fibers from the local queue. else false @@ -456,7 +459,9 @@ private[effect] final class WorkerThread[P <: AnyRef]( _active = null } - parked.lazySet(true) + val needsPoll = system.needsPoll(_poller) + flagForParking(needsPoll) + // Announce that the worker thread which was searching for work is now // parking. This checks if the parking worker thread was the last // actively searching thread. @@ -468,7 +473,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( pool.notifyIfWorkPending(rnd) } // Park the thread. - if (park()) + if (park(needsPoll)) return // Work found, transition to executing fibers from the local queue. else () // Proceed to while loop. @@ -506,7 +511,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( if (isStackTracing) { _active = fiber - parked.lazySet(false) + parked.lazySet(ParkedSignal.Unparked) } pool.transitionWorkerFromSearching(rnd) @@ -558,7 +563,9 @@ private[effect] final class WorkerThread[P <: AnyRef]( _active = null } - parked.lazySet(true) + val needsPoll = system.needsPoll(_poller) + flagForParking(needsPoll) + // Announce that the worker thread which was searching for work is now // parking. This checks if the parking worker thread was the last // actively searching thread. @@ -570,7 +577,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( pool.notifyIfWorkPending(rnd) } // Park the thread. - if (park()) + if (park(needsPoll)) return // Work found, transition to executing fibers from the local queue. else () // loop @@ -578,13 +585,16 @@ private[effect] final class WorkerThread[P <: AnyRef]( } } + def flagForParking(needsPoll: Boolean): Unit = + parked.lazySet(if (needsPoll) ParkedSignal.ParkedPolling else ParkedSignal.ParkedSimple) + // returns whether work was found - def park(): Boolean = { + def park(needsPoll: Boolean): Boolean = { metrics.incrementParkedCount() val tt = sleepers.peekFirstTriggerTime() val workFound = if (tt == MIN_VALUE) { // no sleepers - if (parkLoop()) { + if (parkLoop(needsPoll)) { // we polled something, so go straight to local queue stuff pool.transitionWorkerFromSearching(rnd) true @@ -593,7 +603,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( false } } else { - if (parkUntilNextSleeper()) { + if (parkLoopUntilNextSleeper(needsPoll)) { // we made it to the end of our sleeping/polling, so go straight to local queue stuff pool.transitionWorkerFromSearching(rnd) true @@ -627,13 +637,38 @@ private[effect] final class WorkerThread[P <: AnyRef]( acc } + def notifyDoneSleeping(): Unit = { + var st = parked.get() + + if (st ne ParkedSignal.Unparked) { + if (st eq ParkedSignal.Interrupting) { + // our state is being twiddled; wait for that to finish up + // this happens when we wake ourselves at the same moment the pool decides to wake us + + do { + st = parked.get() + } while (st eq ParkedSignal.Interrupting) + } else if (parked.compareAndSet(st, ParkedSignal.Unparked)) { + // we won the race to awaken ourselves, so we need to let the pool know + pool.doneSleeping() + } + } + } + // returns true if polled event, false if unparked - def parkLoop(): Boolean = { + def parkLoop(needsPoll: Boolean): Boolean = { while (!done.get()) { // Park the thread until further notice. val start = System.nanoTime() metrics.incrementPolledCount() - val pollResult = system.poll(_poller, -1) + + val pollResult = if (needsPoll) { + system.poll(_poller, -1) + } else { + LockSupport.park() + PollResult.Interrupted + } + now = System.nanoTime() // update now metrics.addIdleTime(now - start) @@ -641,29 +676,42 @@ private[effect] final class WorkerThread[P <: AnyRef]( if (isInterrupted()) { pool.shutdown() } else if (pollResult ne PollResult.Interrupted) { - if (parked.getAndSet(false)) - pool.doneSleeping() + notifyDoneSleeping() + // TODO, if no tasks scheduled could fastpath back to park? val _ = drainReadyEvents(pollResult, false) return true - } else if (!parked.get()) { // Spurious wakeup check. - return false - } else // loop - () + } else { + // Spurious wakeup check. + var st = parked.get() + if (st eq ParkedSignal.Unparked) { + // awakened intentionally + return false + } else if (st eq ParkedSignal.Interrupting) { + // awakened intentionally, but waiting for the state publish + // we have to block here to ensure we don't go back to sleep again too fast + do { + st = parked.get() + } while (st eq ParkedSignal.Interrupting) + } else { + // awakened spuriously; loop + () + } + } } false } // returns true if timed out or polled event, false if unparked @tailrec - def parkUntilNextSleeper(): Boolean = { + def parkLoopUntilNextSleeper(needsPoll: Boolean): Boolean = { if (done.get()) { false } else { val triggerTime = sleepers.peekFirstTriggerTime() if (triggerTime == MIN_VALUE) { // no sleeper (it was removed) - parkLoop() + parkLoop(needsPoll) } else { now = System.nanoTime() val nanos = triggerTime - now @@ -671,7 +719,14 @@ private[effect] final class WorkerThread[P <: AnyRef]( if (nanos > 0L) { val start = now metrics.incrementPolledCount() - val pollResult = system.poll(_poller, nanos) + + val pollResult = if (needsPoll) { + system.poll(_poller, nanos) + } else { + LockSupport.parkNanos(nanos) + PollResult.Interrupted + } + // we already parked and time passed, so update time again // it doesn't matter if we timed out or were awakened, the update is free-ish now = System.nanoTime() @@ -685,25 +740,33 @@ private[effect] final class WorkerThread[P <: AnyRef]( val polled = pollResult ne PollResult.Interrupted if (polled || (triggerTime - now <= 0)) { // we timed out or polled an event - if (parked.getAndSet(false)) { - pool.doneSleeping() - } + notifyDoneSleeping() + if (polled) { // TODO, if no tasks scheduled and no timers could fastpath back to park? val _ = drainReadyEvents(pollResult, false) } true } else { // we were either awakened spuriously or intentionally - if (parked.get()) // awakened spuriously, re-check next sleeper - parkUntilNextSleeper() - else // awakened intentionally, but not due to a timer or event + var st = parked.get() + if (st eq ParkedSignal.Unparked) { + // awakened intentionally, but not due to a timer or event false + } else if (st eq ParkedSignal.Interrupting) { + // awakened intentionally, but waiting for the state publish + // we have to block here to ensure we don't go back to sleep again too fast + do { + st = parked.get() + } while (st eq ParkedSignal.Interrupting) + false + } else { + // awakened spuriously, re-check next sleeper + parkLoopUntilNextSleeper(needsPoll) + } } } } else { - // a timer already expired - if (parked.getAndSet(false)) { - pool.doneSleeping() - } + // a timer already expired, we need to undo the parking + notifyDoneSleeping() true } } @@ -764,10 +827,9 @@ private[effect] final class WorkerThread[P <: AnyRef]( // we have to check for null since there's a race here when threads convert to blockers // by reading parked *after* reading state, we avoid misidentifying blockers as blocked - if (parked != null && !parked - .get() && (state == Thread.State.BLOCKED || state == Thread - .State - .WAITING || state == Thread.State.TIMED_WAITING)) { + if (parked != null && (parked + .get() eq ParkedSignal.Unparked) && (state == Thread.State.BLOCKED || + state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING)) { System.err.println(mkWarning(state, thread.getStackTrace())) } } @@ -809,7 +871,7 @@ private[effect] final class WorkerThread[P <: AnyRef]( if (isStackTracing) { _active = fiber - parked.lazySet(false) + parked.lazySet(ParkedSignal.Unparked) } // The dequeued element is a single fiber. Execute it immediately. From 2eee23aa8c8d553b2b69da50549376be529fdf98 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sun, 13 Apr 2025 14:17:22 -0500 Subject: [PATCH 02/14] Simplified loops a bit --- .../effect/unsafe/WorkStealingThreadPool.scala | 8 +++++--- .../scala/cats/effect/unsafe/WorkerThread.scala | 15 ++++++++++----- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index 2b1f830e1d..f64df5f98f 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -310,8 +310,10 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( val signal = parkedSignals(index) val st = signal.get() - if ((st eq ParkedSignal.ParkedPolling) || (st eq ParkedSignal.ParkedSimple) || signal - .compareAndSet(st, ParkedSignal.Interrupting)) { + val polling = st eq ParkedSignal.ParkedPolling + val simple = st eq ParkedSignal.ParkedSimple + + if (polling || simple || signal.compareAndSet(st, ParkedSignal.Interrupting)) { doneSleeping() // Fetch the latest references to the worker threads before doing the @@ -323,7 +325,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( // impossible. val worker = workerThreads.get(index) - if (st eq ParkedSignal.ParkedPolling) { + if (polling) { system.interrupt(worker, pollers(index)) } else { LockSupport.unpark(worker) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 9d901fd43f..da4e20f03f 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -656,8 +656,11 @@ private[effect] final class WorkerThread[P <: AnyRef]( } // returns true if polled event, false if unparked + @tailrec def parkLoop(needsPoll: Boolean): Boolean = { - while (!done.get()) { + if (done.get()) { + false + } else { // Park the thread until further notice. val start = System.nanoTime() metrics.incrementPolledCount() @@ -675,31 +678,33 @@ private[effect] final class WorkerThread[P <: AnyRef]( // the only way we can be interrupted here is if it happened *externally* (probably sbt) if (isInterrupted()) { pool.shutdown() + false // we know `done` is true } else if (pollResult ne PollResult.Interrupted) { notifyDoneSleeping() // TODO, if no tasks scheduled could fastpath back to park? val _ = drainReadyEvents(pollResult, false) - return true + true } else { // Spurious wakeup check. var st = parked.get() if (st eq ParkedSignal.Unparked) { // awakened intentionally - return false + false } else if (st eq ParkedSignal.Interrupting) { // awakened intentionally, but waiting for the state publish // we have to block here to ensure we don't go back to sleep again too fast do { st = parked.get() } while (st eq ParkedSignal.Interrupting) + + false } else { // awakened spuriously; loop - () + parkLoop(needsPoll) } } } - false } // returns true if timed out or polled event, false if unparked From b09a2b78e2eecf18444cabd02cc1c0a96c309ef3 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sun, 13 Apr 2025 14:44:09 -0500 Subject: [PATCH 03/14] Fixed conditional and interrupt test --- .../scala/cats/effect/unsafe/WorkStealingThreadPool.scala | 4 +--- .../src/test/scala/cats/effect/IOPlatformSpecification.scala | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index f64df5f98f..cf56712058 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -313,7 +313,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( val polling = st eq ParkedSignal.ParkedPolling val simple = st eq ParkedSignal.ParkedSimple - if (polling || simple || signal.compareAndSet(st, ParkedSignal.Interrupting)) { + if ((polling || simple) && signal.compareAndSet(st, ParkedSignal.Interrupting)) { doneSleeping() // Fetch the latest references to the worker threads before doing the @@ -725,8 +725,6 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef]( while (i < threadCount) { val workerThread = workerThreads.get(i) if (workerThread ne currentThread) { - // we don't know which state we're in, so just try both interruptions - LockSupport.unpark(workerThread) system.interrupt(workerThread, pollers(i)) workerThread.interrupt() } diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index 9c35359f71..548d23f02f 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -591,7 +591,8 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala def processReadyEvents(poller: Poller): Boolean = false - def needsPoll(poller: Poller): Boolean = false + // if we don't claim to need polling, then the worker won't bother calling it + def needsPoll(poller: Poller): Boolean = true def interrupt(targetThread: Thread, poller: Poller): Unit = { wasInterrupted.set(true) From 158dac49b905a3f88b811aadea7710a0e0d65a03 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sun, 13 Apr 2025 15:30:31 -0700 Subject: [PATCH 04/14] Restored Scala 3 support --- .../scala/cats/effect/unsafe/WorkerThread.scala | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index da4e20f03f..e3018d1508 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -645,9 +645,10 @@ private[effect] final class WorkerThread[P <: AnyRef]( // our state is being twiddled; wait for that to finish up // this happens when we wake ourselves at the same moment the pool decides to wake us - do { + while ({ st = parked.get() - } while (st eq ParkedSignal.Interrupting) + st eq ParkedSignal.Interrupting + }) {} } else if (parked.compareAndSet(st, ParkedSignal.Unparked)) { // we won the race to awaken ourselves, so we need to let the pool know pool.doneSleeping() @@ -694,9 +695,10 @@ private[effect] final class WorkerThread[P <: AnyRef]( } else if (st eq ParkedSignal.Interrupting) { // awakened intentionally, but waiting for the state publish // we have to block here to ensure we don't go back to sleep again too fast - do { + while ({ st = parked.get() - } while (st eq ParkedSignal.Interrupting) + st eq ParkedSignal.Interrupting + }) {} false } else { @@ -759,9 +761,11 @@ private[effect] final class WorkerThread[P <: AnyRef]( } else if (st eq ParkedSignal.Interrupting) { // awakened intentionally, but waiting for the state publish // we have to block here to ensure we don't go back to sleep again too fast - do { + while ({ st = parked.get() - } while (st eq ParkedSignal.Interrupting) + st eq ParkedSignal.Interrupting + }) {} + false } else { // awakened spuriously, re-check next sleeper From d807be03f636a90550cadb59c3e8ad20cd75c8f4 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sun, 13 Apr 2025 15:46:53 -0700 Subject: [PATCH 05/14] Added test for new mixed-mode polling handling --- .../cats/effect/IOPlatformSpecification.scala | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index 548d23f02f..abc4a96cc4 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -725,6 +725,48 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala poller.wasInterrupted.get() must beTrue } + "handle mixed-mode poller/simple interruption with complex timers" in { + val delegate = unsafe.SelectorSystem() + + val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool( + threads = 1, + shutdownTimeout = 60.seconds, + pollingSystem = new PollingSystem { + type Api = delegate.Api + type Poller = delegate.Poller + def makeApi(ctx: PollingContext[Poller]) = delegate.makeApi(ctx) + def close() = delegate.close() + def makePoller() = delegate.makePoller() + def closePoller(poller: Poller) = delegate.closePoller(poller) + def poll(poller: Poller, nanos: Long) = delegate.poll(poller, nanos) + + // allows us to test what happens when some threads suspend with polling and some simple + def needsPoll(poller: Poller) = math.random() >= 0.5d + + def interrupt(thread: Thread, poller: Poller) = delegate.interrupt(thread, poller) + def metrics(poller: Poller) = delegate.metrics(poller) + def processReadyEvents(poller: Poller) = delegate.processReadyEvents(poller) + } + ) + + implicit val runtime: IORuntime = + IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build() + + // just create a bit of chaos with timers and async completion + val sleeps = 0.until(10).map(i => IO.sleep((i * 10).millis)).toList + + val latch = IO.deferred[Unit].flatMap(d => d.complete(()).start *> d.get) + val latches = 0.until(10).map(_ => latch).toList + + val test = (sleeps ::: latches).parSequence.parReplicateA_(100) + + try { + test.unsafeRunTimed(20.seconds) must beSome + } finally { + runtime.shutdown() + } + } + if (javaMajorVersion >= 21) "block in-place on virtual threads" in real { val loomExec = classOf[Executors] From 86ed31b442696d8252f228767dd2da9c37e1249a Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 07:47:53 -0700 Subject: [PATCH 06/14] Update core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala Co-authored-by: Arman Bilge --- core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala index 6e48117a36..7f1bb6edf0 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala @@ -16,7 +16,7 @@ package cats.effect.unsafe -sealed trait ParkedSignal extends Product with Serializable +private sealed abstract class ParkedSignal extends Product with Serializable object ParkedSignal { case object Unparked extends ParkedSignal From 0c2cfa1a6e051d991383eb72b69e14d59c87928f Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 07:48:02 -0700 Subject: [PATCH 07/14] Update core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala Co-authored-by: Arman Bilge --- core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala index 7f1bb6edf0..d0e7e27bed 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/ParkedSignal.scala @@ -18,7 +18,7 @@ package cats.effect.unsafe private sealed abstract class ParkedSignal extends Product with Serializable -object ParkedSignal { +private object ParkedSignal { case object Unparked extends ParkedSignal case object ParkedPolling extends ParkedSignal From 07400e09c0b32cb070445f328c48e146ecaafacf Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 07:49:31 -0700 Subject: [PATCH 08/14] Update core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala Co-authored-by: Arman Bilge --- .../src/main/scala/cats/effect/unsafe/WorkerThread.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index e3018d1508..c92c66a227 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -638,17 +638,14 @@ private[effect] final class WorkerThread[P <: AnyRef]( } def notifyDoneSleeping(): Unit = { - var st = parked.get() + val st = parked.get() if (st ne ParkedSignal.Unparked) { if (st eq ParkedSignal.Interrupting) { // our state is being twiddled; wait for that to finish up // this happens when we wake ourselves at the same moment the pool decides to wake us - while ({ - st = parked.get() - st eq ParkedSignal.Interrupting - }) {} + while (parked.get() eq ParkedSignal.Interrupting) {} } else if (parked.compareAndSet(st, ParkedSignal.Unparked)) { // we won the race to awaken ourselves, so we need to let the pool know pool.doneSleeping() From 3c2b5fe5c855d114785339f943e7611b26c83eec Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 07:50:01 -0700 Subject: [PATCH 09/14] Update core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala Co-authored-by: Arman Bilge --- .../src/main/scala/cats/effect/unsafe/WorkerThread.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index c92c66a227..940ed387b2 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -685,17 +685,14 @@ private[effect] final class WorkerThread[P <: AnyRef]( true } else { // Spurious wakeup check. - var st = parked.get() + val st = parked.get() if (st eq ParkedSignal.Unparked) { // awakened intentionally false } else if (st eq ParkedSignal.Interrupting) { // awakened intentionally, but waiting for the state publish // we have to block here to ensure we don't go back to sleep again too fast - while ({ - st = parked.get() - st eq ParkedSignal.Interrupting - }) {} + while (parked.get() eq ParkedSignal.Interrupting) {} false } else { From 9f1505870888d93d824d4e1cfcc3205f29114544 Mon Sep 17 00:00:00 2001 From: Arman Bilge Date: Mon, 14 Apr 2025 09:55:40 -0700 Subject: [PATCH 10/14] Use `createDefaultPollingSystem` --- .../src/test/scala/cats/effect/IOPlatformSpecification.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index abc4a96cc4..5ac728c138 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -726,7 +726,7 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala } "handle mixed-mode poller/simple interruption with complex timers" in { - val delegate = unsafe.SelectorSystem() + val delegate = IORuntime.createDefaultPollingSystem() val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool( threads = 1, From 1c13c669901697d0fe391d5b982289ee999ba70c Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 18:03:42 -0700 Subject: [PATCH 11/14] Create more chaos with the external queue --- .../scala/cats/effect/IOPlatformSpecification.scala | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index 5ac728c138..6ed41692ee 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -752,18 +752,29 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala implicit val runtime: IORuntime = IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build() + val (scheduler, schedShut) = + IORuntime.createDefaultScheduler(threadPrefix = "complex-timer-test") + // just create a bit of chaos with timers and async completion val sleeps = 0.until(10).map(i => IO.sleep((i * 10).millis)).toList + val externalSleeps = 0.until(10).toList map { i => + IO.async_[Unit] { cb => + val _ = scheduler.sleep((i * 10 + 5).millis, () => cb(Right(()))) + () + } + } + val latch = IO.deferred[Unit].flatMap(d => d.complete(()).start *> d.get) val latches = 0.until(10).map(_ => latch).toList - val test = (sleeps ::: latches).parSequence.parReplicateA_(100) + val test = (sleeps ::: externalSleeps ::: latches).parSequence.parReplicateA_(100) try { test.unsafeRunTimed(20.seconds) must beSome } finally { runtime.shutdown() + schedShut() } } From f5ea181c9092121b2cdda4d880cc31193f8d6a9f Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 14 Apr 2025 18:08:20 -0700 Subject: [PATCH 12/14] Corrected blocked thread detection to handle interrupting state --- .../scala/cats/effect/unsafe/WorkerThread.scala | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 940ed387b2..32133a0daa 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -830,10 +830,19 @@ private[effect] final class WorkerThread[P <: AnyRef]( // we have to check for null since there's a race here when threads convert to blockers // by reading parked *after* reading state, we avoid misidentifying blockers as blocked - if (parked != null && (parked - .get() eq ParkedSignal.Unparked) && (state == Thread.State.BLOCKED || - state == Thread.State.WAITING || state == Thread.State.TIMED_WAITING)) { - System.err.println(mkWarning(state, thread.getStackTrace())) + if (parked != null) { + val pst = parked.get() + + val expectAlive = + (pst eq ParkedSignal.Unparked) || (pst eq ParkedSignal.Interrupting) + + val actuallyBlocked = state == Thread.State.BLOCKED || + state == Thread.State.WAITING || + state == Thread.State.TIMED_WAITING + + if (expectAlive && actuallyBlocked) { + System.err.println(mkWarning(state, thread.getStackTrace())) + } } } From 514943798cb6a87de7db5fe609033c08417a268e Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Sat, 7 Jun 2025 11:13:25 -0500 Subject: [PATCH 13/14] Run `MutexSpec` sequentially on windows --- tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala b/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala index caef791615..eefb80d6c0 100644 --- a/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala @@ -27,6 +27,12 @@ import scala.concurrent.duration._ final class MutexSpec extends BaseSpec with DetectPlatform { + if (System.getProperty("os.name").toLowerCase.contains("windows")) { + // these tests seem oddly flaky on windows post #4377 + val _ = sequential + () + } + final override def executionTimeout = 2.minutes "ConcurrentMutex" should { From 9231cf0e3d8aff68c33d0c1ab1af2ca1832e3c10 Mon Sep 17 00:00:00 2001 From: Daniel Spiewak Date: Mon, 7 Jul 2025 14:53:39 -0400 Subject: [PATCH 14/14] SJS doesn't support getProperty --- tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala b/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala index eefb80d6c0..9e707d03fc 100644 --- a/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala @@ -27,7 +27,7 @@ import scala.concurrent.duration._ final class MutexSpec extends BaseSpec with DetectPlatform { - if (System.getProperty("os.name").toLowerCase.contains("windows")) { + if (!isJS && System.getProperty("os.name").toLowerCase.contains("windows")) { // these tests seem oddly flaky on windows post #4377 val _ = sequential ()