I'm trying out `host_callback` from `jax.experimental`, but I cannot even run `host_callback_test.py` successfully. Here is a minimal reproducible example from my server with 2 devices (I'm not sure whether it reproduces on other servers, since the error occurs even when running the official test files):
I've also tested with 4 devices by setting `x_shape=(4,2)` and got the same result. The error message is:
```
E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:56] Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape s32[2]{0}
*** Begin stack trace ***
clone
*** End stack trace ***
2021-03-20 22:09:07.463390: E external/org_tensorflow/tensorflow/compiler/xla/status_macros.cc:56] Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape s32[2]{0} and outfeed source buffer shape u32[2]{0}
*** Begin stack trace ***
clone
*** End stack trace ***
2021-03-20 22:09:07.463428: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 1 failed: Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape u32[2]{0} and outfeed source buffer shape s32[2]{0}
2021-03-20 22:09:07.463440: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:1886] Execution of replica 0 failed: Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape s32[2]{0} and outfeed source buffer shape u32[2]{0}
Traceback (most recent call last):
  File "/home/yonghao/code/tmp/test.py", line 20, in <module>
    basic_test(x)
jax._src.traceback_util.FilteredStackTrace: RuntimeError: Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape s32[2]{0} and outfeed source buffer shape u32[2]{0}: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/yonghao/code/tmp/test.py", line 20, in <module>
    basic_test(x)
  File "/home/yonghao/code/jax/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/yonghao/code/jax/jax/api.py", line 1583, in f_pmapped
    global_arg_shapes=tuple(global_arg_shapes_flat))
  File "/home/yonghao/code/jax/jax/core.py", line 1453, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/yonghao/code/jax/jax/core.py", line 1385, in call_bind
    outs = primitive.process(top_trace, fun, tracers, params)
  File "/home/yonghao/code/jax/jax/core.py", line 1456, in process
    return trace.process_map(self, fun, tracers, params)
  File "/home/yonghao/code/jax/jax/core.py", line 625, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/yonghao/code/jax/jax/interpreters/pxla.py", line 620, in xla_pmap_impl
    return compiled_fun(*args)
  File "/home/yonghao/code/jax/jax/interpreters/pxla.py", line 1167, in execute_replicated
    out_bufs = compiled.execute_sharded_on_local_devices(input_bufs)
RuntimeError: Internal: RET_CHECK failure (external/org_tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc:80) ShapeUtil::Equal(source_slices_[index].shape, output_shape) Mismatch between outfeed output buffer shape s32[2]{0} and outfeed source buffer shape u32[2]{0}: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).
```
However, in the example above the result shape is given as `x`, and the callback returns `x` itself (the same as its input), following the `hcb_sin` example from the `host_callback` examples.
The same error occurs even when I explicitly set the dtype of the return value, following the `host_eig` example.
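For comparison, here is a minimal sketch of a per-device callback whose declared result shape and dtype are derived from the input, in the spirit of the `hcb_sin` example. It is not the code from this report: it swaps in `jax.pure_callback` (`host_callback` has since been deprecated in favor of JAX's external callbacks such as `jax.pure_callback`), and the names `host_sin` and `device_fn` are hypothetical. The host function casts its result back to the input dtype so that what it returns matches the declared `jax.ShapeDtypeStruct` exactly.

```python
import numpy as np
import jax
import jax.numpy as jnp


def host_sin(x):
    # Hypothetical host-side function. Convert to NumPy and cast the result
    # back to the input dtype so it matches the declared result dtype exactly.
    x = np.asarray(x)
    return np.sin(x).astype(x.dtype)


def device_fn(x):
    # Hypothetical per-device function: declare the callback's result from
    # the input's own shape and dtype (the result "looks like" x).
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_sin, result_shape, x)


# One row per local device, two elements per row (like x_shape=(2, 2)
# on a 2-device machine).
n_dev = jax.local_device_count()
x = jnp.arange(n_dev * 2, dtype=jnp.float32).reshape((n_dev, 2))
out = jax.pmap(device_fn)(x)
print(out.dtype, out.shape)
```

Deriving the declared result from `x.shape` and `x.dtype`, rather than hard-coding them, keeps the declaration in sync with what the host function actually returns.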
How can I find out why this happens?
My jaxlib was built from source last Monday (March 15), and my jax is from last Wednesday (March 17).