-
Notifications
You must be signed in to change notification settings - Fork 354
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support weight streaming #3111
Conversation
@@ -109,10 +111,119 @@ def __init__( | |||
if self.serialized_engine is not None and not self.settings.lazy_engine_init: | |||
self.setup_engine() | |||
|
|||
def set_weight_streaming_budget(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@keehyuna do you need to add something similar to the C++ API?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sorry for confusion. It's dead code, all are moved to py/torch_tensorrt/runtime/_weight_streaming.py. C++ apis are updated in execute_engine.cpp
core/runtime/execute_engine.cpp
Outdated
@@ -95,11 +95,13 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ | |||
} | |||
|
|||
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) { | |||
compiled_engine->init_context(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will add first run latency. Why cant it run in the constructor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for advice. I added it in constructor. Latency is in compiler() context creation in forward() will be skipped when weight streaming is not used.
@@ -218,13 +219,25 @@ def set_weight_streaming_budget_v1( | |||
self.engine.minimum_weight_streaming_budget | |||
) | |||
|
|||
def reset_context(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do these context resets atomically with whatever runtime settting change. Leave as much out of the forward function as we can
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I came crossed two ideas and tried #1. Please let me know if there is better way to handle it automatically.
- reset_context(delete context) and apply set_weight_streaming_budget() api. context is created at forward()
- Enqueue runtime setting change like set_weight or profile enable. Then delete context->apply pending api-> create context in forward()
assert self.engine, f"Context is used before setting up the engine" | ||
|
||
if self.context is None: | ||
self.context = self.engine.create_execution_context() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we already have a setup engine function, not sure why we need to handle this at exec time?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weight streaming needs to be set before context is created. Or TRT throw the error. engine setup was completed in compile of runtime trt module. context needs to be recreated.
engine = runtime.deserialize_cuda_engine()
engine.weight_streaming_budget_v2 = budget_bytes
engine.create_execution_context()
fine
engine = runtime.deserialize_cuda_engine()
engine.create_execution_context()
engine.weight_streaming_budget_v2 = budget_bytes
ERROR:torch_tensorrt [TensorRT Conversion Context]:ICudaEngine::setWeightStreamingBudgetV2: Error Code 3: API Usage Error (Parameter check failed, condition: mExecutionContextCounter.use_count() == 1. The weight streaming budget cannot be modified while there are active IExecutionContexts.)
@@ -109,10 +111,119 @@ def __init__( | |||
if self.serialized_engine is not None and not self.settings.lazy_engine_init: | |||
self.setup_engine() | |||
|
|||
def set_weight_streaming_budget(self) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm sorry for confusion. It's dead code, all are moved to py/torch_tensorrt/runtime/_weight_streaming.py. C++ apis are updated in execute_engine.cpp
assert self.engine, f"Context is used before setting up the engine" | ||
|
||
if self.context is None: | ||
self.context = self.engine.create_execution_context() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weight streaming needs to be set before context is created. Or TRT throw the error. engine setup was completed in compile of runtime trt module. context needs to be recreated.
engine = runtime.deserialize_cuda_engine()
engine.weight_streaming_budget_v2 = budget_bytes
engine.create_execution_context()
fine
engine = runtime.deserialize_cuda_engine()
engine.create_execution_context()
engine.weight_streaming_budget_v2 = budget_bytes
ERROR:torch_tensorrt [TensorRT Conversion Context]:ICudaEngine::setWeightStreamingBudgetV2: Error Code 3: API Usage Error (Parameter check failed, condition: mExecutionContextCounter.use_count() == 1. The weight streaming budget cannot be modified while there are active IExecutionContexts.)
core/runtime/execute_engine.cpp
Outdated
@@ -95,11 +95,13 @@ bool _cudagraphs_validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ | |||
} | |||
|
|||
std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) { | |||
compiled_engine->init_context(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for advice. I added it in constructor. Latency is in compiler() context creation in forward() will be skipped when weight streaming is not used.
@@ -218,13 +219,25 @@ def set_weight_streaming_budget_v1( | |||
self.engine.minimum_weight_streaming_budget | |||
) | |||
|
|||
def reset_context(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I came crossed two ideas and tried #1. Please let me know if there is better way to handle it automatically.
- reset_context(delete context) and apply set_weight_streaming_budget() api. context is created at forward()
- Enqueue runtime setting change like set_weight or profile enable. Then delete context->apply pending api-> create context in forward()
def get_weight_streaming_budget(self): | ||
return self.engine.streamable_weights_size | ||
|
||
def set_weight_streaming_budget(self, budget_bytes): | ||
self.reset_context() | ||
self.engine.weight_streaming_budget_v2 = budget_bytes | ||
if self.engine.weight_streaming_budget_v2 != budget_bytes: | ||
logger.error(f"Failed to set weight streaming budget to {budget_bytes}") | ||
budget_bytes = self.engine.weight_streaming_budget_v2 | ||
if self.engine.streamable_weights_size == budget_bytes: | ||
logger.warning("Weight streaming is disabled") | ||
|
||
return budget_bytes | ||
|
||
def set_automatic_streaming_budget(self): | ||
budget_bytes = self.engine.get_weight_streaming_automatic_budget() | ||
return self.set_weight_streaming_budget(budget_bytes) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This api is same as in TorchTensorRTModule class. If this interface is good to go, parent class can used to share it and other some methods.
We probably need to think about what the user flow is here: So @ compile-time:
@ runtime
|
return budget_bytes | ||
|
||
def set_automatic_streaming_budget(self): | ||
budget_bytes = self.engine.get_weight_streaming_automatic_budget() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a good default we can use in setup_engine
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. I set automatic weight streaming when compiler options is set
@@ -191,6 +221,7 @@ def __del__(self) -> None: | |||
self.cudagraph.reset() | |||
|
|||
def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: | |||
self.init_context() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really want to pull these calls out, It should assume that the engine is setup and error if not
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood. recreation of context happens only when set_weight_streaming_budget is called
Hi @narendasan
|
enable_weight_streaming=True, | ||
) | ||
# Weight streaming budget is applied manually. | ||
ws_context = torchtrt.runtime.weight_streaming_context(optimized_model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use the context manager syntax to use this?
with torch_tensorrt.runtime.weight_streaming(model) as weight_streaming_ctx:
current_budget = weight_streaming_ctx.device_budget
weight_streaming_ctx.device_budget = current_budget * 0.7 # Can add listeners to __setattr__ to trigger functions
optimized_model(*input)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thoughts :
- If we use weight streaming as default, is there any problem with perf ? assuming we don't allocate any budget or if automatic is chosen, and the model can fit on GPU memory completely
cast_layer = ctx.net.add_cast(input_val, trt_dtype) | ||
cast_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" | ||
|
||
return cast_layer.get_output(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are currently in llm_examples_main PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rebase with main as the llm_examples PR is merged
cast_layer = ctx.net.add_cast(input_val, trt_dtype) | ||
cast_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - [{target_name}]-[{name}]" | ||
|
||
return cast_layer.get_output(0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rebase with main as the llm_examples PR is merged
if ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED): | ||
promoted_type = trt_inputs[0].dtype | ||
for each_input in trt_inputs[1:]: | ||
promoted_type = _enums.dtype._from( | ||
torch.promote_types( | ||
_enums.dtype._from(promoted_type).to(torch.dtype), | ||
_enums.dtype._from(each_input.dtype).to(torch.dtype), | ||
) | ||
) | ||
|
||
trt_promoted_type = promoted_type.to(trt.DataType) | ||
trt_casted_inputs = [] | ||
for i, each_input in enumerate(trt_inputs): | ||
casted_input = cast_trt_tensor( | ||
ctx, each_input, trt_promoted_type, f"{name}_input_casted_{i}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The type promotion is fine but does it needs to only happen when strong typing is enabled? Why not do this in general cases as well ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought trt can optimize the perf for relaxed precision. But it seems multiple inputs in ops are eventually casted to same type. Tested sd unet model with/without promoted types, there was no differences. I will generalize.
dtype = input.dtype if strongly_typed else None | ||
bias = to_numpy(bias, dtype=dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the type of bias be always input.dtype
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If test fp16 variant of sd_unet model, bias data type is float16. It needs to be casted to run with weight streaming option.
@@ -85,6 +85,12 @@ def __init__( | |||
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) | |||
flag |= EXPLICIT_BATCH | |||
|
|||
if compilation_settings.enable_weight_streaming: | |||
STRONGLY_TYPED = 1 << (int)( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should log this at least since it affects the graph being created
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Waiting for separate compiler option to use strongly typed network.
https://github.com/pytorch/TensorRT/pull/3110/files#diff-4396607120a22430fe9fdb7d00b094ae5d55f28d0d2e3543a878ac48583ebd21R83
I will incorporate with it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super minor stuff at this point, think its almost ready to go
7c9cc49
to
9842b00
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this mostly looks good to me, anything outstanding?
No pending items. I think this PR can be merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor comments:
- Added a comment in the example
- Also update this example reference in the docsrc/infex.rst to get rendered.
- Rebase with main to resolve conflicts.
Overall, changes LGTM
9842b00
to
3ac5da1
Compare
3ac5da1
to
bf10495
Compare
bf10495
to
c12f76f
Compare
Description
Fixes # (issue)
Type of change
Please delete options that are not relevant and/or add your own.
Checklist: