From 99f76f74960e0fcea5222f14389723a68bb22c84 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Tue, 3 Sep 2024 13:06:35 +0000
Subject: [PATCH] chore: cpp runtime update for min + streamable budget

---
 core/runtime/TRTEngine.cpp                  |  13 +-
 core/runtime/TRTEngine.h                    |   2 +
 core/runtime/register_jit_hooks.cpp         |   3 +-
 .../runtime/_PythonTorchTensorRTModule.py   |  32 ++---
 .../dynamo/runtime/_TorchTensorRTModule.py  |  15 +-
 .../runtime/_weight_streaming.py            | 111 ++++++++-------
 .../runtime/test_004_weight_streaming.py    | 132 ++++++++----------
 7 files changed, 153 insertions(+), 155 deletions(-)

diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 130176de0d..1ff6e77bc7 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -91,7 +91,12 @@ TRTEngine::TRTEngine(
   cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
   TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");
 
-  if (get_min_required_device_budget() > 0) {
+  if (get_streamable_weights_size() > 0) {
+    // Scratch memory size may change based on the current weight streaming budget
+    // The memory required for full streaming is used as the minimum weight budget
+    set_device_memory_budget(0);
+    min_required_device_budget = cuda_engine->getWeightStreamingScratchMemorySize();
+
     int64_t budget_bytes = get_weight_streaming_automatic_budget();
     LOG_INFO("Set automatic weight streaming budget bytes " << budget_bytes);
     set_device_memory_budget(budget_bytes);
@@ -275,10 +280,14 @@ bool TRTEngine::set_device_memory_budget(int64_t budget) {
 }
 
 // Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
-int64_t TRTEngine::get_min_required_device_budget() {
+int64_t TRTEngine::get_streamable_weights_size() {
   return cuda_engine->getStreamableWeightsSize();
 }
 
+int64_t TRTEngine::get_min_required_device_budget() {
+  return min_required_device_budget;
+}
+
 int64_t TRTEngine::get_weight_streaming_automatic_budget() {
   return cuda_engine->getWeightStreamingAutomaticBudget();
 }
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index e59f8c7055..1d91797f0e 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -73,6 +73,7 @@ struct TRTEngine : torch::CustomClassHolder {
   void dump_engine_layer_info();
   int64_t get_device_memory_budget();
   bool set_device_memory_budget(int64_t budget);
+  int64_t get_streamable_weights_size();
   int64_t get_min_required_device_budget();
   int64_t get_weight_streaming_automatic_budget();
   void init_context();
@@ -106,6 +107,7 @@ struct TRTEngine : torch::CustomClassHolder {
  std::string cuda_graph_debug_path;
  std::mutex mu;
  std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
+  int64_t min_required_device_budget;
 };
 
 } // namespace runtime
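The constructor change above pins down the engine's true minimum budget: with the budget forced to zero (full streaming), the remaining scratch allocation is the least device memory the engine can run with, after which the automatic budget is applied as the default. A minimal sketch of the same sequence against the TensorRT Python API, for readers cross-checking the C++ flow; it assumes `engine` is a deserialized ICudaEngine built with the weight streaming builder flag, and uses the same attributes the Python runtime below relies on:

    import tensorrt as trt

    def init_weight_streaming_defaults(engine: trt.ICudaEngine) -> int:
        # Returns the minimum required device budget in bytes
        # (0 when the engine was not built with weight streaming).
        if engine.streamable_weights_size == 0:
            return 0
        # Force full streaming; the scratch memory required in this state
        # is the smallest budget the engine can execute with.
        engine.weight_streaming_budget_v2 = 0
        min_required_device_budget = engine.weight_streaming_scratch_memory_size
        # Then apply TensorRT's automatic budget as the default.
        engine.weight_streaming_budget_v2 = (
            engine.get_weight_streaming_automatic_budget()
        )
        return min_required_device_budget
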
.def("init_context", &TRTEngine::init_context) .def("reset_context", &TRTEngine::reset_context) .def_pickle( diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index 5cc35f1435..805592642e 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -2,9 +2,8 @@ import logging from contextlib import nullcontext -from functools import wraps from tempfile import tempdir -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import tensorrt as trt import torch @@ -13,6 +12,9 @@ from torch_tensorrt._Device import Device from torch_tensorrt._enums import Platform, dtype from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( + recreate_context_decorator, +) from torch_tensorrt.dynamo.utils import DYNAMIC_DIM from torch_tensorrt.logging import TRT_LOGGER from torch_tensorrt.runtime._utils import ( @@ -24,22 +26,6 @@ logger = logging.getLogger(__name__) -def recreate_context_decorator(method: Callable[..., Any]) -> Callable[..., Any]: - """ - A decorator that destroys a context before a method execution and - creates it after the method execution within the same class instance. - """ - - @wraps(method) - def wrapper(self: object, *args: Any, **kwargs: Any) -> Any: - self.reset_context() - result = method(self, *args, **kwargs) - self.init_context() - return result - - return wrapper - - class PythonTorchTensorRTModule(Module): # type: ignore[misc] """PythonTorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine. 
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index b95c980c81..f5cd335354 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -192,12 +192,18 @@ def init_context(self) -> None:
     def reset_context(self) -> None:
         self.engine.reset_context()
 
+    def get_streamable_weights_size(self) -> Any:
+        return self.engine.streamable_weights_size
+
     def get_min_required_device_budget(self) -> Any:
         return self.engine.min_required_device_budget
 
     def get_weight_streaming_budget(self) -> Any:
         return self.engine.device_memory_budget
 
+    def get_automatic_weight_streaming_budget(self) -> Any:
+        return self.engine.weight_streaming_automatic_budget
+
     @recreate_context_decorator
     def set_device_memory_budget(self, budget_bytes: int) -> int:
         return self._set_device_memory_budget(budget_bytes)
@@ -205,21 +211,16 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
     def _set_device_memory_budget(self, budget_bytes: int) -> int:
         # Disable weight streaming for invalid budget size
         if budget_bytes < 0:
-            budget_bytes = self.get_min_required_device_budget()
-
+            budget_bytes = self.get_streamable_weights_size()
         self.engine.device_memory_budget = budget_bytes
         if self.get_weight_streaming_budget() != budget_bytes:
             logger.error(f"Failed to set weight streaming budget to {budget_bytes}")
             budget_bytes = self.get_weight_streaming_budget()
-        if self.engine.min_required_device_budget == budget_bytes:
+        if self.get_min_required_device_budget() == budget_bytes:
             logger.warning("Weight streaming is disabled")
         return budget_bytes
 
-    def set_automatic_streaming_budget(self) -> int:
-        budget_bytes = self.engine.get_weight_streaming_automatic_budget()
-        return self._set_device_memory_budget(budget_bytes)
-
     def setup_engine(self) -> None:
         """
         Setup engine for a module which has deferred engine setup.
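After this change both runtime wrappers expose the same accessor surface (get_streamable_weights_size, get_min_required_device_budget, get_weight_streaming_budget, get_automatic_weight_streaming_budget) and the same sentinel convention for set_device_memory_budget. An illustrative sketch of that convention, where `rt_mod` stands for either wrapper and the helper names are hypothetical:

    def disable_weight_streaming(rt_mod) -> int:
        # A negative budget is the "disable" sentinel: _set_device_memory_budget
        # replaces it with the full streamable weights size, keeping every
        # weight resident on the device.
        return rt_mod.set_device_memory_budget(-1)

    def enable_full_weight_streaming(rt_mod) -> int:
        # A zero budget keeps no streamable weights on the device; everything
        # is streamed from host memory during execution.
        return rt_mod.set_device_memory_budget(0)
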
diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py
index 8fbaa4919a..304815b838 100755
--- a/py/torch_tensorrt/runtime/_weight_streaming.py
+++ b/py/torch_tensorrt/runtime/_weight_streaming.py
@@ -15,83 +15,94 @@ class _WeightStreamingContextManager(object):
     def __init__(self, module: torch.fx.GraphModule) -> None:
         rt_mods = []
         torch_budget = 0
-        trt_budget = 0
+        trt_min_budget = 0
+        trt_max_budget = 0
+        current_trt_budget = 0
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
                 rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
             ):
-                trt_budget += rt_mod.min_required_device_budget
-                trt_budget += rt_mod.get_streamable_weights_size()
+                trt_min_budget += rt_mod.get_min_required_device_budget()
+                trt_max_budget += rt_mod.get_streamable_weights_size()
+                current_trt_budget += rt_mod.get_weight_streaming_budget()
                 rt_mods.append((name, rt_mod))
             else:
                 torch_budget += sum(
                     [p.numel() * p.element_size() for p in rt_mod.parameters()]
                 )
+
+        trt_max_budget += trt_min_budget
         self.torch_budget = torch_budget
         self.rt_mods = rt_mods
-        total_device_budget = torch_budget + trt_budget
-        # device_budget is -1 if there is no trt module
-        device_budget = -1 if trt_budget == 0 else total_device_budget
+        device_budget = torch_budget + trt_min_budget + current_trt_budget
         super().__setattr__("device_budget", device_budget)
-        super().__setattr__("total_trt_budget", trt_budget)
+        super().__setattr__("trt_max_budget", trt_max_budget)
+
+    def get_automatic_weight_streaming_budget(self) -> int:
+        ws_budget_bytes = self.torch_budget
+        for _, rt_mod in self.rt_mods:
+            ws_budget_bytes += rt_mod.get_automatic_weight_streaming_budget()
+            ws_budget_bytes += rt_mod.get_min_required_device_budget()
+        return ws_budget_bytes
 
-    def get_min_required_device_budget(self) -> int:
+    def get_required_device_budgets(self) -> tuple[int, int]:
         min_budget = self.torch_budget
+        max_budget = self.torch_budget + self.trt_max_budget
         for _, rt_mod in self.rt_mods:
-            min_budget += rt_mod.min_required_device_budget
-        return min_budget
+            min_budget += rt_mod.get_min_required_device_budget()
+        return min_budget, max_budget
 
     def __enter__(self) -> "_WeightStreamingContextManager":
         return self
 
     def __exit__(self, *args: Any) -> None:
-        for name, rt_mod in self.rt_mods:
-            streamable_budget = rt_mod.get_streamable_weights_size()
-            rt_mod.set_device_memory_budget(streamable_budget)
+        max_budget = self.torch_budget + self.trt_max_budget
+        if self.trt_max_budget > 0:
             logger.debug(
-                f"Disable weight streaming by setting size {streamable_budget} for {name}"
+                f"Disable weight streaming by applying max budget size {max_budget}"
             )
+            self.device_budget = max_budget
 
-    def __setattr__(self, name: str, value: Any) -> None:
-        if name == "device_budget":
-            requested_budget = value
-            trt_engine_budget = requested_budget - self.torch_budget
-            value = 0
-            if self.total_trt_budget == 0:
-                logger.error(
-                    "Streamable bytes are zero. Was module complied with enable_weight_streaming=True option?"
-                )
-                value = -1
-            elif trt_engine_budget <= 0:
-                logger.error(
-                    f"Requested budget {requested_budget} is less than mininum torch budget: {self.torch_budget}"
-                )
-                value = -1
-            else:
-                # Normalized size is applied for multiple trt runtime module.
-                # e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively.
-                # Then 25B and 75B are applied for each module.
-                for mod_name, rt_mod in self.rt_mods:
-                    max_budget = (
-                        rt_mod.min_required_device_budget
-                        + rt_mod.get_streamable_weights_size()
-                    )
-                    normalized_size = (
-                        int(max_budget / self.total_trt_budget * trt_engine_budget)
-                        - rt_mod.min_required_device_budget
-                    )
-                    if normalized_size < 0:
-                        logger.error(
-                            f"Requested trt budget {trt_engine_budget} is less than mininum trt budget: {rt_mod.min_required_device_budget}"
-                        )
-                        value = -1
-                        break
-                    value += rt_mod.set_device_memory_budget(normalized_size)
-                    value += rt_mod.min_required_device_budget
-                    logger.debug(
-                        f"Set weight streaming size {normalized_size} for {mod_name}"
-                    )
+    def _set_streamable_weight_bytes(self, requested_budget: int) -> int:
+        ws_budget_bytes = self.torch_budget
+        trt_engine_budget = requested_budget - self.torch_budget
+        if self.trt_max_budget == 0:
+            raise RuntimeError(
+                "Streamable bytes are zero. Was the module compiled with the enable_weight_streaming=True option?"
+            )
+        elif trt_engine_budget <= 0:
+            raise RuntimeError(
+                f"Requested budget {requested_budget} is less than minimum torch budget: {self.torch_budget}"
+            )
+        else:
+            # Normalized size is applied for multiple trt runtime modules.
+            # e.g. 100B budget is applied to two modules and they have 1000B and 3000B max streamable size respectively.
+            # Then 25B and 75B are applied for each module.
+            for mod_name, rt_mod in self.rt_mods:
+                max_budget = (
+                    rt_mod.get_min_required_device_budget()
+                    + rt_mod.get_streamable_weights_size()
+                )
+                normalized_size = (
+                    int(max_budget / self.trt_max_budget * trt_engine_budget)
+                    - rt_mod.get_min_required_device_budget()
+                )
+
+                if normalized_size < 0:
+                    raise RuntimeError(
+                        f"Requested trt budget {trt_engine_budget} is less than minimum trt budget of submodule {mod_name} size={rt_mod.get_min_required_device_budget()}"
+                    )
+                ws_budget_bytes += rt_mod.set_device_memory_budget(normalized_size)
+                ws_budget_bytes += rt_mod.get_min_required_device_budget()
+                logger.debug(
+                    f"Set weight streaming size {normalized_size} for {mod_name}"
+                )
+        return ws_budget_bytes
+
+    def __setattr__(self, name: str, value: Any) -> None:
+        if name == "device_budget":
+            value = self._set_streamable_weight_bytes(value)
         super().__setattr__(name, value)
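The proportional split in _set_streamable_weight_bytes can be seen with a worked example (illustrative only; it omits the per-module minimum scratch term that the real code subtracts and adds back). Two hypothetical submodules with maximum budgets of 1000 and 3000 bytes share a requested TRT budget of 100 bytes, as in the comment above:

    def split_trt_budget(trt_engine_budget: int, module_max_budgets: list[int]) -> list[int]:
        # Each submodule receives a share proportional to its maximum
        # streamable size, truncated to an integer byte count.
        trt_max_budget = sum(module_max_budgets)
        return [
            int(max_budget / trt_max_budget * trt_engine_budget)
            for max_budget in module_max_budgets
        ]

    assert split_trt_budget(100, [1000, 3000]) == [25, 75]

Because each share is truncated by int(), the sum of the applied budgets can undershoot the request by a byte or so, which is why the multi-runtime test below tolerates a difference of at most 1.
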
diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py
index c727d12556..dc1c95488a 100644
--- a/tests/py/dynamo/runtime/test_004_weight_streaming.py
+++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py
@@ -2,33 +2,19 @@
 import torch_tensorrt as torchtrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
-from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
 
 INPUT_SIZE = (64, 100)
 
-# Helper to get current weight streaming budet in runtime module
-def get_current_weight_streaming_bytes(runtime_module):
-    total_bytes = 0
-    for name, rt_mod in runtime_module.named_children():
-        if "_run_on_acc" in name and (
-            isinstance(rt_mod, PythonTorchTensorRTModule)
-            or isinstance(rt_mod, TorchTensorRTModule)
-        ):
-            total_bytes += rt_mod.get_weight_streaming_budget()
-            total_bytes += rt_mod.min_required_device_budget
-    return total_bytes
-
-
 class SampleModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.layer1 = torch.nn.Linear(100, 128)
         self.layer2 = torch.nn.Linear(30, 64)
         self.mat1 = torch.randn((128, 32)).cuda()
+        self.mat2 = torch.randn((64, 512)).cuda()
         self.relu = torch.nn.ReLU()
         self.conv = torch.nn.Conv1d(64, 6, 3)
-        self.mat2 = torch.randn((64, 512)).cuda()
 
     def forward(self, x):
         out = self.layer1(x)
@@ -45,7 +31,7 @@ class TestWeightStreamingPython(TestCase):
     @parameterized.expand(
         [
             ("python_runtime", True),
-            # ("cpp_runtime", False),
+            ("cpp_runtime", False),
         ]
     )
     def test_weight_streaming_default(self, _, use_python_runtime):
@@ -60,13 +46,12 @@ def test_weight_streaming_default(self, _, use_python_runtime):
             min_block_size=1,
             cache_built_engines=False,
             reuse_cached_engines=False,
-            debug=True,
             use_python_runtime=use_python_runtime,
             enable_weight_streaming=True,
         )
         # Checking default weight streaming budget(automatic) is applied
-        current_ws_budget_bytes = get_current_weight_streaming_bytes(optimized_model)
-        assert current_ws_budget_bytes > 0
+        with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
+            assert weight_streaming_ctx.device_budget > 0
 
         ref = model(*input)
         out = optimized_model(*input)
@@ -84,7 +69,7 @@ def test_weight_streaming_default(self, _, use_python_runtime):
     @parameterized.expand(
         [
             ("python_runtime", True),
-            # ("cpp_runtime", False),
+            ("cpp_runtime", False),
         ]
     )
     def test_weight_streaming_manual(self, _, use_python_runtime):
@@ -96,48 +81,40 @@ def test_weight_streaming_manual(self, _, use_python_runtime):
             fx_graph,
             inputs=input,
             ir="dynamo",
+            min_block_size=1,
             cache_built_engines=False,
             reuse_cached_engines=False,
-            min_block_size=1,
-            debug=True,
             use_python_runtime=use_python_runtime,
             enable_weight_streaming=True,
         )
         # Weight streaming budget is applied manually.
         with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
-            min_budget = weight_streaming_ctx.get_min_required_device_budget()
-            current_budget = weight_streaming_ctx.device_budget
-            streamable_budget = current_budget - min_budget
-            weight_streaming_ctx.device_budget = min_budget + int(
-                streamable_budget * 0.7
-            )
-
-            current_ws_budget_bytes = get_current_weight_streaming_bytes(
-                optimized_model
-            )
-            assert weight_streaming_ctx.device_budget == current_ws_budget_bytes
+            min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
+            streamable_budget = max_budget - min_budget
+
+            requested_budget = min_budget + int(streamable_budget * 0.7)
+            weight_streaming_ctx.device_budget = requested_budget
+            assert weight_streaming_ctx.device_budget == requested_budget
             optimized_model(*input)
 
             # Full streaming by applying min budget
             weight_streaming_ctx.device_budget = min_budget
-            current_ws_budget_bytes = get_current_weight_streaming_bytes(
-                optimized_model
-            )
-            assert weight_streaming_ctx.device_budget == current_ws_budget_bytes
-
-            weight_streaming_ctx.device_budget = min_budget + int(
-                streamable_budget * 0.5
-            )
-            current_ws_budget_bytes = get_current_weight_streaming_bytes(
-                optimized_model
-            )
-            assert weight_streaming_ctx.device_budget == current_ws_budget_bytes
+            assert weight_streaming_ctx.device_budget == min_budget
+
+            # Automatic weight streaming size
+            val = weight_streaming_ctx.get_automatic_weight_streaming_budget()
+            weight_streaming_ctx.device_budget = val
+            assert weight_streaming_ctx.device_budget > 0
+
+            requested_budget = min_budget + int(streamable_budget * 0.5)
+            weight_streaming_ctx.device_budget = requested_budget
+            assert weight_streaming_ctx.device_budget == requested_budget
+
             out = optimized_model(*input)
 
         # Weight streaming is disabled after the exit from weight streaming context
-        current_ws_budget_bytes = get_current_weight_streaming_bytes(optimized_model)
-        assert current_ws_budget_bytes == current_budget
+        assert weight_streaming_ctx.device_budget == max_budget
 
         ref = model(*input)
         torch.testing.assert_close(
@@ -152,11 +129,13 @@ def test_weight_streaming_manual(self, _, use_python_runtime):
 
     @parameterized.expand(
         [
-            ("python_runtime", True),
-            ("cpp_runtime", False),
+            ("python_runtime", True, False),
+            ("python_runtime_multi_rt", True, True),
+            ("cpp_runtime", False, False),
+            ("cpp_runtime_multi_rt", False, True),
         ]
     )
-    def no_test_weight_streaming_invalid_usage(self, _, use_python_runtime):
+    def test_weight_streaming_invalid_usage(self, _, use_python_runtime, multi_rt):
         model = SampleModel().eval().cuda()
         input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
         fx_graph = torch.fx.symbolic_trace(model)
@@ -168,28 +147,32 @@ def no_test_weight_streaming_invalid_usage(self, _, use_python_runtime):
             min_block_size=1,
             cache_built_engines=False,
             reuse_cached_engines=False,
-            debug=True,
+            torch_executed_ops=(
+                {"torch.ops.aten.convolution.default"} if multi_rt else {}
+            ),
             use_python_runtime=use_python_runtime,
             enable_weight_streaming=True,
         )
         # Setting weight streaming context to unsupported module
         with torchtrt.runtime.weight_streaming(model) as weight_streaming_ctx:
-            current_budget = weight_streaming_ctx.device_budget
-            assert current_budget == -1
+            min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
+            assert min_budget == max_budget
 
-        # Expects weight streaming is disabled if invalid budget size is set
         with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
-            current_budget = weight_streaming_ctx.device_budget
+            min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
 
-            # Values is larger than streamable weights size
-            weight_streaming_ctx.device_budget = current_budget + 1
-            assert weight_streaming_ctx.device_budget == current_budget
-            optimized_model(*input)
+            # A value larger than the max budget disables weight streaming
+            weight_streaming_ctx.device_budget = max_budget + 1
+            assert weight_streaming_ctx.device_budget == max_budget
+
+            try:
+                # Runtime error if requested budget is less than minimum budget
+                weight_streaming_ctx.device_budget = min_budget - 1
+                assert False
+            except RuntimeError:
+                assert True
 
-            # negative weight budget size
-            weight_streaming_ctx.device_budget = -1
-            assert weight_streaming_ctx.device_budget == current_budget
             optimized_model(*input)
 
         torch._dynamo.reset()
@@ -197,7 +180,7 @@ def no_test_weight_streaming_invalid_usage(self, _, use_python_runtime):
     @parameterized.expand(
         [
             ("python_runtime", True),
-            # ("cpp_runtime", False),
+            ("cpp_runtime", False),
         ]
     )
     def test_weight_streaming_multi_rt(self, _, use_python_runtime):
@@ -210,7 +193,6 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime):
             inputs=input,
             ir="dynamo",
             min_block_size=1,
-            debug=True,
             cache_built_engines=False,
             reuse_cached_engines=False,
             torch_executed_ops={"torch.ops.aten.convolution.default"},
@@ -219,19 +201,19 @@ def test_weight_streaming_multi_rt(self, _, use_python_runtime):
 
         with torchtrt.runtime.weight_streaming(optimized_model) as weight_streaming_ctx:
-            min_budget = weight_streaming_ctx.get_min_required_device_budget()
-            current_budget = weight_streaming_ctx.device_budget
-            streamable_budget = current_budget - min_budget
-            for pct in [0.1, 0.2, 0.4, 0.8]:
-                weight_streaming_ctx.device_budget = min_budget + int(
-                    pct * streamable_budget
-                )
-                current_ws_budget_bytes = get_current_weight_streaming_bytes(
-                    optimized_model
-                )
-                assert weight_streaming_ctx.device_budget == current_ws_budget_bytes
+            min_budget, max_budget = weight_streaming_ctx.get_required_device_budgets()
+            streamable_budget = max_budget - min_budget
+            for pct in [0.05, 0.2, 0.4, 0.8, 1.0]:
+                requested_budget = min_budget + int(streamable_budget * pct)
+                weight_streaming_ctx.device_budget = requested_budget
+
+                # Budget distribution to multiple submodules may result in integer differences of at most 1
+                assert abs(weight_streaming_ctx.device_budget - requested_budget) <= 1
                 out = optimized_model(*input)
 
+        # Weight streaming is disabled after the exit from weight streaming context
+        assert weight_streaming_ctx.device_budget == max_budget
+
         ref = model(*input)
         torch.testing.assert_close(
             out.cpu(),
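Taken together, the updated API is exercised end to end as in the following usage sketch, which mirrors the tests above (SampleModel and INPUT_SIZE are the test fixtures; other names are placeholders):

    import torch
    import torch_tensorrt as torchtrt

    model = SampleModel().eval().cuda()
    input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()]
    optimized_model = torchtrt.compile(
        torch.fx.symbolic_trace(model),
        inputs=input,
        ir="dynamo",
        min_block_size=1,
        enable_weight_streaming=True,
    )
    with torchtrt.runtime.weight_streaming(optimized_model) as ctx:
        min_budget, max_budget = ctx.get_required_device_budgets()
        # Keep roughly half of the streamable weights on the device.
        ctx.device_budget = min_budget + (max_budget - min_budget) // 2
        out = optimized_model(*input)
    # On exit the budget is restored to max_budget, disabling streaming.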