diff --git a/patrickstar/core/chunk_data.py b/patrickstar/core/chunk_data.py
index 3b590fa39..2a56f5e21 100644
--- a/patrickstar/core/chunk_data.py
+++ b/patrickstar/core/chunk_data.py
@@ -152,7 +152,11 @@ def allocate_payload(self, device):
                 device=device,
                 pin_memory=(device.type == "cpu"),
             )
-            self.memory_tracer.add(device.type, self.get_payload_space())
+            self.memory_tracer.add(
+                device.type,
+                self.get_payload_space(),
+                self.payload.is_pinned(),
+            )
         except RuntimeError:
             if self._time_profile:
                 global_timer.my_timer.finish_profile(
@@ -178,7 +182,11 @@ def release_payload(self):
             # must delete reference of `Chunk` to self.payload
             self.payload = None
         else:
-            self.memory_tracer.delete(self.get_device().type, self.get_payload_space())
+            self.memory_tracer.delete(
+                self.get_device().type,
+                self.get_payload_space(),
+                self.payload.is_pinned(),
+            )
             del self.payload
             self.payload = None
         if profiler.started():
@@ -324,8 +332,16 @@ def move_sync(self, target_device: torch.device):
                 self.payload = self.payload.pin_memory()
             self.payload = self.payload.to(target_device)
 
-        self.memory_tracer.delete(src_device.type, self.get_payload_space())
-        self.memory_tracer.add(target_device.type, self.get_payload_space())
+        self.memory_tracer.delete(
+            src_device.type,
+            self.get_payload_space(),
+            self.payload.is_pinned(),
+        )
+        self.memory_tracer.add(
+            target_device.type,
+            self.get_payload_space(),
+            self.payload.is_pinned(),
+        )
 
         if self._time_profile:
             if target_device.type == "cuda":
diff --git a/patrickstar/core/client.py b/patrickstar/core/client.py
index 2b49e68d0..07708173d 100644
--- a/patrickstar/core/client.py
+++ b/patrickstar/core/client.py
@@ -319,6 +319,7 @@ def param_fp16_chunks_max_mem_usage(self):
                 * self.default_chunk_size
                 * 2
                 / world_size
+                + self.default_chunk_size * 2
             )
         else:
             # non-MSC has to cache world_size - 1 buffers.
diff --git a/patrickstar/core/memory_cache.py b/patrickstar/core/memory_cache.py
index 9fd492bea..d9e1d4f4a 100644
--- a/patrickstar/core/memory_cache.py
+++ b/patrickstar/core/memory_cache.py
@@ -55,7 +55,7 @@ def _new_mem(self, size, data_type, device_type, pin_memory):
             device=device_type,
             pin_memory=pin_memory,
         )
-        self._memtracer.add(device_type.type, space_size)
+        self._memtracer.add(device_type.type, space_size, pin_memory)
         return ret
 
     def pop_or_allocate(
@@ -104,8 +104,9 @@ def push(self, payload):
         size = payload.numel()
         # the cache is full
         if len(self._cached_tensors[(device_type, data_type)]) == self._capacity:
+            is_pinned_flag = payload.is_pinned()
             del payload
             space_size = getsizeof(data_type) * size
-            self._memtracer.delete(device_type.type, space_size)
+            self._memtracer.delete(device_type.type, space_size, is_pinned_flag)
         else:
             self._cached_tensors[(device_type, data_type)].append(payload.zero_())
diff --git a/patrickstar/core/memtracer/memtracer.py b/patrickstar/core/memtracer/memtracer.py
index cee477362..33b0f3f4a 100644
--- a/patrickstar/core/memtracer/memtracer.py
+++ b/patrickstar/core/memtracer/memtracer.py
@@ -103,6 +103,7 @@ def __init__(self, local_rank: int = 0, config=None):
 
         self.gpu_chunk_used_mem = 0
         self.cpu_chunk_used_mem = 0
+        self.cpu_chunk_used_mem_pinned = 0
 
         if config is not None:
             self._overall_gpu_mem_ratio = config.get("overall_gpu_mem_ratio", 0.8)
@@ -165,6 +166,7 @@ def __init__(self, local_rank: int = 0, config=None):
         # from peak system memory.
         self._margin_chunk_num_for_gpu_adam = 0
         self._default_chunk_size = 0
+        self.max_cpu_sys_used = 0
 
     def close_tracer(self):
         """
@@ -195,21 +197,23 @@ def update_margin_mem(self):
             logger.warning(
                 "No gpu info collected. Maybe there are no chunk based tensors."
             )
-            max_cpu_sys_used = 0
+            self.max_cpu_sys_used = 0
         else:
-            max_cpu_sys_used = max(self.cpu_sys_used_list)
+            self.max_cpu_sys_used = max(self.cpu_sys_used_list)
 
         margin_mem_size = (
             self._overall_gpu_mem - max_gpu_sys_used - self._param_fp16_chunk_size
         )
-        # 12 = 4 + 4 + 4 fp32 + m + v
+        # 12 = 4 + 4 + 4 (fp32 + m + v)
         self._margin_chunk_num_for_gpu_adam = (
             (margin_mem_size) / (self._default_chunk_size * 12) * self._margin_use_ratio
         )
 
         log_dist("--------------- GPU INFO AFTER BWD ----------------")
         log_dist(f"Max GPU System Mem (non-chunk) Used {max_gpu_sys_used / 1e6} MB")
-        log_dist(f"Max CPU System Mem (non-chunk) Used {max_cpu_sys_used / 1e6} MB")
+        log_dist(
+            f"Max CPU System Mem (non-chunk) Used {self.max_cpu_sys_used / 1e6} MB"
+        )
         log_dist(f"Param FP16 Chunk Size {self._param_fp16_chunk_size / 1e6} MB")
         log_dist(
             f"Margin Mem Size {margin_mem_size / 1e6} MB, "
@@ -280,7 +284,11 @@ def trace_memory(self):
         cpu_used = get_sys_memory_used(cpu_device)
         self.cpu_used_list.append(cpu_used)
         self.cpu_chunk_used_list.append(self.cpu_chunk_used_mem)
-        self.cpu_sys_used_list.append((cpu_used - self.cpu_chunk_used_mem))
+        # The measured CPU usage already excludes pinned memory, so subtract
+        # only the non-pinned chunk memory to get non-chunk (system) CPU usage.
+        self.cpu_sys_used_list.append(
+            cpu_used - (self.cpu_chunk_used_mem - self.cpu_chunk_used_mem_pinned)
+        )
 
         # For non-warmup iter, we update the mem of index cur_mom,
         # and for warmup iter, we append the gpu mem to the end of the list.
@@ -297,17 +305,21 @@
 
         self.metronome.tiktac()
 
-    def add(self, device_type: str, size_in_bytes: int):
+    def add(self, device_type: str, size_in_bytes: int, is_pinned: bool = False):
         if device_type == "cpu":
             self.cpu_chunk_used_mem += size_in_bytes
+            if is_pinned:
+                self.cpu_chunk_used_mem_pinned += size_in_bytes
         elif device_type == "cuda":
             self.gpu_chunk_used_mem += size_in_bytes
         else:
             raise f"device type {device_type} is not supported"
 
-    def delete(self, device_type, size_in_bytes):
+    def delete(self, device_type, size_in_bytes, is_pinned: bool = False):
         if device_type == "cpu":
             self.cpu_chunk_used_mem -= size_in_bytes
+            if is_pinned:
+                self.cpu_chunk_used_mem_pinned -= size_in_bytes
         elif device_type == "cuda":
             self.gpu_chunk_used_mem -= size_in_bytes
         else:
@@ -377,7 +389,11 @@ def available_chunk_mem(self, device_type):
             return self._overall_gpu_mem * self.warmup_gpu_chunk_mem_ratio
 
         if device_type == "cpu":
-            return self._overall_cpu_mem
+            local_world_size = get_local_world_size()
+            if self.metronome.training_stage() != TrainingStage.ADAM:
+                return self._overall_cpu_mem - self.max_cpu_sys_used / local_world_size
+            else:
+                return self._overall_cpu_mem
         elif device_type == "cuda":
             world_size = get_world_size()
             if self.metronome.training_stage() == TrainingStage.ADAM:
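
A note on the tracer change above: `get_sys_memory_used` reports a figure that already excludes pinned memory (per the new comment in `trace_memory`), so the non-chunk "system" usage must subtract only the *non-pinned* chunk bytes, which is exactly what the new `cpu_chunk_used_mem_pinned` counter enables. A minimal, self-contained sketch of this bookkeeping; the class name and byte counts are invented for illustration, not the PatrickStar API:

```python
# Sketch of the pinned-aware accounting that RuntimeMemTracer.add/delete
# perform after this patch. Class name and sizes are illustrative only.


class PinnedAwareCounter:
    def __init__(self):
        self.cpu_chunk_used_mem = 0
        self.cpu_chunk_used_mem_pinned = 0

    def add(self, size_in_bytes: int, is_pinned: bool = False):
        # Mirrors the "cpu" branch of RuntimeMemTracer.add.
        self.cpu_chunk_used_mem += size_in_bytes
        if is_pinned:
            self.cpu_chunk_used_mem_pinned += size_in_bytes

    def delete(self, size_in_bytes: int, is_pinned: bool = False):
        # Mirrors the "cpu" branch of RuntimeMemTracer.delete.
        self.cpu_chunk_used_mem -= size_in_bytes
        if is_pinned:
            self.cpu_chunk_used_mem_pinned -= size_in_bytes


counter = PinnedAwareCounter()
counter.add(64 * 1024**2, is_pinned=True)   # pinned chunk payload
counter.add(32 * 1024**2, is_pinned=False)  # pageable chunk payload

# Hypothetical OS reading that, like get_sys_memory_used, already excludes
# the 64 MB of pinned pages. Subtracting only the non-pinned chunk bytes
# yields the non-chunk system usage, as in the new trace_memory code:
cpu_used = 512 * 1024**2
cpu_sys_used = cpu_used - (
    counter.cpu_chunk_used_mem - counter.cpu_chunk_used_mem_pinned
)
assert cpu_sys_used == 480 * 1024**2
```

Before the patch, `cpu_used - cpu_chunk_used_mem` would have subtracted the pinned 64 MB a second time (giving 416 MB here), under-reporting system usage. Note also why `MemoryCache.push` stores `payload.is_pinned()` in `is_pinned_flag` first: the flag can no longer be queried once `del payload` drops the tensor reference.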
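The `available_chunk_mem` change builds a budget reservation on the same statistic: outside the ADAM stage, each local rank now subtracts its share of `max_cpu_sys_used` (the post-backward peak recorded in `update_margin_mem`) from the CPU chunk budget instead of claiming all of `_overall_cpu_mem`. A worked example, with all figures hypothetical:

```python
# Invented numbers: 240 GB usable CPU memory, a 32 GB peak of non-chunk
# system memory observed after backward, 8 local ranks on the host.
overall_cpu_mem = 240e9
max_cpu_sys_used = 32e9
local_world_size = 8

# Non-ADAM stages: reserve this rank's share of the peak system usage.
available = overall_cpu_mem - max_cpu_sys_used / local_world_size
assert available == 236e9  # vs. the full 240e9 returned before this patch
# During the ADAM stage the full _overall_cpu_mem is still returned.
```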