diff --git a/README.md b/README.md index 9200e3c2b..82f852df4 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,16 @@ We also evaluated PatrickStar v0.4.3 on a single node of A100 SuperPod. It is ab Detail benchmark results on WeChat AI data center as well as NVIDIA SuperPod are posted on this [Google Doc](https://docs.google.com/spreadsheets/d/136CWc_jA_2zC4h1r-6dzD4PrOvp6aw6uCDchEyQv6sE/edit?usp=sharing). + +Scale PatrickStar to multiple machines (nodes) on SuperPod. +We succeeded in training a GPT3-175B model on 32 GPUs. As far as we know, it is the first work +to run GPT3 on such a small GPU cluster. +Microsoft used 10,000 V100 GPUs to pretrain GPT3. +Now you can finetune it or even pretrain your own one on 32 A100 GPUs, amazing! + +![alt perf](./doc/m_node_superpod.png "performance testing result on multiple nodes of SuperPod") + + We've also trained the [CLUE-GPT2](https://huggingface.co/uer/gpt2-chinese-cluecorpussmall) model with PatrickStar, the loss and accuracy curve is shown below: ![CLUE-GPT2](./doc/clue-gpt2-loss-n-acc.png) diff --git a/doc/m_node_superpod.png b/doc/m_node_superpod.png new file mode 100644 index 000000000..cd516c0e9 Binary files /dev/null and b/doc/m_node_superpod.png differ diff --git a/doc/one_node_perf_a100.png b/doc/one_node_perf_a100.png index 7cdbd9202..ae5c3d4b3 100644 Binary files a/doc/one_node_perf_a100.png and b/doc/one_node_perf_a100.png differ diff --git a/examples/benchmark/process_logs.py b/examples/benchmark/process_logs.py index 22e1779a2..3896cb5d8 100644 --- a/examples/benchmark/process_logs.py +++ b/examples/benchmark/process_logs.py @@ -29,6 +29,8 @@ import os import sys +import numpy as np +from scipy.stats import t def is_run_this_file(path, file, res_dict, file_dict): @@ -48,6 +50,8 @@ def is_run_this_file(path, file, res_dict, file_dict): f = open(path + "/" + file) is_run = True + + perf_list = np.array([]) if not os.path.isdir(file): fn_list = file.split(".")[1].split("_") for i in range(len(fn_list)): @@ -62,17 +66,31 @@ def 
is_run_this_file(path, file, res_dict, file_dict): if "Tflops" in line and "WARM" not in line: sline = line.split() perf = float(sline[-2]) - if key not in res_dict: - res_dict[key] = perf - file_dict[key] = file - else: - if res_dict[key] < perf: - res_dict[key] = perf - file_dict[key] = file + + perf_list = np.append(perf_list, perf) + is_run = False if "RuntimeError" in line: return False + if len(perf_list) == 0: + return False + + # calculate CI of perf_list + perf_list = perf_list[1:-1] + m = perf_list.mean() + s = perf_list.std() + dof = len(perf_list) - 1 + confidence = 0.95 + t_crit = np.abs(t.ppf((1 - confidence) / 2, dof)) + ic_perf = ( + -s * t_crit / np.sqrt(len(perf_list)), + +s * t_crit / np.sqrt(len(perf_list)), + ) + + res_dict[key] = (*ic_perf, m) + file_dict[key] = file + return is_run diff --git a/examples/model_builder.py b/examples/model_builder.py index f53e75e36..74fab4700 100644 --- a/examples/model_builder.py +++ b/examples/model_builder.py @@ -176,6 +176,16 @@ def model_config(model_name): SEQ_LEN = 1024 NUM_LAYER = 96 NUM_HEAD = 96 + elif model_name == "GPT_220B": + HIDDEN_DIM = 12288 + SEQ_LEN = 1024 + NUM_LAYER = 120 + NUM_HEAD = 96 + elif model_name == "GPT_250B": + HIDDEN_DIM = 12288 + SEQ_LEN = 1024 + NUM_LAYER = 137 + NUM_HEAD = 96 elif model_name == "GPT_310B": HIDDEN_DIM = 16384 SEQ_LEN = 1024 diff --git a/examples/pretrain_bert_demo.py b/examples/pretrain_bert_demo.py index ca18a933c..a7ec08b4a 100644 --- a/examples/pretrain_bert_demo.py +++ b/examples/pretrain_bert_demo.py @@ -237,7 +237,7 @@ def test_transformer_model_helper( is_ckp=use_ckp, is_fp16=use_fp16, dist_plan=dist_plan, - num_steps=5, + num_steps=20, ) print("*" * 20 + " LOSS " + "*" * 20) print(f"{loss_list}") diff --git a/examples/run_transformers.sh b/examples/run_transformers.sh index 1c8f15624..ee6e92511 100644 --- a/examples/run_transformers.sh +++ b/examples/run_transformers.sh @@ -28,16 +28,25 @@ export MEM_PROF=${MEM_PROF:-0} # asyn memory monitor for mem 
sampler export AMM=${AMM:-1} # mem saving comm -export MSC=${MSC:-0} +export MSC=${MSC:-1} # mem caching comm export CACHE=${CACHE:-1} # async move export ASYNC_MOVE=${ASYNC_MOVE:-0} # linear tiling comm export TILING=${TILING:-0} +# hybrid adam +export HYB=${HYB:-1} + export LOCAL_WORLD_SIZE=${LOCAL_WORLD_SIZE:-1} export CS_SEARCH=${CS_SEARCH:-0} +export NNODES=${NNODES:-1} +export NODE_RANK=${NODE_RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} +export MASTER_PORT=${MASTER_PORT:-"12345"} +export SUFFIX=${SUFFIX:-""} + if [[ ${TILING} == 1 ]]; then TILING_FLAG="--with_tiling_linear" else @@ -109,13 +118,20 @@ else fi let CHUNK_SIZE=${CS}*1024*1024 -export HYBRID_ADAM_FLAG="--use_hybrid_adam" + +if [[ ${HYB} == 1 ]]; then + export HYBRID_ADAM_FLAG="--use_hybrid_adam" +else + export HYBRID_ADAM_FLAG="" +fi + + LOG_DIR="./logs_${MODEL_NAME}" mkdir -p ${LOG_DIR} GIT_VER=`git rev-parse --short=5 HEAD` -LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}" +LOG_FILE="log.${MODEL_NAME}_gpu_${GPU_NUM}_cs_${CS}_bs_${BS}_cpueb_${CPU_EBD}_hyb_${HYB}_offload_${ACT_OFFLOAD}_SP_${SP}_AMM_${AMM}_MSC_${MSC}_CACHE_${CACHE}_TILING_${TILING}_${GIT_VER}_node_${NNODES}_${SUFFIX}" is_run_flag=`python ./benchmark/is_run_this_file.py --path "${LOG_DIR}" --file "${LOG_FILE}"` echo is_run_flag $is_run_flag @@ -183,6 +199,7 @@ python -m torch.distributed.launch --nproc_per_node=1 \ done else env OMP_NUM_THREADS=${TNUM} timeout -s SIGKILL 30m python -m torch.distributed.launch --nproc_per_node=${GPU_NUM} \ +--nnodes=${NNODES} --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ pretrain_bert_demo.py \ --default_chunk_size=${CHUNK_SIZE} \ ${cmd_opts} \ diff --git a/patrickstar/core/chunk_list.py b/patrickstar/core/chunk_list.py index c9d4dc32d..9d36ba292 100644 --- a/patrickstar/core/chunk_list.py +++ 
b/patrickstar/core/chunk_list.py @@ -34,7 +34,8 @@ from patrickstar.core.const import ChunkType from patrickstar.core.memtracer import RuntimeMemTracer from patrickstar.profiler import profiler -from patrickstar.utils import logger, get_rank, get_world_size +from patrickstar.utils import logger, get_rank, get_world_size, log_dist +import logging import patrickstar.utils.global_timer as global_timer from .chunk_data import Chunk from .comm import CommInfo @@ -216,23 +217,26 @@ def prepare_device(self, target_device: torch.device, need_bytes: int): target_device.type ) - logger.debug( + log_dist( f"prepare_target: device {target_device} need_bytes {need_bytes / 1e6} MB, " f"ava_chunk_mem_size {ava_chunk_mem_size / 1e6} MB, " - f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB." + f"remaining_chunk_mem_size {remaining_chunk_mem_size / 1e6} MB.", + level=logging.DEBUG, ) # TODO(jiaruifang) Situation where there is no space. # This condition is not good enough, we need to check if botn CPU and GPU # don't have enough space. if ava_chunk_mem_size < need_bytes: - logger.warning( - f"{target_device} has not enough space for {need_bytes} elements" + log_dist( + f"{target_device} has not enough space for {need_bytes} elements", + level=logging.WARNING, ) - logger.warning( + log_dist( f"{target_device} has not enough space for {need_bytes / 1e6} MB. " f"Device used Chunk Memory is {self.get_chunk_memory_used(target_device) / 1e6} MB. 
" - f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB" + f"Avaibale Chunk Memory is {ava_chunk_mem_size / 1e6} MB", + level=logging.WARNING, ) if self._time_profile: global_timer.my_timer.finish_profile("CHUNK_LIST_prepare_device") diff --git a/patrickstar/core/client.py b/patrickstar/core/client.py index 07708173d..f5803f927 100644 --- a/patrickstar/core/client.py +++ b/patrickstar/core/client.py @@ -74,7 +74,9 @@ def __init__(self, rank: int, default_chunk_size: int, config=None): tracer_config = default_tracer_config opt_config = default_opt_config - self.mem_tracer = RuntimeMemTracer(self.local_rank, tracer_config) + self.mem_tracer = RuntimeMemTracer( + self.local_rank, tracer_config, opt_config["with_mem_saving_comm"] + ) self.opt_config = opt_config self.chunk_eviction_strategy = LatestAccessChunkEvictionPolicy( @@ -396,6 +398,7 @@ def _fetch_remote_chunks( # If the gpu owns the chunk (local rank), access it. # If the gpu do not own the chunk (remote chunk), allocate memory. if src_rank == rank: + self.chunk_eviction_strategy.trace_access(chunk_id, compute_device) self.chunk_list.access_chunk(chunk_id, compute_device) else: self.chunk_list.try_best_allocate_payload( @@ -447,6 +450,7 @@ def _fetch_remote_chunks( # Use collective communication to achieve the most efficient communication. # However, it is memory consumping. world_size chunks on GPU simutaneously. + self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device) self.chunk_list.access_chunk(local_chunk_id, compute_device) self.chunk_list[local_chunk_id].pin() allgather_payload_buff = [] @@ -493,6 +497,7 @@ def _fetch_remote_chunks( global_timer.my_timer.finish_profile("CLIENT_fetch_remote_chunks") def _access_tensor_in_chunk(self, param, access_type, compute_device, chunk_id): + self.chunk_eviction_strategy.trace_access(chunk_id, compute_device) self.chunk_list.access_chunk(chunk_id, compute_device) # 2. Locate the param on the chunk. 
tensor_id = param.ps_attr.get_tensor_id(access_type) @@ -584,7 +589,7 @@ def access_dist( local_chunk_id = chunk_id # collect the time a chunk has to be placed on compute-device - self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device) + # self.chunk_eviction_strategy.trace_access(local_chunk_id, compute_device) ret = self._access_tensor_in_chunk(param, access_type, compute_device, chunk_id) if self._time_profile: @@ -640,7 +645,7 @@ def access( chunk_id = self.chunk_tensor_index.get_chunk_id(param, access_type) # collect the time a chunk has to be placed on compute-device - self.chunk_eviction_strategy.trace_access(chunk_id, compute_device) + # self.chunk_eviction_strategy.trace_access(chunk_id, compute_device) if chunk_id is None: raise RuntimeError( @@ -763,6 +768,7 @@ def release_dist( break if do_allreduce: # move the chunk_id to GPU + self.chunk_eviction_strategy.trace_access(chunk_id, self.device) self.chunk_list.access_chunk(chunk_id, self.device) if self._time_profile: global_timer.my_timer.start_profile( @@ -818,6 +824,7 @@ def release_dist( assert self.chunk_list[local_chunk_id].payload is not None input_list = [] for i in chunk_id_list: + self.chunk_eviction_strategy.trace_access(i, self.device) self.chunk_list.access_chunk(i, self.device) self.chunk_list[i].pin() input_list.append(self.chunk_list[i].payload) diff --git a/patrickstar/core/eviction_policy.py b/patrickstar/core/eviction_policy.py index 0b4f2f250..d76e7481a 100644 --- a/patrickstar/core/eviction_policy.py +++ b/patrickstar/core/eviction_policy.py @@ -32,7 +32,8 @@ from queue import PriorityQueue from patrickstar.core.memtracer import Metronome from patrickstar.core.const import ChunkState -from patrickstar.utils import logger +from patrickstar.utils import log_dist +import logging class ChunkEvictionPolicyBase(ABC): @@ -112,6 +113,8 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device): chunk.get_device() is not None and chunk.get_device().type == 
target_device.type and chunk.get_state() != ChunkState.COMPUTE + and chunk.get_state() != ChunkState.RELEASED + and chunk.get_state() != ChunkState.FREE and not chunk.is_pin() ): # The next moment when this chunk was accessed. @@ -133,10 +136,12 @@ def derive_eviction_list(self, id_to_chunk_map, need_bytes, target_device): # Raise error when failed to make enough room. if moved_bytes < need_bytes: - logger.warning( + log_dist( f"device {target_device} still needs {need_bytes / 1e6} MB, " f"but there is not enough space on it, only {moved_bytes / 1e6} MB available. " - f"movable_chunk_info {movable_chunk_info}" + f"movable_chunk_info {movable_chunk_info}", + [0], + logging.WARNING, ) return moved_list diff --git a/patrickstar/core/memtracer/memtracer.py b/patrickstar/core/memtracer/memtracer.py index 33b0f3f4a..0bc570078 100644 --- a/patrickstar/core/memtracer/memtracer.py +++ b/patrickstar/core/memtracer/memtracer.py @@ -37,9 +37,9 @@ log_dist, get_memory_info, get_sys_memory_used, - get_world_size, get_local_world_size, logger, + get_world_size, ) from patrickstar.core.memtracer.metronome import Metronome from concurrent.futures import ThreadPoolExecutor @@ -95,7 +95,9 @@ class RuntimeMemTracer(object): Chunkable Memry: Memory can be used to store chunk. 
""" - def __init__(self, local_rank: int = 0, config=None): + def __init__( + self, local_rank: int = 0, config=None, with_mem_saving_comm: bool = False + ): self.local_rank = local_rank self.metronome = Metronome() self.gpu_chunk_available_mem = 0 @@ -104,7 +106,7 @@ def __init__(self, local_rank: int = 0, config=None): self.gpu_chunk_used_mem = 0 self.cpu_chunk_used_mem = 0 self.cpu_chunk_used_mem_pinned = 0 - + self.with_mem_saving_comm = with_mem_saving_comm if config is not None: self._overall_gpu_mem_ratio = config.get("overall_gpu_mem_ratio", 0.8) self._overall_cpu_mem_ratio = config.get("overall_cpu_mem_ratio", 0.8) @@ -395,7 +397,10 @@ def available_chunk_mem(self, device_type): else: return self._overall_cpu_mem elif device_type == "cuda": - world_size = get_world_size() + if self.with_mem_saving_comm: + msc_factor = 1 + else: + msc_factor = get_world_size() if self.metronome.training_stage() == TrainingStage.ADAM: return self._overall_gpu_mem - 4 * self._default_chunk_size * 4 elif self.metronome.training_stage() == TrainingStage.FWD: @@ -409,7 +414,7 @@ def available_chunk_mem(self, device_type): ) return ( min(next_mom_ava_mem, cur_mom_ava_mem) - - world_size * 2 * self._default_chunk_size + - msc_factor * 2 * self._default_chunk_size ) elif self.metronome.training_stage() == TrainingStage.BWD: next_mom = self.metronome.next_moment() @@ -422,5 +427,5 @@ def available_chunk_mem(self, device_type): ) return ( min(next_mom_ava_mem, cur_mom_ava_mem) - - world_size * 2 * self._default_chunk_size + - msc_factor * 2 * self._default_chunk_size * msc_factor )