@@ -116,6 +116,7 @@ void LlamaWeight::release()
116116 }
117117
118118 decoder_layer_weights.clear ();
119+ pinned_weights_.clear ();
119120
120121 // Wait for deallocations
121122 core::Context::stream ().Sync ();
@@ -127,21 +128,22 @@ void LlamaWeight::release()
127128
128129void LlamaWeight::to_device (const core::Device& device)
129130{
130- core::ContextGuard guard = context ();
131-
132- auto to_device = [&](Tensor& x) -> Tensor {
133- auto tmp = std::exchange (x, empty_like (x, device));
134- Copy (tmp, x);
135- return tmp;
136- };
137-
138- std::vector<Tensor> tmp_cpu_tensors;
131+ TM_CHECK (device.type == kCPU || device.type == kDEVICE );
132+ core::ContextGuard guard{stream_, alloca_, Allocator{kCPUpinned }};
139133
140134 auto tensor_ptr_map = get_parameters ();
141135 for (auto & [name, tensor_ptr] : tensor_ptr_map) {
142- auto tmp_tensor = to_device (*tensor_ptr);
143- if (tmp_tensor.device ().type != kDEVICE ) {
144- tmp_cpu_tensors.push_back (tmp_tensor);
136+ if (device.type == kCPU ) {
137+ if (pinned_weights_.find (name) == pinned_weights_.end ()) {
138+ pinned_weights_[name] = empty_like (*tensor_ptr, kCPUpinned );
139+ Copy (*tensor_ptr, pinned_weights_[name]);
140+ }
141+ *tensor_ptr = {};
142+ }
143+ else {
144+ TM_CHECK (pinned_weights_.find (name) != pinned_weights_.end ());
145+ *tensor_ptr = empty_like (pinned_weights_[name], kDEVICE );
146+ Copy (pinned_weights_[name], *tensor_ptr);
145147 }
146148 }
147149 core::Context::stream ().Sync ();