Hi all. This is my first discussion post, so please let me know if I’ve made any mistakes. I’m running into some unexpected behavior from a matrix multiplication on the GPU. I have a few different questions, but if anyone could address even one of them I’d be very grateful. Here’s my code with five different inputs, each of which produces a different output:
import jax
import jax.numpy as jnp
from functools import partial
from jax import random

jax.config.update("jax_compiler_enable_remat_pass", False)

@jax.jit
def jnp_jit_matmul(a, b):
    return jnp.matmul(a, b, preferred_element_type=jnp.float16)

@partial(jax.jit, static_argnums=(1,))
def generate_arg_b(rng, dim):
    # Random matrix of -1's and 1's.
    c = jax.random.randint(rng, dim, 0, 2, dtype=jnp.int8)
    return (c * 2) - 1

def main():
    n = 46137344   # output A
    # n = 46137345  # output B
    # n = 48400000  # output C
    # n = 48500000  # output D
    # n = 50000000  # output E
    key = random.PRNGKey(0)
    key_a, key_b = random.split(key)
    param_a = random.normal(key_a, (1, 128), dtype=jnp.float16)
    param_b = generate_arg_b(key_b, (128, n)).astype(jnp.float16)
    result = jnp.matmul(param_a, param_b)
    print(result.shape)

if __name__ == "__main__":
    main()
The graphics card this is running on is an Nvidia GeForce RTX 4090 with 24 gigabytes of memory. The environment is a fresh Python venv with simply:
I have the environment variables:
As I understand it, a matrix multiplication has a minimum memory footprint of input_a + input_b + output. Even though my input_b is all -1’s and 1’s, I cast it to float16 so that the dtypes of the multiplication match. param_a takes up essentially no memory, so the math I’m using to calculate approximate memory usage is:
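For concreteness, here is the back-of-the-envelope arithmetic I'm using (my own estimate, not anything reported by JAX):

```python
# Rough memory estimate for the matmul (1, 128) @ (128, n),
# with everything stored in float16 (2 bytes per element).
n = 46137344  # output A

param_a_bytes = 1 * 128 * 2   # negligible
param_b_bytes = 128 * n * 2   # dominates the footprint
output_bytes = 1 * n * 2

total_bytes = param_a_bytes + param_b_bytes + output_bytes
print(f"param_b: {param_b_bytes / 2**30:.2f} GiB")  # 11.00 GiB for output A
print(f"total:   {total_bytes / 2**30:.2f} GiB")    # about 11.09 GiB
```

(46137344 happens to be 44 * 1024 * 1024, which is why param_b lands on exactly 11 GiB.)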
Output A behaves exactly as expected, so I’ll move to output B first. This is only a warning, and it starts when param_b takes up exactly 11 gigabytes in memory (I'm not too worried about this, but thought I'd mention it in case it's important):
2024-11-10 12:28:05.863035: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:397] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_1', 72 bytes spill stores, 72 bytes spill loads
2024-11-10 12:28:05.868270: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:397] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_1', 56 bytes spill stores, 56 bytes spill loads
2024-11-10 12:28:05.918815: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:397] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_1', 400 bytes spill stores, 380 bytes spill loads
2024-11-10 12:28:05.929907: I external/xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:397] ptxas warning : Registers are spilled to local memory in function 'gemm_fusion_dot_1', 194 bytes spill stores, 192 bytes spill loads
(1, 46137345)
Output C produces many OUT_OF_MEMORY CUDA errors, but still computes an output. My math indicates this matrix multiplication should take a little over 11.5 gigs, which falls well under my card’s capacity.
2024-11-10 12:29:17.318519: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:477] failed to allocate 92.32MiB (96800000 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
2024-11-10 12:29:17.318594: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:477] failed to allocate 92.32MiB (96800000 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
[... the same "failed to allocate 92.32MiB" message repeats 21 more times ...]
(1, 48400000)
Output D produces the strangest error, in my opinion. I have very little idea of where to even start evaluating this one.
2024-11-10 12:29:56.763565: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:477] failed to allocate 92.51MiB (97000000 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
2024-11-10 12:29:56.789383: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:477] failed to allocate 92.51MiB (97000000 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
[... the same "failed to allocate 92.51MiB" message repeats 22 more times ...]
Traceback (most recent call last):
File "/project_dir/ian/gpu_attack/optimized_cf_attack/./memory_err_demo.py", line 39, in <module>
main()
File "/project_dir/ian/gpu_attack/optimized_cf_attack/./memory_err_demo.py", line 35, in main
result = jnp.matmul(param_a, param_b)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/pjit.py", line 338, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/core.py", line 2803, in bind
return self.bind_with_trace(top_trace, args, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/core.py", line 442, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/core.py", line 955, in process_primitive
return primitive.impl(*tracers, **params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
return xc._xla.pjit(
^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1644, in _pjit_call_impl_python
).compile(compile_options)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2345, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2854, in from_hlo
xla_executable = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2666, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/compiler.py", line 434, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/compiler.py", line 662, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/compiler.py", line 273, in backend_compile
raise e
File "/project_dir/tempenv/lib/python3.11/site-packages/jax/_src/compiler.py", line 267, in backend_compile
return backend.compile(built_c, compile_options=options)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: There are no algorithm candidates for computing:
%bitcast.17 = f16[1,48500000]{1,0} bitcast(f16[48500000]{0} %bitcast.16), metadata={op_name="jit(matmul)/jit(main)/dot_general" source_file="/project_dir/ian/gpu_attack/optimized_cf_attack/./memory_err_demo.py" source_line=35}
This likely means that the instruction shape is not supported by the target GPU library.
Finally, output E (and greater inputs) produces this allocation failure:
2024-11-10 12:33:34.278022: I external/xla/xla/stream_executor/cuda/cuda_executor.cc:477] failed to allocate 11.94GiB (12816777216 bytes) from device: RESOURCE_EXHAUSTED: : CUDA_ERROR_OUT_OF_MEMORY: out of memory
Traceback (most recent call last):
File "/project_dir/ian/gpu_attack/optimized_cf_attack/./memory_err_demo.py", line 39, in <module>
main()
File "/project_dir/ian/gpu_attack/optimized_cf_attack/./memory_err_demo.py", line 35, in main
result = jnp.matmul(param_a, param_b)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [... the rest of the traceback is identical to the internal JAX frames shown for output D ...]
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 11.94GiB (12816777216B) on device ordinal 0
I’d mostly just like to understand what’s going on here, but if I had to ask some specific questions, they’d be these: why am I running out of memory earlier than I expect? Does my matrix multiplication input need to exist in memory twice, and thus take up double the space I’m expecting? Is there a better way to perform this matrix multiplication, given that I know the larger of the two inputs will only ever contain -1’s and 1’s? I tried passing them natively as integers, but that produced more issues, since the multiplication must be in a supported GEMM format.
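To give that last question some shape: one workaround I've been considering (a sketch only, untested at this scale, and the chunk size is arbitrary) is to keep param_b resident as int8, which halves its footprint, and cast only one column chunk to float16 right before each partial matmul:

```python
import jax.numpy as jnp

def chunked_matmul(a, b_int8, chunk=4_000_000):
    """Multiply float16 a of shape (1, k) by int8 b of shape (k, n),
    casting one column chunk of b to float16 at a time so that the
    float16 temporaries stay small."""
    parts = []
    for start in range(0, b_int8.shape[1], chunk):
        block = b_int8[:, start:start + chunk].astype(jnp.float16)
        parts.append(jnp.matmul(a, block, preferred_element_type=jnp.float16))
    return jnp.concatenate(parts, axis=1)
```

This keeps the big array at 1 byte per element (about 5.9 GB rather than 11.9 GB at n = 50,000,000), at the cost of launching one GEMM per chunk; whether XLA actually frees each float16 block before materializing the next is something I'd still need to verify.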