Skip to content

Commit

Permalink
addressing jason's comment (#11587)
Browse files Browse the repository at this point in the history
Signed-off-by: Hongbin Ma (Mahone) <[email protected]>
  • Loading branch information
binmahone authored Oct 11, 2024
1 parent e8ac073 commit 0ba4fd2
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,15 @@ private final class SemaphoreTaskInfo(val taskAttemptId: Long) extends Logging {
}
}

def tryAcquire(semaphore: GpuBackingSemaphore): Boolean = synchronized {
def tryAcquire(semaphore: GpuBackingSemaphore, taskAttemptId: Long): Boolean = synchronized {
val t = Thread.currentThread()
if (hasSemaphore) {
activeThreads.add(t)
true
} else {
if (blockedThreads.size() == 0) {
// No other threads for this task are waiting, so we might be able to grab this directly
val ret = semaphore.tryAcquire(numPermits, lastHeld)
val ret = semaphore.tryAcquire(numPermits, lastHeld, taskAttemptId)
if (ret) {
hasSemaphore = true
activeThreads.add(t)
Expand Down Expand Up @@ -335,7 +335,7 @@ private final class GpuSemaphore() extends Logging {
onTaskCompletion(context, completeTask)
new SemaphoreTaskInfo(taskAttemptId)
})
if (taskInfo.tryAcquire(semaphore)) {
if (taskInfo.tryAcquire(semaphore, taskAttemptId)) {
GpuDeviceManager.initializeFromTask()
SemaphoreAcquired
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,21 +31,26 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
var signaled: Boolean = false
}

// use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
private val priorityComp = Ordering.by[ThreadInfo, T](_.priority).reverse.
thenComparing((a, b) => a.taskId.compareTo(b.taskId))

// We expect a relatively small number of threads to be contending for this lock at any given
// time, therefore we are not concerned with the insertion/removal time complexity.
private val waitingQueue: PriorityQueue[ThreadInfo] =
new PriorityQueue[ThreadInfo](
// use task id as tie breaker when priorities are equal (both are 0 because never hold lock)
Ordering.by[ThreadInfo, T](_.priority).reverse.
thenComparing((a, b) => a.taskId.compareTo(b.taskId))
)
new PriorityQueue[ThreadInfo](priorityComp)

def tryAcquire(numPermits: Int, priority: T): Boolean = {
def tryAcquire(numPermits: Int, priority: T, taskAttemptId: Long): Boolean = {
lock.lock()
try {
if (waitingQueue.size() > 0 && ordering.gt(waitingQueue.peek.priority, priority)) {
if (waitingQueue.size() > 0 &&
priorityComp.compare(
waitingQueue.peek(),
ThreadInfo(priority, null, numPermits, taskAttemptId)
) < 0) {
false
} else if (!canAcquire(numPermits)) {
}
else if (!canAcquire(numPermits)) {
false
} else {
commitAcquire(numPermits)
Expand All @@ -59,7 +64,7 @@ class PrioritySemaphore[T](val maxPermits: Int)(implicit ordering: Ordering[T])
def acquire(numPermits: Int, priority: T, taskAttemptId: Long): Unit = {
lock.lock()
try {
if (!tryAcquire(numPermits, priority)) {
if (!tryAcquire(numPermits, priority, taskAttemptId)) {
val condition = lock.newCondition()
val info = ThreadInfo(priority, condition, numPermits, taskAttemptId)
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ class PrioritySemaphoreSuite extends AnyFunSuite {
test("tryAcquire should return true if permits are available") {
val semaphore = new TestPrioritySemaphore(10)

assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(3, 0))
assert(semaphore.tryAcquire(2, 0))
assert(!semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(5, 0, 0))
assert(semaphore.tryAcquire(3, 0, 0))
assert(semaphore.tryAcquire(2, 0, 0))
assert(!semaphore.tryAcquire(1, 0, 0))
}

test("acquire and release should work correctly") {
val semaphore = new TestPrioritySemaphore(1)

assert(semaphore.tryAcquire(1, 0))
assert(semaphore.tryAcquire(1, 0, 0))

val t = new Thread(() => {
try {
Expand Down Expand Up @@ -94,10 +94,36 @@ class PrioritySemaphoreSuite extends AnyFunSuite {

// Here, there should be 5 available permits, but a thread with higher priority (2)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(5, 0))
assert(!semaphore.tryAcquire(5, 0, 0))
semaphore.release(5)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(5, 0))
assert(semaphore.tryAcquire(5, 0, 0))
}

// this case is described at https://github.com/NVIDIA/spark-rapids/pull/11574/files#r1795652488
test("thread with larger task id should not surpass smaller task id in the waiting queue") {
val semaphore = new TestPrioritySemaphore(10)
semaphore.acquire(8, 0, 0)
val t = new Thread(() => {
semaphore.acquire(5, 0, 0)
semaphore.release(5)
})
t.start()
Thread.sleep(100)

// Here, there should be 2 available permits, and a thread with same task id (0)
// is waiting to acquire 5 permits, in this case we should succeed here
assert(semaphore.tryAcquire(2, 0, 0))
semaphore.release(2)

// Here, there should be 2 available permits, but a thread with smaller task id (0)
// is waiting to acquire, therefore we should get rejected here
assert(!semaphore.tryAcquire(2, 0, 1))

semaphore.release(8)
t.join(1000)
// After the high priority thread finishes, we can acquire with lower priority
assert(semaphore.tryAcquire(2, 0, 1))
}
}

0 comments on commit 0ba4fd2

Please sign in to comment.