Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamic Shape] More correctness guards #276

Merged
merged 23 commits into from
Jun 16, 2023

Conversation

Aalanli
Copy link
Collaborator

@Aalanli Aalanli commented Jun 9, 2023

Add support for dynamic shape assertions in the C++ runtime.
Further, add shape and dtype check for python runtime. With this being enabled by default.

@Aalanli
Copy link
Collaborator Author

Aalanli commented Jun 9, 2023

Moreover, there is a subtle bug where the symbol registry gets rewritten multiple times if there exists inputs which promises the same shape, but runtime inputs have different shapes.
Eg. hidet.symbol(['a', 'a']), but during runtime: hidet.randn([1, 2])

So added a check for that as well.

@Aalanli
Copy link
Collaborator Author

Aalanli commented Jun 12, 2023

I have no idea why the gpt2 test passed before, there was a similar error with the llama implementation, where the compiled model reinterpreted an int64 tensor as an int32 tensor, producing random outputs.

Copy link
Member

@yaoyaoding yaoyaoding left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @Aalanli.

python/hidet/backend/codegen.py Outdated Show resolved Hide resolved
python/hidet/drivers/build_task.py Outdated Show resolved Hide resolved
python/hidet/ir/task.py Outdated Show resolved Hide resolved
python/hidet/ir/task.py Outdated Show resolved Hide resolved
python/hidet/ir/task.py Outdated Show resolved Hide resolved
python/hidet/option.py Outdated Show resolved Hide resolved
python/hidet/runtime/compiled_graph.py Outdated Show resolved Hide resolved
python/hidet/runtime/compiled_graph.py Outdated Show resolved Hide resolved
python/hidet/runtime/compiled_graph.py Show resolved Hide resolved
Comment on lines 154 to 177
for i, (traced, new) in enumerate(zip(self.meta_data.inputs, inputs)):
if ir.data_type(traced.dtype) != new.dtype:
raise RuntimeError(
f"dtype mismatch at arg {i} between original: {traced.dtype} and new: {new.dtype}"
)
traced_shape = traced.shape
concrete_shape = new.shape
if len(traced_shape) != len(concrete_shape):
raise RuntimeError(
f"Rank of input {i} not equal to original. ({len(concrete_shape)} vs. {len(traced_shape)})"
)
for j, (orig_shape, new_shape) in enumerate(zip(traced_shape, concrete_shape)):
if isinstance(orig_shape, int) and orig_shape != new_shape:
raise RuntimeError(
f'shape mismatch at dimension {j}, original: \
{orig_shape} vs. new: {new_shape}'
)
elif orig_shape not in symbol_map:
symbol_map[orig_shape] = new_shape
elif symbol_map[orig_shape] != new_shape:
raise RuntimeError(
f"There exists multiple instances of the same symbol {orig_shape}\
with different values in inputs (ex: {symbol_map[orig_shape]} and {new_shape})"
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same logic as the CompiledGraph. Consider implement as a utility function and put it at the same module of TensorSignature.

@yaoyaoding
Copy link
Member

Thanks @Aalanli !

@yaoyaoding yaoyaoding merged commit 5b490b6 into hidet-org:main Jun 16, 2023
@Aalanli Aalanli deleted the dyn-shape-assertion branch September 27, 2023 18:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants