Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

package org.elasticsearch.xpack.gpu.codec;

import com.nvidia.cuvs.CagraIndexParams;
import com.nvidia.cuvs.CuVSMatrix;
import com.nvidia.cuvs.CuVSResources;
import com.nvidia.cuvs.GPUInfoProvider;
import com.nvidia.cuvs.spi.CuVSProvider;

import org.elasticsearch.core.Strings;
Expand Down Expand Up @@ -47,10 +47,8 @@ public interface CuVSResourceManager {
* effect on GPU memory and compute usage to determine whether to give out
* another resource or wait for a resource to be returned before giving out another.
*/
// numVectors and dims are currently unused, but could be used along with GPU metadata,
// memory, generation, etc, when acquiring for 10M x 1536 dims, or 100,000 x 128 dims,
// to give out a resources or not.
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException;
ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
throws InterruptedException;

/** Marks the resources as finished with regard to compute. */
void finishedComputation(ManagedCuVSResources resources);
Expand Down Expand Up @@ -80,31 +78,31 @@ class PoolingCuVSResourceManager implements CuVSResourceManager {
static class Holder {
static final PoolingCuVSResourceManager INSTANCE = new PoolingCuVSResourceManager(
MAX_RESOURCES,
CuVSProvider.provider().gpuInfoProvider()
new RealGPUMemoryService(CuVSProvider.provider().gpuInfoProvider())
);
}

private final ManagedCuVSResources[] pool;
private final int capacity;
private final GPUInfoProvider gpuInfoProvider;
private final GPUMemoryService gpuMemoryService;
private int createdCount;

ReentrantLock lock = new ReentrantLock();
Condition enoughResourcesCondition = lock.newCondition();

public PoolingCuVSResourceManager(int capacity, GPUInfoProvider gpuInfoProvider) {
PoolingCuVSResourceManager(int capacity, GPUMemoryService gpuMemoryService) {
if (capacity < 1 || capacity > MAX_RESOURCES) {
throw new IllegalArgumentException("Resource count must be between 1 and " + MAX_RESOURCES);
}
this.capacity = capacity;
this.gpuInfoProvider = gpuInfoProvider;
this.gpuMemoryService = gpuMemoryService;
this.pool = new ManagedCuVSResources[MAX_RESOURCES];
}

private ManagedCuVSResources getResourceFromPool() {
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked == false) {
if (res.isLocked() == false) {
return res;
}
}
Expand All @@ -120,43 +118,45 @@ private int numLockedResources() {
int lockedResources = 0;
for (int i = 0; i < createdCount; ++i) {
var res = pool[i];
if (res.locked) {
if (res.isLocked()) {
lockedResources++;
}
}
return lockedResources;
}

@Override
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType) throws InterruptedException {
public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams)
throws InterruptedException {
try {
var started = System.nanoTime();
lock.lock();

boolean allConditionsMet = false;
ManagedCuVSResources res = null;

long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType, cagraIndexParams);
logger.debug(
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
numVectors,
dims,
dataType.name(),
requiredMemoryInBytes
);

while (allConditionsMet == false) {
res = getResourceFromPool();

final boolean enoughMemory;
if (res != null) {
long requiredMemoryInBytes = estimateRequiredMemory(numVectors, dims, dataType);
logger.debug(
"Estimated memory for [{}] vectors, [{}] dims of type [{}] is [{} B]",
numVectors,
dims,
dataType.name(),
requiredMemoryInBytes
);

// Check immutable constraints
long totalDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).totalDeviceMemoryInBytes();
if (requiredMemoryInBytes > totalDeviceMemoryInBytes) {
long totalMemoryInBytes = gpuMemoryService.totalMemoryInBytes(res);
if (requiredMemoryInBytes > totalMemoryInBytes) {
String message = Strings.format(
"Requested GPU memory for [%d] vectors, [%d] dims is greater than the GPU total memory [%d B]",
numVectors,
dims,
totalDeviceMemoryInBytes
totalMemoryInBytes
);
logger.error(message);
throw new IllegalArgumentException(message);
Expand All @@ -169,9 +169,9 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
}

// Check resources availability
long freeDeviceMemoryInBytes = gpuInfoProvider.getCurrentInfo(res).freeDeviceMemoryInBytes();
enoughMemory = requiredMemoryInBytes <= freeDeviceMemoryInBytes;
logger.debug("Free device memory [{} B], enoughMemory[{}]", freeDeviceMemoryInBytes, enoughMemory);
long availableMemoryInBytes = gpuMemoryService.availableMemoryInBytes(res);
enoughMemory = requiredMemoryInBytes <= availableMemoryInBytes;
logger.debug("Free device memory [{} B], enoughMemory[{}]", availableMemoryInBytes, enoughMemory);
} else {
logger.debug("No resources available in pool");
enoughMemory = false;
Expand All @@ -184,19 +184,33 @@ public ManagedCuVSResources acquire(int numVectors, int dims, CuVSMatrix.DataTyp
}
var elapsed = started - System.nanoTime();
logger.debug("Resource acquired in [{}ms]", elapsed / 1_000_000.0);
res.locked = true;
gpuMemoryService.reserveMemory(requiredMemoryInBytes);
res.lock(() -> gpuMemoryService.releaseMemory(requiredMemoryInBytes));
return res;
} finally {
lock.unlock();
}
}

private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType) {
/**
 * Estimates the GPU memory, in bytes, needed to process {@code numVectors} vectors of {@code dims} dimensions.
 * <p>
 * When the CAGRA graph build algorithm is IVF_PQ and its index params specify a non-zero PQ dimension, the
 * estimate is derived from the approximate size of the compressed IVF-PQ index; otherwise it is derived from
 * the raw size of the vectors. In both cases the result is scaled by {@code GPU_COMPUTATION_MEMORY_FACTOR}.
 *
 * @param numVectors       number of vectors
 * @param dims             number of dimensions per vector
 * @param dataType         element type of the vectors; determines the bytes per element
 * @param cagraIndexParams CAGRA index parameters; may be {@code null}, in which case the raw-size estimate is used
 * @return the estimated required memory in bytes
 */
private long estimateRequiredMemory(int numVectors, int dims, CuVSMatrix.DataType dataType, CagraIndexParams cagraIndexParams) {
    int elementTypeBytes = switch (dataType) {
        case FLOAT -> Float.BYTES;
        case INT, UINT -> Integer.BYTES;
        case BYTE -> Byte.BYTES;
    };

    // A null cagraIndexParams (or any null level below it) falls through to the raw-size estimate
    // instead of failing with a NullPointerException.
    if (cagraIndexParams != null && cagraIndexParams.getCagraGraphBuildAlgo() == CagraIndexParams.CagraGraphBuildAlgo.IVF_PQ) {
        // Resolve the nested IVF-PQ index params once rather than re-walking the getter chain per field.
        var ivfPqParams = cagraIndexParams.getCuVSIvfPqParams();
        var ivfPqIndexParams = ivfPqParams == null ? null : ivfPqParams.getIndexParams();
        if (ivfPqIndexParams != null && ivfPqIndexParams.getPqDim() != 0) {
            // See https://docs.rapids.ai/api/cuvs/nightly/neighbors/ivfpq/#index-device-memory
            // NOTE(review): the linked formula appears to use the size of the vector-id type for the per-vector
            // overhead; elementTypeBytes is kept here exactly as before - confirm which one is intended.
            var pqDim = ivfPqIndexParams.getPqDim();
            var pqBits = ivfPqIndexParams.getPqBits();
            var numClusters = ivfPqIndexParams.getnLists();
            // pqBits / 8.0 forces floating-point arithmetic, so the per-vector term cannot overflow an int.
            var approximatedIvfBytes = numVectors * (pqDim * (pqBits / 8.0) + elementTypeBytes) + (long) numClusters * Integer.BYTES;
            return (long) (GPU_COMPUTATION_MEMORY_FACTOR * approximatedIvfBytes);
        }
    }

    // Default: raw vector data size scaled by the computation factor.
    return (long) (GPU_COMPUTATION_MEMORY_FACTOR * numVectors * dims * elementTypeBytes);
}

Expand All @@ -217,8 +231,8 @@ public void release(ManagedCuVSResources resources) {
logger.debug("Releasing resources to pool");
try {
lock.lock();
assert resources.locked;
resources.locked = false;
assert resources.isLocked();
resources.unlock();
enoughResourcesCondition.signalAll();
} finally {
lock.unlock();
Expand All @@ -238,8 +252,9 @@ public void shutdown() {
/** A managed resource. Cannot be closed. */
final class ManagedCuVSResources implements CuVSResources {

final CuVSResources delegate;
boolean locked = false;
private final CuVSResources delegate;
private static final Runnable NOT_LOCKED = () -> {};
private Runnable unlockAction = NOT_LOCKED;

ManagedCuVSResources(CuVSResources resources) {
this.delegate = resources;
Expand Down Expand Up @@ -269,5 +284,17 @@ public Path tempDirectory() {
public String toString() {
return "ManagedCuVSResources[delegate=" + delegate + "]";
}

/**
 * Marks this resource as locked by recording the given action as its unlock action.
 * The resource reports {@link #isLocked()} as {@code true} until {@link #unlock()} resets it.
 *
 * @param unlockAction the action associated with this lock
 */
void lock(final Runnable unlockAction) {
    this.unlockAction = unlockAction;
}

/**
 * Unlocks this resource: marks it as not locked and runs the action that was registered via
 * {@link #lock(Runnable)}. Calling this on a resource that is not locked is a no-op.
 */
void unlock() {
    final Runnable action = unlockAction;
    // Reset before running, so the resource is returned to the unlocked state even if the action throws.
    unlockAction = NOT_LOCKED;
    // Previously the registered action was discarded without being run, so whatever it was meant to
    // undo (e.g. a memory reservation made when the resource was acquired) was never undone.
    if (action != null) {
        action.run();
    }
}

/**
 * Tells whether this resource is currently locked.
 *
 * @return {@code true} if the recorded unlock action is anything other than the {@code NOT_LOCKED} sentinel
 */
boolean isLocked() {
    return (unlockAction == NOT_LOCKED) == false;
}
}
}
Loading