diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala index f0089b226d5..78d05efb0c2 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuSemaphore.scala @@ -280,7 +280,7 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging { } } - def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized { + def tryAcquire(semaphore: GpuBackingSemaphore, taskAttemptId: Long): Boolean = synchronized { val t = Thread.currentThread() if (hasSemaphore) { activeThreads.add(t) @@ -288,7 +288,7 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging { } else { if (blockedThreads.size() == 0) { // No other threads for this task are waiting, so we might be able to grab this directly - val ret = semaphore.tryAcquire(numPermits, lastHeld) + val ret = semaphore.tryAcquire(numPermits, lastHeld, taskAttemptId) if (ret) { hasSemaphore = true activeThreads.add(t) @@ -335,7 +335,7 @@ private final class GpuSemaphore() extends Logging { onTaskCompletion(context, completeTask) new SemaphoreTaskInfo(taskAttemptId) }) - if (taskInfo.tryAcquire(semaphore)) { + if (taskInfo.tryAcquire(semaphore, taskAttemptId)) { GpuDeviceManager.initializeFromTask() SemaphoreAcquired } else { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala index cdee5ab1c79..dc90382d3a0 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/PrioritySemaphore.scala @@ -31,21 +31,26 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) var signaled: Boolean = false } + // use task id as tie breaker when priorities are equal (both are 0 because never hold lock) 
+ private val priorityComp = Ordering.by[ThreadInfo, T](_.priority).reverse. + thenComparing((a, b) => a.taskId.compareTo(b.taskId)) + // We expect a relatively small number of threads to be contending for this lock at any given // time, therefore we are not concerned with the insertion/removal time complexity. private val waitingQueue: PriorityQueue[ThreadInfo] = - new PriorityQueue[ThreadInfo]( - // use task id as tie breaker when priorities are equal (both are 0 because never hold lock) - Ordering.by[ThreadInfo, T](_.priority).reverse. - thenComparing((a, b) => a.taskId.compareTo(b.taskId)) - ) + new PriorityQueue[ThreadInfo](priorityComp) - def tryAcquire(numPermits: Int, priority: T): Boolean = { + def tryAcquire(numPermits: Int, priority: T, taskAttemptId: Long): Boolean = { lock.lock() try { - if (waitingQueue.size() > 0 && ordering.gt(waitingQueue.peek.priority, priority)) { + if (waitingQueue.size() > 0 && + priorityComp.compare( + waitingQueue.peek(), + ThreadInfo(priority, null, numPermits, taskAttemptId) + ) < 0) { false - } else if (!canAcquire(numPermits)) { + } + else if (!canAcquire(numPermits)) { false } else { commitAcquire(numPermits) @@ -59,7 +64,7 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T]) def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = { lock.lock() try { - if (!tryAcquire(numPermits, priority)) { + if (!tryAcquire(numPermits, priority, taskAttemptId)) { val condition = lock.newCondition() val info = ThreadInfo(priority, condition, numPermits, taskAttemptId) try { diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala index cd9660a5de5..7199aa55df6 100644 --- a/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala +++ b/tests/src/test/scala/com/nvidia/spark/rapids/PrioritySemaphoreSuite.scala @@ -26,16 +26,16 @@ class PrioritySemaphoreSuite extends 
AnyFunSuite { test("tryAcquire should return true if permits are available") { val semaphore = new TestPrioritySemaphore(10) - assert(semaphore.tryAcquire(5, 0)) - assert(semaphore.tryAcquire(3, 0)) - assert(semaphore.tryAcquire(2, 0)) - assert(!semaphore.tryAcquire(1, 0)) + assert(semaphore.tryAcquire(5, 0, 0)) + assert(semaphore.tryAcquire(3, 0, 0)) + assert(semaphore.tryAcquire(2, 0, 0)) + assert(!semaphore.tryAcquire(1, 0, 0)) } test("acquire and release should work correctly") { val semaphore = new TestPrioritySemaphore(1) - assert(semaphore.tryAcquire(1, 0)) + assert(semaphore.tryAcquire(1, 0, 0)) val t = new Thread(() => { try { @@ -94,10 +94,36 @@ class PrioritySemaphoreSuite extends AnyFunSuite { // Here, there should be 5 available permits, but a thread with higher priority (2) // is waiting to acquire, therefore we should get rejected here - assert(!semaphore.tryAcquire(5, 0)) + assert(!semaphore.tryAcquire(5, 0, 0)) semaphore.release(5) t.join(1000) // After the high priority thread finishes, we can acquire with lower priority - assert(semaphore.tryAcquire(5, 0)) + assert(semaphore.tryAcquire(5, 0, 0)) + } + + // this case is described at https://github.com/NVIDIA/spark-rapids/pull/11574/files#r1795652488 + test("thread with larger task id should not surpass smaller task id in the waiting queue") { + val semaphore = new TestPrioritySemaphore(10) + semaphore.acquire(8, 0, 0) + val t = new Thread(() => { + semaphore.acquire(5, 0, 0) + semaphore.release(5) + }) + t.start() + Thread.sleep(100) + + // Here, there should be 2 available permits, and a thread with same task id (0) + // is waiting to acquire 5 permits, in this case we should succeed here + assert(semaphore.tryAcquire(2, 0, 0)) + semaphore.release(2) + + // Here, there should be 2 available permits, but a thread with smaller task id (0) + // is waiting to acquire, therefore we should get rejected here + assert(!semaphore.tryAcquire(2, 0, 1)) + + semaphore.release(8) + t.join(1000) + // After the 
thread with the smaller task id finishes, we can acquire with the larger task id + assert(semaphore.tryAcquire(2, 0, 1)) } }