Dynamic shape index #3039
Conversation
Force-pushed from 059d09a to 7f9d604.
There are some changes that do not conform to Python style guidelines:
Force-pushed from 7f9d604 to daf1be0.
@apbose is this ready?
@peri044 Right now the dynamic input cases are supported; I will add some test cases for dynamic index, although the implementation should already support it. Also, is there an example script to test it with Mistral-7B and SD? I have not tested it end to end with those models yet.
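A rough sketch of what such a dynamic-shape test could look like is below; the module, shape bounds, and tolerances are illustrative assumptions, and the repository's existing conversion test harness would normally be used instead.

import torch
import torch_tensorrt

# Hypothetical module whose forward lowers to aten.index via advanced indexing.
class IndexModule(torch.nn.Module):
    def forward(self, x):
        idx = torch.tensor([0, 2], device=x.device)  # constant gather indices
        return x[:, idx]

model = IndexModule().eval().cuda()

# Mark the first axis as dynamic so the converter's dynamic-shape path is exercised.
compile_inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 4, 8),
        opt_shape=(4, 4, 8),
        max_shape=(8, 4, 8),
        dtype=torch.float32,
    )
]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=compile_inputs)

# Compare against eager PyTorch for one shape inside the dynamic range.
x = torch.randn(4, 4, 8, device="cuda")
torch.testing.assert_close(trt_model(x), model(x), rtol=1e-3, atol=1e-3)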
For the Mistral model, you can use the same example https://github.com/pytorch/TensorRT/blob/llm_examples_main/examples/dynamo/torch_export_llama2.py and swap the model name.
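For reference, the swap would presumably amount to something like the following sketch; the model id, dynamic dimension bounds, and precision are assumptions, and the linked example script is the source of truth.

import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.3"  # swapped in place of the Llama-2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = (
    AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
    .eval()
    .cuda()
)

input_ids = tokenizer("The future of AI is", return_tensors="pt").input_ids.cuda()

# Mark the sequence length as dynamic so the exported graph exercises dynamic shapes.
seq_len = torch.export.Dim("seq_len", min=2, max=1024)
exp_program = torch.export.export(model, (input_ids,), dynamic_shapes=({1: seq_len},))

trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[input_ids],
    enabled_precisions={torch.float16},
)
outputs = trt_model(input_ids)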
(
    get_shape(
        ctx,
        target,
        source_ir,
        name + f"_transpose_tensor_shape_mult_d0_{i}",
        transpose_tensor,
        i,
    )
    if transpose_tensor_shape[i] == DYNAMIC_DIM
    else transpose_tensor_shape[i]
),
consider storing this in a variable for better readability
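A minimal sketch of the suggested refactor, hoisting the conditional into a named variable (the name dim_val is hypothetical):

dim_val = (
    get_shape(
        ctx,
        target,
        source_ir,
        name + f"_transpose_tensor_shape_mult_d0_{i}",
        transpose_tensor,
        i,
    )
    if transpose_tensor_shape[i] == DYNAMIC_DIM
    else transpose_tensor_shape[i]
)

The same variable would then be passed wherever the inline expression is used today.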
(
    get_shape(
        ctx,
        target,
        source_ir,
        name + f"_transpose_tensor_shape_mult_d1_{i}",
        transpose_tensor,
        i,
    )
    if transpose_tensor_shape[i] == DYNAMIC_DIM
    else transpose_tensor_shape[i]
),
same as above
    dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
    name + "_dim_last",
)
multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]
Does this change imply multiplier is always an ITensor, since you were using get_trt_tensor before?
Yes, we do not need get_trt_tensor here; even in the previous non-dynamic cases it was not required, since dim_tensor_list already stores a list of dimension tensors.
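To illustrate the point, a simplified before/after (assumed, not the converter's exact code):

# Before (assumed): re-wrapping an entry that is already an ITensor
# multiplier = get_trt_tensor(
#     ctx,
#     dim_tensor_list[adv_indx_indices[adv_indx_count - 1]],
#     name + "_dim_last",
# )

# After: dim_tensor_list already holds dimension ITensors, so use the entry directly
multiplier = dim_tensor_list[adv_indx_indices[adv_indx_count - 1]]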
Force-pushed from daf1be0 to a5ebd76.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:09:49.859558+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:10:09.332451+00:00
@@ -532,6 +532,6 @@
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
- torch.export.save(exp_program, file_path)
\ No newline at end of file
+ torch.export.save(exp_program, file_path)
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:18:47.939073+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:19:11.393152+00:00
@@ -532,6 +532,6 @@
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
- torch.export.save(exp_program, file_path)
\ No newline at end of file
+ torch.export.save(exp_program, file_path)
Force-pushed from 673dd7b to 8fba0f2.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:20:46.489048+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:21:12.958728+00:00
@@ -532,6 +532,6 @@
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
- torch.export.save(exp_program, file_path)
\ No newline at end of file
+ torch.export.save(exp_program, file_path)
I ran Mistral-7B-v0.3 and it shows me this at the end:
Looks like it got past the dynamic index case.
Force-pushed from 8fba0f2 to 8295117.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:24:39.231642+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-16 00:25:02.079196+00:00
@@ -532,6 +532,6 @@
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
- torch.export.save(exp_program, file_path)
\ No newline at end of file
+ torch.export.save(exp_program, file_path)
Force-pushed from 8295117 to acd0abe.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-22 23:54:07.789419+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_compile.py 2024-08-22 23:54:27.515379+00:00
@@ -532,6 +532,6 @@
with enable_torchbind_tracing():
exp_program = torch.export.export(
module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False
)
- torch.export.save(exp_program, file_path)
\ No newline at end of file
+ torch.export.save(exp_program, file_path)
Force-pushed from acd0abe to 73a2fdb.
Force-pushed from 73a2fdb to a64c29a.
Force-pushed from a64c29a to 4cafd01.
Force-pushed from 4cafd01 to 5379328.